//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Resample/Particle/ReMesocrystal.cpp
//! @brief     Implements class ReMesocrystal.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Resample/Particle/ReMesocrystal.h"
#include <numbers>
using std::numbers::pi;
#include "Base/Spin/SpinMatrix.h"
#include "Base/Types/Span.h"
#include "Base/Vector/WavevectorInfo.h"
#include "Resample/Particle/ReParticle.h"

// Builds the resampled mesocrystal.
// Copies the lattice and stores clones of the basis particle and of the outer shape,
// obtained through their clone() methods. Also keeps the position variance that
// debyeWallerFactor() uses.
// Finally precomputes m_max_rec_length. theFF() and thePolFF() scale that value to get the
// radius within which reciprocal lattice vectors are summed.
ReMesocrystal::ReMesocrystal(const std::optional<size_t>& i_layer, const Lattice3D& lattice,
                             const IReParticle& basis, const ReParticle& outer_shape,
                             double position_variance)
    : IReParticle(i_layer)
    , m_lattice(lattice)
    , m_basis(basis.clone())
    , m_outer_shape(outer_shape.clone())
    , m_position_variance(position_variance)
{
    // Depends on m_lattice, which the initialiser list above has already set.
    calculateLargestReciprocalDistance();
}

// Defaulted out of line. Presumably this is so the types held by the owning members
// (m_basis, m_outer_shape) only need to be complete in this translation unit.
// NOTE(review): confirm against the header.
ReMesocrystal::~ReMesocrystal() = default;

// Polymorphic copy. Returns a newly heap-allocated mesocrystal built from this one's
// layer index, lattice, basis, outer shape and position variance.
// The constructor clones the basis and the shape again, so the copy does not share them
// with *this. The caller is responsible for deleting the returned object.
ReMesocrystal* ReMesocrystal::clone() const
{
    const auto& basis = *m_basis;
    const auto& shape = *m_outer_shape;
    return new ReMesocrystal(i_layer(), m_lattice, basis, shape, m_position_variance);
}

// Volume of the mesocrystal, taken to be the volume of its outer shape.
double ReMesocrystal::volume() const
{
    const auto& shape = *m_outer_shape;
    return shape.volume();
}

// Radial extension of the mesocrystal, delegated to its outer shape.
double ReMesocrystal::radialExtension() const
{
    const auto& shape = *m_outer_shape;
    return shape.radialExtension();
}

// Vertical span of the mesocrystal, delegated to its outer shape.
Span ReMesocrystal::zSpan() const
{
    const auto& shape = *m_outer_shape;
    return shape.zSpan();
}

// Scalar form factor of the mesocrystal.
// Sums over the reciprocal lattice vectors G returned by
// m_lattice.reciprocalLatticeVectorsWithinRadius(Re(q), 2.1 * m_max_rec_length).
// Each G contributes debyeWallerFactor(G) times the basis form factor evaluated with
// wavevectors (0, -G), times the outer-shape form factor evaluated with wavevectors (0, G - q).
// The sum is divided by the unit cell volume.
complex_t ReMesocrystal::theFF(const WavevectorInfo& wavevectors) const
{
    const C3 q = wavevectors.getQ();
    const double lambda = wavevectors.vacuumLambda();

    // Candidate reciprocal lattice vectors near Re(q).
    const double search_radius = 2.1 * m_max_rec_length;
    const std::vector<R3> G_list =
        m_lattice.reciprocalLatticeVectorsWithinRadius(q.real(), search_radius);

    // Discrete convolution of basis and outer-shape form factors over the lattice.
    complex_t sum(0.0, 0.0);
    for (const R3& G : G_list) {
        const complex_t dw = debyeWallerFactor(G);
        const complex_t ff_basis = m_basis->theFF(WavevectorInfo(R3(), -G, lambda));
        const complex_t ff_shape =
            m_outer_shape->theFF(WavevectorInfo(C3(), G.complex() - q, lambda));
        sum += dw * ff_basis * ff_shape;
    }

    // Fourier-transforming the lattice delta train yields a prefactor (2pi)^3/V.
    // The (2pi)^3 cancels against the convolution of the Fourier transforms, so only
    // the division by the unit cell volume remains.
    return sum / m_lattice.unitCellVolume();
}

// Polarized (spin-matrix) form factor of the mesocrystal.
// Uses the same lattice sum as theFF(), except that the basis contribution is the basis
// particle's thePolFF().
// The outer-shape factor stays scalar, evaluated with wavevectors (0, G - q), and each term
// is weighted by debyeWallerFactor(G). The sum is divided by the unit cell volume.
SpinMatrix ReMesocrystal::thePolFF(const WavevectorInfo& wavevectors) const
{
    const C3 q = wavevectors.getQ();
    const double lambda = wavevectors.vacuumLambda();

    // Candidate reciprocal lattice vectors near Re(q).
    const double search_radius = 2.1 * m_max_rec_length;
    const std::vector<R3> G_list =
        m_lattice.reciprocalLatticeVectorsWithinRadius(q.real(), search_radius);

    // Discrete convolution of basis (spin-matrix) and outer-shape (scalar) form factors.
    SpinMatrix sum;
    for (const R3& G : G_list) {
        const complex_t dw = debyeWallerFactor(G);
        const SpinMatrix ff_basis = m_basis->thePolFF(WavevectorInfo(R3(), -G, lambda));
        const complex_t ff_shape =
            m_outer_shape->theFF(WavevectorInfo(C3(), G.complex() - q, lambda));
        sum += dw * ff_basis * ff_shape;
    }

    // Fourier-transforming the lattice delta train yields a prefactor (2pi)^3/V.
    // The (2pi)^3 cancels against the convolution of the Fourier transforms, so only
    // the division by the unit cell volume remains.
    return sum / m_lattice.unitCellVolume();
}

void ReMesocrystal::calculateLargestReciprocalDistance()
{
    R3 a1 = m_lattice.basisVectorA();
    R3 a2 = m_lattice.basisVectorB();
    R3 a3 = m_lattice.basisVectorC();

    m_max_rec_length = std::max(pi / a1.mag(), pi / a2.mag());
    m_max_rec_length = std::max(m_max_rec_length, pi / a3.mag());
}

// Debye-Waller damping factor exp(-|q_i|^2 * m_position_variance / 2) for the lattice
// vector q_i.
// The value is real; it is returned as complex_t.
complex_t ReMesocrystal::debyeWallerFactor(const R3& q_i) const
{
    const double exponent = -q_i.mag2() * m_position_variance / 2.0;
    return std::exp(exponent);
}
