/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkInvertDisplacementFieldImageFilter_hxx
#define itkInvertDisplacementFieldImageFilter_hxx


#include "itkComposeDisplacementFieldsImageFilter.h"
#include "itkImageDuplicator.h"
#include "itkImageRegionIterator.h"
#include <mutex>
#include "itkProgressTransformer.h"

namespace itk
{

/*
 * InvertDisplacementFieldImageFilter class definitions
 */
template <typename TInputImage, typename TOutputImage>
InvertDisplacementFieldImageFilter<TInputImage, TOutputImage>::InvertDisplacementFieldImageFilter()
  : m_Interpolator(DefaultInterpolatorType::New())
  , m_MaxErrorToleranceThreshold(0.1)
  , m_MeanErrorToleranceThreshold(0.001)
  , m_ComposedField(DisplacementFieldType::New())
  , m_ScaledNormImage(RealImageType::New())
  , m_MaxErrorNorm(0.0)
  , m_MeanErrorNorm(0.0)
  , m_Epsilon(0.0)

{
  this->SetNumberOfRequiredInputs(1);
  this->DynamicMultiThreadingOn();
}

/*
 * Replace the interpolator. Marks the filter modified when the pointer actually changes and, when both a
 * non-null interpolator and an input displacement field are available, points the interpolator at that field.
 */
template <typename TInputImage, typename TOutputImage>
void
InvertDisplacementFieldImageFilter<TInputImage, TOutputImage>::SetInterpolator(InterpolatorType * interpolator)
{
  itkDebugMacro("setting Interpolator to " << interpolator);
  if (this->m_Interpolator != interpolator)
  {
    this->m_Interpolator = interpolator;
    this->Modified();
    // Only wire the interpolator when there is both an interpolator and a field to hand it. The previous test
    // was inverted: it ran only when no field was set (so it could only pass a null image) and it dereferenced
    // the interpolator even after SetInterpolator(nullptr).
    if (interpolator != nullptr && this->GetDisplacementField())
    {
      this->m_Interpolator->SetInputImage(this->GetInput(0));
    }
  }
}

template <typename TInputImage, typename TOutputImage>
void
InvertDisplacementFieldImageFilter<TInputImage, TOutputImage>::GenerateData()
{
  // Iteratively refines the inverse-field estimate held in output 0. Each iteration:
  //   1. composes the input field with the current estimate; the composed field is used as the residual error,
  //   2. (threaded) stores each voxel's spacing-normalized residual norm and accumulates the mean and max norms,
  //   3. (threaded) adds a capped, epsilon-weighted step against that residual to the estimate.
  // The loop ends after m_MaximumNumberOfIterations iterations, or as soon as either the max or the mean error
  // norm is no longer above its tolerance threshold.
  this->UpdateProgress(0.0f);
  this->AllocateOutputs();

  typename DisplacementFieldType::ConstPointer forwardField = this->GetInput();

  // Seed the estimate with zeros, or with a private copy of the user-supplied initial inverse. The copy
  // replaces output 0 so the threaded passes update it in place.
  typename InverseDisplacementFieldType::Pointer inverseField;
  if (!this->GetInverseFieldInitialEstimate())
  {
    inverseField = this->GetOutput();
    inverseField->FillBuffer(VectorType{});
  }
  else
  {
    using DuplicatorType = ImageDuplicator<InverseDisplacementFieldType>;
    auto duplicator = DuplicatorType::New();
    duplicator->SetInputImage(this->GetInverseFieldInitialEstimate());
    duplicator->Update();

    inverseField = duplicator->GetOutput();
    this->SetNthOutput(0, inverseField);
  }

  // Cache the input spacing; the threaded norm pass divides displacements by it.
  const auto & inputSpacing = forwardField->GetSpacing();
  for (unsigned int dim = 0; dim < ImageDimension; ++dim)
  {
    this->m_DisplacementFieldSpacing[dim] = inputSpacing[dim];
  }

  // Per-voxel scaled residual norms, laid out like the input's requested region.
  const typename DisplacementFieldType::RegionType inputRegion = forwardField->GetRequestedRegion();
  this->m_ScaledNormImage->CopyInformation(forwardField);
  this->m_ScaledNormImage->SetRegions(inputRegion);
  this->m_ScaledNormImage->AllocateInitialized();

  const SizeValueType numberOfPixelsInRegion = inputRegion.GetNumberOfPixels();

  // Start "infinitely wrong" so the first iteration always runs (when iterations are allowed at all).
  this->m_MaxErrorNorm = NumericTraits<RealType>::max();
  this->m_MeanErrorNorm = NumericTraits<RealType>::max();

  // Runs DynamicThreadedGenerateData() over the output's requested region; which of the two passes executes is
  // selected by m_DoThreadedEstimateInverse. Progress is reported through the given process object.
  const auto runThreadedPass = [this](ProcessObject * progressReporter) {
    this->GetMultiThreader()->template ParallelizeImageRegion<TOutputImage::ImageDimension>(
      this->GetOutput()->GetRequestedRegion(),
      [this](const OutputImageRegionType & outputRegionForThread) {
        this->DynamicThreadedGenerateData(outputRegionForThread);
      },
      progressReporter);
  };

  // Every iteration is split into two equally weighted progress stages.
  const auto progressScale = static_cast<float>(2 * this->m_MaximumNumberOfIterations);
  float      stageStart = 0.0f;

  unsigned int iteration = 0;
  while (iteration < this->m_MaximumNumberOfIterations && this->m_MaxErrorNorm > this->m_MaxErrorToleranceThreshold &&
         this->m_MeanErrorNorm > this->m_MeanErrorToleranceThreshold)
  {
    ++iteration; // 1-based inside the body

    itkDebugMacro("Iteration " << iteration << ": mean error norm = " << this->m_MeanErrorNorm
                               << ", max error norm = " << this->m_MaxErrorNorm);

    // Residual: the input field composed with the current inverse estimate.
    using ComposerType = ComposeDisplacementFieldsImageFilter<DisplacementFieldType>;
    auto composer = ComposerType::New();
    composer->SetDisplacementField(forwardField);
    composer->SetWarpingField(inverseField);

    this->m_ComposedField = composer->GetOutput();
    this->m_ComposedField->Update();
    this->m_ComposedField->DisconnectPipeline();

    // Stage 1: scaled norms plus the global mean / max error norms (accumulated by the worker threads).
    this->m_MeanErrorNorm = RealType{};
    this->m_MaxErrorNorm = RealType{};

    float               stageEnd = static_cast<float>(2 * iteration - 1) / progressScale;
    ProgressTransformer normStageProgress(stageStart, stageEnd, this);
    this->m_DoThreadedEstimateInverse = false;
    this->GetMultiThreader()->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
    runThreadedPass(normStageProgress.GetProcessObject());

    this->m_MeanErrorNorm /= static_cast<RealType>(numberOfPixelsInRegion);

    // The first iteration takes a larger step.
    this->m_Epsilon = (iteration == 1) ? 0.75 : 0.5;

    // Stage 2: fold the capped, epsilon-weighted residual into the inverse estimate.
    stageStart = stageEnd;
    stageEnd = static_cast<float>(2 * iteration) / progressScale;
    ProgressTransformer inverseStageProgress(stageStart, stageEnd, this);
    this->m_DoThreadedEstimateInverse = true;
    runThreadedPass(inverseStageProgress.GetProcessObject());
    stageStart = stageEnd;
  }

  this->UpdateProgress(1.0f);
}

/*
 * Threaded worker for GenerateData(). m_DoThreadedEstimateInverse selects the pass:
 *  - false: for every voxel of `region`, store the spacing-normalized norm of the composed (residual) field in
 *           m_ScaledNormImage, negate the composed field in place, and merge the local mean-sum / max of those
 *           norms into m_MeanErrorNorm / m_MaxErrorNorm under m_Mutex.
 *  - true:  add an epsilon-weighted, norm-capped copy of the (negated) residual to the output estimate. When
 *           m_EnforceBoundaryCondition is set, zero voxels on any face of the composed field's requested region.
 */
template <typename TInputImage, typename TOutputImage>
void
InvertDisplacementFieldImageFilter<TInputImage, TOutputImage>::DynamicThreadedGenerateData(const RegionType & region)
{
  const typename DisplacementFieldType::RegionType fullRegion = this->m_ComposedField->GetRequestedRegion();
  // First and last valid index along each axis. The last index is start + size - 1 (GetUpperIndex());
  // the old expression size - start - 1 was only correct for regions starting at index 0.
  const typename DisplacementFieldType::IndexType firstIndex = fullRegion.GetIndex();
  const typename DisplacementFieldType::IndexType lastIndex = fullRegion.GetUpperIndex();
  const typename DisplacementFieldType::PixelType zeroVector{};

  ImageRegionIterator<DisplacementFieldType> ItE(this->m_ComposedField, region);
  ImageRegionIterator<RealImageType>         ItS(this->m_ScaledNormImage, region);

  if (this->m_DoThreadedEstimateInverse)
  {
    // Residuals whose scaled norm exceeds this bound are rescaled down to it before the epsilon step.
    const RealType maxScaledStep = this->m_Epsilon * this->m_MaxErrorNorm;

    ImageRegionIterator<DisplacementFieldType> ItI(this->GetOutput(), region);

    for (ItI.GoToBegin(), ItE.GoToBegin(), ItS.GoToBegin(); !ItI.IsAtEnd(); ++ItI, ++ItE, ++ItS)
    {
      VectorType     update = ItE.Get();
      const RealType scaledNorm = ItS.Get();

      if (scaledNorm > maxScaledStep)
      {
        update *= (maxScaledStep / scaledNorm);
      }
      update = ItI.Get() + update * this->m_Epsilon;
      ItI.Set(update);

      if (this->m_EnforceBoundaryCondition)
      {
        const typename DisplacementFieldType::IndexType index = ItI.GetIndex();
        for (unsigned int d = 0; d < ImageDimension; ++d)
        {
          if (index[d] == firstIndex[d] || index[d] == lastIndex[d])
          {
            ItI.Set(zeroVector);
            break;
          }
        }
      } // enforce boundary condition
    }
  }
  else
  {
    VectorType inverseSpacing;
    RealType   localMean{};
    RealType   localMax{};
    for (unsigned int d = 0; d < ImageDimension; ++d)
    {
      inverseSpacing[d] = 1.0 / this->m_DisplacementFieldSpacing[d];
    }
    for (ItE.GoToBegin(), ItS.GoToBegin(); !ItE.IsAtEnd(); ++ItE, ++ItS)
    {
      const VectorType & displacement = ItE.Get();
      RealType           scaledNorm = 0.0;
      for (unsigned int d = 0; d < ImageDimension; ++d)
      {
        scaledNorm += itk::Math::sqr(displacement[d] * inverseSpacing[d]);
      }
      scaledNorm = std::sqrt(scaledNorm);

      localMean += scaledNorm;
      if (localMax < scaledNorm)
      {
        localMax = scaledNorm;
      }

      ItS.Set(scaledNorm);
      // Store the negated residual; the estimate pass adds (a scaled copy of) it to the inverse.
      ItE.Set(-displacement);
    }
    {
      // Merge this work unit's partial results into the shared totals.
      const std::lock_guard<std::mutex> lockGuard(m_Mutex);
      this->m_MeanErrorNorm += localMean;
      if (this->m_MaxErrorNorm < localMax)
      {
        this->m_MaxErrorNorm = localMax;
      }
    }
  }
}

/*
 * Print the filter's configuration and internal iteration state. Follows the ITK convention of listing every
 * data member; previously only the interpolator and three tuning parameters were printed.
 */
template <typename TInputImage, typename TOutputImage>
void
InvertDisplacementFieldImageFilter<TInputImage, TOutputImage>::PrintSelf(std::ostream & os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);

  itkPrintSelfObjectMacro(Interpolator);

  os << indent << "Maximum number of iterations: " << this->m_MaximumNumberOfIterations << std::endl;
  os << indent << "Max error tolerance threshold: " << this->m_MaxErrorToleranceThreshold << std::endl;
  os << indent << "Mean error tolerance threshold: " << this->m_MeanErrorToleranceThreshold << std::endl;
  os << indent << "Enforce boundary condition: " << (this->m_EnforceBoundaryCondition ? "On" : "Off") << std::endl;

  // State left behind by the most recent GenerateData() run.
  itkPrintSelfObjectMacro(ComposedField);
  itkPrintSelfObjectMacro(ScaledNormImage);

  os << indent << "Max error norm: " << this->m_MaxErrorNorm << std::endl;
  os << indent << "Mean error norm: " << this->m_MeanErrorNorm << std::endl;
  os << indent << "Epsilon: " << this->m_Epsilon << std::endl;

  os << indent << "Displacement field spacing: [";
  for (unsigned int d = 0; d < ImageDimension; ++d)
  {
    os << (d > 0 ? ", " : "") << this->m_DisplacementFieldSpacing[d];
  }
  os << ']' << std::endl;

  os << indent << "Do threaded estimate inverse: " << (this->m_DoThreadedEstimateInverse ? "On" : "Off")
     << std::endl;
}

} // end namespace itk

#endif
