Example #1
0
 def partition_function(self, batch_size, prec):
     """Sum exp(-free_hidden_energy) over every hidden state.

     All 2**n_hid hidden configurations are enumerated in batches, so this
     is only feasible for a small number of hidden units.  Progress is
     written to stderr on a single, carriage-return-terminated line.

     Args:
         batch_size: number of hidden states per batch handed to
             free_hidden_energy.
         prec: decimal precision (significant digits) used for the
             summation; 0 selects the floating-point path instead.

     Returns:
         If prec != 0: the partition function Z as a decimal.Decimal.
         If prec == 0: logsum() of the negated free energies (presumably
         log Z rather than Z -- confirm against logsum's definition).
     """
     use_decimal = prec != 0
     with decimal.localcontext() as ctx:
         if use_decimal:
             ctx.prec = prec
         batches = ml.common.util.pack_in_batches(all_states(self.n_hid),
                                                  batch_size)
         z_sum = decimal.Decimal(0)
         # Chunks are collected and concatenated once at the end instead of
         # re-copying the growing array for every batch.  The leading empty
         # float64 array keeps the result float64 and makes the
         # zero-batch case well defined.
         neg_fhe_chunks = [np.array([])]
         seen_samples = 0
         total_samples = 2 ** self.n_hid
         for hid in batches:
             # "\r" without newline: the counter overwrites itself in place.
             stderr.write("%i / %i           \r"
                          % (seen_samples, total_samples))
             fhes = gp.as_numpy_array(self.free_hidden_energy(hid))
             if use_decimal:
                 for fhe in fhes:
                     # Decimal only accepts real Python floats; NumPy
                     # scalars are converted explicitly (lossless).
                     z_sum += decimal.Decimal(float(-fhe)).exp()
             else:
                 neg_fhe_chunks.append(-fhes)
             seen_samples += hid.shape[0]
         if use_decimal:
             return z_sum
         return logsum(np.concatenate(neg_fhe_chunks))
Example #2
0
File: ais.py Project: surban/ml
    def log_partition_function(self, betas, ais_runs, sampling_gibbs_steps=1,
                               mean_precision=0):
        """Estimate the log partition function of self.rbm by AIS.

        An intermediate RBM is annealed from the base distribution
        (beta = 0) to self.rbm (beta = 1).  For every beta > 0 the
        importance weight of each run is increased by the change in
        -free_energy of the current visible sample, after which the sample
        is advanced by Gibbs sampling in the intermediate RBM.

        Args:
            betas: sequence of inverse temperatures; must start with 0 and
                end with 1.
            ais_runs: number of runs (length of the importance-weight
                vector and argument to base_sample_vis).
            sampling_gibbs_steps: number of steps passed to gibbs_sample
                for each transition.
            mean_precision: 0 computes the statistics in the log domain
                with floats; any other value is used as the decimal
                precision for computing them with decimal.Decimal.

        Returns:
            Tuple (wmean, wmean_minus_3_std, wmean_plus_3_std): the log of
            the mean importance weight (including the base log partition
            function) and the logs of that mean minus / plus three
            standard errors.  In the Decimal path the lower value is -inf
            when mean - 3 * stderr is not positive.
        """
        start_time = time()

        assert betas[0] == 0 and betas[-1] == 1

        irbm = ml.rbm.RestrictedBoltzmannMachine(0,
                                                 self.rbm.n_vis,
                                                 self.rbm.n_hid,
                                                 0)
        iw = gp.zeros(ais_runs)

        n_betas = len(betas)
        for i, beta in enumerate(betas):
            if ml.common.show_progress and i % 1000 == 0:
                print("%d / %d" % (i, n_betas))

            beta = float(beta)
            # beta == 0 is the base distribution: nothing to weight yet,
            # only the initial sample is drawn.
            annealing = beta != 0

            # log p_(i-1)(v), evaluated before irbm is overwritten
            if annealing:
                lp_prev_vis = -irbm.free_energy(vis)

            # build intermediate RBM
            irbm.weights = beta * self.rbm.weights
            irbm.bias_hid = beta * self.rbm.bias_hid
            irbm.bias_vis = ((1.0 - beta) * self.base_bias_vis +
                             beta * self.rbm.bias_vis)

            if annealing:
                # log p_i(v), then update importance weight
                lp_vis = -irbm.free_energy(vis)
                iw += lp_vis - lp_prev_vis
                # sample v_(i+1)
                vis, _ = irbm.gibbs_sample(vis, sampling_gibbs_steps)
            else:
                vis = self.base_sample_vis(ais_runs)

        # calculate mean and standard deviation
        if mean_precision == 0:
            npiw = gp.as_numpy_array(iw)
            w = npiw + self.base_log_partition_function()
            log_n = np.log(w.shape[0])
            wmean = logsum(w) - log_n
            wsqmean = logsum(2 * w) - log_n
            # Var[x] = E[x^2] - E[x]^2; in the log domain E[x]^2 is
            # 2 * wmean (not wmean).  Halving the log takes the sqrt.
            wstd = logminus(wsqmean, 2 * wmean) / 2.0
            wmeanstd = wstd - 0.5 * math.log(w.shape[0])

            log_3_stderr = wmeanstd + math.log(3.0)
            wmean_plus_3_std = logplus(wmean, log_3_stderr)
            wmean_minus_3_std = logminus(wmean, log_3_stderr,
                                         raise_when_negative=False)
        else:
            with decimal.localcontext() as ctx:
                ctx.prec = mean_precision
                blpf = self.base_log_partition_function()
                n_runs = int(iw.shape[0])

                ewsum = decimal.Decimal(0)
                ewsqsum = decimal.Decimal(0)
                for w in gp.as_numpy_array(iw):
                    ew = decimal.Decimal(float(w + blpf)).exp()
                    ewsum += ew
                    ewsqsum += ew ** 2
                ewmean = ewsum / n_runs
                ewsqmean = ewsqsum / n_runs
                # Rounding can push a (near-)zero variance slightly below
                # zero, which would make sqrt() signal InvalidOperation.
                ewvar = max(ewsqmean - ewmean ** 2, decimal.Decimal(0))
                ewstd = ewvar.sqrt()
                # Decimal does not mix with float, so the square root of
                # the run count is taken as a Decimal as well.
                ewstdmean = ewstd / decimal.Decimal(n_runs).sqrt()

                wmean = float(ewmean.ln())
                wmean_plus_3_std = float((ewmean + 3 * ewstdmean).ln())
                # Guard on exactly the quantity whose log is taken.
                ew_lower = ewmean - 3 * ewstdmean
                if ew_lower > 0:
                    wmean_minus_3_std = float(ew_lower.ln())
                else:
                    wmean_minus_3_std = float('-inf')

        end_time = time()
        print("AIS took %d s" % (end_time - start_time))

        return wmean, wmean_minus_3_std, wmean_plus_3_std