function swir_new = calculate_weighted_swir(swir_sel, class_res, K)
    % 初始化变量以保存每个类别的 "swir_sel" 均值
    swir_sel_means = zeros(K, size(swir_sel, 3));
    % 初始化变量以保存每个类别的像素数量
    num_pixels_per_class = zeros(1, K);
    % 循环遍历每个类别
    for m = 1:K
        % 找出属于当前类别的像素的索引
        class_pixels_indices = find(class_res == m);
        % 计算当前类别中的像素数量
        num_pixels_per_class(m) = numel(class_pixels_indices);
        % 创建一个与原图像大小相同的逻辑数组
        binary_mask = false(size(swir_sel, 1), size(swir_sel, 2));
        binary_mask(class_pixels_indices) = true;
        % 提取这些像素的 "swir_sel" 值
        swir_sel_values = swir_sel .* repmat(binary_mask, [1, 1, size(swir_sel, 3)]);
        swir_sel_values(swir_sel_values == 0) = NaN;
        % 计算 "swir_sel" 在第三维上的均值
        swir_sel_means(m, :) = squeeze(mean(mean(swir_sel_values, 1, 'omitnan'), 2, 'omitnan'));
    end
    
    % 计算权重
    weights = zeros(size(swir_sel_means));
    A_mean = mean(mean(swir_sel, 1), 2);
    for n = 1:size(swir_sel, 3)
        weights(:, n) = A_mean(n) ./ swir_sel_means(:, n);
    end
    
    % 计算加权后的 SWIR
    swir_new = zeros(size(swir_sel));
    for n = 1:size(swir_sel, 3)
        swir_temp = swir_sel(:,:,n);
        swir_temp2 = zeros(size(swir_sel, [1 2]));
        for m = 1 : K
            class_pixels_indices = find(class_res == m);
            swir_temp2(class_pixels_indices) = swir_temp(class_pixels_indices) * weights(m, n);
        end
        swir_new(:,:,n) = swir_temp2;
    end
end
