#=
Worked examples for the random-effects-on-durations paper.

Each `let` block produces one PDF figure. Run with:

    julia --project=. example.jl

Required packages: Pumas, AlgebraOfGraphics, CairoMakie, DataFrames,
ForwardDiff, Optim, Random.

The Pumas API for `randeffs_initialization` / `RandeffsInitialization.Center`
(referenced in the practical-recommendations of the paper) is from Pumas 2.8.
=#

using Pumas, CairoMakie, DataFrames, ForwardDiff, Optim, Random

# All figures are written into a `figures/` directory next to this script.
const FIGDIR = joinpath(@__DIR__, "figures")
# `mkpath` does nothing when the directory already exists, so it can be
# called unconditionally.
mkpath(FIGDIR)

# ---------------------------------------------------------------
# Model — IV infusion with LogNormal BSV on the infusion duration
# ---------------------------------------------------------------

# Pumas model used by every figure below: `Central1` dynamics with fixed
# CL and Vc, and a single random effect — the infusion duration itself.
mdl = @model begin
    # Population parameters; each one is bounded below by zero.
    @param begin
        θCL  ∈ RealDomain(lower = 0.0)
        θVc  ∈ RealDomain(lower = 0.0)
        θdur ∈ RealDomain(lower = 0.0)
        ωdur ∈ RealDomain(lower = 0.0)
        σ    ∈ RealDomain(lower = 0.0)
    end

    # The random effect is drawn on the natural (duration) scale:
    # LogNormal with log-scale location log(θdur) and log-scale spread ωdur.
    @random begin
        ηdur ~ LogNormal(log(θdur), ωdur)
    end

    # CL and Vc are taken straight from the population parameters
    # (no random effects enter here).
    @pre begin
        CL = θCL
        Vc = θVc
    end

    # The infusion duration for doses into `Central` is the random effect.
    @dosecontrol begin
        duration = (; Central = ηdur)
    end

    # NOTE(review): `Central1` is presumably Pumas' built-in one-compartment
    # model with a `Central` compartment — confirm against the Pumas docs.
    @dynamics Central1

    # Observation model: y is Normal around Central / Vc with standard
    # deviation σ.
    # NOTE(review): `:=` is assumed to keep μ as an internal variable that is
    # not part of the model output — confirm against the Pumas docs.
    @derived begin
        μ := Central ./ Vc
        y ~ @. Normal(μ, σ)
    end
end

# Parameter values shared by every figure: they are used both to simulate
# the data and to evaluate NPLL(η). Declared `const` (like `FIGDIR`) because
# the binding is never reassigned; this also keeps the global concretely typed.
const DURPARS = (;
    θCL  = 1.0,
    θVc  = 1.5,
    θdur = 0.4,
    ωdur = 0.5,
    σ    = 1.0,
)

# ---------------------------------------------------------------
# Figure 1 — single-infusion NPLL(η): non-differentiability
# ---------------------------------------------------------------

let
    # Reproducible simulation: seed the default RNG before drawing data.
    rng = Random.default_rng()
    Random.seed!(rng, 128)

    # One dose with `rate = -2`, observed at four time points.
    regimen   = DosageRegimen(100; rate = -2)
    obs_times = [0.5, 1.0, 2.0, 4.0]
    template  = Subject(; events = regimen, time = obs_times)
    simulated = simobs(mdl, template, DURPARS; rng = rng)
    subject   = Subject(simulated)

    # Profile the penalized conditional NLL over a grid of ηdur values.
    ηs = range(0.4, 0.7; length = 400)
    profile = map(ηs) do η
        Pumas.penalized_conditional_nll(mdl, subject, DURPARS, (; ηdur = η))
    end

    fig  = Figure(; size = (640, 420))
    axis = Axis(fig[1, 1]; xlabel = "ηdur", ylabel = "NPLL(ηdur)")
    lines!(axis, ηs, profile; color = :black, linewidth = 1.5)
    # Dashed reference line at ηdur = 0.5 (the first observation time).
    vlines!(axis, [0.5]; linestyle = :dash, color = (:gray, 0.7), linewidth = 1)

    save(joinpath(FIGDIR, "fig1-kink.pdf"), fig)
    @info "wrote fig1-kink.pdf"
end

# ---------------------------------------------------------------
# Figure 2 — multi-infusion NPLL(η) with the inner mode landing on a kink
# ---------------------------------------------------------------
#
# A specific simulation seed (found by example2.jl) where the global
# minimum of NPLL(η) lands exactly on one of the (observation, dose)
# kinks. Top panel shows NPLL with the kinked minimum highlighted;
# bottom panel shows the forward-difference derivative jumping
# discontinuously at the same kink.

let
    seed  = 71
    times = collect(0.5:0.5:11.5)

    # Dosing schedule. The dose times used to locate kinks are derived from
    # the same `ii`/`addl` that are handed to the regimen, so the two cannot
    # drift apart if the schedule is edited.
    ii, addl   = 2, 5
    dose_times = collect(0.0:ii:(ii * addl))
    dr         = DosageRegimen(100; ii = ii, addl = addl, rate = -2)

    rng = Random.default_rng()
    Random.seed!(rng, seed)
    simsub = simobs(mdl, Subject(events = dr, time = times), DURPARS,
                    rng = rng)
    sub = Subject(simsub)

    # NPLL profile on a dense grid. The function is called with the `Pumas.`
    # qualifier, as in every other figure block of this script, so the call
    # does not depend on the name being exported.
    ηrange = (0.05, 1.5)
    ηgrid  = range(ηrange[1], ηrange[2], length = 4000)
    npll   = [Pumas.penalized_conditional_nll(mdl, sub, DURPARS, (; ηdur = η))
              for η in ηgrid]

    # Kink locations within ηrange: every (observation time − dose time)
    # difference that falls inside the plotted range, de-duplicated after
    # rounding to 6 digits.
    kinks = sort(unique(round.([t - d for t in times, d in dose_times
                                if ηrange[1] <= t - d <= ηrange[2]];
                               digits = 6)))

    # Grid minimum of the profile.
    idx_min = argmin(npll)
    η_min   = ηgrid[idx_min]
    f_min   = npll[idx_min]

    # Distance from the grid minimum to the nearest kink, and the one-sided
    # finite-difference slopes on either side of the minimum. The index is
    # clamped so a minimum on the grid boundary yields a zero slope there.
    d_kink, k_idx = findmin(abs.(kinks .- η_min))
    h = step(ηgrid)
    sL = (npll[idx_min] - npll[max(idx_min - 1, 1)])           / h
    sR = (npll[min(idx_min + 1, length(npll))] - npll[idx_min]) / h
    @info "η_min ≈ $(round(η_min; digits=4)), nearest kink ≈ $(round(kinks[k_idx]; digits=4)) (Δ ≈ $(round(d_kink; digits=5)))"
    @info "one-sided slopes: left $(round(sL; digits=2)), right $(round(sR; digits=2))"

    # Forward-difference derivative of NPLL. We replace the cell whose
    # ηgrid interval straddles each kink with NaN so the rendered line
    # actually breaks at the discontinuity instead of connecting across it.
    dnpll = collect(diff(npll) ./ h)
    ηmid  = collect((ηgrid[1:end-1] .+ ηgrid[2:end]) ./ 2)
    for k in kinks
        # j is the first grid index with ηgrid[j] >= k, so cell j-1 covers k.
        j = searchsortedfirst(ηgrid, k)
        if 2 <= j <= length(ηgrid)
            dnpll[j-1] = NaN
        end
    end

    fig = Figure(size = (900, 900))

    # Two stacked panels of equal height sharing the η axis.
    ax1 = Axis(fig[1, 1], ylabel = "NPLL(ηdur)",
               title = "Inner minimum at a kink")
    ax2 = Axis(fig[2, 1], xlabel = "ηdur", ylabel = "d NPLL / d ηdur")
    rowsize!(fig.layout, 1, Auto(1))
    rowsize!(fig.layout, 2, Auto(1))
    linkxaxes!(ax1, ax2)
    hidexdecorations!(ax1, ticks = false, grid = false)

    # Faint dotted guide at every kink, in both panels.
    for ax in (ax1, ax2)
        for k in kinks
            vlines!(ax, [k], color = (:gray, 0.3),
                    linestyle = :dot, linewidth = 0.7)
        end
    end
    hlines!(ax2, [0.0], color = (:gray, 0.5), linewidth = 0.6)

    lines!(ax1, ηgrid, npll, color = :black, linewidth = 1.5)
    scatter!(ax1, [η_min], [f_min], color = :crimson, markersize = 11)

    lines!(ax2, ηmid, dnpll, color = :black, linewidth = 1.5)

    save(joinpath(FIGDIR, "fig2-kinked-minimum.pdf"), fig)
    @info "wrote fig2-kinked-minimum.pdf"
end

# ---------------------------------------------------------------
# Figure 3 — multi-infusion NPLL(η): multiple local minima
# ---------------------------------------------------------------
#
# Six infusions every 2 h, dense sampling every 0.5 h. With many overlapping
# (obs, dose) pairs the absorption curve's non-monotonicity in η creates
# many secondary basins per observation; if enough align in η, NPLL
# carries multiple local minima even at the simulation-truth θ.

let
    # Reproducible simulation: seed the default RNG before drawing data.
    rng = Random.default_rng()
    Random.seed!(rng, 128)

    regimen   = DosageRegimen(100; ii = 2, addl = 5, rate = -2)  # doses at 0, 2, …, 10
    obs_times = collect(0.5:0.5:11.5)                            # 23 obs every 0.5 h
    template  = Subject(; events = regimen, time = obs_times)
    simulated = simobs(mdl, template, DURPARS; rng = rng)
    subject   = Subject(simulated)

    # Profile the penalized conditional NLL over a grid of ηdur values.
    ηs = range(0.1, 1.5; length = 1200)
    profile = map(ηs) do η
        Pumas.penalized_conditional_nll(mdl, subject, DURPARS, (; ηdur = η))
    end

    # Interior grid points strictly below both neighbours are local minima;
    # the two end points are never flagged.
    interior = 2:(length(profile) - 1)
    min_idx  = filter(interior) do i
        profile[i] < profile[i - 1] && profile[i] < profile[i + 1]
    end
    η_at_min = ηs[min_idx]
    f_at_min = profile[min_idx]

    fig  = Figure(; size = (640, 420))
    axis = Axis(fig[1, 1]; xlabel = "ηdur", ylabel = "NPLL(ηdur)")
    lines!(axis, ηs, profile; color = :black, linewidth = 1.5)
    # Hollow crimson circles mark each detected local minimum.
    scatter!(axis, η_at_min, f_at_min;
             marker = :circle, markersize = 9,
             color = :transparent,
             strokecolor = :crimson, strokewidth = 1.5,
             label = "Local minima")
    axislegend(axis; position = :lt)

    save(joinpath(FIGDIR, "fig3-multimodal.pdf"), fig)
    @info "wrote fig3-multimodal.pdf — found $(length(η_at_min)) local minima"
end

# ---------------------------------------------------------------
# Figure 4 — multi-dose NPLL with one Laplace quadratic per local mode
# ---------------------------------------------------------------
#
# Same data as Figure 3. For each local minimum of NPLL we compute a
# Laplace quadratic — a parabola centred at the mode with curvature
# equal to the local Hessian. None of them captures the multimodal
# structure: each fits its own basin and diverges past the nearest kink.

let
    # Reproducible simulation: seed the default RNG before drawing data.
    rng = Random.default_rng()
    Random.seed!(rng, 128)
    dr = DosageRegimen(100; ii = 2, addl = 5, rate = -2)
    times = collect(0.5:0.5:11.5)
    simsub = simobs(mdl, Subject(events = dr, time = times), DURPARS, rng = rng)
    sub = Subject(simsub)

    # NPLL as a scalar function of ηdur for this subject.
    nll(η) = Pumas.penalized_conditional_nll(mdl, sub, DURPARS, (; ηdur = η))

    # Locate local minima via a dense grid scan, then refine each with Brent.
    ηgrid_scan = range(0.1, 1.5, length = 2400)
    npll_scan  = nll.(ηgrid_scan)
    is_min = falses(length(npll_scan))
    for i in 2:length(npll_scan)-1
        is_min[i] = npll_scan[i] < npll_scan[i-1] && npll_scan[i] < npll_scan[i+1]
    end

    h_scan = step(ηgrid_scan)
    modes = NamedTuple{(:η̂, :f̂, :H), Tuple{Float64,Float64,Float64}}[]
    for i in findall(is_min)
        η0  = ηgrid_scan[i]
        # Refine within ±3 grid steps of the scanned minimum.
        res = optimize(nll, η0 - 3h_scan, η0 + 3h_scan, Optim.Brent())
        η̂   = Optim.minimizer(res)
        f̂   = Optim.minimum(res)
        # Second derivative at the refined mode via nested forward-mode AD.
        H   = ForwardDiff.derivative(x -> ForwardDiff.derivative(nll, x), η̂)
        push!(modes, (; η̂ = η̂, f̂ = f̂, H = H))
    end

    ηgrid = range(0.1, 1.5, length = 1200)
    npll  = nll.(ηgrid)

    fig = Figure(size = (640, 420))
    ax = Axis(fig[1, 1], xlabel = "ηdur", ylabel = "NPLL(ηdur)")
    # Clip the y-axis so the diverging parabolas do not dominate the plot.
    ylims!(ax, -100, 1.1 * maximum(npll))
    lines!(ax, ηgrid, npll, color = :black, linewidth = 1.5, label = "NPLL")
    pal = (:crimson, :royalblue, :forestgreen, :darkorange, :purple)
    for (i, m) in pairs(modes)
        # The number of modes depends on the simulated data, so cycle through
        # the palette rather than indexing past its end when the scan finds
        # more modes than there are colours.
        colour = pal[mod1(i, length(pal))]
        # Laplace quadratic: parabola centred at the mode with curvature H.
        quad = m.f̂ .+ 0.5 .* m.H .* (ηgrid .- m.η̂).^2
        lines!(ax, ηgrid, quad, color = colour, linewidth = 1.3,
               linestyle = :dash,
               label = "quadratic at η̂ = $(round(m.η̂; digits = 2))")
        scatter!(ax, [m.η̂], [m.f̂], color = colour, markersize = 9)
    end
    axislegend(ax, position = :lt)
    save(joinpath(FIGDIR, "fig4-laplace-overlay.pdf"), fig)
    @info "wrote fig4-laplace-overlay.pdf — modes: $modes"
end

# ---------------------------------------------------------------
# Figure 5 — NPLL(η) under increasing density of observations around the doses
# ---------------------------------------------------------------
#
# Three doses fixed; observations spread across [0.5, 4.0] h with
# increasing density. Each new observation lands close to a dose-time
# boundary, so each adds at least one kink to NPLL(η). With more obs
# the basin near the truth tightens but the curve accumulates more
# kinks rather than fewer.

let
    levels   = (3,)                    # number of doses
    obs_grid = [4, 6, 8, 10, 30]       # sampling-density levels, one curve each
    palette  = (:steelblue, :darkorange, :forestgreen, :firebrick, :black)

    fig = Figure(size = (640, 420))
    ax = Axis(fig[1, 1], xlabel = "ηdur", ylabel = "NPLL(η) − min NPLL")

    for n_doses in levels, (j, n_obs) in enumerate(obs_grid)
        # Re-seed for every curve so all curves use the same random stream.
        rng = Random.default_rng()
        Random.seed!(rng, 128)

        if n_doses == 1
            dr    = DosageRegimen(100; rate = -2)
            times = collect(0.5:0.5:4.0)
        else
            dr    = DosageRegimen(100; ii = 2, addl = n_doses - 1, rate = -2)
            # The [0.5, 4.0] window is split into n_obs equal steps; the
            # legend reports the resulting number of sampling times.
            times = collect(0.5:(3.5 / n_obs):4.0)
        end
        # Diagnostic only; hidden unless debug logging is enabled.
        @debug "Figure 5 sampling times" n_doses n_obs times

        simsub = simobs(mdl, Subject(events = dr, time = times), DURPARS, rng = rng)
        sub    = Subject(simsub)

        ηgrid = range(0.01, 2.5, length = 2000)
        npll  = [Pumas.penalized_conditional_nll(mdl, sub, DURPARS, (; ηdur = η))
                 for η in ηgrid]
        # Shift each curve so its grid minimum sits at zero, making the
        # curves comparable on one axis.
        npll_shifted = npll .- minimum(npll)

        lines!(ax, ηgrid, npll_shifted, color = palette[j], linewidth = 1.4,
               label = "$(n_doses) doses, $(length(times)) obs")
    end

    # Dotted reference line at θdur.
    vlines!(ax, [DURPARS.θdur], color = (:gray, 0.6), linestyle = :dot, linewidth = 1)
    axislegend(ax, position = :lt)
    save(joinpath(FIGDIR, "fig5-data-growth-around-doses.pdf"), fig)
    @info "wrote fig5-data-growth-around-doses.pdf"
end

# ---------------------------------------------------------------
# Figure 6 — NPLL(η) under increasing density of trough/late observations
# ---------------------------------------------------------------
#
# Three doses fixed, with a sparse fixed sampling around dose 1 plus a
# growing number of trough/late samples in [4, 12] h. The trough
# observations are far from any dose-time boundary in η ∈ [0, 2.5], so
# they do not add new kinks within the relevant range — they only
# rescale the residual penalty. The kink locations are unchanged across
# all curves.

let
    levels   = (3,)                    # number of doses
    obs_grid = [1, 3, 5, 7, 27]        # number of trough/late samples per curve
    palette  = (:steelblue, :darkorange, :forestgreen, :firebrick, :black)

    fig = Figure(size = (640, 420))
    ax = Axis(fig[1, 1], xlabel = "ηdur", ylabel = "NPLL(η) − min NPLL")

    for n_doses in levels, (j, n_obs) in enumerate(obs_grid)
        # Re-seed for every curve so all curves use the same random stream.
        rng = Random.default_rng()
        Random.seed!(rng, 128)

        if n_doses == 1
            dr    = DosageRegimen(100; rate = -2)
            times = collect(0.5:0.5:4.0)
        else
            dr    = DosageRegimen(100; ii = 2, addl = n_doses - 1, rate = -2)
            # Fixed early samples at 0.5, 1.5, 2.5, 3.5 h plus n_obs samples
            # spread evenly over [4, 12] h. A single late sample is placed
            # at 4 h, because a length-1 range cannot span two endpoints.
            late  = n_obs == 1 ? [4.0] : range(4.0, 12, n_obs)
            times = vcat(collect(0.5:1.0:3.5), late)
        end
        # Diagnostic only; hidden unless debug logging is enabled.
        @debug "Figure 6 sampling times" n_doses n_obs times

        simsub = simobs(mdl, Subject(events = dr, time = times), DURPARS, rng = rng)
        sub    = Subject(simsub)

        ηgrid = range(0.1, 2.5, length = 1200)
        npll  = [Pumas.penalized_conditional_nll(mdl, sub, DURPARS, (; ηdur = η))
                 for η in ηgrid]
        # Shift each curve so its grid minimum sits at zero, making the
        # curves comparable on one axis.
        npll_shifted = npll .- minimum(npll)

        lines!(ax, ηgrid, npll_shifted, color = palette[j], linewidth = 1.4,
               label = "$(n_doses) doses, $(length(times)) obs")
    end

    # Dotted reference line at θdur.
    vlines!(ax, [DURPARS.θdur], color = (:gray, 0.6), linestyle = :dot, linewidth = 1)
    axislegend(ax, position = :lt)
    save(joinpath(FIGDIR, "fig6-data-growth-trough.pdf"), fig)
    @info "wrote fig6-data-growth-trough.pdf"
end
