def testMaximumLikelihoodTraining(self):
    # Test Maximum Likelihood training with default bijector.
    with self.test_session() as sess:
        base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
        batch_norm = BatchNormalization(training=True)
        dist = transformed_distribution_lib.TransformedDistribution(
            distribution=base_dist, bijector=batch_norm)
        target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.])
        target_samples = target_dist.sample(100)
        dist_samples = dist.sample(3000)
        loss = -math_ops.reduce_mean(dist.log_prob(target_samples))
        with ops.control_dependencies(batch_norm.batchnorm.updates):
            train_op = adam.AdamOptimizer(1e-2).minimize(loss)
            moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean)
            moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance)
        variables.global_variables_initializer().run()
        for _ in range(3000):
            sess.run(train_op)
        [dist_samples_, moving_mean_, moving_var_] = sess.run(
            [dist_samples, moving_mean, moving_var])
        self.assertAllClose([1., 2.], np.mean(dist_samples_, axis=0), atol=5e-2)
        self.assertAllClose([1., 2.], moving_mean_, atol=5e-2)
        self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
def build_policy_network_op(self, scope="policy"): action_means = [] with tf.variable_scope(scope): for i in range(self.num_models): with tf.variable_scope('policy_%s' % i): if self.discrete: raise NotImplementedError else: action_means.append(dnn(input=self.obv_ph, output_size=self.act_dim, scope='dnn', n_layers=c.pg_pi_n_layers, size=c.pg_pi_hidden_fc_size)) action_mean_tr = tf.stack(action_means, axis=1) self.weighted_action_means = tf.reduce_sum( tf.expand_dims(self.belief_prob_target, axis=2) * action_mean_tr, axis=1) act_log_std = tf.get_variable('act_log_std', shape=[self.act_dim], trainable=True) self.action_std = tf.exp(act_log_std) multivariate = tfd.MultivariateNormalDiag( loc=self.weighted_action_means, scale_diag=self.action_std) self.sampled_action = tf.random_normal( [self.act_dim]) * self.action_std + self.weighted_action_means self.logprob = multivariate.log_prob(self.act_ph) multivariate1 = tfd.MultivariateNormalDiag( loc=self.weighted_action_means_ph1, scale_diag=self.action_std_ph1) multivariate2 = tfd.MultivariateNormalDiag( loc=self.weighted_action_means_ph2, scale_diag=self.action_std_ph2) self.kl_divergence = tf.reduce_mean( tf.distributions.kl_divergence(multivariate1, multivariate2))
def tf_gen_sample(self, values, var):
    if self.dist == 'norm':
        std = tf.sqrt(var + 1e-4)
        self.ref_dis = ds.MultivariateNormalDiag(loc=values, scale_diag=std)
        self.ref_tile_dis = ds.MultivariateNormalDiag(
            loc=tf.tile(values[:, None, :], [1, self._kernel_n_particles, 1]),
            scale_diag=tf.tile(std[:, None, :],
                               [1, self._kernel_n_particles, 1]))
        return tf.transpose(self.ref_dis.sample(self._kernel_n_particles),
                            [1, 0, 2])
    elif self.dist == 'beta':
        var = var + 1e-4
        a = values * (values * (1 - values) / var - 1)
        b = (1 - values) * (values * (1 - values) / var - 1)
        masked = (var < values * (1 - values))
        a = tf.where(masked, a, a * 0 + 2)
        b = tf.where(masked, b, b * 0 + 2)
        self.a = a
        self.b = b
        self.values = values
        self.var = var
        self.ref_dis_beta = ds.Beta(a, b)
        self.ref_tile_dis_beta = ds.Beta(
            tf.tile(a[:, None, :], [1, self._kernel_n_particles, 1]),
            tf.tile(b[:, None, :], [1, self._kernel_n_particles, 1]))
        return (tf.transpose(
            self.ref_dis_beta.sample(self._kernel_n_particles),
            [1, 0, 2]) * 2 - 1) * 0.99
    else:
        print('method has not been implemented')
        return None
def kl_loss(X_true, X_predict):
    latent_prior = dist.MultivariateNormalDiag(
        [0.] * latent_dimensions, [1.] * latent_dimensions)
    approximate_posterior = dist.MultivariateNormalDiag(
        z_mu, K.sqrt(K.exp(z_ls2)))
    return {'kl_loss': K.mean(dist.kl(latent_prior, approximate_posterior))}
def ecg_loss(y_true, y_pred):
    # Robust negative log-likelihood: mixes a unit-variance Gaussian with a
    # wider Gaussian (scale ecg_c), weighted by ecg_epsilon.
    num_dims = y_pred.get_shape().as_list()[1]
    n = tfd.MultivariateNormalDiag(
        loc=y_pred, scale_diag=tf.ones(num_dims)).prob(y_true)
    nc = tfd.MultivariateNormalDiag(
        loc=y_pred, scale_diag=tf.ones(num_dims) * ecg_c).prob(y_true)
    return tf.reduce_mean(
        tf.log((1.0 - ecg_epsilon) * n + ecg_epsilon * nc) * -1.0)
def vae_loss(X_true, X_predict):
    xent_loss = K.sum(
        0.5 * x_ls2 + (tf.square(x - x_mu) / (2.0 * tf.exp(x_ls2))), 1)
    latent_prior = dist.MultivariateNormalDiag(
        [0.] * latent_dimensions, [1.] * latent_dimensions)
    approximate_posterior = dist.MultivariateNormalDiag(
        z_mu, K.sqrt(K.exp(z_ls2)))
    kl_loss = dist.kl(latent_prior, approximate_posterior)
    return xent_loss + kl_loss
def generative_model(observations, samples, is_training, latent_layer_dims,
                     nn_layers):
    samples = list(reversed(samples))
    latent_layer_dims = list(reversed(latent_layer_dims))
    mu, sigma_sq = generator_net(samples[0], is_training, nn_layers[0],
                                 nn_layers[1], latent_layer_dims[0], 'gaussian')
    mean_list = [mu]
    var_list = [sigma_sq]
    p_lls = []
    p_gen = None

    # reconstruction of training samples
    for i in range(1, len(samples) - 1):
        mu, sigma_sq = generator_net(samples[i], is_training, nn_layers[0],
                                     nn_layers[1], latent_layer_dims[i],
                                     'gaussian')
        p_lls.append(dist.MultivariateNormalDiag(mu, sigma_sq))
        mean_list.append(mu)
        var_list.append(sigma_sq)
    probs = generator_net(samples[-1], is_training, nn_layers[0], nn_layers[1],
                          observations.get_shape().as_list()[1],
                          likelihood='bernoulli')
    p_x = bernoulli_log_likelihood(observations, probs)

    # generation of novel samples
    sample_gen = tf.random_uniform([16], maxval=11, dtype=tf.int32)
    sample_gen = tf.one_hot(sample_gen, 10)
    mu_gen, sigma_sq_gen = generator_net(sample_gen, is_training, nn_layers[0],
                                         nn_layers[1], latent_layer_dims[0],
                                         'gaussian')
    gen_samples = [dist.MultivariateNormalDiag(mu_gen, sigma_sq_gen).sample()]
    for i in range(1, len(latent_layer_dims) - 1):
        mu, sigma_sq = generator_net(samples[i], is_training, nn_layers[0],
                                     nn_layers[1], latent_layer_dims[i],
                                     'gaussian')
        gen_samples.append(dist.MultivariateNormalDiag(mu, sigma_sq).sample())
    probs = generator_net(gen_samples[-1], is_training, nn_layers[0],
                          nn_layers[1], observations.get_shape().as_list()[1],
                          likelihood='bernoulli')
    p_gen = dist.Bernoulli(probs=probs).sample()
    return probs, p_gen, p_x, mean_list, var_list
def test_reparameterized_stochastic_connector(self):
    """Tests the logic of
    :class:`~texar.modules.ReparameterizedStochasticConnector`.
    """
    state_size = (10, 10)
    variable_size = 100
    state_size_ts = (tf.TensorShape([10, 10]), tf.TensorShape([2, 3, 4]))
    sample_num = 10

    mu = tf.zeros([self._batch_size, variable_size])
    var = tf.ones([self._batch_size, variable_size])
    mu_vec = tf.zeros([variable_size])
    var_vec = tf.ones([variable_size])
    gauss_ds = tfds.MultivariateNormalDiag(loc=mu, scale_diag=var)
    gauss_ds_vec = tfds.MultivariateNormalDiag(loc=mu_vec, scale_diag=var_vec)
    gauss_connector = ReparameterizedStochasticConnector(state_size)
    gauss_connector_ts = ReparameterizedStochasticConnector(state_size_ts)

    output_1, _ = gauss_connector(gauss_ds)
    output_2, _ = gauss_connector(
        distribution="MultivariateNormalDiag",
        distribution_kwargs={"loc": mu, "scale_diag": var})
    sample_ts, _ = gauss_connector_ts(gauss_ds)

    # specify sample num
    sample_test_num, _ = gauss_connector(gauss_ds_vec, num_samples=sample_num)

    # test when :attr:`transform` is False
    # sample_test_no_transform = gauss_connector(gauss_ds, transform=False)

    test_list = [output_1, output_2, sample_ts, sample_test_num]

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        out_list = sess.run(test_list)
        out1 = out_list[0]
        out2 = out_list[1]
        out_ts = out_list[2]
        out_test_num = out_list[3]

        # check the same size
        self.assertEqual(out_test_num[0].shape,
                         tf.TensorShape([sample_num, state_size[0]]))
        self.assertEqual(out1[0].shape,
                         tf.TensorShape([self._batch_size, state_size[0]]))
        self.assertEqual(out2[0].shape,
                         tf.TensorShape([self._batch_size, state_size[0]]))
        _assert_same_size(out_ts, state_size_ts)
def input_tensor(batch_size, return_mixture=False):
    gaussians = [
        ds.MultivariateNormalDiag(loc=(5.0, 5.0), scale_diag=(0.5, 0.5)),
        ds.MultivariateNormalDiag(loc=(-5.0, 5.0), scale_diag=(0.5, 0.5)),
        ds.MultivariateNormalDiag(loc=(-5.0, -5.0), scale_diag=(0.5, 0.5)),
        ds.MultivariateNormalDiag(loc=(5.0, -5.0), scale_diag=(0.5, 0.5))
    ]
    uniform_mixture_probs = [1 / len(gaussians)] * len(gaussians)
    mixture = ds.Mixture(cat=ds.Categorical(uniform_mixture_probs),
                         components=gaussians)
    sampled = mixture.sample(batch_size)
    return (sampled, mixture) if return_mixture else sampled
def build_policy_network_op(self, scope="policy_network"): if self.discrete: action_logits = dnn(input=self.obs_ph, output_size=self.action_dim, scope=scope, n_layers=c.pg_pi_n_layers, size=c.pg_pi_hidden_fc_size) self.sampled_action = tf.squeeze(tf.multinomial(action_logits, 1), axis=1) self.logprob = -tf.nn.sparse_softmax_cross_entropy_with_logits( labels=self.action_placeholder, logits=action_logits) else: action_means = dnn(input=self.obs_ph, output_size=self.action_dim, scope=scope, n_layers=c.pg_pi_n_layers, size=c.pg_pi_hidden_fc_size) log_std = tf.get_variable('log_std', shape=[self.action_dim], trainable=True) action_std = tf.exp(log_std) multivariate = tfd.MultivariateNormalDiag(loc=action_means, scale_diag=action_std) self.sampled_action = tf.random_normal( [self.action_dim]) * action_std + action_means self.logprob = multivariate.log_prob(self.action_placeholder)
def radial_gaussians(batch_size, n_mixture=8, std=0.01, radius=1.0,
                     add_far=False):
    thetas = np.linspace(0, 2 * np.pi, n_mixture + 1)[:-1]
    xs, ys = radius * np.cos(thetas), radius * np.sin(thetas)
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std])
             for xi, yi in zip(xs.ravel(), ys.ravel())]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)
def detect_objects(self, images):
    """
    Use RNN to detect objects in the images.

    :param images: image tensor of shape [num_batch, h, w, c]
    :return: z_pres, a tensor of shape [num_batch, num_steps] with values
        between 0 and 1, indicating if an object has been found during that
        step
        z_where, a tensor of shape [num_batch, num_steps, 4], indicating the
        position and scale of the objects
        q_z_where, a tensor of shape [num_batch, num_steps] indicating
        q(z_where | x)
    """
    batch_size, height, width, channels = [int(dim) for dim in images.shape]
    images_lin = tf.reshape(images, [batch_size, height * width * channels])

    # Initial state of the LSTM memory.
    hidden_state = tf.zeros([batch_size, self.lstm_units])
    current_state = tf.zeros([batch_size, self.lstm_units])
    state = hidden_state, current_state

    z = []
    for _ in range(self.conf.num_steps):
        with tf.variable_scope('detection-rnn'):
            # LSTM variables might get initialized on first call
            hidden, state = self.lstm(images_lin, state)
            z.append(self.output_mlp.forward(hidden))
    z = tf.stack(z, axis=1, name='z')

    z_pres = tf.sigmoid(1 * z[:, :, 0], name='z_pres')
    z_where_mean = tf.sigmoid(z[:, :, 1:5])
    z_where_var = 0.3 * tf.sigmoid(z[:, :, 5:])

    # z_where_mean consists of [sx, sy/sx, x, y]
    # scale sx to [0.3, 0.9], sy to [0.75 * sx, 1.25 * sx],
    # and x, y to [0, 0.9 * canvas_size]
    obj_scale_delta = self.conf.max_obj_scale - self.conf.min_obj_scale
    y_scale_delta = self.conf.max_y_scale - self.conf.min_y_scale
    z_where_mean *= [[[obj_scale_delta, y_scale_delta,
                       0.9 * height, 0.9 * width]]]
    z_where_mean += [[[self.conf.min_obj_scale, self.conf.min_y_scale,
                       0.0, 0.0]]]
    scale_y = z_where_mean[:, :, 0:1] * z_where_mean[:, :, 1:2]
    z_where_mean = tf.concat(
        (z_where_mean[:, :, 0:1], scale_y, z_where_mean[:, :, 2:]), axis=2)

    # sample from variational distribution for z_where
    z_where_dist = dists.MultivariateNormalDiag(loc=z_where_mean,
                                                scale_diag=z_where_var)
    z_where = z_where_dist.sample()
    q_z_where = z_where_dist.log_prob(z_where)

    return z_pres, z_where, q_z_where
def get_posterior(xb, j):
    with tf.variable_scope('posterior', reuse=tf.AUTO_REUSE):
        y_ = tf.fill(tf.stack([tf.shape(xb)[0], args.POSTERIOR_NK]), 0.0)
        y = tf.add(y_, tf.constant(np.eye(args.POSTERIOR_NK)[j], 'float32'))
        if args.SEPARATE_POSTERIORS == 1:
            indj = j
        else:
            indj = 0
        hid1 = FC((xb, y), [args.NH, args.NH], activations[args.ACT],
                  name='hid_%d' % indj)
        loc = FC(hid1, [args.NZ], [None], name='u_%d' % indj)
        scale = FC(hid1, [args.NZ], [activations[args.SF]],
                   name='sig_%d' % indj)
        posterior = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        if args.POSTERIOR_IAF == 0:
            return posterior, loc
        bjs = []
        for ii in range(args.POSTERIOR_IAF):
            if args.POSTERIOR_IAF_TYPE == 'affine':
                bj = tfb.Affine(
                    shift=tf.Variable(tf.zeros(6)),
                    scale_tril=tfd.fill_triangular(
                        tf.Variable(tf.ones(6 * (6 + 1) // 2))),
                    name='aff_%d_%d' % (ii, j))
            if args.POSTERIOR_IAF_TYPE == 'iaf':
                bj = tfb.Invert(
                    tfb.MaskedAutoregressiveFlow(
                        shift_and_log_scale_fn=tfb.
                        masked_autoregressive_default_template(
                            hidden_layers=[args.POSTERIOR_IAF_NH,
                                           args.POSTERIOR_IAF_NH])),
                    name='floaw_%d_%d' % (ii, j))
            if args.POSTERIOR_IAF_TYPE == 'nvp':
                bj = tfb.RealNVP(
                    num_masked=3,
                    shift_and_log_scale_fn=tfb.real_nvp_default_template(
                        hidden_layers=[args.POSTERIOR_IAF_NH,
                                       args.POSTERIOR_IAF_NH]))
            if args.POSTERIOR_IAF_TYPE == 'masked':
                bj = tfb.MaskedAutoregressiveFlow(
                    shift_and_log_scale_fn=tfb.
                    masked_autoregressive_default_template(
                        hidden_layers=[args.POSTERIOR_IAF_NH,
                                       args.POSTERIOR_IAF_NH]),
                    name='floaw_%d_%d' % (ii, j))
            bjs.append(bj)
        bj = tfb.Chain(bjs)
        maf = tfd.TransformedDistribution(distribution=posterior, bijector=bj)
        return maf, loc
def _make_encoder(data, code_size):
    with tf.variable_scope('encoder'):
        x = tf.layers.flatten(data)
        x = tf.layers.dense(x, 200, tf.nn.relu)
        x = tf.layers.dense(x, 200, tf.nn.relu)
        loc = tf.layers.dense(x, code_size)
        scale = tf.layers.dense(x, code_size, tf.nn.softplus)
        return tfd.MultivariateNormalDiag(loc, scale)
def build_prior(self):
    if self.whiten:
        mvn = tfd.MultivariateNormalDiag(loc=tf.zeros_like(self.Ug))
    else:
        mvn = tfd.MultivariateNormalFullCovariance(
            loc=tf.zeros_like(self.Ug),
            covariance_matrix=self.kern.K(self.Zg, self.Zg))
    return tf.reduce_sum(mvn.log_prob(self.Ug))
def make_encoder(x, z_dim=8):
    ''' Encoder: q(z|x) '''
    net = make_nn(x, z_dim * 2)
    return tfd.MultivariateNormalDiag(
        loc=net[..., :z_dim],
        scale_diag=tf.nn.softplus(net[..., z_dim:]))
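# Illustrative usage sketch, not part of the original source: how an encoder
# like make_encoder above is typically consumed in a VAE. The `images` input
# and the standard-normal prior are assumptions for this example; only
# make_encoder itself comes from the snippet above.
def example_encoder_usage(images, z_dim=8):
    posterior = make_encoder(images, z_dim=z_dim)        # q(z|x), batch of diagonal MVNs
    prior = tfd.MultivariateNormalDiag(
        loc=tf.zeros_like(posterior.mean()),
        scale_diag=tf.ones_like(posterior.mean()))       # p(z) = N(0, I)
    z = posterior.sample()                               # reparameterized draw
    kl = tfd.kl_divergence(posterior, prior)             # KL(q(z|x) || p(z)) per example
    return z, kl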
def swiss(batch_size, size=1., std=0.01):
    x, _ = datasets.make_swiss_roll(1000)
    norm = x[:, ::2].max()
    xs = x[:, 0] * size / norm
    ys = x[:, 2] * size / norm
    cat = ds.Categorical(tf.zeros(len(x)))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std])
             for xi, yi in zip(xs.ravel(), ys.ravel())]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)
def get_next_states(self, states_actions):
    outputs = [model.get_prediction(states_actions) for model in self.models]
    mu = tf.concat([output[0] for output in outputs], axis=-1)
    sd = tf.concat([output[1] for output in outputs], axis=-1)
    next_state = tfd.MultivariateNormalDiag(loc=mu, scale_diag=sd).sample()
    return next_state
def make_encoder(x, z_dim=Z_DIM):
    ''' Encoder: q(z|x) '''
    with tf.variable_scope("encoder"):
        net = make_nn(x, z_dim * 2)
        print('encoder net', net)
        return tfd.MultivariateNormalDiag(
            loc=net[..., :z_dim],
            scale_diag=tf.nn.softplus(net[..., z_dim:]))
def make_encoder(data, code_size):
    x = tf.layers.flatten(data)
    x = tf.layers.dense(x, hidden, tf.nn.relu)
    x = tf.layers.dense(x, hidden, tf.nn.relu)
    loc = tf.layers.dense(x, code_size)
    scale = tf.layers.dense(x, code_size, tf.nn.softplus)
    return tfd.MultivariateNormalDiag(loc, scale), loc, scale
def radial_gaussians2(batch_size, n_mixture=8, std=0.01, r1=1.0, r2=2.0):
    thetas = np.linspace(0, 2 * np.pi, n_mixture + 1)[:-1]
    x1s, y1s = r1 * np.sin(thetas), r1 * np.cos(thetas)
    x2s, y2s = r2 * np.sin(thetas), r2 * np.cos(thetas)
    xs = np.vstack([x1s, x2s])
    ys = np.vstack([y1s, y2s])
    cat = ds.Categorical(tf.zeros(n_mixture * 2))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std])
             for xi, yi in zip(xs.ravel(), ys.ravel())]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)
def line_1d(batch_size, n_mixture=5, std=0.01, d=1.0, add_far=False):
    xs = np.linspace(-d, d, n_mixture, dtype=np.float32)
    p = [0.] * n_mixture
    if add_far:
        xs = np.concatenate([np.array([-10 * d]), xs, np.array([10 * d])], 0)
        p = [0.] + p + [0.]
    cat = ds.Categorical(tf.constant(p))
    comps = [ds.MultivariateNormalDiag([xi], [std]) for xi in xs.ravel()]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)
def rect(batch_size, std=0.01, nx=5, ny=5, h=2, w=2):
    x = np.linspace(-h, h, nx)
    y = np.linspace(-w, w, ny)
    p = []
    for i in x:
        for j in y:
            p.append((i, j))
    cat = ds.Categorical(tf.zeros(len(p)))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std]) for xi, yi in p]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)
def _get_distribution(state, output_size):
    h = state
    num_layers = 1
    for i in range(num_layers):
        with tf.variable_scope('linear%d' % i) as scope:
            h = linear(h, output_size, scope=scope)
    with tf.variable_scope('Mean'):
        mean = linear(h, output_size, activation=None)
    with tf.variable_scope('Var'):
        var = linear(h, output_size, activation=tf.nn.softplus)
    return tfd.MultivariateNormalDiag(mean, var)
def ring_mog(batch_size, n_mixture=8, std=0.01, radius=1.0):
    thetas = np.linspace(0, 2 * np.pi, n_mixture, endpoint=False)
    xs = radius * np.sin(thetas, dtype=np.float32)
    ys = radius * np.cos(thetas, dtype=np.float32)
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [
        ds.MultivariateNormalDiag([xi, yi], [std, std])
        for xi, yi in zip(xs.ravel(), ys.ravel())
    ]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size), np.stack([xs, ys], axis=1)
def __init__(self, x_ph, log_likelihood_fn, dims, num_samples=16,
             method='hmc', config=None):
    """
    The model implements Hamiltonian AIS.
    Developed by @bilginhalil on top of https://github.com/jiamings/ais/

    Example use case:
    log p(x) = log integral over z of p(x|z, theta) * p(z) dz
    p(x|z, theta) -> likelihood function
    p(z) -> prior
    The prior is assumed to be a normal distribution with mean 0 and identity
    covariance matrix.

    :param x_ph: Placeholder for x
    :param log_likelihood_fn: Outputs log p(x|z, theta); it should take two
        parameters: x and z
    :param dims: e.g. {'output_dim': 28*28, 'input_dim': FLAGS.d,
        'batch_size': 1}
    :param num_samples: Number of samples to sample from in order to estimate
        the likelihood.

    The following are parameters for HMC.
    :param stepsize:
    :param n_steps:
    :param target_acceptance_rate:
    :param avg_acceptance_slowness:
    :param stepsize_min:
    :param stepsize_max:
    :param stepsize_dec:
    :param stepsize_inc:
    """
    self.dims = dims
    self.log_likelihood_fn = log_likelihood_fn
    self.num_samples = num_samples

    self.z_shape = [dims['batch_size'] * self.num_samples, dims['input_dim']]
    if method != 'riem_ld':
        self.prior = tfd.MultivariateNormalDiag(
            loc=tf.zeros(self.z_shape), scale_diag=tf.ones(self.z_shape))
    else:
        self.prior = HypersphericalUniform(dims['input_dim'] - 1)

    self.batch_size = dims['batch_size']
    self.x = tf.tile(x_ph, [self.num_samples, 1])
    self.method = method
    self.config = config if config is not None else default_config[method]
def grid_mog(batch_size, n_mixture=25, std=0.05, space=2.0):
    grid_range = int(np.sqrt(n_mixture))
    modes = np.array([
        np.array([i, j]) for i, j in itertools.product(
            range(-grid_range + 1, grid_range, 2),
            range(-grid_range + 1, grid_range, 2))
    ], dtype=np.float32)
    modes = modes * space / 2.
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [ds.MultivariateNormalDiag(mu, [std, std]) for mu in modes]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size), modes
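# Illustrative sketch, not part of the original code: how the toy 2-D mixture
# samplers above (ring_mog, grid_mog, radial_gaussians, ...) are typically
# consumed in a TF1 graph/session workflow. Assumes `ds` is an alias for
# tf.contrib.distributions or the equivalent tensorflow_probability module.
def example_draw_mog_batch(batch_size=512):
    samples_op, modes = grid_mog(batch_size)   # sampling op plus fixed mode centers
    with tf.Session() as sess:
        samples = sess.run(samples_op)         # numpy array of shape [batch_size, 2]
    return samples, modes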
def build_prior(self):
    if self.kern.ktype == "id" or self.kern.ktype == "kr":
        if self.whiten:
            mvn = tfd.MultivariateNormalDiag(loc=tf.zeros_like(self.U[:, 0]))
        else:
            mvn = tfd.MultivariateNormalFullCovariance(
                loc=tf.zeros_like(self.U[:, 0]),
                covariance_matrix=self.kern.K(self.Z, self.Z))
        probs = tf.add_n(
            [mvn.log_prob(self.U[:, d]) for d in range(self.kern.ndims)])
    else:
        if self.whiten:
            mvn = tfd.MultivariateNormalDiag(loc=tf.zeros_like(self.U))
        else:
            mvn = tfd.MultivariateNormalFullCovariance(
                loc=tf.zeros_like(self.U),
                covariance_matrix=self.kern.K(self.Z, self.Z))
        probs = tf.reduce_sum(mvn.log_prob(tf.squeeze(self.U)))
    return probs
def compute_loss(inf_mean_list, inf_var_list, gen_mean_list, gen_var_list,
                 q_log_discrete, log_px, batch_size):
    gaussian_div = []
    for mean0, var0, mean1, var1 in zip(inf_mean_list, inf_var_list,
                                        reversed(gen_mean_list),
                                        reversed(gen_var_list)):
        kl_gauss = dist.kl_divergence(
            dist.MultivariateNormalDiag(mean0, var0),
            dist.MultivariateNormalDiag(mean1, var1))
        gaussian_div.append(kl_gauss)
    kl_gauss = tf.reshape(tf.concat(gaussian_div, axis=0),
                          [batch_size, len(gaussian_div)])
    kl_dis = dist.kl_divergence(
        dist.OneHotCategorical(logits=q_log_discrete),
        dist.OneHotCategorical(
            logits=tf.log(tf.ones_like(q_log_discrete) * 1 / 10)))
    mean_KL = tf.reduce_mean(tf.reduce_sum(kl_gauss, axis=1) + kl_dis)
    mean_rec = tf.reduce_mean(log_px)
    loss = tf.reduce_mean(
        log_px - 0.5 * (tf.reduce_sum(kl_gauss, axis=1) + kl_dis))
    return loss, mean_rec, mean_KL
def get_next_states(self, states_actions):
    self.string = 'unroll2_gns'
    mu, sigma = [
        tf.concat(e, axis=-1) for e in zip(*[
            model.posterior_predictive_distribution(states_actions, None)
            for model in self.models
        ])
    ]
    self.mus1.append(mu)
    self.sigmas1.append(sigma)
    # print mu.shape
    # print sigma.shape
    next_state = tfd.MultivariateNormalDiag(
        loc=mu, scale_diag=tf.sqrt(sigma)).sample()
    return next_state