%	 Copyright (C) 2008  Frühwirth-Schnatter
%	 Copyright (C) 2011  Bluder, Plankensteiner
%
%    This program is free software: you can redistribute it and/or modify
%    it under the terms of the GNU General Public License as published by
%    the Free Software Foundation, either version 3 of the License, or
%    (at your option) any later version.
%
%    This program is distributed in the hope that it will be useful,
%    but WITHOUT ANY WARRANTY; without even the implied warranty of
%    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%    GNU General Public License for more details.
%
%    You should have received a copy of the GNU General Public License
%    along with this program.  If not, see <http://www.gnu.org/licenses/>.

function [marlik,varargout] = mcmcbf_MOE(data,mcmcout,varargin)
% MCMCBF_MOE  Simulation-based estimates of the log marginal likelihood.
%
%   marlik = mcmcbf_MOE(data,mcmcout)
%   marlik = mcmcbf_MOE(data,mcmcout,options)
%   [marlik,mcmcout] = mcmcbf_MOE(...)
%
% This variant only handles regression models or MoE models based on
% univariate Gaussian data.
%
% Inputs
%   data     structure with the obligatory field y; the optional fields
%            empty, bycolumn, N, t0 and X are read if present.
%   mcmcout  structure holding the MCMC output; the fields M, model, prior,
%            post, log (mixlik, mixprior, optionally t0), par and (for
%            multinomial weights) weight are read.
%   options  optional structure (only honoured when exactly three input
%            arguments are passed) with the fields
%              Smax  maximum number of components in the importance
%                    density (default 2000)
%              M0    multiplier used in S = round(M0*K!) (default 100 for
%                    K<=3, 25 for K==4, 5 otherwise)
%            A field 'method' is rejected with a warning.
%
% Outputs
%   marlik   structure with the fields
%              is, ri, bs   importance sampling, reciprocal importance
%                           sampling and bridge sampling estimates
%              log          the log likelihood, log prior and log
%                           importance density values for the MCMC draws
%                           (loglikmc, priormc, qmc) and the draws from q
%                           (loglikq, priorq, qq)
%              se           standard errors returned by
%                           marginallikelihood_eval (bs, is, ri)
%            marlik is [] if a precondition check fails.
%   mcmcout  (second output) the input mcmcout with the field marlik
%            added; returned unchanged if a precondition check fails.

% reseed the (legacy) random number generators from the clock
rand('state',sum(100*clock));
randn('state',sum(100*clock));

if ~isfield(data,'empty'), data.empty=false; end

%% precondition checks

if ~isfield(mcmcout.model,'d')
    warn('Function only implemenetd for linear regression models and MoE models.')
    % model.d is required further down; stop here instead of failing later
    marlik=[];
    if nargout>1, varargout{1}=mcmcout; end
    return
end

if ~isfield(mcmcout,'post')
    warn('No posterior moments available, set .mcmcstorepost=true before running mixturemcmc')
    marlik=[];
    if nargout>1, varargout{1}=mcmcout; end
    return
end

if ~isfield(data,'y')
    warn('The field y is obligatory in the structure array defining the data when calling the function mixturebf')
    marlik=[];
    if nargout>1, varargout{1}=mcmcout; end
    return
end

%% normalise the data structure

ibycolumn=isfield(data,'bycolumn');
if ibycolumn, ibycolumn=data.bycolumn; end  % ibycolumn true: data stored by column
if ibycolumn
    data.y=data.y';
    % store a logical false; the char array 'false' would test as true
    data.bycolumn=false;
    % NOTE(review): data.X is not transposed here - confirm the layout
    % expected by likelihoodeval_MOE for data that was stored by column
end
if ~isfield(data,'N'), data.N=size(data.y,2); end

%% normalise the model structure

M=mcmcout.M;
model=mcmcout.model;

if ~isfield(model,'indexdf')
    model.df=0;
else
    model.df=size(model.indexdf,1)*size(model.indexdf,2);
end

if ~isfield(model,'K'), model.K=1; end
K=model.K;
if and(K>1,~isfield(model,'indicmod'))
    model.indicmod.dist='Multinomial'; model.indicmod.T=1; model.indicmod.cat=K;
end

% strncmp does not fail on names shorter than six characters
if ~strncmp(model.dist,'Normal',6)
    warn('Function is only implemented for normal distributions.')
end

if ~isfield(model,'error') % default: switching variance
    model.error='switch';
end

if ~isfield(model,'indicfix'), model.indicfix=false; end
if ~isfield(model,'indicmod')
    model.indicmod.dist='Multinomial';
elseif ~isfield(model.indicmod,'init')
    model.indicmod.init='ergodic';
end

if ~isfield(data,'t0'), datat0=1; else datat0=data.t0; end
if and(model.d==1,~isfield(data,'X'))
    data.X=ones(1,data.N);
end

% copy of the data used when the likelihood of the MCMC draws has to be
% recomputed (t0 differs between estimation and BF calculation); defined
% unconditionally so that it exists whenever computelogmcmc is true
dataar=data;

if ~isfield(mcmcout.log,'t0'), logt0=1; else logt0=mcmcout.log.t0; end
if logt0~=datat0
    computelogmcmc=true;
    loglikmc=zeros(M,1);
else
    computelogmcmc=false;
    loglikmc=mcmcout.log.mixlik;
end

prior=mcmcout.prior;

%% options: defaults first, then overrides from the options structure

Smax=2000;  % maximum numbers of components in the importance density
if le(K,3)
    M0=100;
elseif K==4
    M0=25;
else
    M0=5;
end

if nargin==3
    opts=varargin{1};
    if isfield(opts,'method')
        warn(['Method ' opts.method ' not implemented in this function.']);
        marlik=[];
        if nargout>1, varargout{1}=mcmcout; end
        return
    end
    if isfield(opts,'Smax'), Smax=opts.Smax; end
    if isfield(opts,'M0'),   M0=opts.M0;     end
end

%%  construct the importance density q

simstud=false;
%simstud=true;dfq=10;  % use stud-t instead of normal distributions (currently implemented only for regression models)

S=round(M0*exp(gammaln(K+1)));   % M0*K!
S=min([S Smax M]);  % truncated, if S is too large

%indexchoose=[M-S+1:M];  % extract the posterior moments corresponding to the last S components
index=randperm(M);
indexchoose=sort(index(M-S+1:M)); % extract S randomly chosen posterior moments

mcmcoutq=mcmcsubseq_MOE(mcmcout,indexchoose);
post=mcmcoutq.post; clear mcmcoutq;

% compute the information matrix, the cholesky decomposition and the
% log determinant once for all S components to speed up computation
post.par.beta.Binv=0*post.par.beta.B;
post.par.beta.Bchol=0*post.par.beta.B;
post.par.beta.logdet=zeros(S,K);
if isfield(model,'indexdf')
    model.df=size(model.indexdf,1)*size(model.indexdf,2);
    post.par.alpha.Ainv=0*post.par.alpha.A;
    post.par.alpha.Achol=0*post.par.alpha.A;
    post.par.alpha.logdet=zeros(S,1);
else
    model.df=0;
end

for i=1:S
    for k=1:K
        Binv=inv(squeeze(post.par.beta.B(i,:,:,k)));
        post.par.beta.Binv(i,:,:,k)=Binv;
        post.par.beta.logdet(i,k)=log(det(Binv));
        post.par.beta.Bchol(i,:,:,k)=chol(squeeze(post.par.beta.B(i,:,:,k)))';
    end
    if model.df>0
        Ainv=inv(squeeze(post.par.alpha.A(i,:,:)));
        post.par.alpha.Ainv(i,:,:)=Ainv;
        post.par.alpha.Achol(i,:,:)=chol(squeeze(post.par.alpha.A(i,:,:)))';
        post.par.alpha.logdet(i)=log(det(Ainv));
    end
end

qout.log.mixprior=zeros(M,1); qout.log.mixlik=zeros(M,1); qout.log.q=zeros(M,1);
mcmcout.log.q=zeros(M,1);

%% compute  bayes factors
%  sample from q and evaluate importance density at the q and the MCMC
%  draws

t0 = clock; tlast=t0;

% choose the components for sampling from the mixture importance density
qs = simuni(S,M);

for m=1:M
    % start every draw from a model without weight and parameters
    if isfield(model,'weight')
        model=rmfield(model,'weight');
    end

    % modified by: Olivia Bluder Sept 29, 2010
    if strncmp(model.indicmod.dist,'FixedW',6)
        model.weight = mcmcout.prior.weight;
    end
    % end modification

    if isfield(model,'par')
        model=rmfield(model,'par');
    end

    % progress report, at most once per minute
    if etime(clock,tlast)>60
        ext=fix(etime(clock,t0)/m*(M-m));
        disp(['estimated completion in about ' num2str(fix(ext/60))  ' minutes  and ' num2str(ext-fix(ext/60)*60) ' seconds'])
        tlast = clock;
    end

    % log importance density of the q draw (qq) and of the m-th MCMC draw
    % (qmc), one entry per component of the importance density
    qq=zeros(S,1); qmc=zeros(S,1);

    %% weight distribution

    if and(K>1,~model.indicfix)
        if strncmp(model.indicmod.dist,'Multin',6)
            model.weight=dirichsim(post.weight(qs(m),:));
            qq = qq + dirichpdflog(post.weight,model.weight);
            qmc = qmc + dirichpdflog(post.weight,mcmcout.weight(m,:));
        end
    else
        model.weight=1;
    end

    %%  mixture and Markov switching regression

    if isfield(model,'error')
        % strcmp: an element-wise == fails for strings of different length
        if strcmp(model.error,'switch')
            model.par.sigma=1./prodgamsim(struct('a',post.par.sigma.c(qs(m),:),'b',post.par.sigma.C(qs(m),:)));
            % evaluate q
            modseq=struct('a',post.par.sigma.c,'b',post.par.sigma.C);
            qq  = qq +prodinvgampdflog(modseq,model.par.sigma);
            qmc = qmc+prodinvgampdflog(modseq,mcmcout.par.sigma(m,:));
        end
    end

    % modified by: Olivia Bluder May 18, 2011
    % sample the regression coefficients from component qs(m)
    if and(K>1,size(post.par.beta.b,2)>1)
        if ~simstud
            model.par.beta=prodnormultsim(struct('mu',squeeze(post.par.beta.b(qs(m),:,:)),'sigmachol',squeeze(post.par.beta.Bchol(qs(m),:,:,:))));
        else
            model.par.beta=prodstudmultsim(struct('mu',squeeze(post.par.beta.b(qs(m),:,:)),'sigmachol',squeeze(post.par.beta.Bchol(qs(m),:,:,:)),'df',dfq*ones(1,K)));
        end
    elseif K>1
        if ~simstud
            model.par.beta=prodnormultsim(struct('mu',squeeze(post.par.beta.b(qs(m),:,:))','sigmachol',squeeze(post.par.beta.Bchol(qs(m),:,:,:))'));
        else
            model.par.beta=prodstudmultsim(struct('mu',squeeze(post.par.beta.b(qs(m),:,:))','sigmachol',squeeze(post.par.beta.Bchol(qs(m),:,:,:))','df',dfq*ones(1,K)));
        end
    else
        if ~simstud
            model.par.beta=prodnormultsim(struct('mu',post.par.beta.b(qs(m),:)','sigmachol',squeeze(post.par.beta.Bchol(qs(m),:,:))));
        else
            model.par.beta=prodstudmultsim(struct('mu',post.par.beta.b(qs(m),:)','sigmachol',squeeze(post.par.beta.Bchol(qs(m),:,:)),'df',dfq));
        end
    end

    % evaluate q at the sampled and at the MCMC regression coefficients
    if and((model.d-model.df)==1,K==1)
        if simstud
            modseq=struct('mu',post.par.beta.b,'sigmainv',1./post.par.beta.B,'logdet',post.par.beta.logdet,'df',repmat(dfq,size(post.par.beta.logdet)));
            qq  = qq +prodstudmultpdflog(modseq,model.par.beta);
            qmc = qmc + prodstudmultpdflog(modseq,mcmcout.par.beta(m,:));
        else
            modseq=struct('mu',post.par.beta.b,'sigma',post.par.beta.B,'logdet',post.par.beta.logdet);
            qq  = qq +prodnorpdflog(modseq,model.par.beta);
            qmc = qmc + prodnorpdflog(modseq,mcmcout.par.beta(m,:));
        end
    elseif and((model.d-model.df)>1,K>1)
        modseq=struct('mu',post.par.beta.b,'sigmainv',post.par.beta.Binv,'logdet',post.par.beta.logdet);
        if simstud
            modseq.df=repmat(dfq,size(post.par.beta.logdet));
            qq  = qq +prodstudmultpdflog(modseq,model.par.beta);
            qmc = qmc +prodstudmultpdflog(modseq,squeeze(mcmcout.par.beta(m,:,:)));
        else
            qq  = qq +prodnormultpdflog(modseq,model.par.beta);
            qmc = qmc +prodnormultpdflog(modseq,squeeze(mcmcout.par.beta(m,:,:)));
        end
    elseif and((model.d-model.df)==1,K>1)
        modseq=struct('mu',squeeze(post.par.beta.b),'sigmainv',squeeze(post.par.beta.Binv),'logdet',squeeze(post.par.beta.logdet));
        if simstud
            modseq.df=repmat(dfq,size(post.par.beta.logdet));
            qq  = qq +prodstudmultpdflog(modseq,model.par.beta);
            qmc = qmc +prodstudmultpdflog(modseq,squeeze(mcmcout.par.beta(m,:,:))');
        else
            qq  = qq +prodnormultpdflog(modseq,model.par.beta);
            qmc = qmc +prodnormultpdflog(modseq,squeeze(mcmcout.par.beta(m,:,:))');
        end
    else
        modseq=struct('mu',post.par.beta.b,'sigmainv',post.par.beta.Binv,'logdet',post.par.beta.logdet);
        if simstud
            modseq.df=dfq;
            qq  = qq +prodstudmultpdflog(modseq,model.par.beta);
            qmc = qmc +prodstudmultpdflog(modseq,squeeze(mcmcout.par.beta(m,:))');
        else
            qq  = qq +prodnormultpdflog(modseq,model.par.beta);
            qmc = qmc +prodnormultpdflog(modseq,squeeze(mcmcout.par.beta(m,:))');
        end
    end

    if model.df>0
        model.par.alpha=prodnormultsim(struct('mu',post.par.alpha.a(qs(m),:)','sigmachol',squeeze(post.par.alpha.Achol(qs(m),:,:))));
        modseq=struct('mu',post.par.alpha.a,'sigmainv',post.par.alpha.Ainv,'logdet',post.par.alpha.logdet);
        qq  = qq +prodnormultpdflog(modseq,model.par.alpha);
        qmc = qmc +prodnormultpdflog(modseq,squeeze(mcmcout.par.alpha(m,:))');
    end

    %% compute the functional value of the importance density
    % log of the mean over the S components, shifted by the maximum
    % before exponentiating

    qqmax=max(qq);
    qout.log.q(m)=qqmax+log(mean(exp(max(qq-qqmax,-1e300))));
    qmcmax=max(qmc);
    mcmcout.log.q(m)=qmcmax+log(mean(exp(max(qmc-qmcmax,-1e300))));

    %% compute likelihood and prior for the qsample
    if data.empty
        qout.log.mixlik(m)=0;
    else
        % modified by: Olivia Bluder Sept 29, 2010
        qout.log.mixlik(m)=likelihoodeval_MOE(data,model);
        % end modification
        if computelogmcmc
            modelmcmc=mcmcextract_MOE(mcmcout,m);
            % same evaluator as above (function names are case-sensitive)
            loglikmc(m)=likelihoodeval_MOE(dataar,modelmcmc);
        end
    end
    % modified by: Olivia Bluder Sept 30, 2010
    qout.log.mixprior(m)=prioreval_MOE(model,prior);
    % end modification
end

%% compute the marginal likelihoods

priormc=mcmcout.log.mixprior;
qmc=mcmcout.log.q;
loglikq=qout.log.mixlik;
priorq=qout.log.mixprior;
qq=qout.log.q;
Mbs=M;

%% Marginal likelihood based on importance sampling

ratio = loglikq+priorq-qq;
ratiomax=max(ratio);
mllogmi = ratiomax+log(mean(exp(ratio-ratiomax)));

%% Marginal likelihood based on reciprocal importance sampling

ratio = qmc-loglikmc-priormc;
ratiomax=max(ratio);
mllogri = -1*(ratiomax+log(mean(exp(ratio-ratiomax))));

%% Marginal likelihood based on iterative bridge  sampling, starting from ri and mi

maxit=1000;
mllogbs=zeros(maxit,2);
mllogbs(1,1)=real(mllogmi);   % column 1: chain started from the is estimate
mllogbs(1,2)=real(mllogri);   % column 2: chain started from the ri estimate

for i=2:maxit
    % update of the chain started from the ri estimate
    logpostmc=loglikmc+priormc-mllogbs(i-1,2);
    logpostq=loglikq+priorq-mllogbs(i-1,2);
    maxqmc=max([qq logpostq],[],2);
    rq=logpostq-maxqmc-log(exp(log(Mbs)+logpostq-maxqmc)+exp(log(Mbs)+qq-maxqmc));
    maxqmc=max([qmc logpostmc],[],2);
    rmc=qmc-maxqmc-log(exp(log(Mbs)+logpostmc-maxqmc)+exp(log(Mbs)+qmc-maxqmc));
    mllogbs(i,2)=mllogbs(i-1,2)+max(rq)+log(mean(exp(max(rq-max(rq),-1e15))))-max(rmc)-log(mean(exp(max(rmc-max(rmc),-1e15))));

    % update of the chain started from the is estimate
    logpostmc=loglikmc+priormc-mllogbs(i-1,1);
    logpostq=loglikq+priorq-mllogbs(i-1,1);
    maxqmc=max([qq logpostq],[],2);
    rq=logpostq-maxqmc-log(exp(log(Mbs)+logpostq-maxqmc)+exp(log(Mbs)+qq-maxqmc));
    maxqmc=max([qmc logpostmc],[],2);
    rmc=qmc-maxqmc-log(exp(log(Mbs)+logpostmc-maxqmc)+exp(log(Mbs)+qmc-maxqmc));
    mllogbs(i,1)=mllogbs(i-1,1)+max(rq)+log(mean(exp(max(rq-max(rq),-1e15))))-max(rmc)-log(mean(exp(max(rmc-max(rmc),-1e15))));
end

%% store the results

marlik.ri=mllogbs(1,2);
marlik.is=mllogbs(1,1);
marlik.bs=mllogbs(end,2);   % last iterate of the chain started from ri

marlik.log.loglikmc=loglikmc;
marlik.log.priormc=priormc;
marlik.log.qmc=qmc;
marlik.log.loglikq=qout.log.mixlik;
marlik.log.priorq=qout.log.mixprior;
marlik.log.qq=qout.log.q;

%% compute standard errors

% three outputs are requested as before; only the first one is used here
[sebs,ach,acr] = marginallikelihood_eval(marlik,false,0);
marlik.se.bs=sebs(1:3);
marlik.se.is=sebs(4);
marlik.se.ri=sebs(5);

if nargout==2
    mcmcout.marlik=marlik;
    varargout{1}=mcmcout;
end