classdef ColorTransformer
    properties
        M % 颜色变换矩阵 (3x3)
        b % 偏置向量 (3x1)
    end

    methods
        % 构造函数
        function obj = ColorTransformer(colorMap_AHSI, colorMap_Google)
            % colorMap_AHSI: AHSI颜色矩阵
            % colorMap_Google: Google颜色矩阵
            if nargin > 0
                % 计算颜色变换矩阵和偏置
                [obj.M, obj.b] = obj.calculateTransformation(colorMap_AHSI, colorMap_Google);
            end
        end

        % 计算颜色变换矩阵
        function [M, b] = calculateTransformation(~, colorMap_AHSI, colorMap_Google)
            % 验证输入
            assert(size(colorMap_AHSI, 1) == size(colorMap_Google, 1), '矩阵行数不匹配');
            assert(size(colorMap_AHSI, 2) == 3 && size(colorMap_Google, 2) == 3, '矩阵必须是RGB格式');

            % 构建方程组
            n = size(colorMap_AHSI, 1);
            A = [colorMap_AHSI, ones(n, 1)]; % 添加偏置项
            B = colorMap_Google;

            % 最小二乘解 [M | b]
            X = (A' * A) \ (A' * B);

            % 提取 M 和 b
            M = X(1:3, :)'; % 3x3 变换矩阵
            b = X(4, :)';   % 3x1 偏置向量
        end

        % 应用颜色变换
        function outputImage = transformImage(obj, inputImage)
            % inputImage: 输入图像 (MxNx3)
            % outputImage: 转换后的图像

            % 获取图像的行列数
            [rows, cols, ~] = size(inputImage);

            % 将输入图像转为 M x 3 矩阵，其中每行是一个像素的 RGB 值
            inputImage_reshaped = reshape(inputImage, rows * cols, 3);  % 将图像转为行向量形式

            % 应用颜色变换：使用矩阵乘法进行转换
            outputImage_reshaped = inputImage_reshaped * obj.M' + obj.b';  % obj.M' 是 3x3 矩阵，obj.b 是 1x3 向量

            % 确保像素值在 [0, 255] 范围内
            outputImage_reshaped = min(max(outputImage_reshaped, 0), 255);

            % 将转换后的像素值重新转换回图像的原始维度
            outputImage = reshape(outputImage_reshaped, rows, cols, 3);

            % 将输出图像数据类型转换为 uint8
            outputImage = uint8(outputImage);
        end

        % 保存模型到文件
        function saveModel(obj, filePath)
            % filePath: 保存的文件路径
            save(filePath, 'obj');
        end
    end

    % 静态方法：从文件加载模型
    methods (Static)
        function obj = loadModel(filePath)
            % filePath: 模型文件路径
            data = load(filePath, 'obj');
            obj = data.obj;
        end
    end

end
