% Source: www.pudn.com > ConstrainedEM.zip > mixture_of_gaussians3.m



% This is the ConstrainedEM function. It takes the data, the required number
% of models, the chunklets and the negative constraints, and returns the model found.

% function parameters :

% MANDATORY PARAMETERS :

% data                      : the data - every row is a data point.

% k                         : the number of components.

% CONSTRAINTS PARAMETERS :

% chunks                    : a tag list containing the chunklet information
%                           (one entry per data point, i.e. length = size(data,1)).
%                           The list contains -1 and integer tags 1..(number
%                           of chunklets).
%                           If the i'th entry is -1, the point doesn't
%                           belong to any chunklet. If it contains the
%                           tag 'j', the point is in a chunklet with all the
%                           other points tagged 'j'.
%                           If chunks is empty or not given, no point
%                           belongs to a chunklet.

% anti_chunks               : a (number of constraints x 2) table of pairs; each
%                           pair is known to be negatively constrained.
%                           May be empty or not given (no negative constraints).

% ADDITIONAL OPTIONAL PARAMETERS

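% param                     : a k-by-3 cell array of starting conditions:
%                              param(:,1) holds the centers,
%                              param(:,2) holds the covariance matrices,
%                              param(:,3) holds the weights (scalars), which should sum to 1.
%                           If param is empty ([]), several random starting
%                           conditions are tried and the one with the highest
%                           likelihood is used (the score uses the chunklet
%                           information but not the anti-chunklet information).
%                           param has no default value, so pass [] rather than
%                           omitting it.
%                           If only param{1,2} is non-empty, it is used as
%                           the starting covariance matrix, while the centers and
%                           weights are randomized.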

% ClassLabels               : if given, the labels are used to assess the
%                           purity/accuracy score. The results are returned in z.
%                           See also the internal parameter z_record_flag.

% single_cov_mat_flag       : 1 - a single covariance matrix is used.
%                           0 - k matrices are used.
%                           Defaults to 1 (single covariance matrix).

% late_oracleP              : early (0) / late (1) oracle flag (default 1).

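% calc_grad_num             : number of general partition-function (pf) gradient
%                           calculations per iteration (default 0).
%                           NOTE(review): the default check in the code tests the
%                           misspelled name 'calc_gard_num', which is never defined,
%                           so a passed value is currently always overwritten with 0.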

% aa_flagP                  : approximate anti-chunklet pf flag (default 1).

% reduce_arcsP              : controls pruning of anti-chunklet arcs:
%                           0 - no pruning.
%                           1 - prune contextually on every EM round.
%                           2 - prune once, only at the beginning (default 2).

% fixed_covmat              : if 0 (default), the covariance matrix is estimated;
%                             if 1, it is not estimated and the one given in
%                             param(:,2) is used unchanged.

% small_model_policy        : when 0, small models are removed and no grace turn is given;
%                             when 1, small models are frozen;
%                             when 2, small models are removed and a grace turn is given.
%                           The default value is 2.

%returns:

% results                   : the best parameters achieved
% logLikelihood             : a log of the ll decrease.
% groups                    : an assignment vector containing in the place
%                           i the MAP assignment of data point i.

% z                         : When z_record_flag=0 (the default), returns
%                           a vector of 3 numbers:
%                           1. purity  2. accuracy  3. z_score
%                           of the final model.
%                           If z_record_flag=1, a development log of
%                           (purity, accuracy, z_score) is returned, as a
%                           (number of iterations x 3) table.
%                           NOTE(review): the per-iteration record stores only
%                           the score 2*purity*accuracy/(purity+accuracy) in
%                           z(iteration); confirm the documented format.

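% anti_chunks_pruned        : meaningful only when reduce_arcsP==2; the number of
%                           pruned negative constraints.

% exit_flags                : a bit field:
%                           bit 1 : the last round decreased the ll.
%                           bit 2 : aa_flag was turned off in the midst of EM.
%                           bit 3 : one or more models were removed.
%                           bit 4 : catastrophe (value 8): the ll reached -infinity or
%                                   all models have been deleted. In that case
%                                   results, logLikelihood, groups and
%                                   anti_chunks_pruned are returned empty and
%                                   z = [-1 -1 -1].
%                           bit 5 : a covmat was close to singular during convergence.
%                           bit 6 : a model was frozen during EM.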


function [results,logLikelihood ,groups ,z , anti_chunks_pruned ,exit_flags ]=...
   mixture_of_gaussians3(data,k,chunks, anti_chunks,param, ClassLabels,...
   single_cov_mat_flag, late_oracleP,...
   calc_grad_num,aa_flagP,reduce_arcsP,fixed_covmat,small_model_policy)

% important parameters :

s=size(data);
d=s(2);     % the dimension.
n=s(1);     % the number of samples.

z_record_flag=0;		% when set , a purity-acuracy score record is returned , in which
				% the purity-acuracy score is calculated in every iteration.
				% when 0 , only a final score is calculated -as [purity,accuracy,score].

if ~exist('single_cov_mat_flag')
    single_cov_mat_flag=1;
end

if ~exist('ClassLabels')
    ClassLabels=[];
end
                
global late_oracle;
if ~exist('late_oracleP')
  late_oracle=1;		% early or late oracle. default is late oracle.
else
  late_oracle=late_oracleP;
end

% handle lack of chunkletts or anti-chunkletts
if ~exist('anti_chunks')
   anti_chunks=[];
end

if ~exist('chunks') | (isempty(chunks))
   chunks=-ones(n,1);
end

if ~exist('calc_gard_num')
   calc_grad_num=0;		% number of times a gradient is calculated per iteration.
end

global aa_flag;	% shared with calculate_partition_function2
if ~exist('aa_flagP')
   aa_flag=1;
else
   aa_flag=aa_flagP;
end
prev_aa_flag= aa_flag;

global reduce_arcs;
if ~exist('reduce_arcsP')
   reduce_arcs=2;
else 
   reduce_arcs=reduce_arcsP;
end

if ~exist('fixed_covmat') % when fixed covmat is set , the covmats aren't updated. param(:,2) is used without change. 
    fixed_covmat=0;
end

if ~exist('small_model_policy')
   small_model_policy=2;
end

% default initialization of the cluster weights


stop_saf=1e-6;		% stopping criterion 
	
% when param is empty :
% several starting conditions are tried and the one with the highest ll is used.
% the score for initial conditions takes into consideration the chunklett information, 
% but not the anti chunklet info.

if isempty(param)
   ch_num=length(unique(chunks))-1;
   nc_inds=find(chunks==-1);   % nc_inds - non chunkletted data indexes
   param=best_params_general(data,param,single_cov_mat_flag,chunks,ch_num,nc_inds);
end

% initialization

global pruned_flag;   
pruned_flag=0;
iteration=1;     % iteration.
global nets;	% the markov nets ( global used in calculate_partition_function2 and in calc_p_and_pll)
global OriginalNets;		% stores the full original nets ( before prunning ) in case 
								% we want to re prune.
original_k=k;	% save for case of model degradation
last_iteration_missed=0;	% a flag indicating whether the last iteration decreased ll
z=[];		% z record ( purity-acuracy scores )
global chunklet_sizes;	% for chunklet pf calculation
global anti_chunk_num;	% for anti chunklet pf aproximation calculation
anti_chunk_num=size(anti_chunks,1);
global anti_chunks_pruned;  % meaningfull only when reduce_arcs==2. the number of pruned negative constraints
anti_chunks_pruned=0;
global exit_flags;	% error messages : details on function remark
exit_flags=0;
global frozen_models;	% used in maximize. a vector containing the indexes of models
								% that were frozed in this session ( not neccesarily that they are frozen now )
frozen_models=[];
remove_models_flag=1;		% if small_model_policy is 1 this variable is 1 for 
                           % the first round, 0 for the other rounds.
grace_turn_flag=0;			% used when small model policy is 2, to indicate the grace turn.
                           
% divide the data into connected components and prepare graphs for each.
[ nets singles nc_inds c_inds oth chunklet_sizes ] = organize_constraints_info(chunks, anti_chunks , k );
OriginalNets=nets;

% start the EM iterations :

while (1)
   % E -step :
   % calculate the probabilities p( hidden i belongs to center j).
      
  
   [probabilities, logLikelihood ]=calc_p_and_pll2(data,param,single_cov_mat_flag,...
                                      c_inds,nc_inds,singles);
                                  
   
   % a calculation of the partition function 
   [ pf stam exp_const ]=calculate_partition_function2(cell2mat(param(:,3)),1,2);
   
   logLikelihood = logLikelihood -log(pf)+ exp_const;
  
   % keeping purity acuracy record
   if (~isempty(ClassLabels)) & (z_record_flag )
      if ~mm_flag
         [ stam , stam , ass ] =calc_ass_and_maxll(data,param,single_cov_mat_flag,c_inds,nc_inds,singles);
      end
      ass=ass(oth);	% from hidden assignment to observable assignment.
      [a2 , b2]=purity_accuracy2(ClassLabels,ass,original_k);
      %visual_k_means(data,ClassLabels,ass);

      z(iteration)=2*a2*b2/(a2+b2);
   end

   if(logLikelihood == -inf)
     z=[-1 -1 -1];
     results=[];
     logLikelihood=[];
     groups=[];
     anti_chunks_pruned=[];
     exit_flags=bitor(exit_flags,8);
     return
 end
   
   % record loglikelihood and check for stop
   ll(iteration)=logLikelihood;
  
   if iteration>1
      if  ( (ll(iteration)-ll(iteration-1))/abs(ll(iteration)) )2 % a model has been removed - give grace turn
      grace_turn_flag=1;
   else
      grace_turn_flag=0;
   end
        
   k=size(param,1);	% update the number of models ( models might have been removed ) 
   
   if iteration==2	% after the first iteration, models can be freezed
      if small_model_policy==1
         remove_models_flag=0;	% in itrations >=2 no models are removed.
      end
   end
   
   % change noam 23.6
   if isempty(param) % meaning: maximize has deleted all models
     z=[-1 -1 -1];
     results=[];
     logLikelihood=[];
     groups=[];
     anti_chunks_pruned=[];
     exit_flags=bitor(exit_flags,8);
     return
   end
   %%%% end change noam 23.6
end


% handling terrible bug ( which is expected at this version )
if ll(end)