def partition_function(self, batch_size, prec):
    """Compute the partition function by enumerating the hidden states.

    Only feasible for a small number of hidden units: the states produced
    by ``all_states(self.n_hid)`` are processed in batches and the progress
    counter assumes there are ``2 ** self.n_hid`` of them.

    Parameters:
        batch_size: batch size handed to ``pack_in_batches``; each batch is
            passed to ``self.free_hidden_energy`` in one call.
        prec: precision of the ``decimal`` context used to accumulate the
            sum.  ``0`` disables the decimal path and sums in the log
            domain with ``logsum`` instead.

    Returns:
        ``prec != 0``: the sum of ``exp(-free_hidden_energy)`` over all
        states (i.e. Z itself) as a ``decimal.Decimal``.
        ``prec == 0``: ``logsum`` of the negated free energies.
        NOTE(review): that is presumably log Z, not Z -- the two modes
        return different quantities, so callers must handle both.
    """
    use_decimal = prec != 0
    with decimal.localcontext() as ctx:
        if use_decimal:
            ctx.prec = prec
        batches = ml.common.util.pack_in_batches(all_states(self.n_hid),
                                                 batch_size)

        z = decimal.Decimal(0)
        # Seed with an empty float64 array so the final concatenate yields
        # the same dtype (and the same empty result when there are no
        # batches) as growing an ``np.array([])`` step by step would.
        neg_energy_chunks = [np.array([])]

        seen_samples = 0
        total_samples = 2 ** self.n_hid
        for hid in batches:
            # write() instead of print: the trailing "\r" rewinds the cursor
            # so the counter is redrawn in place, with no newline added.
            stderr.write("%i / %i \r" % (seen_samples, total_samples))
            fhes = gp.as_numpy_array(self.free_hidden_energy(hid))
            if use_decimal:
                for fhe in fhes:
                    # Decimal() only accepts real Python floats; NumPy
                    # scalars that do not subclass float (e.g. float32)
                    # would be rejected, so convert explicitly.
                    z += decimal.Decimal(-float(fhe)).exp()
            else:
                # Collect chunks and concatenate once after the loop;
                # concatenating inside the loop re-copies everything
                # accumulated so far on every batch.
                neg_energy_chunks.append(-fhes)
            seen_samples += hid.shape[0]

        if use_decimal:
            return z
        return logsum(np.concatenate(neg_energy_chunks))
def log_partition_function(self, betas, ais_runs, sampling_gibbs_steps=1,
                           mean_precision=0):
    """Estimate the log partition function of the RBM with AIS.

    A chain of intermediate RBMs is built by scaling the weights and hidden
    biases of ``self.rbm`` with each ``beta`` and interpolating the visible
    bias between ``self.base_bias_vis`` and ``self.rbm.bias_vis``.  For every
    step the importance weights are updated with the change in negative free
    energy of the current visible samples, which are then advanced by Gibbs
    sampling in the intermediate RBM.

    Parameters:
        betas: sequence of interpolation coefficients; must start at 0 and
            end at 1 (checked with an assertion).
        ais_runs: number of importance weights / visible samples drawn from
            the base distribution.
        sampling_gibbs_steps: step count passed to ``gibbs_sample`` for each
            intermediate RBM.
        mean_precision: ``0`` computes the statistics in the log domain with
            floating point; any other value is used as the precision of a
            ``decimal`` context in which the weights are exponentiated and
            averaged directly.

    Returns:
        Tuple ``(wmean, wmean_minus_3_std, wmean_plus_3_std)``: the log of
        the mean importance weight (including the base log partition
        function) and the logs of that mean minus / plus three standard
        deviations of the mean.  The lower bound is ``-inf`` in the decimal
        path when mean minus three standard deviations is not positive.
    """
    start_time = time()
    assert betas[0] == 0 and betas[-1] == 1

    irbm = ml.rbm.RestrictedBoltzmannMachine(0, self.rbm.n_vis,
                                             self.rbm.n_hid, 0)
    iw = gp.zeros(ais_runs)

    n_betas = len(betas)
    for i, beta in enumerate(betas):
        if ml.common.show_progress and i % 1000 == 0:
            print("%d / %d" % (i, n_betas))

        beta = float(beta)
        at_base = beta == 0

        # log p_(i-1)(v): must be evaluated BEFORE irbm is re-parameterised.
        if not at_base:
            lp_prev_vis = -irbm.free_energy(vis)

        # Build the intermediate RBM for this beta.
        irbm.weights = beta * self.rbm.weights
        irbm.bias_hid = beta * self.rbm.bias_hid
        irbm.bias_vis = ((1.0 - beta) * self.base_bias_vis
                         + beta * self.rbm.bias_vis)

        if at_base:
            # beta == 0: draw the initial samples from the base distribution;
            # there is no previous distribution to weight against.
            vis = self.base_sample_vis(ais_runs)
        else:
            # log p_i(v) under the new parameters, then update the weights
            # and move the samples forward.
            lp_vis = -irbm.free_energy(vis)
            iw += lp_vis - lp_prev_vis
            vis, _ = irbm.gibbs_sample(vis, sampling_gibbs_steps)

    # Mean and standard deviation of the importance weights.
    if mean_precision == 0:
        npiw = gp.as_numpy_array(iw)
        w = npiw + self.base_log_partition_function()
        log_n = np.log(w.shape[0])
        wmean = logsum(w) - log_n            # log E[W]
        wsqmean = logsum(2 * w) - log_n      # log E[W^2]
        # Var[W] = E[W^2] - E[W]^2, so the subtrahend in the log domain is
        # 2 * log E[W] (matching ``ewsqmean - ewmean**2`` in the decimal
        # path).  Assumes logminus(a, b) == log(exp(a) - exp(b)), as its use
        # for the lower bound below implies.
        wstd = logminus(wsqmean, 2 * wmean) / 2.0
        wmeanstd = wstd - 0.5 * math.log(w.shape[0])
        three_std = wmeanstd + math.log(3.0)
        wmean_plus_3_std = logplus(wmean, three_std)
        wmean_minus_3_std = logminus(wmean, three_std,
                                     raise_when_negative=False)
    else:
        with decimal.localcontext() as ctx:
            ctx.prec = mean_precision
            blpf = self.base_log_partition_function()
            n_runs = int(iw.shape[0])

            ewsum = decimal.Decimal(0)
            ewsqsum = decimal.Decimal(0)
            for w in gp.as_numpy_array(iw):
                # float(): Decimal() rejects NumPy scalars that are not
                # Python float subclasses.
                ew = decimal.Decimal(float(w + blpf)).exp()
                ewsum += ew
                ewsqsum += ew ** 2

            ewmean = ewsum / n_runs
            ewsqmean = ewsqsum / n_runs
            ewstd = (ewsqmean - ewmean ** 2).sqrt()
            # Decimal cannot be divided by a float (TypeError), so take the
            # square root of the run count as a Decimal as well.
            ewstdmean = ewstd / decimal.Decimal(n_runs).sqrt()

            wmean = float(ewmean.ln())
            wmean_plus_3_std = float((ewmean + 3 * ewstdmean).ln())
            # Guard on exactly the quantity whose logarithm is taken.
            lower = ewmean - 3 * ewstdmean
            if lower > 0:
                wmean_minus_3_std = float(lower.ln())
            else:
                wmean_minus_3_std = float('-inf')

    end_time = time()
    print("AIS took %d s" % (end_time - start_time))

    return wmean, wmean_minus_3_std, wmean_plus_3_std