#include <cmath>
#include <curand.h>
#include <curand_kernel.h>
#include <cuda_runtime.h>
#include <cfloat>

#include "../Commons/utilityKernel.cuh"

/**
 * @brief Seeds one cuRAND generator state per thread.
 *
 * Thread `tid` initialises state[tid] with seed 7 + tid, subsequence tid and
 * offset 0, so every thread gets both a distinct seed and a distinct
 * subsequence. Threads whose index is >= nbr return without doing anything.
 *
 * @param state Device array of at least nbr curandState entries (written).
 * @param nbr   Number of generator states to initialise.
 */
__global__ void setupKernelRandom(curandState *state, const int nbr)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= nbr)
        return;

    const unsigned long long seed = 7 + tid;
    curand_init(seed, tid, 0, &state[tid]);
}

/**
 * @brief Draws one random radius per thread using rejection sampling.
 *
 * Thread `id` (< n) keeps drawing samples until one lies inside [lmin, lmax]:
 *  - type == true : r = curand_log_normal(mean, dev); mean/dev are the
 *                   parameters of the underlying normal distribution.
 *  - type == false: r = mean + dev * N(0, 1).
 * A bound equal to 0 means "not set":
 *  - max == 0 -> lmax = FLT_MAX (no upper bound);
 *  - min == 0 -> lmin = FLT_MIN, the smallest positive normalised float,
 *    so non-positive samples are rejected.
 * The accepted sample is written to radius[id], and the advanced RNG state is
 * stored back into states[id].
 *
 * NOTE(review): if [lmin, lmax] has (almost) zero probability under the chosen
 * distribution, the loop does not terminate. Callers must pass sensible bounds.
 *
 * @param type   true for log-normal, false for normal sampling.
 * @param n      Number of radii to draw (one per thread).
 * @param states Per-thread cuRAND states, at least n entries (read and updated).
 * @param radius Output array, at least n entries.
 * @param min    Lower bound (0 = strictly positive).
 * @param max    Upper bound (0 = unbounded).
 * @param mean   Mean parameter of the distribution.
 * @param dev    Standard-deviation parameter of the distribution.
 */
__global__ void genGPUR(bool type, const int n, curandState *states, float *radius, const float min, const float max, const float mean, const float dev)
{
    const int id = threadIdx.x + blockIdx.x * blockDim.x;
    if (id < n)
    {
        curandState local = states[id];
        const float lmax = max == 0.0f ? FLT_MAX : max;
        const float lmin = min == 0.0f ? FLT_MIN : min;
        float r;
        // Redraw while the sample is OUTSIDE [lmin, lmax]. The negated
        // in-range test also rejects NaN samples.
        do
        {
            if (type)
                r = curand_log_normal(&local, mean, dev);
            else
                r = mean + curand_normal(&local) * dev;
        } while (!(r >= lmin && r <= lmax));
        radius[id] = r;
        states[id] = local;
    }
}

/**
 * @brief Minimum-image correction of one displacement component.
 *
 * Returns delta - size or delta + size when either is strictly shorter than
 * delta itself; otherwise returns delta unchanged. The "- size" candidate is
 * tested first.
 */
static __host__ __device__ __inline__ double minimumImageComponent(const double delta, const double size)
{
    if (fabs(delta - size) < fabs(delta))
        return delta - size;
    if (fabs(delta + size) < fabs(delta))
        return delta + size;
    return delta;
}

/**
 * @brief Displacement from point n to point p in a cubic periodic box.
 *
 * Computes p - n and then, for each axis independently, shifts the component
 * by one box length (+size or -size) if that makes it shorter.
 *
 * @param p    The first point.
 * @param n    The second point (subtracted from p).
 * @param size Edge length of the periodic box.
 * @return The displacement p - n, adjusted for periodic boundaries.
 */
static __host__ __device__ __inline__ double3 computeVector(const double3 p, const double3 n, const float size)
{
    const double3 raw = p - n;
    const double box = size;
    return make_double3(minimumImageComponent(raw.x, box),
                        minimumImageComponent(raw.y, box),
                        minimumImageComponent(raw.z, box));
}

/**
 * @brief Flattens 3-D grid-cell coordinates into a linear index.
 *
 * The x coordinate varies fastest, then y, then z:
 * index = x + y * cells + z * cells^2.
 *
 * @param p     The 3-D cell coordinates.
 * @param cells Number of cells along each axis.
 * @return The linear cell index.
 */
static __host__ __device__ __inline__ int getCell1D(const int3 p, const int cells)
{
    const int plane = cells * cells;
    return p.z * plane + p.y * cells + p.x;
}

/**
 * @brief Wraps an integer cell coordinate into [0, cells) (periodic boundary).
 *
 * Works for any integer i, including values below -cells and values at or
 * above cells. C++ `%` keeps the sign of the dividend, so the remainder is
 * shifted by `cells` before the final modulo.
 *
 * @param i     The (possibly out-of-range) cell coordinate.
 * @param cells Number of cells along the axis; must be > 0.
 * @return The PBC-corrected coordinate in [0, cells).
 */
static __host__ __device__ __inline__ int getCoordPbc(const int i, const int cells)
{
    return ((i % cells) + cells) % cells;
}

/**
 * @brief Linear index of the cell at (x, y, z) in a cubic grid.
 *
 * The x coordinate varies fastest, then y, then z:
 * index = x + y * cells + z * cells^2 (evaluated in Horner form).
 *
 * @param x     X-coordinate of the cell.
 * @param y     Y-coordinate of the cell.
 * @param z     Z-coordinate of the cell.
 * @param cells Number of cells along each axis.
 * @return The linear cell index.
 */
static __host__ __device__ __inline__ int getCell1D(const int x, const int y, const int z, const int cells)
{
    return (z * cells + y) * cells + x;
}

/**
 * @brief Finds the cell boundaries for particles in a grid (documents `findCellBounds`, defined further below).
 *
 * For each sorted particle slot, this kernel decodes the particle's cell from its hash value.
 * It then records the first and last slot belonging to that cell using atomic min/max operations.
 * It also gathers the particle data (the whole particle record and its global ID) into sorted arrays,
 * following the sorted particle indices.
 *
 * @param g The grid parameters.
 * @param count The total number of particles.
 * @param grids Per-cell bounds: .x receives the first slot (atomicMin), .y the last slot (atomicMax, inclusive).
 * @param hash Array of sorted particle hash values.
 * @param index Array of particle indices, in the same order as hash.
 * @param p Array of (unsorted) particles.
 * @param ps Output array for the sorted particles.
 * @param gid Array of particle global IDs.
 * @param gid_sort Output array for the sorted particle global IDs.
 */

/**
 * @brief Fills the neighbor list for each particle in a grid (documents `fillList`, defined further below).
 *
 * For each particle, the kernel scans its own cell and the adjacent cells, with periodic wrap-around,
 * to find potential neighbors. It computes the minimum-image distance between the particle and each candidate.
 * If that distance is below the cutoff g.Rl * (r_i + r_j) / 2, the candidate's index is appended to the list.
 * The number of neighbors of each particle is stored in the counters array.
 *
 * @param nbr The total number of particles.
 * @param lsize The capacity of each particle's neighbor list.
 * @param counters Output array to store the count of neighbors for each particle.
 * @param pos Array of particles, expected to be the cell-sorted array produced by findCellBounds.
 * @param grids Per-cell bounds: .x is the first and .y the last (inclusive) particle index of the cell.
 * @param verlets Output array to store the neighbor lists (lsize entries per particle).
 * @param g The grid parameters.
 */

/**
 * @brief Draws one 3-D vector of independent standard-normal components per thread.
 *
 * Thread `tid` (< nbr) loads its cuRAND state and draws three values with
 * curand_normal_double, in x, y, z order. It stores the advanced state back
 * and writes the vector to out[tid].
 *
 * @param state Per-thread cuRAND states, at least nbr entries (read and updated).
 * @param out   Output array of random vectors, at least nbr entries.
 * @param nbr   Number of vectors to generate.
 */
__global__ void generate_normal_kernelr(curandState *state, double3 *out, const int nbr)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= nbr)
        return;

    curandState rng = state[tid];
    double3 v;
    v.x = curand_normal_double(&rng);
    v.y = curand_normal_double(&rng);
    v.z = curand_normal_double(&rng);
    state[tid] = rng;
    out[tid] = v;
}

/**
 * @brief Wraps a single coordinate into [0, size).
 *
 * fmod gives the exact remainder in (-size, size), so displacements of any
 * number of box lengths are handled, not just one. For inputs already in
 * [-size, 2*size) the result equals the single add/subtract wrap.
 */
static inline __host__ __device__ double wrapPeriodicCoord(double v, const double size)
{
    v = fmod(v, size);
    if (v < 0.0)
        v += size;
    // A tiny negative v plus size can round up to exactly size.
    if (v >= size)
        v -= size;
    return v;
}

/**
 * @brief Maps a position back into the periodic box [0, size)^3.
 *
 * Each component is wrapped independently. The wrap is correct even when the
 * position is several box lengths outside the box.
 *
 * @param p    Position to wrap.
 * @param size Edge length of the cubic box.
 * @return The wrapped position.
 */
static inline __host__ __device__ double3 applyPBC(const double3 p, const float size)
{
    const double box = size;
    return make_double3(wrapPeriodicCoord(p.x, box),
                        wrapPeriodicCoord(p.y, box),
                        wrapPeriodicCoord(p.z, box));
}

/**
 * @brief Advances every particle by one step and tracks its accumulated displacement.
 *
 * For particle tid (< nbr), the step is
 *   step = (factorSumForce / r) * forces[tid] + sqrt(fluctuatingForceFactor / r) * randoms[tid].
 * The step is added to the position, which is then wrapped with applyPBC.
 * The step is also added to move[tid]. If |move[tid]| exceeds verletSecure,
 * update[0] is set to true; presumably this asks the host to rebuild the
 * neighbour lists (verify against the caller). Finally forces[tid] is reset to zero.
 *
 * Several threads may write `true` to update[0] concurrently; only that value
 * is ever written here.
 *
 * @param nbr                    Number of particles.
 * @param verletSecure           Displacement threshold that triggers update[0].
 * @param parts                  Particles (read and updated).
 * @param randoms                Per-particle random vectors.
 * @param forces                 Per-particle accumulated forces (read, then zeroed).
 * @param move                   Per-particle accumulated displacement (read and updated).
 * @param g                      Grid parameters (g.size is the box length).
 * @param fluctuatingForceFactor Scale of the random term (divided by the radius).
 * @param factorSumForce         Scale of the deterministic force term (divided by the radius).
 * @param update                 Flag set to true when any particle moved too far.
 */
__global__ void apply(const int nbr, const float verletSecure, particle *parts, const double3 *randoms, double3 *forces, float3 *move, const grid_t g, const float fluctuatingForceFactor, const float factorSumForce, bool *update)
{
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= nbr)
        return;

    particle self = parts[tid];
    const float noise = sqrt(fluctuatingForceFactor / self.r);
    const float drift = factorSumForce / self.r;

    const double3 f = forces[tid];
    const double3 xi = randoms[tid];
    const double3 step = make_double3((double)drift * f.x + xi.x * (double)noise,
                                      (double)drift * f.y + xi.y * (double)noise,
                                      (double)drift * f.z + xi.z * (double)noise);

    self.p.x += step.x;
    self.p.y += step.y;
    self.p.z += step.z;
    self.p = applyPBC(self.p, g.size);

    const float3 prev = move[tid];
    const float3 travelled = make_float3(step.x + prev.x, step.y + prev.y, step.z + prev.z);

    parts[tid] = self;
    move[tid] = travelled;
    if (travelled.x * travelled.x + travelled.y * travelled.y + travelled.z * travelled.z > verletSecure * verletSecure)
        update[0] = true;
    forces[tid] = make_double3(0., 0., 0.);
}

/**
 * @brief Accumulates pairwise forces over each particle's neighbour list.
 *
 * For particle id (< pcount), the kernel visits the counters[id] entries of
 * nlist that start at id * lsize. A neighbour contributes only if:
 *  - its type pair matches the interaction: (F.data.first, F.data.second)
 *    equals (neighbour.t, self.t) in either order, or both are -1 (wildcard);
 *  - the minimum-image distance d satisfies d < g.Rc * (r_i + r_j) / 2.
 * Each contribution is F.compute(r_i, r_j, d, dir, psis[t_i], psis[t_j]).
 * The sum is added to forces[id].
 *
 * @tparam T Interaction functor exposing `data.first`, `data.second` and `compute(...)` returning double3.
 * @param pcount   Number of particles.
 * @param p        Particles.
 * @param nlist    Neighbour lists, lsize slots per particle.
 * @param lsize    Capacity of each neighbour list.
 * @param counters Number of valid entries in each neighbour list.
 * @param forces   Per-particle force accumulators (incremented).
 * @param g        Grid parameters (g.size box length, g.Rc cutoff factor).
 * @param F        Interaction functor.
 * @param psis     Per-type parameters passed to F.compute.
 */
template <typename T>
__global__ void computeForceT(const int pcount, const particle *p, const int *nlist, const int lsize, const int *counters, double3 *forces, const grid_t g, T F, float *psis)
{
    const int id = threadIdx.x + blockIdx.x * blockDim.x;
    if (id >= pcount)
        return;

    const particle self = p[id];
    const int first = id * lsize;
    const int last = first + counters[id];
    const bool anyPair = F.data.first == -1 && F.data.second == -1;

    double3 acc = make_double3(0., 0., 0.);
#pragma unroll
    for (int k = first; k < last; ++k)
    {
        const particle other = p[nlist[k]];
        const bool forward = F.data.first == other.t && F.data.second == self.t;
        const bool backward = F.data.second == other.t && F.data.first == self.t;
        if (!forward && !backward && !anyPair)
            continue;

        const double3 dir = computeVector(self.p, other.p, g.size);
        const float d = sqrt(dir.x * dir.x + dir.y * dir.y + dir.z * dir.z);
        const float R = self.r + other.r;
        if (d < g.Rc * R / 2.f)
            acc += F.compute(self.r, other.r, d, dir, psis[self.t], psis[other.t]);
    }
    forces[id] += acc;
}

/**
 * @brief Records per-cell [first, last] slots and gathers particles into cell order.
 *
 * For sorted slot idx (< count), the kernel decodes the cell from hash[idx]
 * with Z_decode2. It then lowers grids[cell].x to idx (atomicMin) and raises
 * grids[cell].y to idx (atomicMax); .y is an inclusive bound. It also copies
 * p[index[idx]] into ps[idx] and gid[index[idx]] into gid_sort[idx].
 *
 * Assumes grids has g.nbCells entries and was reset beforehand so that
 * atomicMin/atomicMax produce meaningful bounds (NOTE(review): confirm the
 * host-side reset values). A decoded cell outside [0, g.nbCells) is reported
 * with printf and does not touch grids, which avoids an out-of-bounds write.
 *
 * @param g        Grid parameters.
 * @param count    Number of particles.
 * @param grids    Per-cell bounds (.x first slot, .y last slot).
 * @param hash     Sorted particle hashes.
 * @param index    Original particle index for each sorted slot.
 * @param p        Unsorted particles.
 * @param ps       Output: particles in sorted order.
 * @param gid      Unsorted global IDs.
 * @param gid_sort Output: global IDs in sorted order.
 */
__global__ void findCellBounds(const grid_t g, const int count, int2 *grids, const uint64_t *hash, const int *index, const particle *p, particle *ps, const int *gid, int *gid_sort)
{
    const int idx = threadIdx.x + (blockDim.x * blockIdx.x);
    if (idx < count)
    {
        const int3 cellId = Z_decode2(hash[idx]);
        const int cell = getCell1D(cellId.x, cellId.y, cellId.z, g.cellPerAxis);
        const int localId = index[idx];
        // Valid indices are 0 .. nbCells-1; cell == nbCells is already out of range.
        if (cell >= 0 && cell < g.nbCells)
        {
            atomicMin(&grids[cell].x, idx);
            atomicMax(&grids[cell].y, idx);
        }
        else
        {
            printf("Error  %d %d %d\n", cellId.x, cellId.y, cellId.z);
        }
        gid_sort[idx] = gid[localId];
        ps[idx] = p[localId];
    }
}

/**
 * @brief Computes the cell hash of every particle and records its original index.
 *
 * For particle tid (< pcount), the kernel finds the particle's cell with
 * getCell and writes Z_encode2 of that cell (presumably a Z-order/Morton code;
 * verify in utilityKernel.cuh) to out_hash[tid]. It writes tid to
 * out_index[tid], so the pair can later be sorted by hash.
 *
 * @param pcount    Number of particles.
 * @param g         Grid parameters (cellSize, cellPerAxis).
 * @param out_hash  Output: per-particle cell hash.
 * @param out_index Output: per-particle original index.
 * @param p         Particles.
 */
__global__ void calcHash(const int pcount, const grid_t g, uint64_t *out_hash, int *out_index, const particle *__restrict__ p)
{
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= pcount)
        return;

    const int3 cell = getCell(p[tid].p, g.cellSize, g.cellPerAxis);
    out_hash[tid] = Z_encode2(cell);
    out_index[tid] = tid;
}

/**
 * @brief Builds each particle's neighbour (Verlet) list from the cell grid.
 *
 * For particle idx (< nbr), the kernel scans its own cell and the adjacent
 * cells, with periodic wrap-around. For each candidate i != idx in those
 * cells (slots grids[c].x .. grids[c].y, inclusive), it computes the
 * minimum-image distance d. If d < g.Rl * (r_idx + r_i) / 2, the candidate's
 * index is appended to verlets[idx * lsize ...]. The number of stored
 * neighbours goes to counters[idx].
 *
 * When cellPerAxis < 3, the offsets -1 and +1 wrap onto the same cell (or onto
 * the particle's own cell). Those redundant offsets are skipped so that no
 * neighbour is listed twice.
 *
 * Capacity: a full list trips the assert in debug builds. In NDEBUG builds,
 * extra neighbours are dropped and counters[idx] is clamped to lsize, so
 * other particles' lists are never overwritten.
 *
 * NOTE(review): assumes empty cells have grids[c].x > grids[c].y (e.g. reset
 * to INT_MAX / -1) so the inner loop is skipped; confirm against the host code.
 *
 * @param nbr      Number of particles.
 * @param lsize    Capacity of each neighbour list.
 * @param counters Output: number of neighbours per particle.
 * @param pos      Cell-sorted particles.
 * @param grids    Per-cell [first, last] particle slots.
 * @param verlets  Output: neighbour lists, lsize slots per particle.
 * @param g        Grid parameters.
 */
__global__ void fillList(const int nbr, const int lsize, int *counters, const particle *pos, const int2 *grids, int *verlets, const grid_t g)
{
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx < nbr)
    {
        int count = 0;
        const particle p = pos[idx];
        const int3 pl = getCell(p.p, g.cellSize, g.cellPerAxis);
        const int cpa = g.cellPerAxis;
        // Offsets that alias another visited offset on tiny grids.
        const bool skipMinus = cpa < 3; // -1 == +1 (mod 2) or == 0 (mod 1)
        const bool skipPlus = cpa < 2;  // +1 == 0 (mod 1)
#pragma unroll
        for (int dx = -1; dx < 2; dx++)
        {
            if ((dx == -1 && skipMinus) || (dx == 1 && skipPlus))
                continue;
            const int x = (pl.x + dx + cpa) % cpa;
#pragma unroll
            for (int dy = -1; dy < 2; dy++)
            {
                if ((dy == -1 && skipMinus) || (dy == 1 && skipPlus))
                    continue;
                const int y = (pl.y + dy + cpa) % cpa;
#pragma unroll
                for (int dz = -1; dz < 2; dz++)
                {
                    if ((dz == -1 && skipMinus) || (dz == 1 && skipPlus))
                        continue;
                    const int z = (pl.z + dz + cpa) % cpa;
                    const int current = getCell1D(x, y, z, cpa);
                    const int lstart = grids[current].x;
                    const int lstop = grids[current].y;
                    for (int i = lstart; i <= lstop; i++)
                    {
                        if (i == idx)
                            continue;
                        const particle tmpp = pos[i];
                        const double3 tmp = computeVector(p.p, tmpp.p, g.size);
                        const float d = sqrt(tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z);

                        if (d < g.Rl * (p.r + tmpp.r) / 2.f)
                        {
                            // Check capacity BEFORE writing so a full list is allowed
                            // and an overflow can never write into the next particle's list.
                            assert(count < lsize);
                            if (count < lsize)
                            {
                                verlets[idx * lsize + count] = i;
                                count++;
                            }
                        }
                    }
                }
            }
        }
        counters[idx] = count;
    }
}