コード例 #1
0
class NADE_q:
    """Distribution over visible vectors modelled by a single NADE.

    Every method delegates to the wrapped ``self.nade`` instance.
    """

    def __init__(self, num_dims, num_hidden, temp=.1, hard=False):
        # `temp` is handed to NADE under the `temperature` keyword and
        # `hard` under `hard`.
        self.nade = NADE(num_dims, num_hidden, hard=hard, temperature=temp)

    # for visualization
    def get_v(self, num_samples):
        """Return ``num_samples`` NADE samples converted with ``.numpy()``."""
        drawn = self.nade.sample(num_samples)
        return drawn.numpy()

    def log_prob(self, num_samples):
        """Draw ``num_samples`` samples and score them under the NADE.

        Returns:
            Tuple ``(log_probs, samples)`` where ``log_probs`` is
            ``self.nade.logprobs(samples)``.
        """
        drawn = self.nade.sample(num_samples)
        return self.nade.logprobs(drawn), drawn

    def params(self):
        """Return the wrapped NADE's ``params()`` unchanged."""
        nade = self.nade
        return nade.params()
コード例 #2
0
class NADEDEC:
    """Generative model p(h2) p(h1 | h2) p(v | h1).

    ``h2`` is drawn from a NADE prior. ``h1`` and ``v`` come from two
    ``DEC`` decoders (named 'h2h1' and 'h1v') stacked on top of it.

    Args:
        dim_zh2: pair whose entries are passed to the prior as
            ``NADE(dim_zh2[1], dim_zh2[0], temp)``.
        dims_h2h1: sizes for the h2 -> h1 decoder (presumably layer widths).
        dims_h1v: sizes for the h1 -> v decoder (presumably layer widths).
        activation_fn: nonlinearity handed to both decoders.
        temp: passed to the prior and to both decoders.
        hard: passed to both decoders only (not to the prior).
        train_pi: accepted but not used by this class.
    """

    def __init__(self,
                 dim_zh2,
                 dims_h2h1,
                 dims_h1v,
                 activation_fn=tf.nn.relu,
                 temp=.1,
                 hard=False,
                 train_pi=True):
        def make_decoder(dims, tag):
            # Both decoders share every option except their sizes and name.
            return DEC(dims,
                       name=tag,
                       activation_fn=activation_fn,
                       temp=temp,
                       hard=hard)

        # NOTE(review): NADE_q forwards `hard` to NADE but this prior does
        # not; confirm whether that is intentional.
        self.prior_h2 = NADE(dim_zh2[1], dim_zh2[0], temp)
        self.dec_h2h1 = make_decoder(dims_h2h1, 'h2h1')
        self.dec_h1v = make_decoder(dims_h1v, 'h1v')

    # for visualization
    def get_v(self, num_samples):
        """Sample top-down with the decoders' hard samplers; return v.numpy()."""
        sample = self.prior_h2.sample(num_samples)
        for decoder in (self.dec_h2h1, self.dec_h1v):
            sample = decoder.get_sample_hard(sample)
        return sample.numpy()

    # for training
    def get_h2_h1_v(self, num_samples):
        """Sample top-down with the decoders' soft samplers; return (h2, h1, v)."""
        layers = [self.prior_h2.sample(num_samples)]
        for decoder in (self.dec_h2h1, self.dec_h1v):
            layers.append(decoder.get_sample_soft(layers[-1]))
        return tuple(layers)

    def log_prob(self, num_samples):
        """Sample (h2, h1, v) top-down and return their joint log-probability.

        Returns:
            ``(log p(v|h1) + log p(h1|h2) + log p(h2), h2, h1, v)``, where the
            conditional terms come from each decoder's
            ``log_conditional_prob``.
        """
        h2 = self.prior_h2.sample(num_samples)
        log_prior = self.prior_h2.logprobs(h2)
        log_h1, h1 = self.dec_h2h1.log_conditional_prob(h2)
        log_v, v = self.dec_h1v.log_conditional_prob(h1)
        joint = log_v + log_h1 + log_prior
        return joint, h2, h1, v

    def params(self):
        """Return prior_h2.params() + dec_h1v.params() + dec_h2h1.params()."""
        prior_params = self.prior_h2.params()
        return prior_params + self.dec_h1v.params() + self.dec_h2h1.params()
コード例 #3
0
class NADESBN():
    """Generative model p(h) p(v | h): a NADE prior over h plus a dense decoder.

    The latent ``h`` is drawn from a ``NADE``. ``v`` is produced by passing
    ``h`` through the dense layers in ``h2v``, whose final sigmoid output is
    given to ``sample_from_bernoulli`` / ``gumbel_sigmoid_sample`` /
    ``bernoulli_log_likelihood``.

    Args:
        dim_z2h: pair passed to the prior as ``NADE(dim_z2h[1], dim_z2h[0],
            temp)``. This is presumably ``(num_hidden, num_dims)``, mirroring
            NADE_q's ``NADE(num_dims, num_hidden, ...)`` call; confirm.
        dim_h2v: decoder widths, from the size of ``h`` to the size of ``v``.
            Must contain at least two entries.
        activation_fn: hidden-layer nonlinearity. It also selects the
            ``const`` given to ``tf_xavier_init``: 1.0 for ``tf.nn.relu``,
            4.0 otherwise.
        temp: given to the NADE prior (third positional argument) and to
            ``gumbel_sigmoid_sample``.
        hard: given to ``gumbel_sigmoid_sample``.
        train_pi: accepted but not used by this class.

    Raises:
        ValueError: if ``dim_h2v`` has fewer than two entries.
    """

    def __init__(self,
                 dim_z2h,
                 dim_h2v,
                 activation_fn=tf.nn.relu,
                 temp=.1,
                 hard=False,
                 train_pi=True):
        # Validate before creating any variables. An `assert` would be
        # silently stripped under `python -O`.
        if len(dim_h2v) < 2:
            raise ValueError(
                'dim_h2v needs at least two layer sizes (input and output), '
                'got %r' % (dim_h2v,))

        # NOTE(review): NADE_q forwards `hard` to NADE but this prior does
        # not; confirm whether that is intentional.
        self.prior_h = NADE(dim_z2h[1], dim_z2h[0], temp)
        self.dim_h2v = dim_h2v
        self.w_h2v, self.b_h2v = [], []
        # Scale passed to tf_xavier_init; presumably the usual 4x factor for
        # sigmoid-like units versus 1x for ReLU.
        const = 1.0 if activation_fn == tf.nn.relu else 4.0

        # One dense layer per consecutive pair of widths, numbered from 1.
        # enumerate() replaces the Python-2-only xrange() counter.
        for i, (d1, d2) in enumerate(zip(dim_h2v[:-1], dim_h2v[1:]), 1):
            self.w_h2v.append(
                tfe.Variable(tf_xavier_init(d1, d2, const=const),
                             name='sbn.w_h2v.' + str(i)))
            self.b_h2v.append(
                tfe.Variable(tf.zeros([d2]),
                             dtype=tf.float32,
                             name='sbn.b_h2v.' + str(i)))

        self.activation_fn = activation_fn
        self.temp = temp
        self.hard = hard

    def h2v(self, h):
        """Map ``h`` through the decoder layers.

        Each layer computes ``tf.matmul(x, w) + b``. Hidden layers then apply
        ``self.activation_fn``, and the last layer applies ``tf.nn.sigmoid``.
        """
        last = len(self.w_h2v) - 1
        v = h
        for i, (w, b) in enumerate(zip(self.w_h2v, self.b_h2v)):
            v = tf.matmul(v, w) + b
            v = tf.nn.sigmoid(v) if i == last else self.activation_fn(v)
        return v

    def get_h(self, num_samples):
        """Return ``num_samples`` prior samples of h converted with ``.numpy()``."""
        h = self.prior_h.sample(num_samples)
        return h.numpy()

    # for visualization
    def get_v(self, num_samples):
        """Sample h from the prior, then v via ``sample_from_bernoulli``."""
        h = self.prior_h.sample(num_samples)
        v = sample_from_bernoulli(self.h2v(h))
        return v.numpy()

    # for training
    def get_h_and_v(self, num_samples):
        """Sample h from the prior and v via ``gumbel_sigmoid_sample``.

        v is drawn with ``self.temp`` and ``self.hard``; presumably this is
        a differentiable relaxation, per the helper's name.

        Returns:
            Tuple ``(h, v)``.
        """
        h = self.prior_h.sample(num_samples)
        v = gumbel_sigmoid_sample(self.h2v(h), self.temp, self.hard)
        return h, v

    def logprob_v_and_h(self, h, v):
        """Return ``log p(v | h) + log p(h)`` for the given ``h`` and ``v``."""
        logp_h = self.prior_h.logprobs(h)
        logp_v_given_h = bernoulli_log_likelihood(v, self.h2v(h))
        return logp_v_given_h + logp_h

    def params(self):
        """Return the prior's params followed by all decoder weights, then biases."""
        return self.prior_h.params() + tuple(self.w_h2v + self.b_h2v)