%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% function [x_predictions,y_predictions,window_times,K,P] =
%   make_kalman_predictions(experiment,filter,start_time,end_time)
%
% Generates x and y position predictions for the specified interval by running the
% provided Kalman filter over binned neural responses
%
% experiment : an experiment structure in the form returned by plx_to_matlab.m
%
% filter : a filter structure in the form returned by make_kalman_filter.m
%
% start_time, end_time : the start and end times to decode, in seconds
%
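% Returns column vectors containing the x and y predictions and, for each bin k, the
% time start_time + (k-1 + filter.lag)*bin_size associated with that bin. Also returns
% K (4 x num_cells x num_bins), the Kalman gain used at each bin, and
% P (4 x 4 x num_bins), the posterior state covariance at each bin.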
%
% Dan Morris, Stanford University, 2003
% http://techhouse.brown.edu/~dmorris
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [x_predictions,y_predictions,window_times,K,P] = make_kalman_predictions(experiment,filter,start_time,end_time)

if (nargin < 4)
    fprintf('\nImproper number of arguments.\n');
    help make_kalman_predictions;
    % Fail loudly: none of the outputs can be produced without all four inputs.
    error('make_kalman_predictions:nargin', ...
        'make_kalman_predictions requires 4 input arguments, got %d.', nargin);
end

% Filter parameters. The +1 presumably converts 0-based unit indices to
% MATLAB's 1-based indexing -- TODO confirm against make_kalman_filter.m.
cell_list = filter.cell_list + 1;
bin_size = filter.bin_size;
num_bins = floor( (end_time - start_time) / bin_size);
% numel counts the cells whether cell_list is stored as a row or a column
% (size(cell_list,1) gives 1 for a row vector, which breaks the K/H shapes).
num_cells = numel(cell_list);

% An interval shorter than one bin yields empty outputs instead of
% writing bin 1 of zero-length arrays.
if (num_bins < 1)
    x_predictions = zeros(0,1);
    y_predictions = zeros(0,1);
    window_times = zeros(0,1);
    K = zeros(4,num_cells,0);
    P = zeros(4,4,0);
    return;
end

% Response matrix laid out as (cells,bins); only the first output of
% make_response_matrix is needed here.
unformatted_response = make_response_matrix(num_cells,num_bins,experiment,cell_list,bin_size,start_time,0);

% Seed the filter with the recorded position and velocity at the first
% sample at or after start_time. Each kinematic array carries its own time
% column, so x and y are located independently.
real_xpos = experiment.xpos;
real_ypos = experiment.ypos;
x_index = first_sample_at_or_after(real_xpos(:,1), start_time);
y_index = first_sample_at_or_after(real_ypos(:,1), start_time);
x = real_xpos(x_index,2);
y = real_ypos(y_index,2);
dx = finite_difference_velocity(real_xpos, x_index);
dy = finite_difference_velocity(real_ypos, y_index);

% State vector per bin: [x - center_x; y - center_y; dx; dy]
state = zeros(4,num_bins);
state(:,1) = [(x - filter.center(1)); (y - filter.center(2)); dx; dy];

x_predictions = zeros(num_bins,1);
y_predictions = zeros(num_bins,1);
x_predictions(1) = x;
y_predictions(1) = y;

% Bin k is timestamped start_time + (k-1)*bin_size, shifted by filter.lag
% bins because the filter was built with a lag between kinematics and
% spikes. Computed directly rather than by repeated addition, which
% accumulates floating-point error over long intervals.
window_times = start_time + ((0:num_bins-1)' + filter.lag) * bin_size;

% P(:,:,1) stays zero: the initial state comes from the recorded
% kinematics and is treated as exactly known.
P = zeros(4,4,num_bins);
K = zeros(4,num_cells,num_bins);

A = filter.A;
W = filter.W;
H = filter.H;
Q = filter.Q;
I4 = eye(4);

fprintf(1,'Generating predictions...\n');

for step = 2:num_bins
    % Prior estimation (time update)
    P_prior = A*P(:,:,step-1)*A' + W;
    state_prior = A*state(:,step-1);

    % Posterior estimation (measurement update). "/" solves against the
    % innovation covariance instead of forming inv(...) explicitly.
    z = unformatted_response(:,step);
    K(:,:,step) = (P_prior*H') / (H*P_prior*H' + Q);
    P(:,:,step) = (I4 - K(:,:,step)*H)*P_prior;
    state(:,step) = state_prior + K(:,:,step)*(z - H*state_prior);

    % Progress indicator: a dot every 100 bins, a line break every 6000.
    if (mod(step,100) == 0)
        fprintf(1,'.');
    end
    if (mod(step,60*100) == 0)
        fprintf(1,'\n');
    end

    % Undo the centering to get absolute x, y predictions.
    x_predictions(step) = state(1,step) + filter.center(1);
    y_predictions(step) = state(2,step) + filter.center(2);
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% index = first_sample_at_or_after(times,t)
%
% Returns the index of the first entry of times that is >= t, or the last
% index if every entry is earlier than t.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function index = first_sample_at_or_after(times, t)
index = 1;
while (index < length(times) & times(index) < t)
    index = index + 1;
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% v = finite_difference_velocity(samples,index)
%
% samples holds time in column 1 and position in column 2. Returns the
% velocity from the forward difference between rows index and index+1.
% When index is the last row, it uses the backward difference (index-1,
% index) so the array is not read past its end. Returns 0 when there is
% only one sample.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function v = finite_difference_velocity(samples, index)
num_samples = size(samples,1);
if (num_samples < 2)
    v = 0;
    return;
end
index = min(index, num_samples - 1);
v = (samples(index+1,2) - samples(index,2)) / (samples(index+1,1) - samples(index,1));

