classdef GeneratorRGB
    % GeneratorRGB  Synthesize an RGB image from a hyperspectral cube.
    %   Each spectral band is weighted by an analytic approximation of the
    %   CIE 1931 XYZ colour-matching functions (sums of piecewise Gaussians),
    %   the resulting XYZ image is mapped to RGB with TRANSFORM_MATRIX, and
    %   the result is min-max normalized to [0, 1] (optionally gamma-encoded).

    properties
        wavelengths % Band wavelengths (nm), one entry per spectral band
        % 3x3 XYZ -> RGB conversion matrix (applied as rgb = M * xyz).
        % NOTE(review): values look like the sRGB (D50, Bradford-adapted)
        % inverse matrix -- confirm before relying on a specific colour space.
        transform_matrix = [3.1338561, -1.6168667, -0.4906146;
                            -0.9787684, 1.9161415, 0.0334540;
                             0.0719453, -0.2289914, 1.4052427];
    end

    methods
        function obj = GeneratorRGB(wavelength, transform_matrix)
            % Construct a generator for the given band wavelengths.
            %   obj = GeneratorRGB(wavelength) stores WAVELENGTH (vector, nm)
            %   and keeps the default transform matrix.
            %   obj = GeneratorRGB(wavelength, transform_matrix) also
            %   overrides the 3x3 XYZ -> RGB matrix.
            %   With no arguments both properties keep their defaults, which
            %   lets MATLAB create/load objects without arguments.
            if nargin >= 1
                obj.wavelengths = wavelength;
            end
            if nargin >= 2 && ~isempty(transform_matrix)
                if ~isequal(size(transform_matrix), [3 3])
                    error('GeneratorRGB:badTransform', ...
                        'transform_matrix must be a 3x3 matrix.');
                end
                obj.transform_matrix = transform_matrix;
            end
        end

        % Piecewise Gaussian: width SIGMA1 left of the centre U, SIGMA2 at/right of it.
        function y = piecewise_gaussian(~, x, u, sigma1, sigma2)
            y = (x < u) .* exp(-0.5 * (x - u).^2 / sigma1^2) + (x >= u) .* exp(-0.5 * (x - u).^2 / sigma2^2);
        end

        % x-bar colour-matching function (three-lobe piecewise-Gaussian fit).
        function result = x_gaussian(obj, wavelength)
            result = 0.362 * obj.piecewise_gaussian(wavelength, 442.0, 16.0, 26.7) ...
                   - 0.065 * obj.piecewise_gaussian(wavelength, 501.1, 20.4, 26.2) ...
                   + 1.056 * obj.piecewise_gaussian(wavelength, 599.8, 37.9, 31.0);
        end

        % y-bar colour-matching function (two-lobe piecewise-Gaussian fit).
        function result = y_gaussian(obj, wavelength)
            result = 0.286 * obj.piecewise_gaussian(wavelength, 530.9, 16.3, 31.1) ...
                   + 0.821 * obj.piecewise_gaussian(wavelength, 568.8, 46.9, 40.5);
        end

        % z-bar colour-matching function (two-lobe piecewise-Gaussian fit).
        function result = z_gaussian(obj, wavelength)
            % The published CIE 1931 fit (Wyman et al. 2013) uses 1.217 for
            % the 437 nm lobe; 0.980 under-estimated the z-bar peak by ~13%.
            result = 1.217 * obj.piecewise_gaussian(wavelength, 437.0, 11.8, 36.0) ...
                   + 0.681 * obj.piecewise_gaussian(wavelength, 459.0, 26.0, 13.8);
        end

        % Evaluate the x/y/z colour-matching functions at obj.wavelengths.
        % Each output has the same shape as obj.wavelengths.
        function [cie_x, cie_y, cie_z] = get_CIE_XYZ_weights(obj)
            cie_x = obj.x_gaussian(obj.wavelengths);
            cie_y = obj.y_gaussian(obj.wavelengths);
            cie_z = obj.z_gaussian(obj.wavelengths);
        end

        % Plot the x/y/z colour-matching functions in a new figure.
        function show_weights_curve(obj)
            [cie_x, cie_y, cie_z] = obj.get_CIE_XYZ_weights();
            figure('Name', 'xyz_cmf');
            plot(obj.wavelengths, cie_x, 'r', 'LineWidth', 1.5);
            hold on;
            plot(obj.wavelengths, cie_y, 'g', 'LineWidth', 1.5);
            plot(obj.wavelengths, cie_z, 'b', 'LineWidth', 1.5);
            xlabel('wavelength/nm');
            ylabel('value');
            legend('x', 'y', 'z');
            grid on;
            hold off;
        end

        % Gamma encoding: rgb .^ (1/gamma); gamma defaults to 1.5.
        function rgb = gamma_correction(~, rgb, gamma)
            if nargin < 3
                gamma = 1.5;
            end
            rgb = rgb .^ (1 / gamma);
        end

        function rgb_img = get_rgb(obj, hyper_image, apply_gamma, gamma)
            % Render a hyperspectral cube as an RGB image.
            %   rgb_img = get_rgb(obj, hyper_image) weights the bands of the
            %   rows x cols x bands cube HYPER_IMAGE by the CIE x/y/z
            %   functions, maps XYZ to RGB with obj.transform_matrix and
            %   min-max normalizes the whole image to [0, 1].
            %   get_rgb(obj, hyper_image, true) additionally gamma-encodes
            %   the result (default gamma 1.5); a 4th argument GAMMA
            %   overrides that exponent.
            %   The number of bands must equal numel(obj.wavelengths); the
            %   wavelengths may be a row or column vector.
            if nargin < 3 || isempty(apply_gamma)
                apply_gamma = false;
            end

            [rows, cols, bands] = size(hyper_image);
            if bands ~= numel(obj.wavelengths)
                error('GeneratorRGB:bandMismatch', ...
                    'hyper_image has %d bands but %d wavelengths were given.', ...
                    bands, numel(obj.wavelengths));
            end
            % Integer cubes (e.g. uint16) would otherwise be weighted and
            % summed with integer rounding/saturation.
            if ~isfloat(hyper_image)
                hyper_image = double(hyper_image);
            end

            [cie_x, cie_y, cie_z] = obj.get_CIE_XYZ_weights();
            % (:) makes each weight a column regardless of wavelength orientation.
            weights = [cie_x(:), cie_y(:), cie_z(:)];            % bands x 3
            pixels = reshape(hyper_image, rows * cols, bands);    % pixels x bands
            xyz = pixels * weights;                               % pixels x 3 (X, Y, Z)
            rgb = xyz * obj.transform_matrix.';                   % rgb = M * xyz per pixel
            rgb_img = reshape(rgb, rows, cols, 3);

            % Global min-max normalization; a constant image maps to zeros
            % instead of 0/0 = NaN, and an empty image is returned as is.
            if ~isempty(rgb_img)
                lo = min(rgb_img(:));
                span = max(rgb_img(:)) - lo;
                if span == 0
                    rgb_img = zeros(size(rgb_img), class(rgb_img));
                else
                    rgb_img = (rgb_img - lo) / span;
                end
            end

            if apply_gamma
                if nargin < 4 || isempty(gamma)
                    rgb_img = obj.gamma_correction(rgb_img);
                else
                    rgb_img = obj.gamma_correction(rgb_img, gamma);
                end
            end
        end
    end
end
