class NADE_q:
    """Wrapper exposing a single ``NADE`` model through a small interface.

    Args:
        num_dims: passed to ``NADE`` as its first positional argument.
        num_hidden: passed to ``NADE`` as its second positional argument.
        temp: passed to ``NADE`` as the ``temperature`` keyword.
        hard: passed to ``NADE`` as the ``hard`` keyword.
    """

    def __init__(self, num_dims, num_hidden, temp=.1, hard=False):
        nade_options = dict(temperature=temp, hard=hard)
        self.nade = NADE(num_dims, num_hidden, **nade_options)

    # for visualization
    def get_v(self, num_samples):
        """Sample ``num_samples`` times and return the result of ``.numpy()``."""
        drawn = self.nade.sample(num_samples)
        return drawn.numpy()

    def log_prob(self, num_samples):
        """Draw samples from the NADE and score them with the same NADE.

        Returns:
            A ``(logp_v, v)`` pair: the value of ``logprobs`` on the samples,
            followed by the samples themselves.
        """
        drawn = self.nade.sample(num_samples)
        scores = self.nade.logprobs(drawn)
        return scores, drawn

    def params(self):
        """Return whatever the wrapped NADE reports from ``params()``."""
        wrapped = self.nade
        return wrapped.params()
class NADEDEC:
    """Generative chain h2 -> h1 -> v with a ``NADE`` prior on h2.

    The prior is a ``NADE``; the two conditional stages are ``DEC`` objects
    named ``'h2h1'`` and ``'h1v'``.

    Args:
        dim_zh2: indexable; element 1 and element 0 are passed (in that
            order) as the first two positional arguments of ``NADE``.
        dims_h2h1: passed as the first argument of the ``'h2h1'`` ``DEC``.
        dims_h1v: passed as the first argument of the ``'h1v'`` ``DEC``.
        activation_fn: forwarded to both ``DEC`` objects.
        temp: forwarded positionally to ``NADE`` and as ``temp`` to both
            ``DEC`` objects.
        hard: forwarded to both ``DEC`` objects (not to the prior).
        train_pi: accepted but not used in this class.
    """

    def __init__(self, dim_zh2, dims_h2h1, dims_h1v, activation_fn=tf.nn.relu,
                 temp=.1, hard=False, train_pi=True):
        prior_first_arg = dim_zh2[1]
        prior_second_arg = dim_zh2[0]
        self.prior_h2 = NADE(prior_first_arg, prior_second_arg, temp)

        # Both decoders share every option except their dims and name.
        shared = dict(activation_fn=activation_fn, temp=temp, hard=hard)
        self.dec_h2h1 = DEC(dims_h2h1, name='h2h1', **shared)
        self.dec_h1v = DEC(dims_h1v, name='h1v', **shared)

    # for visualization
    def get_v(self, num_samples):
        """Run the chain with ``get_sample_hard`` and return ``v.numpy()``."""
        top = self.prior_h2.sample(num_samples)
        middle = self.dec_h2h1.get_sample_hard(top)
        visible = self.dec_h1v.get_sample_hard(middle)
        return visible.numpy()

    # for training
    def get_h2_h1_v(self, num_samples):
        """Run the chain with ``get_sample_soft`` and return ``(h2, h1, v)``."""
        top = self.prior_h2.sample(num_samples)
        middle = self.dec_h2h1.get_sample_soft(top)
        visible = self.dec_h1v.get_sample_soft(middle)
        return top, middle, visible

    def log_prob(self, num_samples):
        """Sample the whole chain and return its summed log-probability.

        Returns:
            ``(total, h2, h1, v)`` where ``total`` is
            ``logp_v_given_h1 + logp_h1_given_h2 + logp_h2``.
        """
        top = self.prior_h2.sample(num_samples)
        logp_top = self.prior_h2.logprobs(top)
        logp_middle, middle = self.dec_h2h1.log_conditional_prob(top)
        logp_visible, visible = self.dec_h1v.log_conditional_prob(middle)
        total = logp_visible + logp_middle + logp_top
        return total, top, middle, visible

    def params(self):
        """Concatenate the parameters of the prior, ``dec_h1v`` and ``dec_h2h1``."""
        collected = self.prior_h2.params()
        collected = collected + self.dec_h1v.params()
        collected = collected + self.dec_h2h1.params()
        return collected
class NADESBN:
    """``NADE`` prior over h followed by a feed-forward map from h to v.

    The h -> v map is a stack of affine layers. Every layer but the last is
    followed by ``activation_fn``; the last is followed by a sigmoid.

    Args:
        dim_z2h: indexable; element 1 and element 0 are passed (in that
            order) as the first two positional arguments of ``NADE``.
        dim_h2v: sequence of layer sizes for the h -> v stack; consecutive
            pairs give each layer's weight shape. Must have more than one
            entry.
        activation_fn: activation applied after every non-final layer. Also
            selects the ``const`` passed to ``tf_xavier_init`` (1.0 for
            ``tf.nn.relu``, 4.0 otherwise).
        temp: forwarded positionally to ``NADE`` and used by
            ``get_h_and_v`` for ``gumbel_sigmoid_sample``.
        hard: used by ``get_h_and_v`` for ``gumbel_sigmoid_sample``.
        train_pi: accepted but not used in this class.
    """

    def __init__(self, dim_z2h, dim_h2v, activation_fn=tf.nn.relu, temp=.1,
                 hard=False, train_pi=True):
        self.prior_h = NADE(dim_z2h[1], dim_z2h[0], temp)
        self.dim_h2v = dim_h2v
        self.w_h2v, self.b_h2v = [], []
        const = 1.0 if activation_fn == tf.nn.relu else 4.0
        assert len(self.dim_h2v) > 1, \
            'dim_h2v must contain at least two layer sizes'
        # enumerate(..., 1) yields the same 1-based layer index the variable
        # names have always used, without relying on the Python-2-only
        # ``xrange`` builtin.
        layer_shapes = zip(dim_h2v[:-1], dim_h2v[1:])
        for i, (d1, d2) in enumerate(layer_shapes, 1):
            self.w_h2v.append(
                tfe.Variable(tf_xavier_init(d1, d2, const=const),
                             name='sbn.w_h2v.' + str(i)))
            self.b_h2v.append(
                tfe.Variable(tf.zeros([d2]), dtype=tf.float32,
                             name='sbn.b_h2v.' + str(i)))
        self.activation_fn = activation_fn
        self.temp = temp
        self.hard = hard

    def h2v(self, h):
        """Apply the layer stack to ``h``; the final layer uses a sigmoid."""
        last = len(self.w_h2v) - 1
        v = h
        for i, (w, b) in enumerate(zip(self.w_h2v, self.b_h2v)):
            v = tf.matmul(v, w) + b
            if i == last:
                v = tf.nn.sigmoid(v)
            else:
                v = self.activation_fn(v)
        return v

    def get_h(self, num_samples):
        """Sample h from the prior and return ``h.numpy()``."""
        h = self.prior_h.sample(num_samples)
        return h.numpy()

    # for visualization
    def get_v(self, num_samples):
        """Sample h, then return ``sample_from_bernoulli(h2v(h)).numpy()``."""
        h = self.prior_h.sample(num_samples)
        v = sample_from_bernoulli(self.h2v(h))
        return v.numpy()

    # for training
    def get_h_and_v(self, num_samples):
        """Sample h, then draw v with ``gumbel_sigmoid_sample``.

        Returns:
            ``(h, v)``; v is drawn from ``h2v(h)`` using ``self.temp`` and
            ``self.hard``.
        """
        h = self.prior_h.sample(num_samples)
        v = gumbel_sigmoid_sample(self.h2v(h), self.temp, self.hard)
        return h, v

    def logprob_v_and_h(self, h, v):
        """Return ``bernoulli_log_likelihood(v, h2v(h))`` plus the prior's
        ``logprobs(h)``."""
        logp_h = self.prior_h.logprobs(h)
        logp_v_given_h = bernoulli_log_likelihood(v, self.h2v(h))
        return logp_v_given_h + logp_h

    def params(self):
        """Return the prior's params followed by all weights, then all biases."""
        return self.prior_h.params() + tuple(self.w_h2v + self.b_h2v)