%% Load in complex phase difference
% Load the complex phase-difference data (struct field vector_data) and
% preview the wrapped phase at index 4 along the second dimension.

clear all; clc;

pd_file = 'Testing_dataset/2000x_2mm_1y_2mm_256BIP_9V-phase-diff-le-single_SNRFIL227_All.mat';
pd_cplx = load(pd_file);
% Alternative (disabled): keep only index 4 along dimension 2.
% pd_cplx.vector_data = (pd_cplx.vector_data(:,4,:));

figure(1);
wrapped_slice = squeeze(angle(pd_cplx.vector_data(:,4,:))).';
imagesc(wrapped_slice);
colormap(jet)
caxis([-pi pi])
axis image

% Debug switch (off by default). When enabled, vector_data is replaced by
% the real-valued wrapped phase angle(...) at index 1 along dimension 2.
% NOTE(review): B_ph2..B_ph4 are not used elsewhere in the visible script.
andrea = 0;
if andrea
    C_ph = angle(pd_cplx.vector_data);
    B_ph1 = C_ph(:,1,:);
    B_ph2 = squeeze(C_ph(:,2,:)).';
    B_ph3 = squeeze(C_ph(:,3,:));
    B_ph4 = squeeze(C_ph(:,4,:));
    pd_cplx.vector_data = B_ph1;
end
%% Initialization for weighted phase unwrapping
% Voxel size in metres, plus the complex phase-difference data handed to
% volume_unwrap_phase_lateral.
Pd.voxel_size_metres = struct('x', 1e-06, 'y', 1e-06, 'z', 3.45e-06);
Pd.cplx_phase_diff_xyz = pd_cplx.vector_data;

% Destination file passed to the unwrapper.
out_dir = 'Testing_dataset';
out_name_oce  = fullfile(out_dir, ['WPU_unwrapped_phase_difference_sajat', '.mat']);

% Unwrapping parameters (axial range and lateral radii, in samples as
% interpreted by volume_unwrap_phase_lateral -- confirm units there).
unwrap_config = struct( ...
    'PHASE_UNWRAP_Z_RANGE',  7, ...
    'PHASE_UNWRAP_X_RADIUS', 4, ...
    'PHASE_UNWRAP_Y_RADIUS', 4, ...
    'PHASE_UNWRAP_ALWAYS',   true);

out_pd_uw = volume_unwrap_phase_lateral(out_name_oce, Pd, unwrap_config);

if andrea
    % Debug comparison (only when andrea is set): process the data at index 1
    % along dimension 2 with strain_UWA and compare its unwrapped phase with
    % the result of volume_unwrap_phase_lateral.
    Nx = size(pd_cplx.vector_data, 1);
    Nz = size(pd_cplx.vector_data, 3);
    zbar = linspace(0, 1500e-6, Nz);               % 0 .. 1500e-6 over Nz samples
    x_scan_vec = linspace(-150, 150, Nx) .* 1e-6;  % -150e-6 .. 150e-6 over Nx samples
    lambda = 1300e-9;
    B_ph1_xyz = pd_cplx.vector_data; % reshape(B_ph1,[Nx,1,Nz]);

    % The first output is stored as strain_uwa_est, not strain_UWA. Assigning
    % to strain_UWA would turn that name into a workspace variable shadowing
    % the strain_UWA function, so re-running this section without
    % `clear all` would index the variable instead of calling the function.
    [strain_uwa_est, z_uwa, dphi_uw_tmp] = strain_UWA(lambda, zbar.', B_ph1_xyz(:,1,:), x_scan_vec.');
    dphi_uw  = squeeze(dphi_uw_tmp)';
    dphi_uw2 = squeeze(out_pd_uw.pd_unwrapped_xyz(:,1,:))';

    figure(12); clf;
    subplot(2,2,1); imagesc(dphi_uw);
    subplot(2,2,2); imagesc(dphi_uw2); colorbar;
    % Column 55 of both results, overlaid and then on its own.
    subplot(2,2,3); plot(dphi_uw(:,55)); hold on; plot(dphi_uw2(:,55));
    subplot(2,2,4); plot(dphi_uw2(:,55));
end

% Unwrapped phase difference at index 1 along dimension 2.
figure(11); clf;
imagesc(squeeze(out_pd_uw.pd_unwrapped_xyz(:,1,:))');
colormap(jet)
axis image



%% WLS

% Implement WLS strain estimation.
% Weight taken from the signal magnitude: sqrt(|pd|). The previous form
% abs(x)./sqrt(abs(x)) is algebraically the same but evaluates to 0/0 = NaN
% wherever |x| == 0 (possibly voxels zeroed by an SNR filter -- confirm).
% That NaN variance then made every fit window covering the voxel NaN.
% With sqrt, such voxels get infinite variance, i.e. zero weight.
S1 = sqrt(abs(pd_cplx.vector_data)); % weight of OCT SNR
phase1_flip = out_pd_uw.pd_unwrapped_xyz;

% Calculate strain from unwrapped pd
lambda = 1.3e-6;      % central wavelength
fit_length = 100e-6;  % physical distance in air
% NOTE(review): this differs from Pd.voxel_size_metres.z (3.45e-6) used for
% unwrapping -- confirm which axial pixel size is correct.
pix_size_z = 3.423e-6; % axial pixel size in air
strain_fit_pix = round(fit_length / pix_size_z);   % fit window length in pixels
D_i  = pix_size_z * (1:strain_fit_pix)';           % depth offsets inside a window
D_i2 = D_i .^ 2;

OCT = S1;
var_xz = 1 ./ OCT;                   % per-voxel variance used as 1/weight in the fit
% Convert unwrapped phase to displacement: d = lambda * phi / (4*pi).
Disp1 = (lambda/4/pi) .* phase1_flip;
disp_xyz1 = Disp1;




%% Calculate the size of the output arrays and strain
% Weighted least-squares (WLS) local strain: for each A-scan (x index,
% y index) and each window of strain_fit_pix consecutive depth samples, fit
% displacement against depth offset D_i with weights 1./var_xz and keep the
% slope. Output depth length is Nz - strain_fit_pix + 1.
out_size_x = size(disp_xyz1, 1);
out_size_y = size(disp_xyz1, 2);
out_size_z = size(disp_xyz1, 3) - strain_fit_pix + 1;
if out_size_z < 1
    error('Fit window (%d px) is longer than the A-scan (%d px).', ...
          strain_fit_pix, size(disp_xyz1, 3));
end
WLSLocalStrain1 = zeros(out_size_x, out_size_y, out_size_z);
WLSVarEst1      = zeros(out_size_x, out_size_y, out_size_z); % variance of each slope estimate

parfor i = 1:out_size_y
    % Extract B-scan i once per iteration (loop-invariant for LatInc). Kept
    % 3-D (out_size_x x 1 x Nz) rather than squeezed: squeeze would turn a
    % single-A-scan B-scan (out_size_x == 1) into an Nz x 1 column and break
    % the row indexing below.
    disp_slab = disp_xyz1(:, i, :);
    var_slab  = var_xz(:, i, :);

    for LatInc = 1:out_size_x
        % One A-scan as 1 x Nz row vectors.
        disp_Ascan1     = double(reshape(disp_slab(LatInc, 1, :), 1, []));
        disp_var_Ascan1 = double(reshape(var_slab(LatInc, 1, :), 1, []));

        WLS_Ascan1       = zeros(out_size_z, 1);
        WLSVarEst_Ascan1 = zeros(out_size_z, 1);

        for PixInc = 1:out_size_z
            win = PixInc:PixInc + (strain_fit_pix - 1);
            % displacements that make up the fit segment (1 x K)
            N_disp1 = disp_Ascan1(win);
            % variance of the displacements that make up the fit segment (1 x K)
            N_DispVar1 = disp_var_Ascan1(win);

            % weighted sums: sum(w), sum(w*z), sum(w*z^2) with w = 1./var
            k0_wls1 = sum(1 ./ N_DispVar1);
            k1_wls1 = sum(D_i' ./ N_DispVar1);
            k2_wls1 = sum(D_i2' ./ N_DispVar1);

            %% From Kennedy 2012b:
            % WLSnumerator: k0_wls * sum(weight .* D_i .* N_disp) - k1_wls *
            %               sum(weight .* N_disp)
            % weight = 1 ./ N_DispVar;

            %% Simplification from Sze Howe's Honours Thesis (Koh2011)
            % units of 1 / length ^ 2
            WLSnumerator1 = sum(((k0_wls1 .* D_i' - k1_wls1) ./ N_DispVar1) .* N_disp1);

            %% From Kennedy 2012b ("Strain estimation in phase-sensitive
            % optical coherence elastography"), units of 1 / length ^ 2:
            % (\sum w_j) (\sum w_j (z_j - z_{i-1})^2) - (\sum w_j (z_j - z_{i-1}))^2
            WLSdenominator1 = k0_wls1 * k2_wls1 - k1_wls1^2;

            % Unitless slope = local strain
            WLS_Ascan1(PixInc) = WLSnumerator1 / WLSdenominator1;

            %% From Press2007NRTASC, section 15.2
            % Estimated variance of the WLS slope
            WLSVarEst_Ascan1(PixInc) = k0_wls1 / WLSdenominator1;
        end
        WLSLocalStrain1(LatInc, i, :) = WLS_Ascan1;
        WLSVarEst1(LatInc, i, :)      = WLSVarEst_Ascan1;
    end
    fprintf('output B-scan %d\n',i);

end

% Local strain (x1000) at index 1 along dimension 2.
figure(2); clf;
imagesc(squeeze(1000 .* WLSLocalStrain1(:,1,:))'); colorbar;
colormap(hot)
caxis([-3 3])
axis image


