def gumbel_reparameterization(logits_z,
                              tau,
                              rnd_sample=None,
                              hard=True,
                              eps=1e-9):
    '''
    The Gumbel-Softmax reparameterization: draws a relaxed (differentiable)
    one-hot sample of the latent z and returns it together with the KL
    divergence between the (unrelaxed) categorical posterior and a uniform
    categorical prior. The `rnd_sample` and `hard` arguments are currently
    unused.
    '''
    latent_size = logits_z.get_shape().as_list()[1]

    # Prior
    p_z = d.OneHotCategorical(
        probs=tf.constant(1.0 / latent_size, shape=[latent_size]))
    # p_z = d.RelaxedOneHotCategorical(probs=tf.constant(1.0/latent_size,
    #                                                    shape=[latent_size]),
    #                                  temperature=10.0)
    # p_z = 1.0 / latent_size
    # log_p_z = tf.log(p_z + eps)

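    # Relaxed (differentiable) posterior sample for the forward pass; the hard
    # OneHotCategorical posterior below is used only to compute the KL term.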
    with st.value_type(st.SampleValue()):
        q_z = st.StochasticTensor(
            d.RelaxedOneHotCategorical(temperature=tau, logits=logits_z))
        q_z_full = st.StochasticTensor(d.OneHotCategorical(logits=logits_z))

    reduce_index = [1] if len(logits_z.get_shape().as_list()) == 2 else [1, 2]
    kl = d.kl(q_z_full.distribution, p_z, allow_nan_stats=False)
    if len(shp(kl)) > 1:
        return [q_z, tf.reduce_sum(kl, reduce_index)]
    else:
        return [q_z, kl]
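A minimal usage sketch (not part of the original snippet), assuming TF 1.x with `d = tf.contrib.distributions` and `st = tf.contrib.bayesflow.stochastic_tensor` already imported; `encoder_logits` and `reconstruction_log_prob` are hypothetical tensors standing in for an encoder and a decoder term:

import tensorflow as tf

tau = tf.placeholder_with_default(1.0, shape=[])         # annealable temperature
q_z, kl = gumbel_reparameterization(encoder_logits, tau)
# q_z.value() is the relaxed one-hot sample that would be fed to a decoder;
# kl is the per-example KL(q(z|x) || uniform prior) regularizer in the ELBO.
elbo = reconstruction_log_prob - kl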
Example #2
def inference_model(observations, is_training, latent_layer_dims, tau,
                    batch_size, nn_layers):
    samples = [observations]
    q_log_likelihoods = []
    mean_list = []
    var_list = []

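    # One Gaussian stochastic layer per entry in latent_layer_dims; each layer
    # is conditioned on the sample drawn from the layer below it.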
    for i in range(len(latent_layer_dims)):
        mu, sigma_sq = inference_net(samples[i], is_training, nn_layers[0],
                                     nn_layers[1], latent_layer_dims[i])

        sample = sample_gaussian(latent_layer_dims[i], mu, sigma_sq,
                                 batch_size)
        sample_ll = gaussian_log_likelihood(sample, mu, sigma_sq)

        mean_list.append(mu)
        var_list.append(sigma_sq)
        q_log_likelihoods.append(sample_ll)
        samples.append(sample)

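    # Top of the hierarchy: a relaxed (Gumbel-Softmax) categorical over 10 classes.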
    logits_qy = slim.fully_connected(mlp(samples[-1], is_training,
                                         nn_layers[0], nn_layers[1]),
                                     10,
                                     activation_fn=None)
    q_y = dist.RelaxedOneHotCategorical(tau, logits=logits_qy)

    sample = q_y.sample()
    samples.append(sample)
    q_log_likelihoods.append(tf.log(sample + TOL))
    samples.pop(0)  # drop the observations; return only the latent samples

    return samples, mean_list, var_list, q_log_likelihoods
Example #3
    def sample_q(self, k, mode):
        if mode == tf.estimator.ModeKeys.TRAIN:
            z_dist = distributions.RelaxedOneHotCategorical(
                self.temp, logits=self.q_dist.logits)
            z_NK = z_dist.sample(k)
        elif mode == tf.estimator.ModeKeys.EVAL:
            z_NK = tf.to_float(self.q_dist.sample(k))
        return tf.reshape(z_NK, [k, -1, self.N * self.K])
Example #4
def relaxed_test():
    var = tf.Variable([-5, -10], dtype=tf.float32)
    cat = dist.RelaxedOneHotCategorical(2.0, logits=var)

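    # Each relaxed sample is differentiable w.r.t. `var`, so minimizing the
    # first component drives logit 0 down relative to logit 1.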
    loss = tf.reduce_sum((tf.cast(cat.sample([10]), tf.float32)), axis=0)[0]
    train_op = tf.train.AdamOptimizer().minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(100):
            print('Vars:', sess.run(var))
            print('Sample:', sess.run(cat.sample()))
            sess.run(train_op)
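A hedged variation on the test above (not from the original source): in practice the temperature is usually annealed toward a small value during training rather than fixed at 2.0. Assuming the same TF 1.x setup with `dist = tf.contrib.distributions`:

import numpy as np
import tensorflow as tf
dist = tf.contrib.distributions

def relaxed_anneal_test(steps=100, tau0=2.0, tau_min=0.5, rate=0.03):
    var = tf.Variable([-5.0, -10.0], dtype=tf.float32)
    tau = tf.placeholder(tf.float32, shape=[])            # temperature fed each step
    cat = dist.RelaxedOneHotCategorical(tau, logits=var)
    loss = tf.reduce_sum(cat.sample([10]), axis=0)[0]
    train_op = tf.train.AdamOptimizer().minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(steps):
            cur_tau = max(tau_min, tau0 * np.exp(-rate * i))   # exponential decay
            sess.run(train_op, feed_dict={tau: cur_tau})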
Example #5
    def __call__(self, inputs, state):
        input_ = inputs
        input_state = tf.concat([input_, state], 1)
        update, reset = self.gates(input_state)

        q_logits = tf.layers.dense(input_, self._num_modules)

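        # Module controller: a differentiable relaxed sample at train time, the
        # arg-max (mode) of the categorical at test time.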
        ctrl_train = tfd.RelaxedOneHotCategorical(self._softmax_temperature,
                                                  q_logits).sample
        ctrl_test = lambda: one_hot_categorical_mode(q_logits)
        ctrl = tf.cond(self._is_training, ctrl_train, ctrl_test)

        candidate = self.candidate(tf.concat([input_, reset * state], 1), ctrl)

        new_state = update * state + (1 - update) * candidate

        return (new_state, q_logits), new_state
Example #6
    def map_value(self, x, evidence, sample=False):
        child_values = []
        child_probs = []
        for child in self.inputs:
            v, p = child.map_value(x, evidence, sample)
            child_values.append(v)
            child_probs.append(p)
        child_prob_tensor = tf.expand_dims(self.reduction(child_probs), 2)
        child_values_tensor = self.reduction(child_values)
        weights_added = (self.structure + self.weights) + child_prob_tensor
        # probs = tf.reduce_logsumexp(weights_added, 1, keep_dims=False)
        if not sample:
            probs = tf.reduce_max(weights_added, 1, keep_dims=False)
            max_idx = tf.cast(tf.argmax(weights_added, axis=1), tf.int32)
            n = tf.shape(child_values_tensor)[0]
            idx_x = tf.tile(tf.range(0, n, 1), [self.out_dim])
            idx_x = tf.transpose(tf.reshape(idx_x, [self.out_dim, n]))
            values = tf.gather_nd(child_values_tensor,
                                  tf.stack([idx_x, max_idx], axis=2))
        else:
            weights_added_tr = tf.cast(tf.transpose(weights_added, [0, 2, 1]),
                                       tf.float32)
            # cat_dists = dist.Categorical(logits=weights_added)
            relaxed_dists = dist.RelaxedOneHotCategorical(
                SumLayer.SAMPLING_RELAXATION, logits=weights_added_tr)
            idxs = relaxed_dists.sample()
            idxs = tf.cast(idxs, spn_type)
            # print(idxs, child_values_tensor)
            values = tf.matmul(idxs, child_values_tensor)
            # FIXME find better way to prevent nans from log(0)
            probs = tf.log(idxs + 1e-20) + tf.cast(weights_added_tr, spn_type)
            probs = tf.reduce_logsumexp(probs, axis=2)
            # max_idx = cat_dists.sample()
            # probs = cat_dists.log_prob(max_idx)

        return values, probs
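A small standalone sketch (shapes are assumptions, not taken from the class above) of the gather_nd pattern used in the MAP branch: given per-child scores of shape [n, num_children, out_dim] and one scalar value per (batch, child), select the value of the highest-scoring child for every (batch, output) pair.

import tensorflow as tf

n, num_children, out_dim = 4, 3, 5
scores = tf.random_normal([n, num_children, out_dim])     # plays the role of weights_added
child_values = tf.random_normal([n, num_children])        # one value per (batch, child)

max_idx = tf.cast(tf.argmax(scores, axis=1), tf.int32)    # [n, out_dim] winning child index
idx_x = tf.tile(tf.range(0, n, 1), [out_dim])             # repeat the batch indices
idx_x = tf.transpose(tf.reshape(idx_x, [out_dim, n]))     # [n, out_dim] row indices
values = tf.gather_nd(child_values,
                      tf.stack([idx_x, max_idx], axis=2)) # [n, out_dim] selected values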