#pragma once

#include <cmath>

#include <json.hpp>

#include "../../Commons/Physics.h"
#include "../../Commons/utilityKernel.cuh"

/**
 * Plain-data parameter set for the DLVO pair interaction.
 *
 * Every field except `pot` is filled from the JSON key of the same name in
 * DLVO::init; `pot` is reset to 0 there. The field order is part of the
 * memory layout and must not be changed.
 */
struct DLVO_t
{
    float Rc;       // loaded from "Rc"; not referenced by the code in this header
    float k;        // decay constant of the exponential terms in DLVO::compute
    float rcdlvo;   // multiplied by (r0 + r1) to obtain the distance that separates the two force branches
    float rrep;     // loaded from "rrep"; its only use in this header is commented out
    float coeffdir; // scale factor of the force returned below the branch distance
    int second;     // loaded from "second"; not referenced by the code in this header
    int first;      // loaded from "first"; not referenced by the code in this header
    float Hamaker;  // prefactor `A` of the A/6 * (...) term in DLVO::compute
    float pot;      // set to 0 by DLVO::init; not otherwise touched in this header
};

/**
 * DLVO pair-force model.
 *
 * Holds a DLVO_t parameter set by value. `init` fills it on the host from a
 * JSON object; `compute` evaluates the pair force in device code.
 *
 * NOTE(review): `compute` reads `data` through `this`, so the object must be
 * readable from device code (for example by passing it by value as a kernel
 * argument). How callers arrange this is not visible here; verify at the
 * call site.
 */
class DLVO
{
public:
    // Model parameters, stored inline in the object (not a pointer).
    DLVO_t data;

    /**
     * Load the model parameters from a JSON object.
     *
     * @param js          JSON object providing the keys "Rc", "k", "rcdlvo",
     *                    "rrep", "coeffdir", "second", "first" and "Hamaker".
     *                    A missing or non-numeric entry makes the conversion
     *                    throw an nlohmann::json exception.
     * @param temperature Currently unused; kept so existing callers compile.
     */
    void init(nlohmann::json js, float temperature)
    {
        (void)temperature;

        // `data` lives inside this object, so there is nothing to allocate.
        // An earlier cudaMallocManaged((void **)&data, ...) call here wrote a
        // device pointer over the first bytes of the struct and leaked the
        // allocation; it has been removed.
        data.Rc = js["Rc"].get<float>();
        data.k = js["k"].get<float>();
        data.rcdlvo = js["rcdlvo"].get<float>();
        data.rrep = js["rrep"].get<float>();
        data.coeffdir = js["coeffdir"].get<float>();
        data.second = js["second"].get<int>();
        data.first = js["first"].get<int>();
        data.Hamaker = js["Hamaker"].get<float>();

        data.pot = 0;
    }

    /**
     * Pair force between two spheres of radii r0 and r1.
     *
     * @param r0        Radius of the first particle.
     * @param r1        Radius of the second particle.
     * @param dist      Centre-to-centre distance; used as the normalising
     *                  length of diff_p_pn.
     * @param diff_p_pn Separation vector between the two particles
     *                  (sign convention is defined by the caller - TODO confirm).
     * @param psii      Per-particle potential of the first particle
     *                  (presumably a surface potential - confirm with caller).
     * @param psij      Per-particle potential of the second particle.
     * @return For dist > data.rcdlvo * (r0 + r1): -diff_p_pn * (bb + aa) / dist,
     *         where bb is the Hamaker term and aa the exponential
     *         (electrostatic) term. For dist below that threshold:
     *         diff_p_pn * data.coeffdir / dist. Exactly at the threshold the
     *         result is the zero vector.
     *
     * The far-field expression is singular at dist == r0 + r1 (both
     * dist^2 - (r0 + r1)^2 and 1 - exp(-k * h) vanish there), so it is only
     * well defined when data.rcdlvo keeps the threshold above contact.
     * dist == 0 is not guarded in the short-range branch.
     */
    inline __device__ double3 compute(const float r0, const float r1, const double dist, const double3 diff_p_pn, const float psii, const float psij)
    {
        const double ai = r0;
        const double aj = r1;
        const double suma = ai + aj;
        const double rcdlvo = data.rcdlvo * suma;
        // const double rrep = data.rrep * (ai + aj);

        if (dist > rcdlvo)
        {
            const double k = data.k;
            const double A = data.Hamaker;

            // Hamaker term: built from dist^2 - (ai + aj)^2 and dist^2 - (ai - aj)^2.
            const double suma2 = suma * suma;
            const double difa = ai - aj;
            const double difa2 = difa * difa;
            const double dist2 = dist * dist;
            const double gapSum = dist2 - suma2;
            const double gapDif = dist2 - difa2;
            const double num = 4 * ai * aj * dist;
            const double bb = A / 6 * (num / gapSum / gapSum + num / gapDif / gapDif - 2 * dist * (suma2 - difa2) / gapSum / gapDif);

            // Exponential term: shared prefactor, then the cross (psii * psij)
            // and self (psii^2 + psij^2) contributions.
            const double pref = Constant::PI * Constant::water_dielectric_const * Constant::E0 * ai * aj / suma;
            const double con3a1a2 = pref * 2 * psii * psij;
            const double con3a1a2b = pref * (psii * psii + psij * psij);

            // h is the surface-to-surface gap; each exponential is evaluated once.
            const double h = dist - ai - aj;
            const double e1 = exp(-k * h);
            const double e2 = exp(-2 * k * h);
            const double aa = -con3a1a2 * 2 * k * e1 / ((1 - e1) * (1 + e1)) + con3a1a2b * 2 * k * e2 / (1 - e2);

            return -diff_p_pn * (bb + aa) / dist;
        }
        if (dist < rcdlvo)
        {
            // Below the threshold the force has constant magnitude data.coeffdir
            // along the normalised separation vector.
            return diff_p_pn * data.coeffdir / dist;
        }
        // dist == rcdlvo (or a NaN comparison): no force.
        return make_double3(0, 0, 0);
    }

    // Nothing to release: the class owns no resources. Defaulting keeps the
    // type trivially destructible.
    ~DLVO() = default;
};
