% Version: 1.1.0
% Reviser: Huayiwang
% Feature: Cluster Matched Filter
% Date: 2024-01-29

function result_CO2 = CTMF(swir_sel, k_CO2_sel, class_res, Niter)
% CTMF  Cluster-wise iterative matched filter for CO2.
%   result_CO2 = CTMF(swir_sel, k_CO2_sel, class_res) groups the pixels of
%   the cube swir_sel by the labels 1..max(class_res(:)) in class_res, runs
%   the local iterative matched filter (Filter) on each group independently
%   with k_CO2_sel as its s argument, and scatters the per-pixel estimates
%   back into row x col images.
%
%   result_CO2 = CTMF(..., Niter) sets the maximum number of filter
%   iterations (default 15, the previously hard-coded value).
%
%   Inputs
%     swir_sel  : row x col x wav data cube.
%     k_CO2_sel : wav x 1 vector, passed to Filter as s.
%     class_res : per-pixel cluster label; assumed to have row*col elements
%                 laid out like the first two dimensions of swir_sel
%                 (TODO confirm with caller). Pixels whose label is not one
%                 of 1..max(class_res(:)) keep the value 0.
%   Output
%     result_CO2 : Niter x 1 cell; result_CO2{i} is a row x col image of the
%                  estimates after iteration i. If a cluster's filter stops
%                  at iteration N_iter < Niter, its last estimate is repeated
%                  in result_CO2{N_iter+1 : Niter}.
if nargin < 4 || isempty(Niter)
    Niter = 15; % maximum number of iterations
end
s = k_CO2_sel;
K = max(class_res(:));
[row, col, wav] = size(swir_sel);
result_CO2 = repmat({zeros(row, col)}, Niter, 1);

% Flatten the cube to (row*col) x wav: linear pixel indices then select rows
% directly, identical (column-major) to ind2sub over [row, col].
pixels = reshape(swir_sel, row * col, wav);

for j = 1 : K % loop over cluster labels
    mydisp(j)
    idx = find(class_res == j); % linear indices of the pixels in cluster j
    if isempty(idx)
        % Nothing to estimate; running Filter on an empty matrix only
        % produces NaN statistics and a full run of useless iterations.
        continue;
    end
    % double() matches the original copy into a double zeros() buffer.
    cluster_pixels = double(pixels(idx, :));
    [alpha, N_iter] = Filter(cluster_pixels, s, Niter);
    for i = 1 : Niter
        % After an early stop, keep reporting the last computed estimate.
        result_CO2{i}(idx) = alpha{min(i, N_iter)};
    end
end
end
%%
function [alpha, N_iter] = Filter(L, s, Niter)
% Filter  Iterative, albedo-scaled, reweighted matched filter for one cluster.
%   [alpha, N_iter] = Filter(L, s, Niter)
%
%   Inputs
%     L     : N x wav matrix, one pixel spectrum per row.
%     s     : wav x 1 vector; the target signature is formed as mu .* s,
%             where mu is the cluster mean spectrum.
%     Niter : maximum number of iterations (>= 1).
%   Outputs
%     alpha  : Niter x 1 cell; alpha{k} is the N x 1 estimate (clamped to be
%              non-negative) after iteration k. Cells after N_iter are left
%              empty.
%     N_iter : last iteration computed: the first k >= 2 at which
%              sqrt(t' * C^-1 * t) changed by less than 1e-5, otherwise
%              Niter.
mu = cell(Niter, 1);
C = cell(Niter, 1);
alpha = cell(Niter, 1);
w = cell(Niter, 1);

%% Initialisation (iteration 1)
% size(L, 1), not length(L): length returns the band count whenever the
% cluster has fewer pixels than bands.
N = size(L, 1); % number of pixels
mu{1} = mean(L, 'omitnan')'; % wav x 1 mean spectrum, NaNs ignored
C{1} = cov(L, 'omitrows'); % wav x wav covariance, rows containing NaN dropped

%% Albedo factor
tempA = mu{1} .* s; % wav x 1 target signature t
tempC = C{1} \ tempA; % wav x 1, C^-1 * t
r = (L * mu{1}) / (mu{1}' * mu{1}); % N x 1 projection coefficient of each spectrum onto mu
tempB = mu{1} - L'; % wav x N, mean minus each pixel
tempD = tempA' * tempC; % scalar t' * C^-1 * t
alpha{1} = tempB' * tempC ./ (r * tempD);
alpha{1}(alpha{1} < 0) = 0; % clamp to non-negative
RES = sqrt(tempD);
N_iter = 1; % defined even when Niter == 1 and the loop below never runs

%% Iterative refinement
IsWeight = true; % set to false to disable the reweighting term
epsilon = 1e-10; % keeps the weight finite where the previous alpha is 0
for k = 2 : Niter
    if(IsWeight) % reweighting: smaller previous estimates get a larger penalty
        w{k} = 1 ./ (alpha{k - 1} * 2 + epsilon); % N x 1
    else
        w{k} = zeros(N, 1);
    end
    % Re-estimate mean and covariance from spectra with r .* alpha{k-1}
    % times the target subtracted.
    % NOTE(review): alpha is estimated from (mu - x), i.e. a model
    % x = mu - r*alpha*t; removing that contribution would be L + r*alpha*t
    % rather than the subtraction used here -- confirm the sign of s.
    mu{k} = mean(L - r .* alpha{k - 1} * mu{k - 1}' .* s', 'omitnan')';
    tempA = mu{k} .* s; % wav x 1
    LC = L - r .* alpha{k - 1} * tempA'; % N x wav
    C{k} = cov(LC, 'omitrows'); % wav x wav
    tempC = C{k} \ tempA; % wav x 1
    tempB = mu{k} - L'; % wav x N
    tempD = tempA' * tempC;
    alpha{k} = (tempB' * tempC - w{k}) ./ (r * tempD);
    alpha{k}(alpha{k} < 0) = 0;
    N_iter = k;
    if(abs(sqrt(tempD) - RES) < 1e-5)
        mydisp(sprintf('%02d迭代后已经收敛', k))
        return;
    elseif(k == Niter)
        mydisp(sprintf('%02d迭代内未能收敛', Niter))
    end
    RES = sqrt(tempD);
    mydisp(sprintf('%02d: %f',k, RES));
end
end
