#include <algorithm>
#include <cuda_runtime.h>
#include <cuda.h>
#include <curand.h>
#include <curand_kernel.h>
#include <algorithm>
#include <cfloat>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <numeric>
#include <iostream>
#include <fstream>
#include <string>
#include <type_traits>
#include <sstream>
#include <limits>
#include <vector>
#ifndef _MSC_VER
#include <bits/stdc++.h>
#endif
#ifdef _MSC_VER
#include <corecrt_math_defines.h>
#endif

#include <thrust/scan.h>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>
#include <thrust/device_ptr.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <thrust/iterator/counting_iterator.h>

#include "../Commons/Physics.h"
#include "SimulatorMultiPart.cuh"
#include "kernels.cuh"
#include "logDistrib.hpp"
#include "normalDistrib.hpp"
#include "uniformDistrib.hpp"
#include "../Commons/utilityKernel.cuh"
#include "optimizer.hpp"

// Returns true when `path` can be opened for reading.
// The temporary stream is closed automatically when it is destroyed at the end of the statement.
bool testFile(std::string path)
{
    return std::ifstream(path).is_open();
}

// Writes the current particle state as an XYZ-style text frame to
// <outputDir>/XYZ/<format_account_number(saveCount)>.xyz, then increments saveCount.
// Line 1: particle count. Line 2: "time currentStep saveCount boxSize" (box size in output units).
// Then one line per particle: "id type x y z 0 r", lengths multiplied by cpuSim.outputScale and printed
// with cpuSim.savePrecision digits. Exits with a failure status if the file cannot be opened or written.
void SimulatorMultiPart::save()
{
    // parts_s / id_s are managed buffers that kernels write; the host may only read them once the
    // device is idle (otherwise the read races with, or on some devices faults during, a running kernel).
    cudaDeviceSynchronize();
    check_gpu_err(__FILE__, __LINE__);

    const std::string path = cpuSim.outputDir + "/XYZ/" + format_account_number(saveCount) + ".xyz";
    std::ostringstream oss;
    oss << pcount << "\n";
    oss << time << " " << currentStep << " " << saveCount << " " << std::setprecision(std::numeric_limits<float>::digits10 + 1) << grid.size * cpuSim.outputScale << "\n";

    // Formatting buffer; grown on demand instead of letting a large savePrecision overflow it.
    std::vector<char> line(5000);
    for (int i = 0; i < pcount; i++)
    {
        for (;;)
        {
            const int written = std::snprintf(line.data(), line.size(), "%d %d %.*e %.*e %.*e %d %.*e\n", id_s[i], parts_s[i].t, cpuSim.savePrecision, cpuSim.outputScale * parts_s[i].p.x,
                                              cpuSim.savePrecision, cpuSim.outputScale * parts_s[i].p.y, cpuSim.savePrecision, cpuSim.outputScale * parts_s[i].p.z,
                                              0, cpuSim.savePrecision, cpuSim.outputScale * parts_s[i].r);
            if (written < 0)
            {
                std::cout << "Can't format particle " << i << " for " << path << " Exiting..." << std::endl;
                exit(EXIT_FAILURE);
            }
            if (static_cast<std::size_t>(written) < line.size())
                break;
            // snprintf returns the length it needed: grow the buffer and format the line again.
            line.resize(static_cast<std::size_t>(written) + 1);
        }
        oss << line.data();
    }

    const std::string data = oss.str();
    std::ofstream file(path);
    if (!file.good())
    {
        std::cout << "Can't open: " << path << " Exiting..." << std::endl;
        exit(EXIT_FAILURE);
    }
    file << data;
    file.close();
    // close() flushes; a failed flush (e.g. disk full) sets failbit and would otherwise go unnoticed.
    if (!file)
    {
        std::cout << "Can't write: " << path << " Exiting..." << std::endl;
        exit(EXIT_FAILURE);
    }
    saveCount++;
}

// Restores the simulation state from a text frame written by save().
// The particle-count line is skipped; pcount keeps its configured value. The header supplies time,
// currentStep, saveCount and the box size. Each particle line is "id type x y z <unused> r".
// Lengths are converted back from output units by dividing by cpuSim.outputScale. Both the
// sorted-copy buffers (parts_s, id_s) and the working buffers (parts, id) are filled.
// Exits with a failure status if the file cannot be opened or has fewer than pcount particle lines.
void SimulatorMultiPart::readXYZ(std::string path)
{
    std::ifstream file(path);
    if (!file)
    {
        std::cout << "Cant read " << path << "\n";
        exit(EXIT_FAILURE);
    }

    // Earlier kernel launches may still be running; host writes to managed memory must wait for them.
    cudaDeviceSynchronize();
    check_gpu_err(__FILE__, __LINE__);

    std::string line;
    std::getline(file, line); // particle count: not used
    std::getline(file, line);
    std::istringstream header(line);
    header >> time >> currentStep >> saveCount >> grid.size;
    // save() wrote grid.size * outputScale, so undo the scaling like every other length.
    grid.size /= cpuSim.outputScale;

    int unused;
    for (int i = 0; i < pcount; i++)
    {
        if (!std::getline(file, line))
        {
            std::cout << "Error :" << path << " has fewer than " << pcount << " particles.\n";
            exit(EXIT_FAILURE);
        }
        std::istringstream fields(line);
        fields >> id_s[i] >> parts_s[i].t >> parts_s[i].p.x >> parts_s[i].p.y >> parts_s[i].p.z >> unused >> parts_s[i].r;
        parts_s[i].r /= cpuSim.outputScale;
        parts_s[i].p.x /= cpuSim.outputScale;
        parts_s[i].p.y /= cpuSim.outputScale;
        parts_s[i].p.z /= cpuSim.outputScale;
        parts[i] = parts_s[i];
        // Copy the id just read into the working buffer. The old code assigned id_s[i] = id[i],
        // discarding the stored id in favour of the uninitialised `id` buffer.
        id[i] = id_s[i];
    }
    file.close();
}

// Writes a binary restart file to <outputDir>/XYZ/<name>.
// Layout (native endianness): currentStep (sizeof(long int)), pcount (int), time (float),
// saveCount (int), grid.size (float), then for each particle: type (int), radius (float),
// x, y, z (double each) and id (int). readBin() reads exactly this layout.
// Exits with a failure status if the file cannot be opened or fully written.
void SimulatorMultiPart::saveBin(std::string name)
{
    // parts_s / id_s are managed buffers written by kernels; wait for the device before reading them.
    cudaDeviceSynchronize();
    check_gpu_err(__FILE__, __LINE__);

    std::ostringstream oss;
    // Field widths are fixed by the restart format and must stay in sync with readBin().
    auto put = [&oss](const void *src, std::size_t bytes) { oss.write(static_cast<const char *>(src), bytes); };
    put(&currentStep, sizeof(long int));
    put(&pcount, sizeof(int));
    put(&time, sizeof(float));
    put(&saveCount, sizeof(int));
    put(&grid.size, sizeof(float));

    for (int i = 0; i < pcount; i++)
    {
        put(&parts_s[i].t, sizeof(int));
        put(&parts_s[i].r, sizeof(float));
        put(&parts_s[i].p.x, sizeof(double));
        put(&parts_s[i].p.y, sizeof(double));
        put(&parts_s[i].p.z, sizeof(double));
        put(&id_s[i], sizeof(int));
    }

    const std::string path = cpuSim.outputDir + "/XYZ/" + name;
    std::ofstream file(path, std::ios::binary);
    if (!file.good())
    {
        std::cout << "Can't open: " << path << " Exiting..." << std::endl;
        exit(EXIT_FAILURE);
    }
    const std::string data = oss.str();
    file.write(data.data(), data.size());
    file.close();
    // A short write (e.g. disk full) would leave a truncated restart file that readBin() cannot use.
    if (!file)
    {
        std::cout << "Can't write: " << path << " Exiting..." << std::endl;
        exit(EXIT_FAILURE);
    }
}

// Restores the simulation state from a binary restart file written by saveBin().
// Reads currentStep, the particle count, time, saveCount, grid.size, and each particle's
// type/radius/position/id into parts_s and id_s. These are then copied to parts and id on the device.
// The stored particle count replaces pcount; it may not exceed the count the buffers were allocated for.
// Exits with a failure status on an unreadable, oversized or truncated file.
void SimulatorMultiPart::readBin(std::string path)
{
    std::cout << "Start from file :" << path << "\n";
    // Binary mode: saveBin() writes with std::ios::binary. Text mode on Windows would translate
    // CR/LF byte pairs and treat 0x1A as end-of-file, corrupting the data.
    std::ifstream file(path, std::ios::binary);
    if (!file)
    {
        std::cout << "Cant read " << path << "\n";
        exit(EXIT_FAILURE);
    }

    // Earlier kernel launches may still be running; host writes to managed memory must wait for them.
    cudaDeviceSynchronize();
    check_gpu_err(__FILE__, __LINE__);

    // On entry pcount is the configured count that the particle buffers were allocated with.
    const int capacity = pcount;
    int storedCount = 0;
    file.read(reinterpret_cast<char *>(&currentStep), sizeof(long int));
    file.read(reinterpret_cast<char *>(&storedCount), sizeof(int));
    file.read(reinterpret_cast<char *>(&time), sizeof(float));
    file.read(reinterpret_cast<char *>(&saveCount), sizeof(int));
    file.read(reinterpret_cast<char *>(&grid.size), sizeof(float));
    if (!file)
    {
        std::cout << "Error :truncated header in " << path << "\n";
        exit(EXIT_FAILURE);
    }
    if (storedCount < 0 || storedCount > capacity)
    {
        std::cout << "Error :" << path << " holds " << storedCount << " particles but buffers hold " << capacity << ".\n";
        exit(EXIT_FAILURE);
    }
    pcount = storedCount;

    for (int i = 0; i < pcount; i++)
    {
        file.read(reinterpret_cast<char *>(&parts_s[i].t), sizeof(int));
        file.read(reinterpret_cast<char *>(&parts_s[i].r), sizeof(float));
        file.read(reinterpret_cast<char *>(&parts_s[i].p.x), sizeof(double));
        file.read(reinterpret_cast<char *>(&parts_s[i].p.y), sizeof(double));
        file.read(reinterpret_cast<char *>(&parts_s[i].p.z), sizeof(double));
        file.read(reinterpret_cast<char *>(&id_s[i]), sizeof(int));
    }
    if (!file)
    {
        std::cout << "Error :truncated particle data in " << path << "\n";
        exit(EXIT_FAILURE);
    }
    file.close();

    myMemcpy<<<pcount / 32 + 1, 32>>>(pcount, parts_s, parts);
    myMemcpy<<<pcount / 32 + 1, 32>>>(pcount, id_s, id);
    check_gpu_err(__FILE__, __LINE__);
    // Callers read `parts` on the host right after this (e.g. getMaxRadius()), so wait for the copies.
    cudaDeviceSynchronize();
    check_gpu_err(__FILE__, __LINE__);
}

// Fills the per-type psi table: every slot is reset to 0, then each particle group writes its psi
// at index p.type. psis is allocated with 128 floats (cudaMallocManaged(&psis, sizeof(float) * 128)).
// Exits with a failure status if a group's type does not fit in the table.
void SimulatorMultiPart::fillPsis()
{
    const int psiSlots = 128;
    for (int i = 0; i < psiSlots; i++)
        psis[i] = 0.f;
    for (const auto &p : Parts)
    {
        // Without this check an out-of-range type silently writes past the 128-entry allocation.
        if (p.type < 0 || p.type >= psiSlots)
        {
            std::cout << "Error :particle type " << p.type << " outside [0, " << psiSlots - 1 << "].\n";
            exit(EXIT_FAILURE);
        }
        psis[p.type] = p.psi;
    }
}

// Returns the largest cut-off radius (data.Rc) over all configured Yukawa, Lennard-Jones and DLVO
// potentials. The search starts from std::numeric_limits<float>::min() (the smallest positive float),
// which is also the result when no potential is configured.
float SimulatorMultiPart::getRc()
{
    float largest = std::numeric_limits<float>::min();

    for (const auto &pot : Yukawas)
    {
        if (largest < pot.data.Rc)
            largest = pot.data.Rc;
    }
    for (const auto &pot : Lennard_Joness)
    {
        if (largest < pot.data.Rc)
            largest = pot.data.Rc;
    }
    for (const auto &pot : DLVOs)
    {
        if (largest < pot.data.Rc)
            largest = pot.data.Rc;
    }

    return largest;
}

// Builds a simulator from the configuration, particle groups and pair potentials.
// Steps:
//   1. Allocate buffers.
//   2. Seed the per-particle device RNG.
//   3. Either restore state from a restart file (an automatic <outputDir>/XYZ/restart.bin takes
//      precedence over cpuSim.restartFile), or draw fresh radii.
//   4. Size the periodic box and the cell grid.
//   5. Place particles when not restarting.
//   6. Write start.bin.
//   7. Create the launch-size optimizers used by computeNext()/updateGrid().
// Exits with a failure status on a missing/unsupported restart file or a box too small for one cell.
SimulatorMultiPart::SimulatorMultiPart(simulation_t mySim, std::vector<particles_t> myParticles, std::vector<Yukawa> _Yukawas, std::vector<Lennard_Jones> _Lennard_Joness, std::vector<DLVO> _DLVOs)
{
    Yukawas = _Yukawas;
    Lennard_Joness = _Lennard_Joness;
    DLVOs = _DLVOs;

    Parts = myParticles;
    cpuSim = mySim;
    gridRebuildCount = 0;
    simIt = 0;
    // Per-type psi table with 128 slots, filled on the host by fillPsis().
    cudaMallocManaged(&psis, sizeof(float) * 128);
    check_gpu_err(__FILE__, __LINE__);
    fillPsis();
    opttimer = new myTimer("opt", false, 1000);

    // The total particle count is the sum over all particle groups.
    cpuSim.nbrColloids = 0;
    for (const auto &group : myParticles)
        cpuSim.nbrColloids += group.number;
    pcount = cpuSim.nbrColloids;
    listSize = cpuSim.maxListSize;
    std::cout << "Number of particles:" << pcount << "\n";
    initPbuffers();

    // saveBin() writes restart files under outputDir + "/XYZ/". Probe the same location so automatic
    // resume works whether or not outputDir ends with a path separator.
    const std::string restartBinPath = cpuSim.outputDir + "/XYZ/" + "restart.bin";
    bool hasRestart = false;
    const bool hasRestartBin = testFile(restartBinPath);

    setupKernelRandom<<<pcount / 32 + 1, 32>>>(rndBro, pcount);
    check_gpu_err(__FILE__, __LINE__);
    // The host touches managed buffers next (restart readers or genRadius); let the kernel finish first.
    cudaDeviceSynchronize();
    check_gpu_err(__FILE__, __LINE__);

    if (cpuSim.restartFile != "" || hasRestartBin)
    {
        const std::string toLoad = hasRestartBin ? restartBinPath : cpuSim.restartFile;
        if (!testFile(toLoad))
        {
            std::cout << "File " << toLoad << " not found.\n";
            exit(EXIT_FAILURE);
        }
        // Guarded: substr(length() - 3) throws std::out_of_range for paths shorter than 3 characters.
        const std::string ext = toLoad.size() >= 3 ? toLoad.substr(toLoad.size() - 3) : std::string();
        if (ext == "bin")
        {
            readBin(toLoad);
            hasRestart = true;
            std::cout << "From iteration :" << currentStep << " at " << time << " s\n";
        }
        else if (ext == "xyz")
        {
            readXYZ(toLoad);
            hasRestart = true;
        }
        else
        {
            std::cout << "Error :restart file " << toLoad << " not supported.\n";
            exit(EXIT_FAILURE);
        }
    }

    // Radii come from the restored particles, or are drawn now for a fresh start.
    float biggestRadius;
    if (hasRestart)
        biggestRadius = getMaxRadius();
    else
        biggestRadius = genRadius();

    // Cubic box whose volume is (total particle volume) / cpuSim.density.
    const float v = computeVolume();
    const float vBox = v / cpuSim.density;
    const float size = pow(vBox, 1.f / 3.f);
    grid.size = size;
    grid.Rc = getRc();
    grid.Rl = grid.Rc * cpuSim.Rc_multiplicator;
    grid.cellPerAxis = floor(grid.size / (grid.Rl * biggestRadius));
    // Zero cells per axis would make cellSize a division by zero and nbCells empty.
    if (grid.cellPerAxis < 1)
    {
        std::cout << "Error :box size " << grid.size << " is smaller than one grid cell (" << grid.Rl * biggestRadius << ").\n";
        exit(EXIT_FAILURE);
    }
    grid.cellSize = grid.size / grid.cellPerAxis;
    grid.nbCells = grid.cellPerAxis * grid.cellPerAxis * grid.cellPerAxis;
    printf("biggestRadius %e\n", biggestRadius);
    grid.print("");
    nbCell = grid.nbCells;
    if (!hasRestart)
    {
        initPositions();
        currentStep = 0;
        time = 0.f;
        saveCount = 0;
    }
    cudaMallocManaged(&grids, sizeof(int2) * nbCell);
    check_gpu_err(__FILE__, __LINE__);
    // Force a grid / neighbour-list build on the first step.
    update[0] = true;
    fluctuatingForceFactor = (2.f * Constant::kb * cpuSim.temperature * cpuSim.deltaTime) / (6.f * cpuSim.eta * Constant::PI);
    factorSumForce = cpuSim.deltaTime / (6 * cpuSim.eta * Constant::PI);
    verletSecure = (grid.Rl - grid.Rc) / 2.f * biggestRadius;
    saveBin("start.bin");

    // Launch-size optimizers, seeded with each kernel's occupancy-optimal block size (capped at pcount).
    // computeNext()/updateGrid() index them positionally:
    //   0 generate_normal_kernelr, 1 myMemcpy<particle>, 2 myMemcpy<int>, 3 myMemset<float3> (move),
    //   4 myMemset<int2> (grids), 5 findCellBounds, 6 fillList, 7 apply, 8 calcHash,
    //   back() computeForceT<*> (smallest limit over the three potentials).
    // Entries 3 and 4 query the instantiations that are actually launched (previously double3 / int).
    int minGridSize, limit;
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, generate_normal_kernelr, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, myMemcpy<particle>, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, myMemcpy<int>, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, myMemset<float3>, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, myMemset<int2>, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, findCellBounds, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, fillList, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, apply, 0, pcount);
    optimizers.push_back(optimizer(limit, true));
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, calcHash, 0, pcount);
    optimizers.push_back(optimizer(limit, true));

    int best = 2048;
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, computeForceT<Yukawa>, 0, pcount);
    best = std::min(best, limit);
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, computeForceT<Lennard_Jones>, 0, pcount);
    best = std::min(best, limit);
    cudaOccupancyMaxPotentialBlockSize(&minGridSize, &limit, computeForceT<DLVO>, 0, pcount);
    best = std::min(best, limit);
    optimizers.push_back(optimizer(best, true));
}

// Places all particles at random positions in the periodic box, one at a time.
// A position is redrawn until it is not within overlapDistance * (r_i + r_j) of any particle
// already placed in its own cell or a neighbouring cell (offsets -1..1 per axis, wrapped periodically).
// Distances are computed with computeVector(..., grid.size). The result is then copied into parts_s.
// NOTE(review): there is no attempt limit, so a density that cannot be reached loops forever.
void SimulatorMultiPart::initPositions()
{
    UniformDistribution localRnd = UniformDistribution(0.f, grid.size, cpuSim.randomSeedPosition);
    // placed[c] lists the indices of particles already put in cell c (owning container, freed on return).
    std::vector<std::vector<int>> placed(grid.nbCells);

    for (int i = 0; i < cpuSim.nbrColloids; i++)
    {
        int currentCell;
        bool good;
        do
        {
            parts[i].p = make_double3(localRnd.getValue(), localRnd.getValue(), localRnd.getValue());
            // Wrap the cell coordinates with getCoordPbc (already used below to wrap -1 and cellPerAxis).
            // A coordinate equal to, or rounding up to, grid.size would otherwise give index cellPerAxis,
            // one past the last cell.
            int3 cell;
            cell.x = getCoordPbc(static_cast<int>(floor(parts[i].p.x / grid.cellSize)), grid.cellPerAxis);
            cell.y = getCoordPbc(static_cast<int>(floor(parts[i].p.y / grid.cellSize)), grid.cellPerAxis);
            cell.z = getCoordPbc(static_cast<int>(floor(parts[i].p.z / grid.cellSize)), grid.cellPerAxis);
            currentCell = getCell1D(cell, grid.cellPerAxis);

            // The first particle finds every list empty and is accepted immediately.
            good = true;
            for (int dx = -1; dx < 2 && good; dx++)
                for (int dy = -1; dy < 2 && good; dy++)
                    for (int dz = -1; dz < 2 && good; dz++)
                    {
                        const int x = getCoordPbc(dx + cell.x, grid.cellPerAxis);
                        const int y = getCoordPbc(dy + cell.y, grid.cellPerAxis);
                        const int z = getCoordPbc(dz + cell.z, grid.cellPerAxis);
                        const int neighbour = getCell1D(x, y, z, grid.cellPerAxis);
                        for (const int j : placed[neighbour])
                        {
                            const double3 diff = computeVector(parts[i].p, parts[j].p, grid.size);
                            const double d = sqrt(diff.x * diff.x + diff.y * diff.y + diff.z * diff.z);
                            const double rr = parts[i].r + parts[j].r;
                            if (d < rr * cpuSim.overlapDistance)
                            {
                                // One overlap is enough to reject this draw; stop scanning.
                                good = false;
                                break;
                            }
                        }
                    }
        } while (!good);
        placed[currentCell].push_back(i);
    }
    std::memcpy(parts_s, parts, cpuSim.nbrColloids * sizeof(particle));
}

// Returns the largest radius among the first pcount entries of `parts`.
// The search starts from std::numeric_limits<float>::min(), the smallest positive float.
float SimulatorMultiPart::getMaxRadius()
{
    float largest = std::numeric_limits<float>::min();
    for (int k = 0; k < pcount; ++k)
    {
        const float radius = parts[k].r;
        if (largest < radius)
            largest = radius;
    }
    return largest;
}

// Draws a radius for every particle with a host cuRAND generator seeded by cpuSim.randomSeedPosition.
// Groups are processed in order; each group uses its distribution ("normal" or "lognormal"), its
// radius / std_dev, and rejection-samples into [minRadius, maxRadius]. A zero bound is replaced by
// FLT_MIN / FLT_MAX. Also sets id[k] = k and the type of every particle, mirrors parts/id into
// parts_s/id_s, and returns the largest radius.
// Exits with a failure status for a non-empty group with an unsupported distribution.
float SimulatorMultiPart::genRadius()
{
    curandGenerator_t rand_generator_host;
    curandRngType_t generator_type = CURAND_RNG_PSEUDO_DEFAULT;
    curandCreateGeneratorHost(&rand_generator_host, generator_type);
    curandSetPseudoRandomGeneratorSeed(rand_generator_host, cpuSim.randomSeedPosition);
    // Host scratch for the generator; the previous heap arrays leaked. cuRAND's pseudorandom
    // normal/log-normal generators need an even sample count, so 2 values are drawn and only [0] is used.
    float tmpVal[2];

    int count = 0;
    for (const auto &p : Parts)
    {
        const bool isNormal = p.distribution == "normal";
        const bool isLogNormal = p.distribution == "lognormal";
        // Any other name used to leave `tmp` uninitialised (undefined behaviour).
        if (p.number > 0 && !isNormal && !isLogNormal)
        {
            std::cout << "Error :radius distribution " << p.distribution << " not supported.\n";
            exit(EXIT_FAILURE);
        }
        const float lowR = p.minRadius != 0.0 ? p.minRadius : FLT_MIN;
        const float highR = p.maxRadius != 0.0 ? p.maxRadius : FLT_MAX;
        for (int i = 0; i < p.number; i++)
        {
            float tmp;
            // Redraw until the radius falls inside [lowR, highR].
            do
            {
                if (isNormal)
                {
                    // NOTE(review): exp(std_dev * Z + log(radius)) = radius * exp(std_dev * Z) is a
                    // log-normal sample with median p.radius, the same family as the "lognormal" branch.
                    // Confirm whether a Gaussian radius (p.radius + p.std_dev * Z) was intended.
                    curandGenerateNormal(rand_generator_host, tmpVal, 2, 0.0f, 1.0f);
                    tmp = exp(p.std_dev * tmpVal[0] + log(p.radius));
                }
                else
                {
                    curandGenerateLogNormal(rand_generator_host, tmpVal, 2, 0.0f, float(p.std_dev));
                    tmp = p.radius * tmpVal[0];
                }
            } while (tmp > highR || tmp < lowR);
            parts[count].r = tmp;
            id[count] = count;
            parts[count].t = p.type;
            count++;
        }
    }
    // Release the generator (it was previously leaked).
    curandDestroyGenerator(rand_generator_host);

    for (int i = 0; i < cpuSim.nbrColloids; i++)
    {
        parts_s[i] = parts[i];
        id_s[i] = id[i];
    }
    return getMaxRadius();
}

// Allocates every per-particle buffer in managed memory, plus the neighbour-list storage (listSize
// slots per particle) and the single `update` flag. Also wraps hash/tid in thrust device pointers
// for the sort in updateGrid(). Exits with a failure status if any allocation fails.
void SimulatorMultiPart::initPbuffers()
{
    // Keep the first failure so it is reported here instead of surfacing later as a crash.
    cudaError_t firstError = cudaSuccess;
    auto track = [&firstError](cudaError_t status) {
        if (status != cudaSuccess && firstError == cudaSuccess)
            firstError = status;
    };

    track(cudaMallocManaged(&parts, sizeof(particle) * pcount));
    track(cudaMallocManaged(&parts_s, sizeof(particle) * pcount));
    track(cudaMallocManaged(&tid, sizeof(int) * pcount));
    track(cudaMallocManaged(&id, sizeof(int) * pcount));
    track(cudaMallocManaged(&id_s, sizeof(int) * pcount));
    track(cudaMallocManaged(&forceSum, sizeof(double3) * pcount));
    track(cudaMallocManaged(&move, sizeof(float3) * pcount));
    track(cudaMallocManaged(&hash, sizeof(uint64_t) * pcount));
    track(cudaMallocManaged(&rndBro, sizeof(curandState) * pcount));
    track(cudaMallocManaged(&randomsV3, sizeof(double3) * pcount));
    track(cudaMallocManaged(&counters, sizeof(int) * pcount));
    // Neighbour-list storage: listSize entries per particle.
    track(cudaMallocManaged(&nlist, sizeof(int) * pcount * listSize));
    track(cudaMallocManaged(&verletParts, sizeof(particle) * pcount * listSize));
    track(cudaMallocManaged(&update, sizeof(bool)));
    if (firstError != cudaSuccess)
    {
        std::cout << "cudaMallocManaged failed: " << cudaGetErrorString(firstError) << " Exiting..." << std::endl;
        exit(EXIT_FAILURE);
    }

    dev_hash = thrust::device_pointer_cast(hash);
    dev_tid = thrust::device_pointer_cast(tid);
}

// Advances the simulation by one time step:
//   1. Optional text/restart output.
//   2. Grid rebuild when update[0] is set.
//   3. Force evaluation for every configured potential.
//   4. Random-number generation.
//   5. The `apply` kernel (which also receives `update`).
// Then currentStep is incremented and time recomputed.
// Each kernel launch is timed with opttimer and the time is fed to the optimizer that sized it.
void SimulatorMultiPart::computeNext()
{
    simIt++;
    // The previous step's kernels write parts_s (read by save/saveBin) and update[0] (read below)
    // in managed memory; make sure they have finished before the host reads either.
    cudaDeviceSynchronize();
    check_gpu_err(__FILE__, __LINE__);

    // Text frames: chosen by logSave() when savePeriod == 0, otherwise every savePeriod steps.
    if (cpuSim.savePeriod == 0 && logSave(currentStep))
        save();
    if (cpuSim.savePeriod >= 1 && currentStep % cpuSim.savePeriod == 0)
        save();
    if (cpuSim.restartSaveInterval > 0 && currentStep % cpuSim.restartSaveInterval == 0)
        saveBin("restart.bin");
    if (update[0])
        updateGrid();

    // All potentials share the last optimizer (sized in the constructor for computeForceT<*>).
    for (const auto &f : Yukawas)
    {
        opttimer->start();
        computeForceT<<<pcount / optimizers.back().getThreads() + 1, optimizers.back().getThreads()>>>(pcount, parts_s, nlist, listSize, counters, forceSum, grid, f, psis);
        check_gpu_err(__FILE__, __LINE__);
        opttimer->stop();
        optimizers.back().updateTimming(opttimer->getLast());
    }
    for (const auto &f : Lennard_Joness)
    {
        opttimer->start();
        computeForceT<<<pcount / optimizers.back().getThreads() + 1, optimizers.back().getThreads()>>>(pcount, parts_s, nlist, listSize, counters, forceSum, grid, f, psis);
        check_gpu_err(__FILE__, __LINE__);
        opttimer->stop();
        optimizers.back().updateTimming(opttimer->getLast());
    }
    for (const auto &f : DLVOs)
    {
        opttimer->start();
        computeForceT<<<pcount / optimizers.back().getThreads() + 1, optimizers.back().getThreads()>>>(pcount, parts_s, nlist, listSize, counters, forceSum, grid, f, psis);
        check_gpu_err(__FILE__, __LINE__);
        opttimer->stop();
        optimizers.back().updateTimming(opttimer->getLast());
    }

    // Previously this launch had no start(), so optimizers[0] received a stale measurement.
    opttimer->start();
    generate_normal_kernelr<<<pcount / optimizers[0].getThreads() + 1, optimizers[0].getThreads()>>>(rndBro, randomsV3, pcount);
    check_gpu_err(__FILE__, __LINE__);
    opttimer->stop();
    optimizers[0].updateTimming(opttimer->getLast());

    // Previously this launch had no stop(), so optimizers[7] also received a stale measurement.
    opttimer->start();
    apply<<<pcount / optimizers[7].getThreads() + 1, optimizers[7].getThreads()>>>(pcount, verletSecure, parts_s, randomsV3, forceSum, move, grid, fluctuatingForceFactor, factorSumForce, update);
    check_gpu_err(__FILE__, __LINE__);
    opttimer->stop();
    optimizers[7].updateTimming(opttimer->getLast());

    currentStep++;
    time = currentStep * cpuSim.deltaTime;
}

// Rebuilds the cell grid and the per-particle neighbour lists from the current particle state.
// Every kernel launch is bracketed by opttimer->start()/stop(), and the measured time is fed to the
// optimizer whose getThreads() chose that launch's block size. thrust::sort_by_key is not timed.
// Ends by clearing update[0]; computeNext() calls this method whenever update[0] is true.
void SimulatorMultiPart::updateGrid()
{
    gridRebuildCount++;
    // Copy parts_s into parts (argument order presumably (n, src, dst) — confirm in kernels.cuh).
    opttimer->start();
    myMemcpy<<<pcount / optimizers[1].getThreads() + 1, optimizers[1].getThreads()>>>(pcount, parts_s, parts);
    check_gpu_err(__FILE__, __LINE__);
    opttimer->stop();
    optimizers[1].updateTimming(opttimer->getLast());

    // Same copy for the particle ids: id_s into id.
    opttimer->start();
    myMemcpy<<<pcount / optimizers[2].getThreads() + 1, optimizers[2].getThreads()>>>(pcount, id_s, id);
    check_gpu_err(__FILE__, __LINE__);
    opttimer->stop();
    optimizers[2].updateTimming(opttimer->getLast());
    // Zero the per-particle `move` buffer.
    opttimer->start();
    myMemset<float3><<<pcount / optimizers[3].getThreads() + 1, optimizers[3].getThreads()>>>(pcount, move, make_float3(0.f, 0.f, 0.f));
    check_gpu_err(__FILE__, __LINE__);
    opttimer->stop();
    optimizers[3].updateTimming(opttimer->getLast());

    // Reset all nbCell entries of `grids` (sized by cell count, not particle count) to
    // (2147483647 - 32, -1), i.e. (INT_MAX - 32, -1). The margin below INT_MAX is kept per the note below.
    // NOTE(review): presumably an "empty cell" (start, end) sentinel — confirm against findCellBounds.
    opttimer->start();
    myMemset<int2><<<nbCell / optimizers[4].getThreads() + 1, optimizers[4].getThreads()>>>(nbCell, grids, make_int2(2147483647 - 32, -1)); // Care about it or index could become negative
    check_gpu_err(__FILE__, __LINE__);
    opttimer->stop();
    optimizers[4].updateTimming(opttimer->getLast());

    ///// care
    // Per-particle hash/tid from `parts`. Presumably a cell hash into `hash` and the particle index
    // into `tid`, which the sort below uses as key/value.
    opttimer->start();
    calcHash<<<pcount / optimizers[8].getThreads() + 1, optimizers[8].getThreads()>>>(pcount, grid, hash, tid, parts);
    check_gpu_err(__FILE__, __LINE__);
    opttimer->stop();
    optimizers[8].updateTimming(opttimer->getLast());

    // Sort tid by hash on the device (dev_hash / dev_tid wrap the managed hash / tid buffers).
    thrust::sort_by_key(thrust::device, dev_hash, dev_hash + pcount, dev_tid);

    // From the sorted hash/tid, fill `grids` and write parts_s / id_s from parts / id.
    // NOTE(review): presumably a reorder into hash order — confirm in kernels.cuh.
    opttimer->start();
    findCellBounds<<<pcount / optimizers[5].getThreads() + 1, optimizers[5].getThreads()>>>(grid, pcount, grids, hash, tid, parts, parts_s, id, id_s);
    check_gpu_err(__FILE__, __LINE__);
    ;
    opttimer->stop();
    optimizers[5].updateTimming(opttimer->getLast());

    // Build the neighbour lists into nlist (listSize slots per particle).
    // `counters` presumably receives each particle's list length.
    opttimer->start();
    fillList<<<pcount / optimizers[6].getThreads() + 1, optimizers[6].getThreads()>>>(pcount, listSize, counters, parts_s, grids, nlist, grid);
    check_gpu_err(__FILE__, __LINE__);
    ;
    opttimer->stop();
    optimizers[6].updateTimming(opttimer->getLast());

    // The grid is fresh: clear the rebuild request.
    update[0] = false;
}

// Returns the total volume of all colloids: the sum of (4/3)*pi*r^3 over the first
// cpuSim.nbrColloids entries of `parts`. Each term is evaluated in double (M_PI) and added to a
// float running total.
float SimulatorMultiPart::computeVolume()
{
    float total = 0;
    int k = 0;
    while (k < cpuSim.nbrColloids)
    {
        const float radius = parts[k].r;
        total += 4.f / 3.f * M_PI * radius * radius * radius;
        ++k;
    }
    return total;
}