www.pudn.com > ConstrainedEM.zip > infer_local_net.m


function [ probs, loglik , assignment ]=infer_local_net(engine,pots,max_flag) 
% INFER_LOCAL_NET  Infer p( hidden i belongs to model j ) by belief propagation.
%
% Uses the potentials calculated from the data (pots) together with a
% pairwise "labels must differ" constraint potential (ones(k)-eye(k)) on every
% edge of engine.mnet, and propagates them over the cliques of the engine
% (a collect pass followed by a distribute pass).
%
% Inputs:
%   engine   - engine struct. Fields read here: mnet (constraint adjacency;
%              an edge is an entry equal to 1 in its upper triangle -- assumed
%              N x N), cliques, clq_ass_to_node, clq_ass_to_const. Its
%              'maximize' field is set from max_flag before it is handed to the
%              propagation helpers. NOTE(review): assumes engine is a plain
%              struct (value semantics), so these changes stay local to this call.
%   pots     - matrix with one row per hidden node and k columns; row i is the
%              evidence potential of hidden node i over the k models.
%   max_flag - optional; nonzero -> max-propagation (hard assignment),
%              0 / omitted / [] -> sum-propagation (soft assignment).
%
% Outputs when max_flag is set:
%   probs      - N x k {1,0} matrix specifying the assignment
%                (probs(i, assignment(i)) == 1)
%   loglik     - nll from the collect pass plus the score returned by
%                max_distribute_evidence
%   assignment - the assignment as a label list, as returned by
%                max_distribute_evidence
% Outputs when max_flag is 0:
%   probs      - N x k matrix, probs(i,j) = p( hidden i belongs to model j )
%   loglik     - nll from the collect pass
%                NOTE(review): named nll but returned as loglik -- confirm the
%                sign convention of my_collect_evidence.
%   assignment - 0

if nargin < 3 || isempty(max_flag)
   max_flag = 0;
end
engine.maximize = max_flag;

k = size(pots,2);                 % number of models (labels per hidden node)
N = length(engine.mnet);          % number of hidden nodes
assignment = 0;

% prepare initial clique potentials (all-ones tables over k labels per node)
cliques = engine.cliques;
C = length(cliques);
clpot = cell(1,C);
for i=1:C
   clpot{i} = dpot(cliques{i}, k*ones(length(cliques{i}),1));
end

% enter the evidence: multiply each node's data potential into its clique
for i=1:N
   c = engine.clq_ass_to_node(i);
   tpot.domain = i;
   tpot.T = pots(i,:);
   tpot.sizes = k;
   clpot{c} = multiply_by_pot( clpot{c}, tpot );
end

% enter the constraint potentials: connected nodes must take different labels
[x, y] = find(triu(engine.mnet)==1);
constraintT = ones(k,k) - eye(k);   % loop-invariant, so built once
for i=1:length(y)
   c = engine.clq_ass_to_const(x(i),y(i));
   tpot.domain = [ x(i) y(i) ];
   tpot.T = constraintT;
   tpot.sizes = [ k k ];
   clpot{c} = multiply_by_pot( clpot{c}, tpot );
end

seppot = cell(C,C); % empty cells: separator potentials implicitly initialized to 1

[clpot, seppot, nll] = my_collect_evidence(engine, clpot, seppot);
if max_flag		% maximize - get a hard assignment
   [ assignment, loglik ] = max_distribute_evidence(engine, clpot, seppot);	% set assignment backward
   loglik = loglik + nll;
   % One-hot encode probs(i, assignment(i)) = 1. assignment(:) makes this
   % independent of the label list's orientation: adding a row vector to a
   % column vector would otherwise broadcast into an N x N set of indices.
   probs = zeros(N,k);
   probs( sub2ind([N k], (1:N)', assignment(:)) ) = 1;
else				% calculate mean probabilities - get a soft assignment
   [clpot, seppot] = my_distribute_evidence(engine, clpot, seppot);	% propagate evidence backward
   loglik = nll;

   % store the propagated clique potentials on the engine; presumably
   % my_marginal_nodes reads them from there -- confirm
   engine.clpot = clpot;

   % marginalize and calculate the p( hidden i belongs to model j )
   probs = zeros(N,k);   % preallocated so probs is defined even when N == 0
   for i=1:N
      marg = my_marginal_nodes(engine, i);
      probs(i,:) = marg.T(:).';
   end
end