#pragma once

#include <math.h>
#include <cstdint> // include this header for uint64_t
#include <cstdio>   // std::snprintf
#include <cstdlib>  // std::exit, EXIT_FAILURE
#include <iostream> // std::cerr
#include <string>   // std::string

#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>

/**
 * @brief Waits for the device and terminates the program if a CUDA error is pending.
 *
 * @details Calls cudaDeviceSynchronize() so that previously issued work has completed, then
 * cudaGetLastError() to read (and clear) any error recorded earlier, e.g. by a kernel launch.
 * If either call reports a failure, the error name and description are written to std::cerr
 * together with the caller-supplied location, and the process exits with EXIT_FAILURE so that
 * the failure is visible to the calling shell or scheduler.
 *
 * @param file Source file name to print in the message (e.g. __FILE__).
 * @param line Source line number to print in the message (e.g. __LINE__).
 */
static void check_gpu_err(const char *file, int line)
{
    // Keep both results: the synchronize call reports execution faults, the last-error query
    // reports errors recorded by earlier calls such as a bad launch configuration.
    const cudaError_t syncErr = cudaDeviceSynchronize();
    const cudaError_t lastErr = cudaGetLastError();
    const cudaError_t err = (syncErr != cudaSuccess) ? syncErr : lastErr;
    if (err != cudaSuccess)
    {
        std::cerr << "Error at line " << line << " in file " << file << ": "
                  << cudaGetErrorName(err) << ": " << cudaGetErrorString(err) << std::endl;
        // A failed check must not look like a successful run.
        std::exit(EXIT_FAILURE);
    }
}

/**
 * @brief Formats an account number as a decimal string of at least 8 characters.
 *
 * @details The number is printed with the "%08ld" format: values shorter than 8 characters
 * are left-padded with zeros, longer values are printed in full. The scratch buffer is large
 * enough for any 64-bit long (up to 19 digits plus a sign and the terminator), so the result
 * is never truncated.
 *
 * @param acct_no The account number to format.
 *
 * @return The formatted account number.
 */
static std::string format_account_number(long int acct_no)
{
    // 32 bytes comfortably holds the longest 64-bit value ("-9223372036854775808" + NUL).
    char buffer[32];
    std::snprintf(buffer, sizeof(buffer), "%08ld", acct_no);
    return std::string(buffer);
}

/**
 * @brief Tells whether a number falls on a quarter-decade grid.
 *
 * @details Let D be the largest power of ten not exceeding nb. The function returns true when
 * nb is a multiple of D / 4 (integer division, with a minimum stride of 1). Consequently:
 *   - 1..9      : every value matches (stride 1)
 *   - 10..99    : multiples of 2
 *   - 100..999  : multiples of 25
 *   - 1000..9999: multiples of 250, and so on.
 * The decade is found with integer arithmetic only, so the result is exact for the whole range
 * of long int and there is no overflow for large values.
 *
 * @param nb The number to check. Non-positive values have no decade and always return true.
 *
 * @return True if nb lies on the quarter-decade grid described above, false otherwise.
 */
static bool logSave(long int nb)
{
    if (nb <= 0)
        return true;

    // Largest power of ten <= nb; comparing against nb / 10 keeps decade * 10 from overflowing.
    long int decade = 1;
    while (decade <= nb / 10)
        decade *= 10;

    long int stride = decade / 4;
    if (stride == 0)
        stride = 1;
    return nb % stride == 0;
}

// Returns true when nb is an exact multiple of interval (nb % interval == 0), false otherwise.
// An interval of 0 has no multiples to hit, so the function returns false for it instead of
// performing an integer modulo by zero.
static bool linearSave(long int nb, int interval)
{
    if (interval == 0)
        return false;
    return nb % static_cast<long int>(interval) == 0;
}

/**
 * @brief Kernel that fills the first nbr elements of a buffer with one value.
 *
 * @details Each thread computes a flat index from the x components of its block and thread
 * coordinates and writes at most one element. Threads whose index is nbr or larger do nothing,
 * so the launch has to provide at least nbr threads along x to fill the whole range.
 *
 * @tparam T The data type of the buffer elements.
 * @param nbr The number of elements to write.
 * @param buff The pointer to the buffer.
 * @param value The value stored into each written element.
 */
template <typename T>
__global__ void myMemset(int nbr, T *buff, T value)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= nbr)
        return; // surplus thread past the end of the range
    buff[idx] = value;
}

/**
 * @brief Kernel that copies the first nbr elements of one buffer into another.
 *
 * @details Each thread computes a flat index from the x components of its block and thread
 * coordinates and copies at most one element. Threads whose index is nbr or larger do nothing,
 * so the launch has to provide at least nbr threads along x to copy the whole range.
 *
 * @tparam T The data type of the buffer elements.
 * @param nbr The number of elements to copy.
 * @param src The pointer to the source buffer.
 * @param dst The pointer to the destination buffer.
 */
template <typename T>
__global__ void myMemcpy(int nbr, T *src, T *dst)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= nbr)
        return; // surplus thread past the end of the range
    dst[idx] = src[idx];
}

/** @brief Builds an int3 whose three components all equal a. */
static inline __host__ __device__ int3 make_int3(const int a)
{
    int3 v;
    v.x = a;
    v.y = a;
    v.z = a;
    return v;
}

/** @brief Builds a double3 whose three components all equal a. */
static inline __host__ __device__ double3 make_double3(const double a)
{
    double3 v;
    v.x = a;
    v.y = a;
    v.z = a;
    return v;
}

/** @brief Component-wise negation: returns (-a.x, -a.y, -a.z). */
static inline __host__ __device__ double3 operator-(const double3 a)
{
    double3 neg;
    neg.x = -a.x;
    neg.y = -a.y;
    neg.z = -a.z;
    return neg;
}

/** @brief Component-wise sum of two double3 values. */
static inline __host__ __device__ double3 operator+(const double3 a, const double3 b)
{
    double3 sum = a;
    sum.x += b.x;
    sum.y += b.y;
    sum.z += b.z;
    return sum;
}

/** @brief Component-wise difference a - b of two double3 values. */
static inline __host__ __device__ double3 operator-(const double3 a, const double3 b)
{
    double3 diff = a;
    diff.x -= b.x;
    diff.y -= b.y;
    diff.z -= b.z;
    return diff;
}

/** @brief Component-wise product of two double3 values (not a dot or cross product). */
static inline __host__ __device__ double3 operator*(const double3 a, const double3 b)
{
    double3 prod = a;
    prod.x *= b.x;
    prod.y *= b.y;
    prod.z *= b.z;
    return prod;
}

/** @brief Component-wise quotient a / b of two double3 values. */
static inline __host__ __device__ double3 operator/(const double3 a, const double3 b)
{
    double3 quot = a;
    quot.x /= b.x;
    quot.y /= b.y;
    quot.z /= b.z;
    return quot;
}

/** @brief Adds b to a in place, component by component. Returns nothing. */
static inline __host__ __device__ void operator+=(double3 &a, const double3 b)
{
    a.x = a.x + b.x;
    a.y = a.y + b.y;
    a.z = a.z + b.z;
}

/** @brief Subtracts b from a in place, component by component. Returns nothing. */
static inline __host__ __device__ void operator-=(double3 &a, const double3 b)
{
    a.x = a.x - b.x;
    a.y = a.y - b.y;
    a.z = a.z - b.z;
}

/** @brief Multiplies a by b in place, component by component. Returns nothing. */
static inline __host__ __device__ void operator*=(double3 &a, const double3 b)
{
    a.x = a.x * b.x;
    a.y = a.y * b.y;
    a.z = a.z * b.z;
}

/** @brief Divides a by b in place, component by component. Returns nothing. */
static inline __host__ __device__ void operator/=(double3 &a, const double3 b)
{
    a.x = a.x / b.x;
    a.y = a.y / b.y;
    a.z = a.z / b.z;
}

/** @brief Scales every component of a by the scalar b and returns the result. */
static inline __host__ __device__ double3 operator*(double3 a, const double b)
{
    // a is a by-value copy, so it can be scaled directly and returned.
    a.x *= b;
    a.y *= b;
    a.z *= b;
    return a;
}

/**
 * @brief Scaled copy of a temporary double3: returns (a.x * b, a.y * b, a.z * b).
 *
 * @details This overload never modifies a caller-visible object; it only returns the scaled
 * value. It binds to rvalues only. Taking the first operand by value made `v *= 2.0` on a
 * named double3 ambiguous with an in-place `operator*=(double3 &, double)` overload, because a
 * by-value parameter and an lvalue-reference parameter rank equally in overload resolution.
 * Expressions of the form `temporary *= scalar` keep compiling and keep returning the scaled
 * value.
 *
 * @param a Temporary vector to scale.
 * @param b Scalar factor.
 * @return The scaled vector.
 */
static inline __host__ __device__ double3 operator*=(double3 &&a, const double b)
{
    return make_double3(a.x * b, a.y * b, a.z * b);
}

/** @brief Scales every component of a by the integer b and returns the result. */
static inline __host__ __device__ double3 operator*(double3 a, const int b)
{
    // a is a by-value copy, so it can be scaled directly and returned.
    a.x *= b;
    a.y *= b;
    a.z *= b;
    return a;
}

/**
 * @brief Divides every component of a by the scalar b and returns the result.
 *
 * @details The reciprocal 1.0 / b is computed once and each component is multiplied by it,
 * so the result can differ in the last bit from three separate divisions.
 */
static inline __host__ __device__ double3 operator/(double3 a, const double b)
{
    const double inv = 1.0 / b;
    a.x *= inv;
    a.y *= inv;
    a.z *= inv;
    return a;
}

/** @brief Scales a in place by the scalar b. Returns nothing. */
static inline __host__ __device__ void operator*=(double3 &a, const double b)
{
    a.x = a.x * b;
    a.y = a.y * b;
    a.z = a.z * b;
}

/** @brief Scales a in place by the integer b. Returns nothing. */
static inline __host__ __device__ void operator*=(double3 &a, const int b)
{
    a.x = a.x * b;
    a.y = a.y * b;
    a.z = a.z * b;
}

/**
 * @brief Divides a in place by the integer b. Returns nothing.
 *
 * @details Each component is divided in double precision (b is converted to double), so
 * b == 0 yields infinities or NaNs rather than an integer division fault.
 *
 * @param a Vector to divide; modified in place.
 * @param b Integer divisor.
 */
static inline __host__ __device__ void operator/=(double3 &a, const int b)
{
    a.x /= b;
    a.y /= b;
    a.z /= b;
}

/**
 * @brief Euclidean length of a.
 *
 * @details The sum of squares and the square root are evaluated in double precision; the
 * value is narrowed to float by the return type.
 */
static inline __host__ __device__ float size(const double3 a)
{
    const double squaredLength = a.x * a.x + a.y * a.y + a.z * a.z;
    return sqrt(squaredLength);
}

/**
 * @brief Maps a position to integer cell coordinates on a uniform grid.
 *
 * @details Each coordinate is divided by the cell size s and rounded down. A resulting index
 * that is exactly equal to cellPerAxis is pulled back by one; any other index, including
 * negative ones or ones larger than cellPerAxis, is returned unchanged.
 *
 * @param p Position to locate.
 * @param s Edge length of one cell.
 * @param cellPerAxis Index value that gets pulled back by one.
 * @return The cell indices along x, y and z.
 */
static inline __host__ __device__ int3 getCell(const double3 p, const float s, const int cellPerAxis)
{
    int3 cell;
    cell.x = floor(p.x / s);
    cell.y = floor(p.y / s);
    cell.z = floor(p.z / s);

    if (cell.x == cellPerAxis)
        --cell.x;
    if (cell.y == cellPerAxis)
        --cell.y;
    if (cell.z == cellPerAxis)
        --cell.z;

    return cell;
}

/** @brief Dot product of a and b: a.x * b.x + a.y * b.y + a.z * b.z. */
static inline __host__ __device__ double dot(const double3 a, const double3 b)
{
    // Accumulate in x, y, z order, the same order as a single left-to-right sum.
    double acc = a.x * b.x;
    acc += a.y * b.y;
    acc += a.z * b.z;
    return acc;
}

/**
 * @brief Spreads the low 21 bits of w so that input bit i ends up at output bit 3 * i.
 *
 * @details Bits above bit 20 of the input are discarded. The two output bits between each pair
 * of payload bits are zero, which leaves room to interleave three spread values. The function
 * is pure integer arithmetic and is callable from both host and device code.
 *
 * @param w Value whose low 21 bits are spread.
 * @return The spread value; only bits 0, 3, 6, ..., 60 can be set.
 */
static inline __host__ __device__ uint64_t spread(uint64_t w)
{
    w &= 0x00000000001fffffULL;
    w = (w | w << 32) & 0x001f00000000ffffULL;
    w = (w | w << 16) & 0x001f0000ff0000ffULL;
    // The top group must keep bit 60 (input bit 20), not bit 56: selecting bit 56 would drop
    // input bit 20 and copy input bit 16 into its place.
    w = (w | w << 8) & 0x100f00f00f00f00fULL;
    w = (w | w << 4) & 0x10c30c30c30c30c3ULL;
    w = (w | w << 2) & 0x1249249249249249ULL;
    return w;
}

/**
 * @brief Gathers every third bit of w (bits 0, 3, 6, ..., 60) into the low 21 bits.
 *
 * @details Bit 3 * i of the input becomes bit i of the result; all other input bits are
 * ignored. The result is returned as a 32-bit value.
 */
static inline __device__ uint32_t compact(uint64_t w)
{
    uint64_t bits = w & 0x1249249249249249;
    bits = (bits ^ (bits >> 2)) & 0x30c30c30c30c30c3;
    bits = (bits ^ (bits >> 4)) & 0xf00f00f00f00f00f;
    bits = (bits ^ (bits >> 8)) & 0x00ff0000ff0000ff;
    bits = (bits ^ (bits >> 16)) & 0x00ff00000000ffff;
    bits = (bits ^ (bits >> 32)) & 0x00000000001fffff;
    return static_cast<uint32_t>(bits);
}

/**
 * @brief Interleaves the three components of pos into one 64-bit code.
 *
 * @details Each component is converted to uint64_t and passed through spread(); the x result
 * is used as is, the y result is shifted left by 1 and the z result by 2, and the three are
 * OR-ed together.
 */
static inline __device__ uint64_t Z_encode2(int3 pos)
{
    const uint64_t xBits = spread((uint64_t)pos.x);
    const uint64_t yBits = spread((uint64_t)pos.y) << 1;
    const uint64_t zBits = spread((uint64_t)pos.z) << 2;
    return xBits | yBits | zBits;
}

/**
 * @brief Splits a 64-bit interleaved code back into three integer components.
 *
 * @details x is compact() of the code itself, y of the code shifted right by 1 and z of the
 * code shifted right by 2.
 */
static inline __device__ int3 Z_decode2(uint64_t Z_code)
{
    const int x = compact(Z_code);
    const int y = compact(Z_code >> 1);
    const int z = compact(Z_code >> 2);
    return make_int3(x, y, z);
}

/**
 * @brief Raises base to an integer exponent by repeated squaring.
 *
 * @details Non-negative exponents are evaluated by square-and-multiply over the bits of exp.
 * Negative exponents return 1 / base^|exp|; shifting a negative int right never reaches zero,
 * so the magnitude is processed as an unsigned value to guarantee termination. The function
 * is pure arithmetic and is callable from both host and device code.
 *
 * @param base The value to raise.
 * @param exp The integer exponent (may be negative or zero).
 * @return base raised to exp; 1.0f when exp is 0.
 */
static inline __host__ __device__ float power2(float base, int exp)
{
    const bool negative = exp < 0;
    // Unsigned negation is well defined even for the most negative int.
    unsigned int magnitude = negative ? 0u - static_cast<unsigned int>(exp)
                                      : static_cast<unsigned int>(exp);

    float result = 1.f;
    while (magnitude != 0u)
    {
        if ((magnitude & 1u) == 1u)
            result *= base;
        magnitude >>= 1;
        base *= base;
    }

    return negative ? 1.f / result : result;
}

/**
 * @brief Raises v to the integer power p by repeated squaring.
 *
 * @details Non-negative exponents are evaluated by square-and-multiply. Negative exponents
 * return 1 / v^|p| instead of silently yielding 1.0f. The function is pure arithmetic and is
 * callable from both host and device code.
 *
 * @param v The value to raise.
 * @param p The integer exponent (may be negative or zero).
 * @return v raised to p; 1.0f when p is 0.
 */
static inline __host__ __device__ float power(const float v, const int p)
{
    const bool negative = p < 0;
    // Unsigned negation is well defined even for the most negative int.
    unsigned int n = negative ? 0u - static_cast<unsigned int>(p)
                              : static_cast<unsigned int>(p);

    float result = 1.0f;
    float a = v;
    while (n > 0u)
    {
        if (n % 2u == 1u)
            result *= a;
        a *= a;
        n /= 2u;
    }

    return negative ? 1.0f / result : result;
}
