//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Device/Detector/IDetector.cpp
//! @brief     Implements common detector interface.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Device/Detector/IDetector.h"
#include "Base/Axis/Frame.h"
#include "Base/Axis/Scale.h"
#include "Base/Util/Assert.h"
#include "Device/Beam/Beam.h"
#include "Device/Detector/SimulationAreaIterator.h"
#include "Device/Mask/DetectorMask.h"
#include "Device/Mask/InfinitePlane.h"
#include "Device/Resolution/ConvolutionDetectorResolution.h"
#include <iostream>

namespace {

// Index along axis 0 of flat index `i` in a sizeX-by-sizeY grid in which axis 1 varies fastest.
inline size_t xcoord(size_t i, size_t sizeX, size_t sizeY)
{
    const size_t row = i / sizeY;
    return row % sizeX;
}

// Index along axis 1 (the fastest-varying one) of flat index `i` in a grid with sizeY columns.
inline size_t ycoord(size_t i, size_t sizeY)
{
    const size_t column = i % sizeY;
    return column;
}

} // namespace


//... Auxiliary class RoiOfAxis

// Maps the requested interval [_lower, _upper] on `axis` to an inclusive range of bin indices
// [lowerIndex, upperIndex] and records its length (roiSize) and the axis length (detectorSize).
IDetector::RoiOfAxis::RoiOfAxis(const Scale& axis, double _lower, double _upper)
    : lower(_lower)
    , upper(_upper)
{
    ASSERT(lower < upper);
    detectorSize = axis.size();
    lowerIndex = axis.closestIndex(lower);
    upperIndex = axis.closestIndex(upper);

    // suppress tiny bins that are most likely due to floating-point inaccuracy
    const double tinyBinSize = 1e-12 * axis.span() / axis.size();
    if (axis.bin(lowerIndex).binSize() < tinyBinSize) {
        ASSERT(lowerIndex < axis.size() - 1);
        ++lowerIndex;
    }
    if (axis.bin(upperIndex).binSize() < tinyBinSize) {
        ASSERT(upperIndex > 0);
        --upperIndex;
    }

    // If both bounds fall into the same tiny bin, the two adjustments above move the indices
    // past each other; without this check roiSize would wrap around to a huge size_t value.
    ASSERT(lowerIndex <= upperIndex);
    roiSize = upperIndex - lowerIndex + 1;
}

// The interval as originally requested (lower, upper), not snapped to bin boundaries.
std::pair<double, double> IDetector::RoiOfAxis::bounds() const
{
    return std::make_pair(lower, upper);
}


IDetector::IDetector() = default;

// Deep copy: frame, resolution and mask are cloned so that the copy owns its own instances.
// Owned members may be null in `other` (e.g. if setFrame() was never called on it, since the
// default constructor creates none of them); such members stay null in the copy instead of
// being dereferenced.
IDetector::IDetector(const IDetector& other)
    : INode()
    , m_explicitROI(other.m_explicitROI)
    , m_frame(other.m_frame ? other.m_frame->clone() : nullptr)
    , m_polAnalyzer(other.m_polAnalyzer)
    , m_resolution(other.m_resolution ? other.m_resolution->clone() : nullptr)
    , m_mask(other.m_mask ? std::make_unique<DetectorMask>(*other.m_mask) : nullptr)
{
}

// Defaulted out of line. NOTE(review): presumably so that the owned members (Frame, DetectorMask,
// resolution) are destroyed where their types are complete — confirm against IDetector.h.
IDetector::~IDetector() = default;

// Takes ownership of `frame`, which must be two-dimensional, and replaces the detector mask with
// an empty one defined on the new frame's axes.
void IDetector::setFrame(Frame* frame)
{
    // Adopt the pointer before validating it, so that it is released (not leaked) if one of the
    // assertions below throws.
    std::unique_ptr<Frame> owned(frame);
    ASSERT(owned);
    ASSERT(owned->rank() == 2);
    m_frame = std::move(owned);
    m_mask = std::make_unique<DetectorMask>(m_frame->axis(0), m_frame->axis(1));
}

// Returns axis `i` (0 or 1) of the detector frame; setFrame() only accepts rank-2 frames.
const Scale& IDetector::axis(size_t i) const
{
    ASSERT(i < 2);
    const Scale& selected = m_frame->axis(i);
    return selected;
}

size_t IDetector::axisBinIndex(size_t i, size_t selected_axis) const
{
    size_t remainder(i);
    size_t i_axis = 2;
    for (size_t i = 0; i < 2; ++i) {
        --i_axis;
        const Scale& ax = m_frame->axis(i_axis);
        if (selected_axis == i_axis)
            return remainder % ax.size();
        remainder /= ax.size();
    }
    ASSERT_NEVER;
}

// Number of pixels in the user-set region of interest, or 0 if none is set.
size_t IDetector::sizeOfExplicitRegionOfInterest() const
{
    if (m_explicitROI.size() == 2) {
        const auto& roiX = m_explicitROI[0];
        const auto& roiY = m_explicitROI[1];
        return roiX.roiSize * roiY.roiSize;
    }
    return 0;
}

// Number of pixels of the full, unclipped detector grid.
size_t IDetector::totalSize() const
{
    const size_t nBins0 = m_frame->axis(0).size();
    const size_t nBins1 = m_frame->axis(1).size();
    return nBins0 * nBins1;
}

// Number of pixels in the region of interest; without an explicit ROI that is the whole detector.
size_t IDetector::sizeOfRegionOfInterest() const
{
    const size_t roiPixels = sizeOfExplicitRegionOfInterest();
    if (roiPixels == 0)
        return totalSize(); // 0 means "no explicit ROI"
    return roiPixels;
}

// True when an ROI entry exists for both axes (setRegionOfInterest adds one per axis).
bool IDetector::hasExplicitRegionOfInterest() const
{
    const size_t nRoiAxes = m_explicitROI.size();
    return nRoiAxes == 2;
}

// Returns a copy of the detector frame whose two axes are clipped to the region-of-interest
// bounds (the full axis range on axes without an explicit ROI).
Frame IDetector::clippedFrame() const
{
    ASSERT(m_frame); // must precede the loop, which already dereferences m_frame
    std::vector<const Scale*> axes;
    axes.reserve(2);
    for (size_t i = 0; i < 2; ++i)
        axes.emplace_back(new Scale(m_frame->axis(i).clipped(regionOfInterestBounds(i))));
    Frame result = *m_frame;
    // NOTE(review): ownership of the newly allocated Scales is assumed to pass to setAxes.
    result.setAxes(std::move(axes));
    return result;
}

// Replaces the polarization analyzer by one built from a Bloch vector and a mean transmission.
void IDetector::setAnalyzer(const R3 Bloch_vector, double mean_transmission)
{
    m_polAnalyzer = PolFilter{Bloch_vector, mean_transmission};
}

void IDetector::setAnalyzer(const R3 direction, double efficiency, double mean_transmission)
{
    std::cout
        << "Function setAnalyzer(direction, efficiency, transmission) is obsolete since "
           "BornAgain v21,\n"
           "and will eventually be removed. Use setAnalyzer(Bloch_vector, transmission) instead.\n";
    setAnalyzer(direction * efficiency, mean_transmission);
}

// Stores a clone of `detector_resolution`; the caller keeps ownership of the argument.
void IDetector::setDetectorResolution(const IDetectorResolution& detector_resolution)
{
    auto* ownCopy = detector_resolution.clone();
    m_resolution.reset(ownCopy);
}

// TODO: pass dimension-independent argument to this function
// Wraps a 2D resolution function into a ConvolutionDetectorResolution. setDetectorResolution()
// stores its own clone, so passing a temporary is sufficient.
void IDetector::setResolutionFunction(const IResolutionFunction2D& resFunc)
{
    setDetectorResolution(ConvolutionDetectorResolution(resFunc));
}

// Applies the detector resolution (if any) to `intensity_map` in place. When masks are defined,
// only non-masked values are kept afterwards; masked areas are set to zero.
void IDetector::applyDetectorResolution(Datafield* intensity_map) const
{
    if (m_resolution == nullptr)
        return;
    ASSERT(intensity_map);

    m_resolution->applyDetectorResolution(intensity_map);

    const DetectorMask* mask = detectorMask();
    if (!mask || !mask->hasMasks())
        return;

    // Copy the non-masked values into a fresh field on the same frame, then write it back.
    Datafield unmasked(intensity_map->frame().clone());
    iterateOverNonMaskedPoints([&](const_iterator it) {
        const auto k = it.roiIndex();
        unmasked[k] = (*intensity_map)[k];
    });
    intensity_map->setVector(unmasked.flatVector());
}

// Returns an empty Datafield on a copy of the detector frame whose axes are clipped to the
// region of interest.
Datafield IDetector::createDetectorMap() const
{
    ASSERT(m_frame); // must precede axis(i) below, which dereferences m_frame
    std::vector<const Scale*> axes;
    axes.reserve(2);
    for (size_t i = 0; i < 2; ++i)
        axes.emplace_back(new Scale(axis(i).clipped(regionOfInterestBounds(i))));
    auto* f = new Frame(*m_frame);
    // NOTE(review): setAxes and the Datafield constructor are assumed to take ownership of
    // their pointer arguments.
    f->setAxes(std::move(axes));
    return Datafield(f);
}

// Bounds along axis `iAxis` (0 or 1): the explicit ROI interval if one is set for that axis,
// otherwise the full range of the detector axis.
std::pair<double, double> IDetector::regionOfInterestBounds(size_t iAxis) const
{
    ASSERT(iAxis < 2);
    if (iAxis >= m_explicitROI.size())
        return m_frame->axis(iAxis).bounds();
    return m_explicitROI[iAxis].bounds();
}

// Child nodes of the detector: the polarization analyzer and the detector resolution.
// NOTE(review): m_resolution.get() is null when no resolution was set; whether the project's
// operator<< for std::vector<const INode*> skips null entries is not visible here — confirm.
std::vector<const INode*> IDetector::nodeChildren() const
{
    return std::vector<const INode*>() << &m_polAnalyzer << m_resolution.get();
}

// Calls `func` for every point visited by SimulationAreaIterator, i.e. every non-masked point.
void IDetector::iterateOverNonMaskedPoints(std::function<void(const_iterator)> func) const
{
    // The end iterator does not change during the loop; build it once instead of on every
    // comparison.
    const auto stop = SimulationAreaIterator::createEnd(this);
    for (auto it = SimulationAreaIterator::createBegin(this); it != stop; ++it)
        func(it);
}

// Converts a flat index inside the region of interest into the flat index of the same pixel in
// the full detector (row-major, axis 1 varying fastest). Without an explicit ROI the two coincide.
size_t IDetector::regionOfInterestIndexToDetectorIndex(const size_t i) const
{
    if (m_explicitROI.size() == 2) {
        const auto& roiX = m_explicitROI[0];
        const auto& roiY = m_explicitROI[1];
        // Bin indices in the full detector: ROI offset plus position inside the ROI.
        const size_t ix = roiX.lowerIndex + xcoord(i, roiX.roiSize, roiY.roiSize);
        const size_t iy = roiY.lowerIndex + ycoord(i, roiY.roiSize);
        return ix * roiY.detectorSize + iy;
    }
    return i;
}

void IDetector::setRegionOfInterest(double xlow, double ylow, double xup, double yup)
{
    m_explicitROI.clear();
    m_explicitROI.emplace_back(axis(0), xlow, xup);
    m_explicitROI.emplace_back(axis(1), ylow, yup);
}

// Full-detector indices of all points visited by iterateOverNonMaskedPoints, in visiting order.
std::vector<size_t> IDetector::active_indices() const
{
    std::vector<size_t> indices;
    const auto collect = [&indices](const_iterator it) { indices.push_back(it.detectorIndex()); };
    iterateOverNonMaskedPoints(collect);
    return indices;
}

void IDetector::addMask(const IShape2D& shape, bool mask_value)
{
    m_mask->addMask(shape, mask_value);
}

void IDetector::maskAll()
{
    addMask(InfinitePlane(), true);
}

// Non-owning access to the detector mask; may be null, so callers such as
// applyDetectorResolution() check it before use.
const DetectorMask* IDetector::detectorMask() const
{
    const DetectorMask* mask = m_mask.get();
    return mask;
}

// Row-major flat index of bin (x, y), with the axis-1 bin index `y` varying fastest.
size_t IDetector::getGlobalIndex(size_t x, size_t y) const
{
    const size_t rowLength = axis(1).size();
    return y + x * rowLength;
}