% Source: www.pudn.com > gmm_utilities.zip > gmm_em.m


function [g, fit] = gmm_em(s, g, N) 
%function [g, fit] = gmm_em(s, g, N) 
% 
% Refine a Gaussian mixture model by running N iterations of
% expectation-maximisation (EM) on a set of samples.
% 
% INPUTS: 
%   s - samples, one sample per column 
%   g - initial gmm; struct with fields 
%         x - component means, one per column 
%         P - component covariances, P(:,:,i) for component i 
%         w - component weights (normalised here to sum to one) 
%   N - number of iterations of EM 
% 
% OUTPUT: 
%   g - resultant gmm 
%   fit - negative log-likelihood of fit, as evaluated in the E-step of the 
%         last iteration (i.e., for the gmm as it was before the final 
%         M-step). NaN if N <= 0, since no E-step is performed. 
% 
% NOTES: 
%   A component whose total responsibility over all samples is exactly 
%   zero keeps its previous mean and covariance (its weight becomes zero), 
%   rather than being overwritten with NaNs from a 0/0 division. 
% 
% REFERENCES: 
%   Figueiredo et al, On Fitting Mixture Models, 1999. Section 2.2, Equations 6 to 9. 
%   Ian Nabney, NetLab, http://www.ncrg.aston.ac.uk/netlab/index.php 
% 
% Tim Bailey 2005. 
 
NM = size(g.x, 2); % number of mixtures 
NS = size(s, 2);   % number of data samples 
g.w = g.w / sum(g.w); 
 
w = zeros(NM, NS); % preallocate responsibilities: w(i,j) is for component i, sample j 
fit = NaN;         % ensure output is assigned even if the loop never runs (N <= 0) 
 
while N > 0 
    N = N - 1; 
     
    % E-step: compute assignment likelihood 
    for i=1:NM 
        v = s - repvec(g.x(:,i), NS); % sample offsets from mean of component i 
        w(i,:) = g.w(i) * gauss_likelihood(v, g.P(:,:,i)); 
    end 
     
    wsr = sum(w, 1);          % sum across rows: wsr(i) = sum(w(1:NM,i)), where i in NS 
    fit = -sum(log(wsr));     % negative log-likelihood of fit (from NetLab gmmem.m and gmmpost.m) 
    wsr = checkzeros(wsr);    % avoid divide-by-zero 
    % TODO: need to adjust also w for positions where wsr is zero (see NetLab 3.3 gmmpost.m) 
    w = w ./ reprow(wsr, NM); % normalise columns: sum(w(1:NM,i)) == 1, where i in NS 
     
    % M-step: compute new (x,P,w) values for gmm 
    wsc = sum(w, 2);      % sum across columns: wsc(i) = sum(w(i,1:NS)), where i in NM 
    wtotal = sum(wsc);    % note, should equal NS (due to normalisation above) 
    if wtotal ~= 0 
        g.w = wsc / wtotal; 
    end                   % else: no sample supports any component, keep previous weights 
     
    for i=1:NM 
        if wsc(i) == 0 
            % No sample is assigned to this component; w(i,:)./wsc(i) would be 
            % 0/0 = NaN and poison the mean and covariance, so leave them as is. 
            continue 
        end 
        w_norm = w(i,:) ./ wsc(i); % note, wsc(i) is equal to sum(w(i,:)), so sum(w_norm) = 1 
        [g.x(:,i), g.P(:,:,i)] = sample_mean_weighted(s, w_norm); 
        g.P(:,:,i) = checkP(g.P(:,:,i)); % check P has not collapsed 
    end 
end 
 
% 
% 
 
% Tile a column-vector side by side to form a matrix with N identical columns 
function x = repvec(x,N) 
% Indexing trick: selecting column 1 of x, N times over, yields the tiled matrix. 
first_col_N_times = ones(1,N); 
x = x(:, first_col_N_times); 
 
% Stack a row-vector on top of itself to form a matrix with N identical rows 
function x = reprow(x,N) 
% Indexing trick: selecting row 1 of x, N times over, yields the stacked matrix. 
first_row_N_times = ones(1,N); 
x = x(first_row_N_times, :); 
 
% Replace every exactly-zero element of an array with one, so that the 
% array can safely be used as a divisor. Non-zero elements are untouched. 
function x = checkzeros(x) 
is_zero = (x == 0);   % logical mask of the zero-valued entries 
x(is_zero) = 1; 
% Equivalent formulations:  
%   x = x + (x==0); 
%or i = find(x==0); x(i) = 1; 
 
% Check covariance for collapse or corruption, if so, replace it with identity 
function P = checkP(P) 
% P is returned unchanged unless either 
%   - it contains a NaN or Inf entry (det(P) < eps is false for a NaN 
%     determinant, so such a matrix would otherwise slip through), or 
%   - its determinant is below eps (collapsed or not positive-definite), 
% in which case an identity matrix of the same size is returned instead. 
%if any(abs(diag(P)) < 1e-9) % check trace 
if any(~isfinite(P(:))) || det(P) < eps % check for NaN/Inf, then determinant 
    P = eye(size(P)); 
end 
% TODO: improve checkP. NetLab uses measure: if min(svd(P)) < MINCOV, P=Pinit; end 
% Where Pinit is the original covariance for that component.