/* This file is part of the Palabos library.
 *
 * The Palabos softare is developed since 2011 by FlowKit-Numeca Group Sarl
 * (Switzerland) and the University of Geneva (Switzerland), which jointly
 * own the IP rights for most of the code base. Since October 2019, the
 * Palabos project is maintained by the University of Geneva and accepts
 * source code contributions from the community.
 *
 * Contact:
 * Jonas Latt
 * Computer Science Department
 * University of Geneva
 * 7 Route de Drize
 * 1227 Carouge, Switzerland
 * jonas.latt@unige.ch
 *
 * The most recent release of Palabos can be downloaded at
 * <https://palabos.unige.ch/>
 *
 * The library Palabos is free software: you can redistribute it and/or
 * modify it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * The library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <cmath>
#include <fstream>
#include <iostream>
#include <vector>

#include "palabos3D.h"
#include "palabos3D.hh"

using namespace plb;
using namespace std;

typedef double T;
#define DESCRIPTOR descriptors::D3Q19Descriptor
#define DYNAMICS   BGKdynamics<T, DESCRIPTOR>(parameters.getOmega())

#define NMAX 150

const T pi = (T)4. * std::atan((T)1.);

static T poiseuillePressure(IncomprFlowParam<T> const &parameters, plint maxN)
{
    const T a = parameters.getNx() - 1;
    const T b = parameters.getNy() - 1;

    const T nu = parameters.getLatticeNu();
    const T uMax = parameters.getLatticeU();

    T sum = T();
    for (plint iN = 0; iN < maxN; iN += 2) {
        T twoNplusOne = (T)2 * (T)iN + (T)1;
        sum +=
            ((T)1 / (std::pow(twoNplusOne, (T)3) * std::cosh(twoNplusOne * pi * b / ((T)2 * a))));
    }
    for (plint iN = 1; iN < maxN; iN += 2) {
        T twoNplusOne = (T)2 * (T)iN + (T)1;
        sum -=
            ((T)1 / (std::pow(twoNplusOne, (T)3) * std::cosh(twoNplusOne * pi * b / ((T)2 * a))));
    }

    T alpha = -(T)8 * uMax * pi * pi * pi
              / (a * a * (pi * pi * pi - (T)32 * sum));  // alpha = -dp/dz / mu

    T deltaP = -(alpha * nu);

    return deltaP;
}

T poiseuilleVelocity(plint iX, plint iY, IncomprFlowParam<T> const &parameters, plint maxN)
{
    const T a = parameters.getNx() - 1;
    const T b = parameters.getNy() - 1;

    const T x = (T)iX - a / (T)2;
    const T y = (T)iY - b / (T)2;

    const T alpha = -poiseuillePressure(parameters, maxN) / parameters.getLatticeNu();

    T sum = T();

    for (plint iN = 0; iN < maxN; iN += 2) {
        T twoNplusOne = (T)2 * (T)iN + (T)1;

        sum +=
            (std::cos(twoNplusOne * pi * x / a) * std::cosh(twoNplusOne * pi * y / a)
             / (std::pow(twoNplusOne, (T)3) * std::cosh(twoNplusOne * pi * b / ((T)2 * a))));
    }
    for (plint iN = 1; iN < maxN; iN += 2) {
        T twoNplusOne = (T)2 * (T)iN + (T)1;

        sum -=
            (std::cos(twoNplusOne * pi * x / a) * std::cosh(twoNplusOne * pi * y / a)
             / (std::pow(twoNplusOne, (T)3) * std::cosh(twoNplusOne * pi * b / ((T)2 * a))));
    }

    sum *= ((T)4 * alpha * a * a / std::pow(pi, (T)3));
    sum += (alpha / (T)2 * (x * x - a * a / (T)4));

    return sum;
}

template <typename T>
class SquarePoiseuilleDensityAndVelocity {
public:
    SquarePoiseuilleDensityAndVelocity(IncomprFlowParam<T> const &parameters_, plint maxN_) :
        parameters(parameters_), maxN(maxN_)
    { }
    void operator()(plint iX, plint iY, plint, T &rho, Array<T, 3> &u) const
    {
        rho = (T)1;
        u[0] = T();
        u[1] = T();
        u[2] = poiseuilleVelocity(iX, iY, parameters, maxN);
    }

private:
    IncomprFlowParam<T> parameters;
    plint maxN;
};

template <typename T>
class SquarePoiseuilleVelocity {
public:
    SquarePoiseuilleVelocity(IncomprFlowParam<T> const &parameters_, plint maxN_) :
        parameters(parameters_), maxN(maxN_)
    { }
    void operator()(plint iX, plint iY, plint, Array<T, 3> &u) const
    {
        u[0] = T();
        u[1] = T();
        u[2] = poiseuilleVelocity(iX, iY, parameters, maxN);
    }

private:
    IncomprFlowParam<T> parameters;
    plint maxN;
};

void squarePoiseuilleSetup(
    MultiBlockLattice3D<T, DESCRIPTOR> &lattice, IncomprFlowParam<T> const &parameters,
    OnLatticeBoundaryCondition3D<T, DESCRIPTOR> &boundaryCondition)
{
    const plint nx = parameters.getNx();
    const plint ny = parameters.getNy();
    const plint nz = parameters.getNz();
    Box3D top = Box3D(0, nx - 1, ny - 1, ny - 1, 0, nz - 1);
    Box3D bottom = Box3D(0, nx - 1, 0, 0, 0, nz - 1);

    Box3D inlet = Box3D(0, nx - 1, 0, ny - 1, 0, 0);
    Box3D outlet = Box3D(0, nx - 1, 0, ny - 1, nz - 1, nz - 1);

    Box3D left = Box3D(0, 0, 1, ny - 2, 0, nz - 1);
    Box3D right = Box3D(nx - 1, nx - 1, 1, ny - 2, 0, nz - 1);

    boundaryCondition.setVelocityConditionOnBlockBoundaries(lattice, inlet);
    boundaryCondition.setVelocityConditionOnBlockBoundaries(lattice, outlet);

    boundaryCondition.setVelocityConditionOnBlockBoundaries(lattice, top);
    boundaryCondition.setVelocityConditionOnBlockBoundaries(lattice, bottom);

    boundaryCondition.setVelocityConditionOnBlockBoundaries(lattice, left);
    boundaryCondition.setVelocityConditionOnBlockBoundaries(lattice, right);

    setBoundaryVelocity(lattice, inlet, SquarePoiseuilleVelocity<T>(parameters, NMAX));
    setBoundaryVelocity(lattice, outlet, SquarePoiseuilleVelocity<T>(parameters, NMAX));

    setBoundaryVelocity(lattice, top, Array<T, 3>((T)0.0, (T)0.0, (T)0.0));
    setBoundaryVelocity(lattice, bottom, Array<T, 3>((T)0.0, (T)0.0, (T)0.0));
    setBoundaryVelocity(lattice, left, Array<T, 3>((T)0.0, (T)0.0, (T)0.0));
    setBoundaryVelocity(lattice, right, Array<T, 3>((T)0.0, (T)0.0, (T)0.0));

    initializeAtEquilibrium(
        lattice, lattice.getBoundingBox(), SquarePoiseuilleDensityAndVelocity<T>(parameters, NMAX));

    lattice.initialize();
}

T computeRMSerror(
    MultiBlockLattice3D<T, DESCRIPTOR> &lattice, IncomprFlowParam<T> const &parameters)
{
    MultiTensorField3D<T, 3> analyticalVelocity(lattice);
    setToFunction(
        analyticalVelocity, analyticalVelocity.getBoundingBox(),
        SquarePoiseuilleVelocity<T>(parameters, NMAX));
    MultiTensorField3D<T, 3> numericalVelocity(lattice);
    computeVelocity(lattice, numericalVelocity, lattice.getBoundingBox());

    // Divide by lattice velocity to normalize the error
    return 1. / parameters.getLatticeU() *
           // Compute RMS difference between analytical and numerical solution
           std::sqrt(
               computeAverage(*computeNormSqr(*subtract(analyticalVelocity, numericalVelocity))));
}

int main(int argc, char *argv[])
{
    plbInit(&argc, &argv);
    global::directories().setOutputDir("./tmp/");

    if (argc != 2) {
        pcout << "Error the parameters are wrong. The structure must be :\n";
        pcout << "1 : N\n";
        exit(1);
    }

    const plint N = atoi(argv[1]);
    const T Re = 10.0;
    const plint Nref = 50;
    const T uMaxRef = 0.01;
    const T uMax = uMaxRef / (T)N * (T)Nref;  // Needed to avoid compressibility errors.

    IncomprFlowParam<T> parameters(
        uMax, Re, N,
        1.,  // lx
        1.,  // ly
        1.   // lz
    );
    const T maxT = (T)10.0;

    writeLogFile(parameters, "3D square Poiseuille");

    MultiBlockLattice3D<T, DESCRIPTOR> lattice(
        parameters.getNx(), parameters.getNy(), parameters.getNz(), new DYNAMICS);

    // Use periodic boundary conditions.
    lattice.periodicity().toggle(2, true);

    OnLatticeBoundaryCondition3D<T, DESCRIPTOR> *boundaryCondition =
        createLocalBoundaryCondition3D<T, DESCRIPTOR>();

    squarePoiseuilleSetup(lattice, parameters, *boundaryCondition);

    // Loop over main time iteration.
    util::ValueTracer<T> converge(parameters.getLatticeU(), parameters.getResolution(), 1.0e-3);
    for (plint iT = 0; iT < parameters.nStep(maxT); ++iT) {
        converge.takeValue(getStoredAverageEnergy(lattice), true);
        if (converge.hasConverged()) {
            pcout << "Simulation converged." << endl;
            break;
        }

        // Execute a time iteration.
        lattice.collideAndStream();
    }
    pcout << "For N = " << N << ", Error = " << computeRMSerror(lattice, parameters) << endl;

    delete boundaryCondition;
}
