% Source: www.pudn.com > ConstrainedEM.zip > maximize.m


%last modified by tomer and aharon 24/4/03 to add handling of weighted distribution on data points.

% last modified tomer: 29.9 added diagonal_covmat_flag.
% last modified noam 23.6: the value 4 (bit 3) is OR-ed into exit_flags if all models were removed

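% k                    : the number of components.
% data                 : the data - every row is a data point.
% oth (h_rep_inds)     : row indices into Probabilities of the observable data points
%                        (presumably the representatives of the hiddens).
% Probabilities        : posterior matrix with one column per component;
%                        Probabilities(oth,:) has one row per data point (row of data).
% param                : previous parameters (cell array); param{j,2} is model j's covariance.
% single_cov_mat_flag  : if 1 - a single covariance matrix is used; else - k matrices.
% calc_grad_num        : number of full gradient calculations
% fixed_covmat         : if set, the covariance matrices are not updated (param(:,2) is used unchanged).
% remove_models_flag   : if set, degenerate models are removed; otherwise they are frozen.
%
% diagonal_covmat_flag : if 1 - only estimate diagonal values (variances) in cov matrix.
%                        if 0 (default) - estimate full cov matrix.
% w                    : weights on data points (internal; currently always empty, so no weighting is applied).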

function [newParams]= maximize(data,k,param,Probabilities,oth,single_cov_mat_flag,...
    calc_grad_num,fixed_covmat,remove_models_flag,diagonal_covmat_flag)
%MAXIMIZE  M-step: re-estimate mixing weights, means and covariance matrices.
%
%   newParams = maximize(data,k,param,Probabilities,oth,single_cov_mat_flag, ...
%                        calc_grad_num,fixed_covmat,remove_models_flag, ...
%                        diagonal_covmat_flag)
%
%   data                 : n-by-d matrix, one data point per row.
%   k                    : number of components.
%   param                : previous parameters (cell array); param{j,2} is
%                          component j's previous covariance matrix.
%   Probabilities        : posterior matrix with one column per component;
%                          Probabilities(oth,:) must have one row per row of data.
%   oth                  : row indices of Probabilities that belong to the
%                          observable data points.
%   single_cov_mat_flag  : 1 - one covariance shared by all components, stored
%                          in newParams{1,2}; 0 - one covariance per component.
%   calc_grad_num        : passed through to approximateWeights.
%   fixed_covmat         : optional (default 0). If set, the covariances are
%                          copied from param(:,2) instead of re-estimated.
%   remove_models_flag   : optional (default 0). If set, degenerate components
%                          are removed; otherwise they are frozen (their
%                          covariance is copied from param).
%   diagonal_covmat_flag : optional (default 0). If set, only the diagonal
%                          (variances) of each covariance is estimated and a
%                          small multiple of TOTAL_COV (or of the identity when
%                          TOTAL_COV is empty) is added.
%
%   newParams(j,:) = {mean (1-by-d), covariance, mixing weight}. newParams is
%   [] when every component was removed; in that case the value 4 is OR-ed
%   into the global exit_flags.

global mm_flag;         % enables the degenerate-component check below
global late_oracle;     % chooses the initial weights handed to approximateWeights
global exit_flags;      % status bit field: 4 = all models removed, 32 = models frozen
global frozen_models;   % indices of components whose covariance is kept as-is
global TOTAL_COV; %FOR ANDRE SPEECH STUFF

% Optional arguments. Empty values are treated as "not supplied" so that the
% flags are scalars in every condition below.
if nargin < 8 || isempty(fixed_covmat)
    fixed_covmat= 0;   % when set, covariances are copied from param(:,2) unchanged
end
if nargin < 9 || isempty(remove_models_flag)
    remove_models_flag= 0;
end
if nargin < 10 || isempty(diagonal_covmat_flag)
    diagonal_covmat_flag= 0;   % when set, only the variances are estimated
end

% Per-point weights. NOTE(review): hard-wired to empty, so the weighted
% branch below is currently never taken.
w= [];

[n,d]= size(data);   % n data points of dimension d

% Posteriors of the observable points only (one row per data point).
obsProbabilities= Probabilities(oth,:);

% Effective (optionally weighted) number of points assigned to each component.
% Sum explicitly along dimension 1 so that a single data point still yields
% one value per component.
if ~isempty(w)
    normalizer= w*obsProbabilities;      % w is a 1-by-n row vector
else
    normalizer= sum(obsProbabilities,1);
end

% Initial condition for the gradient ascent inside approximateWeights.
if late_oracle
    % Weights without considering the partition function.
    % NOTE(review): not yet adapted to the per-point weights w.
    naiveWeights= normalizer;
else
    % Noam's old model: these are hidden weights, not observable weights.
    naiveWeights= sum(Probabilities,1);
end

% ---- Degenerate components ----------------------------------------------
% A full d-by-d covariance needs at least d+1 points; a diagonal one needs
% only 2, so the check is skipped when diagonal_covmat_flag is set.
verySmall= [];
if diagonal_covmat_flag == 0
    if ~single_cov_mat_flag
        if mm_flag
            % NOTE(review): this condition was truncated in the copy this file
            % was recovered from (only "find(normalizer" survived). d+1 is the
            % minimum number of points for a non-singular full covariance;
            % confirm the threshold against the original source.
            verySmall= find(normalizer < d+1);
        end
    end

    if ~isempty(verySmall)
        if remove_models_flag
            keep= setdiff(1:k,verySmall);
            k= numel(keep);
            if k > 0
                % Drop the removed components and keep param aligned with the
                % new numbering (param{j,2} is read below).
                obsProbabilities= obsProbabilities(:,keep);
                naiveWeights= naiveWeights(keep);
                normalizer= normalizer(keep);
                param= param(keep,:);
                disp(sprintf('Model %d has been removed\n',verySmall));
            else
                % Every component was removed: signal the caller and stop.
                exit_flags= bitor(exit_flags,4);
                newParams= [];
                return;
            end
        else
            % Freeze the small models (don't change their covariance).
            newlyFrozen= setdiff(verySmall,frozen_models);
            if ~isempty(newlyFrozen)
                disp(sprintf('Model %d has been freezed\n',newlyFrozen));
                exit_flags= bitor(exit_flags,32);
                frozen_models= union(frozen_models,verySmall);
            end
        end
    end
end

newWeights= approximateWeights(naiveWeights,calc_grad_num);

% ---- Weights, means and (unless fixed) covariances ----------------------
newParams= cell(k,3);
newParams{1,2}= zeros(d,d);   % accumulator for the shared covariance

for j=1:k
    newParams{j,3}= newWeights(j);
    % Regular EM update of the center (applied to frozen models as well).
    newParams{j,1}= obsProbabilities(:,j)'*data/normalizer(j);

    if fixed_covmat
        continue;   % covariances are copied from param after the loop
    end

    if remove_models_flag==0 && ismember(j,verySmall)
        % Frozen model: keep the previous covariance.
        newParams{j,2}= param{j,2};
        continue;
    end

    % Posterior-weighted scatter of the data around the new center.
    centeredData= data - ones(n,1)*newParams{j,1};
    probMat= obsProbabilities(:,j)*ones(1,d);           % n-by-d
    covMat= (probMat.*centeredData)'*centeredData;      % d-by-d
    if single_cov_mat_flag==0
        newParams{j,2}= covMat./normalizer(j);
    else
        newParams{1,2}= newParams{1,2} + covMat;         % normalized below
    end
end

if fixed_covmat
    % Covariances are not updated: use param(:,2) unchanged.
    for i=1:k
        newParams{i,2}= param{i,2};
    end
    return;
end

% Shared covariance: normalize by the total number of data points.
if single_cov_mat_flag==1
    newParams{1,2}= newParams{1,2}./n;
end

% Diagonal case: drop the off-diagonal terms and add a small ridge.
if diagonal_covmat_flag==1
    epsilon= 3e-5;
    % NOTE(review): TOTAL_COV is a global set elsewhere; when it is empty the
    % identity (the earlier "basic suggestion") is used instead.
    if isempty(TOTAL_COV)
        ridge= epsilon.*eye(d);
    else
        ridge= epsilon.*TOTAL_COV;
    end
    if ~single_cov_mat_flag
        for i=1:k
            newParams{i,2}= newParams{i,2}.*eye(d) + ridge;
        end
    else
        newParams{1,2}= newParams{1,2}.*eye(d) + ridge;
    end
end