function [res, dX, dY, current_ncc] = register_ncc_v4(src, trg, opt)
%%% REGISTER_NCC_V4 Non-rigid registration of the source image onto the
%%% target image by maximizing the squared normalized cross-correlation
%%%     NCC^2 = (sum(s.*t))^2 / (sum(s.^2) * sum(t.^2))
%%% with a coarse-to-fine gradient-ascent scheme on a Gaussian-smoothed
%%% displacement field (Ux, Uy).
%%%
%%%   [res, dX, dY, current_ncc] = register_ncc_v4(src, trg)
%%%   [res, dX, dY, current_ncc] = register_ncc_v4(src, trg, opt)
%%%
%%% Inputs
%%%   src, trg : source and target images of the same size. Non floating-point
%%%              inputs (e.g. uint8) are converted to double so that the
%%%              squares and sums below do not saturate.
%%%   opt      : optional struct; every missing or empty field gets its default
%%%                max_it    (400)   maximum number of iterations per level
%%%                lambda    (50000) step size of the displacement update
%%%                sigma_p   (1.5)   sigma of the 9x9 Gaussian that smooths
%%%                                  the displacement field
%%%                multiRes  (2)     number of 2x down-sampling levels; levels
%%%                                  run from scale 1/2^multiRes up to 1
%%%                warmup_it (25)    during iterations < warmup_it a step that
%%%                                  lowers NCC^2 is still accepted ...
%%%                max_drop  (0.01)  ... as long as the drop is < max_drop
%%%                verbose   (false) print level / iteration / NCC^2
%%%
%%% Outputs
%%%   res         : src_original sampled at (dX, dY) with interp2 (linear);
%%%                 interp2 returns NaN for points outside the image.
%%%   dX, dY      : full-resolution sampling grid X - Ux, Y - Uy (positions,
%%%                 not displacements).
%%%   current_ncc : NCC^2 between the warped source and the target at the
%%%                 last processed level (NaN if no level was processed).
%%%
%%% Needs fspecial / imfilter / imresize (Image Processing Toolbox) and the
%%% helper NaNFix (defined elsewhere; presumably replaces the NaNs interp2
%%% produces outside the image -- verify).

% ---- options: merge user-supplied fields over the defaults ---------------
defaults = struct('max_it', 400, 'lambda', 50000, 'sigma_p', 1.5, ...
                  'multiRes', 2, 'warmup_it', 25, 'max_drop', 0.01, ...
                  'verbose', false);
if nargin < 3 || isempty(opt)
    opt = struct();
end
optNames = fieldnames(defaults);
for k = 1:numel(optNames)
    if ~isfield(opt, optNames{k}) || isempty(opt.(optNames{k}))
        opt.(optNames{k}) = defaults.(optNames{k});
    end
end
max_it    = opt.max_it;
lambda    = opt.lambda;     % step size of the velocity update
sigma_p   = opt.sigma_p;    % size of Gaussian for smoothing deformation field
multiRes  = opt.multiRes;
warmup_it = opt.warmup_it;
max_drop  = opt.max_drop;
verbose   = opt.verbose;

% ---- input checks --------------------------------------------------------
if ~isequal(size(src), size(trg))
    error('register_ncc_v4:sizeMismatch', ...
          'src and trg must have the same size.');
end
% Integer images would saturate in .^2 and the sums below.
if ~isfloat(src), src = double(src); end
if ~isfloat(trg), trg = double(trg); end

% ---- initialization ------------------------------------------------------
src_original = src;
[M, N] = size(src);
% Identity sampling grid: what gets returned if no resolution level runs.
[X_c, Y_c] = meshgrid(1:N, 1:M);
current_ncc = NaN;

% Normalized Gaussian used to regularize (smooth) the displacement field.
hGp = fspecial('gaussian', [9, 9], sigma_p);
hGp = hGp ./ sum(hGp(:));

% ---- coarse-to-fine pyramid ----------------------------------------------
for r = 0:multiRes
    rt = multiRes - r;                       % current scale is (1/2)^rt
    src_t = imresize(src, (1/2)^rt, 'bilinear');
    trg_t = imresize(trg, (1/2)^rt, 'bilinear');

    % Pixel grid of this level.
    [M, N] = size(src_t);
    [X_0, Y_0] = meshgrid(1:N, 1:M);

    if r == 0
        % Coarsest level: start from the zero (identity) deformation.
        Ux = zeros(M, N);
        Uy = zeros(M, N);
    else
        % Displacements are in pixels, so they double with the grid size.
        Ux = imresize(Ux, [M, N], 'bilinear') * 2;
        Uy = imresize(Uy, [M, N], 'bilinear') * 2;
    end

    t_norm_sqrd = sum(sum(trg_t.^2));

    % NCC^2 of the source warped by the current field.
    X_c = X_0 - Ux;
    Y_c = Y_0 - Uy;
    src_temp = NaNFix(interp2(src_t, X_c, Y_c));
    SfT = sum(sum(src_temp .* trg_t));
    s_norm_sqrd = sum(sum(src_temp.^2));
    current_ncc = (SfT^2) / (s_norm_sqrd * t_norm_sqrd);
    [dIx, dIy] = gradient(src_temp);

    for iterations = 1:max_it
        if verbose
            fprintf('level %d, iteration %d, NCC^2 = %f\n', ...
                    rt, iterations, current_ncc);
        end

        % Gradient of NCC^2 w.r.t. the displacement (warped-image gradient
        % times dNCC^2/ds).
        residual = SfT * src_temp - s_norm_sqrd * trg_t;
        scale = (2 / ((s_norm_sqrd^2) * t_norm_sqrd)) * SfT;
        dUx = scale * dIx .* residual;
        dUy = scale * dIy .* residual;

        % Ascent step (+ since we maximize), then smooth the field.
        Uxt = imfilter(Ux + lambda * dUx, hGp, 'same');
        Uyt = imfilter(Uy + lambda * dUy, hGp, 'same');

        % Keep the image border fixed.
        Uxt([1, M], :) = 0;  Uxt(:, [1, N]) = 0;
        Uyt([1, M], :) = 0;  Uyt(:, [1, N]) = 0;

        % Evaluate the candidate field.
        X_ct = X_0 - Uxt;
        Y_ct = Y_0 - Uyt;
        src_tempt = NaNFix(interp2(src_t, X_ct, Y_ct, 'linear'));
        SfTt = sum(sum(src_tempt .* trg_t));
        s_norm_sqrdt = sum(sum(src_tempt.^2));
        current_ncct = (SfTt^2) / (s_norm_sqrdt * t_norm_sqrd);

        % Accept improvements; during the warm-up phase also accept small
        % decreases so the ascent can leave a poor starting point.
        accept = (current_ncct > current_ncc) || ...
                 (iterations < warmup_it && current_ncc - current_ncct < max_drop);
        if ~accept
            break;
        end

        X_c = X_ct;
        Y_c = Y_ct;
        Ux = Uxt;
        Uy = Uyt;
        src_temp = src_tempt;
        s_norm_sqrd = s_norm_sqrdt;
        SfT = SfTt;
        current_ncc = current_ncct;
        [dIx, dIy] = gradient(src_temp);
    end % iteration loop
end % resolution loop

dX = X_c;
dY = Y_c;
res = interp2(src_original, X_c, Y_c);
% [res] = NaNFix(res);

% Visualization of the sampling grid:
% figure; imshow(res,[]);
% hold on;
% plot(dX(:,1:10:size(res,2)),dY(:,1:10:size(res,2)),'g');
% plot(dX(1:10:size(res,1),:)',dY(1:10:size(res,1),:)','g');
% hold off;