def gumbel_reparmeterization(logits_z, tau, rnd_sample=None, hard=True, eps=1e-9):
    '''The Gumbel-softmax reparameterization: returns a relaxed one-hot
    sample of q(z|x) together with the KL from a uniform categorical prior.'''
    # Assumed TF 1.x imports:
    #   import tensorflow as tf
    #   from tensorflow.contrib import distributions as d
    #   from tensorflow.contrib.bayesflow import stochastic_tensor as st
    # `shp` is a project-local helper returning a tensor's static shape list.
    latent_size = logits_z.get_shape().as_list()[1]

    # Uniform categorical prior p(z).
    p_z = d.OneHotCategorical(
        probs=tf.constant(1.0 / latent_size, shape=[latent_size]))
    # p_z = d.RelaxedOneHotCategorical(probs=tf.constant(1.0/latent_size,
    #                                                    shape=[latent_size]),
    #                                  temperature=10.0)
    # p_z = 1.0 / latent_size
    # log_p_z = tf.log(p_z + eps)

    # Relaxed sample for the differentiable forward pass; a discrete
    # OneHotCategorical posterior is kept alongside it for the KL term.
    with st.value_type(st.SampleValue()):
        q_z = st.StochasticTensor(
            d.RelaxedOneHotCategorical(temperature=tau, logits=logits_z))
        q_z_full = st.StochasticTensor(d.OneHotCategorical(logits=logits_z))

    reduce_index = [1] if len(logits_z.get_shape().as_list()) == 2 else [1, 2]
    kl = d.kl(q_z_full.distribution, p_z, allow_nan_stats=False)
    if len(shp(kl)) > 1:
        return [q_z, tf.reduce_sum(kl, reduce_index)]
    return [q_z, kl]
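A minimal usage sketch under the same TF 1.x contrib setup; `features` and `reconstruction_ll` are hypothetical stand-ins, not names from the source:

import tensorflow as tf
from tensorflow.contrib import distributions as d
from tensorflow.contrib.bayesflow import stochastic_tensor as st

features = tf.placeholder(tf.float32, [None, 128])  # hypothetical encoder output
logits_z = tf.layers.dense(features, 10)            # posterior logits over 10 classes
tau = tf.placeholder_with_default(1.0, [])          # relaxation temperature
q_z, kl = gumbel_reparmeterization(logits_z, tau)
# The KL term then joins the ELBO alongside a decoder likelihood, e.g.:
# elbo = tf.reduce_mean(reconstruction_ll - kl)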
def inference_model(observations, is_training, latent_layer_dims, tau,
                    batch_size, nn_layers):
    # Assumed TF 1.x imports:
    #   import tensorflow as tf
    #   import tensorflow.contrib.slim as slim
    #   from tensorflow.contrib import distributions as dist
    # `inference_net`, `sample_gaussian`, `gaussian_log_likelihood`, `mlp`,
    # and `TOL` are project-local.
    samples = [observations]
    q_log_likelihoods = []
    mean_list = []
    var_list = []

    # Chain of Gaussian latent layers: each layer is inferred from the
    # sample drawn at the previous one.
    for i in range(len(latent_layer_dims)):
        mu, sigma_sq = inference_net(samples[i], is_training, nn_layers[0],
                                     nn_layers[1], latent_layer_dims[i])
        sample = sample_gaussian(latent_layer_dims[i], mu, sigma_sq, batch_size)
        sample_ll = gaussian_log_likelihood(sample, mu, sigma_sq)
        mean_list.append(mu)
        var_list.append(sigma_sq)
        q_log_likelihoods.append(sample_ll)
        samples.append(sample)

    # Top layer: relaxed (Gumbel-softmax) categorical over 10 classes.
    logits_qy = slim.fully_connected(
        mlp(samples[-1], is_training, nn_layers[0], nn_layers[1]),
        10, activation_fn=None)
    q_y = dist.RelaxedOneHotCategorical(tau, logits=logits_qy)
    sample = q_y.sample()
    samples.append(sample)
    q_log_likelihoods.append(tf.log(sample + TOL))

    # Drop the raw observations so only latent samples are returned.
    samples.pop(0)
    return samples, mean_list, var_list, q_log_likelihoods
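`sample_gaussian` and the other helpers above are project-local. A plausible reparameterized-sampler sketch, offered as an assumption rather than the repo's actual code:

import tensorflow as tf

def sample_gaussian(dim, mu, sigma_sq, batch_size):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # so gradients flow through mu and sigma_sq.
    eps = tf.random_normal([batch_size, dim])
    return mu + tf.sqrt(sigma_sq) * eps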
def sample_q(self, k, mode):
    # Assumes: from tensorflow.contrib import distributions
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Differentiable relaxed one-hot samples during training.
        z_dist = distributions.RelaxedOneHotCategorical(
            self.temp, logits=self.q_dist.logits)
        z_NK = z_dist.sample(k)
    else:
        # Discrete one-hot samples otherwise (EVAL/PREDICT), so z_NK is
        # always bound before the reshape below.
        z_NK = tf.to_float(self.q_dist.sample(k))
    return tf.reshape(z_NK, [k, -1, self.N * self.K])
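A usage sketch; a model object `m` with `q_dist`, `temp`, `N`, and `K` already set up is assumed:

z_train = m.sample_q(k=5, mode=tf.estimator.ModeKeys.TRAIN)  # relaxed, differentiable
z_eval = m.sample_q(k=5, mode=tf.estimator.ModeKeys.EVAL)    # discrete one-hots
# Both tensors have shape [k, batch_size, N * K].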
def relaxed_test():
    # Assumed TF 1.x imports:
    #   import tensorflow as tf
    #   from tensorflow.contrib import distributions as dist
    var = tf.Variable([-5, -10], dtype=tf.float32)
    cat = dist.RelaxedOneHotCategorical(2.0, logits=var)
    # The loss is the first coordinate summed over 10 relaxed samples, so
    # gradients flow through the samples and push var[0] further down.
    loss = tf.reduce_sum(tf.cast(cat.sample([10]), tf.float32), axis=0)[0]
    train_op = tf.train.AdamOptimizer().minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(100):
            print('Vars:', sess.run(var))
            print('Sample:', sess.run(cat.sample()))
            sess.run(train_op)
def __call__(self, inputs, state):
    # Assumes `tfd` aliases tf.contrib.distributions (or tfp.distributions).
    input_ = inputs
    input_state = tf.concat([input_, state], 1)
    update, reset = self.gates(input_state)

    # Module-selection controller: sample a relaxed one-hot during training,
    # take the categorical mode at test time.
    q_logits = tf.layers.dense(input_, self._num_modules)
    ctrl_train = tfd.RelaxedOneHotCategorical(
        self._softmax_temperature, q_logits).sample
    ctrl_test = lambda: one_hot_categorical_mode(q_logits)
    ctrl = tf.cond(self._is_training, ctrl_train, ctrl_test)

    # GRU-style update with the controller steering the candidate network.
    candidate = self.candidate(tf.concat([input_, reset * state], 1), ctrl)
    new_state = update * state + (1 - update) * candidate
    return (new_state, q_logits), new_state
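`one_hot_categorical_mode` is project-local; a plausible implementation, offered only as a sketch:

import tensorflow as tf

def one_hot_categorical_mode(logits):
    # Mode of a one-hot categorical: a hard one-hot at the argmax logit.
    return tf.one_hot(tf.argmax(logits, axis=-1), tf.shape(logits)[-1],
                      dtype=logits.dtype)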
def map_value(self, x, evidence, sample=False):
    # `dist` is assumed to alias tf.contrib.distributions; `spn_type` and
    # SumLayer.SAMPLING_RELAXATION are project-local.
    # Gather MAP values and log-probabilities from every child node.
    child_values = []
    child_probs = []
    for child in self.inputs:
        v, p = child.map_value(x, evidence, sample)
        child_values.append(v)
        child_probs.append(p)
    child_prob_tensor = tf.expand_dims(self.reduction(child_probs), 2)
    child_values_tensor = self.reduction(child_values)
    weights_added = (self.structure + self.weights) + child_prob_tensor
    # probs = tf.reduce_logsumexp(weights_added, 1, keep_dims=False)

    if not sample:
        # Exact MAP: select the child with the highest weighted log-probability.
        probs = tf.reduce_max(weights_added, 1, keep_dims=False)
        max_idx = tf.cast(tf.argmax(weights_added, axis=1), tf.int32)
        n = tf.shape(child_values_tensor)[0]
        idx_x = tf.tile(tf.range(0, n, 1), [self.out_dim])
        idx_x = tf.transpose(tf.reshape(idx_x, [self.out_dim, n]))
        values = tf.gather_nd(child_values_tensor,
                              tf.stack([idx_x, max_idx], axis=2))
    else:
        # Relaxed sampling: draw soft one-hot selectors so child selection
        # stays differentiable.
        weights_added_tr = tf.cast(tf.transpose(weights_added, [0, 2, 1]),
                                   tf.float32)
        # cat_dists = dist.Categorical(logits=weights_added)
        relaxed_dists = dist.RelaxedOneHotCategorical(
            SumLayer.SAMPLING_RELAXATION, logits=weights_added_tr)
        idxs = tf.cast(relaxed_dists.sample(), spn_type)
        values = tf.matmul(idxs, child_values_tensor)
        # FIXME find better way to prevent nans from log(0)
        probs = tf.log(idxs + 1e-20) + tf.cast(weights_added_tr, spn_type)
        probs = tf.reduce_logsumexp(probs, axis=2)
        # max_idx = cat_dists.sample()
        # probs = cat_dists.log_prob(max_idx)
    return values, probs
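A minimal, self-contained sketch of the soft-selection idea used in the `sample` branch above; all values are hypothetical:

import tensorflow as tf
from tensorflow.contrib import distributions as dist

logits = tf.constant([[2.0, 0.5, -1.0]])             # weights over 3 children
selector = dist.RelaxedOneHotCategorical(0.5, logits=logits).sample()
child_values = tf.constant([[10.0], [20.0], [30.0]])
# A near-one-hot selector mixes the child values differentiably.
selected = tf.matmul(selector, child_values)         # shape [1, 1]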