def _build(self, inputs):
    locs = self._modules['locs'](inputs)
    log_scales = self._modules['scales'](inputs)
    logits = self._modules['logits'](inputs)

    scales = tf.nn.softplus(log_scales + softplus_inverse(1.0))

    locs = tf.reshape(locs, [-1, self.k, self.ndim])
    scales = tf.reshape(scales, [-1, self.k, self.ndim])
    logits = tf.reshape(logits, [-1, self.k])

    # Reshape so that the first dim is the mixture, because we are going to
    # unstack them. Also swap the batch size and the ones that come from the
    # steps of this run. (K x N x D)
    mix_first_locs = tf.transpose(locs, [1, 0, 2])
    mix_first_scales = tf.transpose(scales, [1, 0, 2])

    outs = {'locs': locs, 'scales': scales, 'logits': logits}
    outs['flattened'] = flatten_mdn(logits, locs, scales, self.FLAGS)

    cat = tfd.Categorical(logits=logits)
    components = []
    eval_components = []
    for loc, scale in zip(tf.unstack(mix_first_locs),
                          tf.unstack(mix_first_scales)):
        normal = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        components.append(normal)
        eval_normal = tfd.MultivariateNormalDiag(loc=loc[..., :2],
                                                 scale_diag=scale[..., :2])
        eval_components.append(eval_normal)
    mixture = tfd.Mixture(cat=cat, components=components)
    eval_cat = tfd.Categorical(logits=logits)
    eval_mixture = tfd.Mixture(cat=eval_cat, components=eval_components)
    outs['mixture'] = mixture
    outs['eval_mixture'] = eval_mixture
    return outs
def mdn_head(h, FLAGS):
    with tf.variable_scope('mdn'):
        locs = tf.reshape(
            tf.layers.dense(h, 2 * FLAGS['k'], activation=None),
            [-1, FLAGS['k'], 2])
        scales = tf.reshape(
            tf.layers.dense(h, 2 * FLAGS['k'], activation=tf.exp),
            [-1, FLAGS['k'], 2])
        logits = tf.layers.dense(h, FLAGS['k'], activation=None)

        cat = tfd.Categorical(logits=logits)
        components = []
        eval_components = []
        for loc, scale in zip(tf.unstack(tf.transpose(locs, [1, 0, 2])),
                              tf.unstack(tf.transpose(scales, [1, 0, 2]))):
            # TODO: does this need to be a more complex distribution?
            normal = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            components.append(normal)
            eval_normal = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            eval_components.append(eval_normal)
        mixture = tfd.Mixture(cat=cat, components=components)
        eval_mixture = tfd.Mixture(cat=cat, components=eval_components)
    return {
        'locs': locs,
        'scales': scales,
        'logits': logits,
        'mixture': mixture,
        'eval_mixture': eval_mixture
    }
def mdn_loss_func(y_true, y_pred):
    # Reshape inputs in case this is used in a TimeDistributed layer
    y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes],
                        name='reshape_ypreds')
    y_true = tf.reshape(y_true, [-1, output_dim], name='reshape_ytrue')
    # Split the inputs into parameters
    out_mu, out_sigma, out_pi = tf.split(
        y_pred,
        num_or_size_splits=[num_mixes * output_dim,
                            num_mixes * output_dim,
                            num_mixes],
        axis=-1,
        name='mdn_coef_split')
    # Construct the mixture models
    cat = tfd.Categorical(logits=out_pi)
    component_splits = [output_dim] * num_mixes
    mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
    coll = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    mixture = tfd.Mixture(cat=cat, components=coll)
    loss = mixture.log_prob(y_true)
    loss = tf.negative(loss)
    loss = tf.reduce_mean(loss)
    return loss
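# A minimal factory sketch (an assumption, not taken from the snippets above):
# `mdn_loss_func` closes over `num_mixes` and `output_dim`, so in practice it
# is usually wrapped in a factory and handed to Keras as a custom loss. It
# assumes the network emits positive sigmas (e.g. via softplus or exp), as in
# the snippets above.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def make_mdn_loss(num_mixes, output_dim):
    def loss_fn(y_true, y_pred):
        y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes])
        y_true = tf.reshape(y_true, [-1, output_dim])
        out_mu, out_sigma, out_pi = tf.split(
            y_pred,
            num_or_size_splits=[num_mixes * output_dim,
                                num_mixes * output_dim,
                                num_mixes],
            axis=-1)
        mus = tf.split(out_mu, [output_dim] * num_mixes, axis=1)
        sigs = tf.split(out_sigma, [output_dim] * num_mixes, axis=1)
        mixture = tfd.Mixture(
            cat=tfd.Categorical(logits=out_pi),
            components=[tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
                        for m, s in zip(mus, sigs)])
        return tf.reduce_mean(-mixture.log_prob(y_true))
    return loss_fn

# model.compile(loss=make_mdn_loss(num_mixes=5, output_dim=2), optimizer='adam')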
def get_distributions_from_tensor(t, dimension, num_mixes):
    y_pred = tf.reshape(t, [-1, (2 * num_mixes * dimension + 1) + num_mixes],
                        name='reshape_ypreds')
    out_e, out_pi, out_mus, out_stds = tf.split(
        y_pred,
        num_or_size_splits=[1,
                            num_mixes,
                            num_mixes * dimension,
                            num_mixes * dimension],
        name='mdn_coef_split',
        axis=-1)
    cat = tfd.Categorical(logits=out_pi)
    components_splits = [dimension] * num_mixes
    mus = tf.split(out_mus, num_or_size_splits=components_splits, axis=1)
    stds = tf.split(out_stds, num_or_size_splits=components_splits, axis=1)
    components = [
        tfd.MultivariateNormalDiag(loc=mu_i, scale_diag=std_i)
        for mu_i, std_i in zip(mus, stds)
    ]
    mix = tfd.Mixture(cat=cat, components=components)
    stroke = tfd.Bernoulli(logits=out_e)
    return mix, stroke
def mixture_sampling(logit):
    """
    Args:
        - logit: [B, 2 * out_dim * num_mix + num_mix]
    Returns:
        - sample: [B, out_dim]
    """
    mean, logit_kappa, logit_pi = tf.split(
        logit,
        num_or_size_splits=[out_dim * num_mix, out_dim * num_mix, num_mix],
        axis=-1,
        name='mix_ivm_coeff_split_sampling')
    mean = tf.reshape(mean, [-1, num_mix, out_dim])
    logit_kappa = tf.reshape(logit_kappa, [-1, num_mix, out_dim])
    kappa = tf.math.softplus(logit_kappa)
    logit_pi = tf.reshape(logit_pi, [-1, num_mix])

    means = tf.unstack(mean, axis=1)
    kappas = tf.unstack(kappa, axis=1)
    mixture = tfd.Mixture(
        cat=tfd.Categorical(logits=logit_pi),
        components=[
            tfd.Independent(
                distribution=tfd.VonMises(loc=loc, concentration=scale),
                reinterpreted_batch_ndims=1)
            for loc, scale in zip(means, kappas)
        ])
    sample = mixture.sample()
    return sample
def gmm_nll_cost(samples, vec_mus, vec_scales, mixing_coeffs, sample_valid,
                 is_constant_scale=False):
    n_comp = mixing_coeffs.get_shape().as_list()[1]
    mus = tf.split(vec_mus, num_or_size_splits=n_comp, axis=1)
    scales = tf.split(vec_scales, num_or_size_splits=n_comp, axis=1)
    if is_constant_scale:
        # Pin every component's scale to 0.1 while keeping the tensor shape.
        gmm_comps = [
            tfd.MultivariateNormalDiag(loc=mu, scale_diag=0 * scale + 0.1)
            for mu, scale in zip(mus, scales)
        ]
    else:
        gmm_comps = [
            tfd.MultivariateNormalDiag(loc=mu, scale_diag=scale)
            for mu, scale in zip(mus, scales)
        ]
    gmm = tfd.Mixture(cat=tfd.Categorical(probs=mixing_coeffs),
                      components=gmm_comps)
    # Masked mean NLL: only samples flagged valid contribute.
    loss = gmm.log_prob(samples)
    loss = tf.expand_dims(loss, axis=1)
    loss = tf.negative(tf.reduce_sum(tf.multiply(loss, sample_valid)))
    loss = tf.divide(loss, tf.reduce_sum(sample_valid))
    return loss
def get_dist(self, y_pred):
    """Turns an output into a distribution.

    Literally use this as you'd use the normal keras predict.

    Args:
        y_pred: nn output

    Returns:
        a probability distribution over outputs, for each input.
    """
    num_mix = self.num_mix
    output_dim = self.output_dim
    y_pred = tf.reshape(y_pred, [-1, (2 * num_mix * output_dim) + num_mix],
                        name='reshape_ypreds')
    out_mu, out_sigma, out_pi = tf.split(
        y_pred,
        num_or_size_splits=[num_mix * output_dim,
                            num_mix * output_dim,
                            num_mix],
        axis=1,
        name='mdn_coef_split')
    cat = tfd.Categorical(logits=out_pi)
    component_splits = [output_dim] * num_mix
    mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
    coll = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    return tfd.Mixture(cat=cat, components=coll)
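# A hypothetical usage sketch for `get_dist` (shapes and names below are
# illustrative assumptions): feed it a raw parameter tensor of width
# 2 * num_mix * output_dim + num_mix, then use the standard distribution API.
import tensorflow as tf

num_mix, output_dim, batch = 5, 2, 8
raw = tf.concat([
    tf.random.normal([batch, num_mix * output_dim]),                   # means
    tf.nn.softplus(tf.random.normal([batch, num_mix * output_dim])),   # sigmas > 0
    tf.random.normal([batch, num_mix]),                                # mixture logits
], axis=1)
# dist = mdn_layer.get_dist(raw)   # `mdn_layer` is an assumed instance
# sample = dist.sample()           # [batch, output_dim]
# nll = -dist.log_prob(targets)    # [batch], per-example negative log-likelihood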
def mse_func(y_true, y_pred):
    # Reshape inputs in case this is used in a TimeDistributed layer
    y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes],
                        name='reshape_ypreds')
    y_true = tf.reshape(y_true, [-1, output_dim], name='reshape_ytrue')
    out_mu, out_sigma, out_pi = tf.split(
        y_pred,
        num_or_size_splits=[num_mixes * output_dim,
                            num_mixes * output_dim,
                            num_mixes],
        axis=1,
        name='mdn_coef_split')
    cat = tfd.Categorical(logits=out_pi)
    component_splits = [output_dim] * num_mixes
    mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
    coll = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    mixture = tfd.Mixture(cat=cat, components=coll)
    samp = mixture.sample()
    mse = tf.reduce_mean(tf.square(samp - y_true), axis=-1)
    # TODO: temperature adjustment for sampling function.
    return mse
def _base_dist(self, p: TensorLike, distributions: List[RandomVariable],
               *args, **kwargs):
    return tfd.Mixture(
        cat=pm.Categorical(p=p, name="MixtureCategories")._distribution,
        components=[d._distribution for d in distributions],
        name=kwargs.get("name"),
    )
def mixture_sampling(logit):
    """
    Args:
        - logit: [B, 2 * out_dim * num_mix + num_mix]
    Returns:
        - sample: [B, out_dim]
    """
    mean, logit_std, logit_pi = tf.split(
        logit,
        num_or_size_splits=[out_dim * num_mix, out_dim * num_mix, num_mix],
        axis=-1,
        name='mix_gaussian_coeff_split_sampling')
    mean = tf.reshape(mean, [-1, num_mix, out_dim])
    logit_std = tf.reshape(tf.maximum(logit_std, log_scale_min_gauss),
                           [-1, num_mix, out_dim])
    std = tf.math.softplus(logit_std)
    logit_pi = tf.reshape(logit_pi, [-1, num_mix])

    if use_tfp:
        means = tf.unstack(mean, axis=1)
        stds = tf.unstack(std, axis=1)
        mixture = tfd.Mixture(
            cat=tfd.Categorical(logits=logit_pi),
            components=[
                tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
                for loc, scale in zip(means, stds)
            ])
        sample = mixture.sample()
    else:
        # Sample the mixture component from the softmax-ed pi via the
        # Gumbel-max trick; see
        # https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/
        batch_size = tf.shape(logit_pi)[0]
        u = tf.random.uniform(tf.shape(logit_pi), minval=1e-5, maxval=1. - 1e-5)
        argmax = tf.argmax(logit_pi - tf.math.log(-tf.math.log(u)), axis=-1)
        onehot = tf.expand_dims(
            tf.one_hot(argmax, depth=num_mix, dtype=tf.float32),
            axis=-1)  # [B, num_mix, 1]
        # Reparameterized sample from the selected Gaussian, using standard
        # normal noise.
        # NOTE: we use softplus() to rescale the logit_std since exp() causes
        # explosion of std values.
        eps = tf.random.normal([batch_size, out_dim])
        mean = tf.reduce_sum(tf.multiply(mean, onehot), axis=1)  # [B, out_dim]
        # TODO: cap std like np.maximum(logit_std, 1)
        logit_std = tf.reduce_sum(tf.multiply(logit_std, onehot), axis=1)
        sample = mean + tf.math.softplus(logit_std) * eps
        # clip sample to [-pi, pi]?
    return sample
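# A minimal, self-contained sketch of the Gumbel-max trick referenced above
# (the probabilities are illustrative assumptions): adding independent
# Gumbel(0, 1) noise to unnormalized log-probabilities and taking the argmax
# yields an exact sample from the corresponding categorical distribution.
import tensorflow as tf

logits = tf.math.log(tf.constant([0.1, 0.3, 0.6]))  # unnormalized log-probs
u = tf.random.uniform(tf.shape(logits), minval=1e-5, maxval=1. - 1e-5)
gumbel_noise = -tf.math.log(-tf.math.log(u))        # Gumbel(0, 1) samples
component = tf.argmax(logits + gumbel_noise)        # exact categorical draw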
def mixture_loss(y_true, logit, mask):
    """
    Args:
        - y_true: [B, L, out_dim]
        - logit: [B, L, 2 * out_dim * num_mix + num_mix]
        - mask: [B, L]
    Return:
        - loss
    """
    shape = tf.shape(y_true)
    batch_size, time_step = shape[0], shape[1]
    mean, logit_kappa, logit_pi = tf.split(
        logit,
        num_or_size_splits=[out_dim * num_mix, out_dim * num_mix, num_mix],
        axis=-1,
        name='mix_ivm_coeff_split')
    mask = tf.reshape(mask, [-1])  # [B*L]
    mean = tf.reshape(mean, [-1, num_mix, out_dim])  # [B*L, num_mix, out_dim]
    logit_kappa = tf.reshape(logit_kappa,
                             [-1, num_mix, out_dim])  # [B*L, num_mix, out_dim]
    logit_pi = tf.reshape(logit_pi, [-1, num_mix])  # [B*L, num_mix]
    # rescale parameters
    kappa = tf.math.softplus(logit_kappa)
    if use_tfp:
        y_true = tf.reshape(y_true, [-1, out_dim])
        means = tf.unstack(mean, axis=1)
        kappas = tf.unstack(kappa, axis=1)
        mixture = tfd.Mixture(
            cat=tfd.Categorical(logits=logit_pi),
            components=[
                tfd.Independent(
                    distribution=tfd.VonMises(loc=loc, concentration=scale),
                    reinterpreted_batch_ndims=1)
                for loc, scale in zip(means, kappas)
            ])
        loss = -mixture.log_prob(y_true)
    else:
        y_true = tf.reshape(y_true, [-1, 1, out_dim])  # [B*L, 1, out_dim]
        cos_diff = tf.cos(y_true - mean)
        # von Mises log-density: kappa * cos(x - mu) - log(2 * pi * I0(kappa)),
        # with log I0(kappa) recovered from the exponentially scaled
        # bessel_i0e for numerical stability.
        log_probs = tf.reduce_sum(
            -LOGTWOPI - (tf.math.log(tf.math.bessel_i0e(kappa)) + kappa)
            + cos_diff * kappa,
            axis=-1)
        mixed_log_probs = log_probs + tf.nn.log_softmax(logit_pi, axis=-1)
        loss = -tf.reduce_logsumexp(mixed_log_probs, axis=-1)
    loss = tf.multiply(loss, mask, name='masking')
    if reduce:
        return tf.reduce_sum(loss)
    else:
        return tf.reshape(loss, [batch_size, time_step])
def build_prior_network(self, voxel_ae=None):
    if voxel_ae is None:
        self.voxel_ae = VoxelAE()
        self.voxel_ae.build_voxel_ae_enc()
        # Get the voxel ae variables to restore the voxel ae before
        # creating new variables for the grasp network.
        self.voxel_ae_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    else:
        self.voxel_ae = voxel_ae
    with tf.variable_scope(self.scope_name + '_struct'):
        voxel_obj_size_concat = tf.concat(
            axis=1,
            values=[self.voxel_ae.ae_struct_res['embedding'],
                    self.holder_obj_size])
        prior_fc_1 = tf.layers.dense(voxel_obj_size_concat, 128)
        # prior_fc_1 = tf.layers.batch_normalization(prior_fc_1, training=False)
        prior_fc_1 = tf.contrib.layers.layer_norm(prior_fc_1)
        prior_fc_1 = tf.nn.relu(prior_fc_1)
        prior_fc_2 = tf.layers.dense(prior_fc_1, 32)
        # prior_fc_2 = tf.layers.batch_normalization(prior_fc_2, training=False)
        prior_fc_2 = tf.contrib.layers.layer_norm(prior_fc_2)
        prior_fc_2 = tf.nn.relu(prior_fc_2)
        locs = tf.layers.dense(prior_fc_2,
                               self.num_components * self.config_dim,
                               activation=None)
        scales = tf.layers.dense(prior_fc_2,
                                 self.num_components * self.config_dim,
                                 activation=tf.exp)
        logits = tf.layers.dense(prior_fc_2, self.num_components,
                                 activation=None)
        # code from Mat's MDN
        locs = tf.reshape(locs, [-1, self.num_components, self.config_dim])
        scales = tf.reshape(scales, [-1, self.num_components, self.config_dim])
        logits = tf.reshape(logits, [-1, self.num_components])
        self.prior_net_res['locs'] = locs
        self.prior_net_res['scales'] = scales
        self.prior_net_res['logits'] = logits
        # Reshape so that the first dim is the mixture, because we are going
        # to unstack them. Also swap the batch size and the ones that come
        # from the steps of this run. (K x N x D)
        mix_first_locs = tf.transpose(locs, [1, 0, 2])
        mix_first_scales = tf.transpose(scales, [1, 0, 2])
        cat = tfd.Categorical(logits=logits)
        components = []
        for loc, scale in zip(tf.unstack(mix_first_locs),
                              tf.unstack(mix_first_scales)):
            normal = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            components.append(normal)
        mixture = tfd.Mixture(cat=cat, components=components)
        self.prior_net_res['mixture'] = mixture
def __init__(self, means, stds, pis):
    self.means = means
    self.stds = stds
    self.pis = pis
    self.dist = tfd.Mixture(
        cat=tfd.Categorical(probs=pis),
        components=[
            tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
            for m, s in zip(means, stds)
        ])
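# A minimal instantiation sketch of the same pattern (all values below are
# illustrative assumptions): two 2-D Gaussian components mixed 70/30.
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

means = [np.zeros(2, np.float32), np.ones(2, np.float32)]
stds = [0.5 * np.ones(2, np.float32), 0.1 * np.ones(2, np.float32)]
pis = np.array([0.7, 0.3], np.float32)
dist = tfd.Mixture(
    cat=tfd.Categorical(probs=pis),
    components=[tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
                for m, s in zip(means, stds)])
x = dist.sample(10)    # [10, 2]
lp = dist.log_prob(x)  # [10]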
def _init_distribution(conditions, **kwargs):
    return tfd.Mixture(
        cat=tfd.Categorical(
            probs=[1.0 - conditions["psi"], conditions["psi"]]),
        components=[
            tfd.Deterministic(loc=tf.zeros_like(conditions["theta"])),
            tfd.Poisson(rate=conditions["theta"]),
        ],
        **kwargs,
    )
def mix(gamma, eta, loc, scale, neg_inf, n):
    return tfd.Mixture(
        cat=tfd.Categorical(probs=tf.stack([gamma, 1 - gamma], axis=-1)),
        components=[
            tfd.Sample(tfd.Normal(np.float64(neg_inf), 1e-5), sample_shape=n),
            tfd.Sample(
                tfd.MixtureSameFamily(
                    mixture_distribution=tfd.Categorical(probs=eta),
                    components_distribution=tfd.Normal(loc=loc, scale=scale)),
                sample_shape=n)
        ])
def _init_distribution(conditions, **kwargs):
    return tfd.Mixture(
        cat=tfd.Categorical(
            probs=[1 - conditions["psi"], conditions["psi"]]),
        components=[
            tfd.Deterministic(loc=tf.zeros_like(conditions["n"])),
            tfd.Binomial(total_count=conditions["n"], probs=conditions["p"]),
        ],
        **kwargs,
    )
def get_gmm(self, vec_mus, vec_scales, mixing_coeffs):
    n_comp = mixing_coeffs.get_shape().as_list()[1]
    mus = tf.split(vec_mus, num_or_size_splits=n_comp, axis=1)
    scales = tf.split(vec_scales, num_or_size_splits=n_comp, axis=1)
    gmm_comps = [
        tfd.MultivariateNormalDiag(loc=mu, scale_diag=scale)
        for mu, scale in zip(mus, scales)
    ]
    gmm = tfd.Mixture(cat=tfd.Categorical(probs=mixing_coeffs),
                      components=gmm_comps)
    return gmm
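# A self-contained sketch of the flat-parameter convention `get_gmm` assumes
# (tensors and sizes below are illustrative assumptions): `vec_mus` and
# `vec_scales` hold all components concatenated along axis 1, and tf.split
# recovers the per-component parameters.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

batch, n_comp, dim = 4, 3, 2
vec_mus = tf.random.normal([batch, n_comp * dim])       # [mu_1 | mu_2 | mu_3]
vec_scales = tf.ones([batch, n_comp * dim]) * 0.2       # positive scales
mixing_coeffs = tf.fill([batch, n_comp], 1.0 / n_comp)  # uniform weights

mus = tf.split(vec_mus, n_comp, axis=1)
scales = tf.split(vec_scales, n_comp, axis=1)
gmm = tfd.Mixture(
    cat=tfd.Categorical(probs=mixing_coeffs),
    components=[tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
                for m, s in zip(mus, scales)])
print(gmm.sample().shape)  # (4, 2)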
def mix(gamma, eta, loc, scale, neg_inf):
    _gamma = gamma[..., tf.newaxis]
    # FIXME: Possible to use tfd.Blockwise?
    return tfd.Mixture(
        cat=tfd.Categorical(probs=tf.concat([_gamma, 1 - _gamma], axis=-1)),
        components=[
            tfd.Deterministic(np.float64(neg_inf)),
            tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(probs=eta),
                components_distribution=tfd.Normal(loc=loc, scale=scale)),
        ])
def gmm_likelihood_simplex(samples, odim):
    mu_s = simplex_coordinates(odim)
    scale = np.ones(odim, dtype=np.float32) * .25
    ngmm = odim + 1
    mixing_coeffs = np.ones(ngmm, dtype=np.float32) / ngmm
    gmm_comps = [
        tfd.MultivariateNormalDiag(loc=mu, scale_diag=scale) for mu in mu_s
    ]
    gmm = tfd.Mixture(cat=tfd.Categorical(probs=mixing_coeffs),
                      components=gmm_comps)
    loss = gmm.prob(samples)
    return loss, gmm
def _init_distribution(conditions, **kwargs):
    return tfd.Mixture(
        cat=tfd.Categorical(
            probs=[1.0 - conditions["psi"], conditions["psi"]]),
        components=[
            tfd.Deterministic(loc=tf.zeros_like(conditions["mu"])),
            tfd.NegativeBinomial(
                total_count=conditions["alpha"],
                probs=conditions["mu"] / (conditions["mu"] + conditions["alpha"]),
            ),
        ],
        **kwargs,
    )
def _base_dist(self, *args, **kwargs):
    """
    Zero-inflated Poisson base distribution.

    A ZeroInflatedPoisson is a mixture between a deterministic
    distribution and a Poisson distribution.
    """
    mix = kwargs.pop("mix")
    return tfd.Mixture(
        cat=tfd.Categorical(probs=[mix, 1.0 - mix]),
        components=[tfd.Deterministic(0.0), tfd.Poisson(*args, **kwargs)],
        name="ZeroInflatedPoisson",
    )
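# A standalone sketch of the same zero-inflation pattern (parameter values are
# illustrative assumptions): with probability `mix` the outcome is an exact
# zero from the Deterministic component, otherwise a Poisson draw.
import tensorflow_probability as tfp
tfd = tfp.distributions

zip_dist = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.2, 0.8]),  # 20% structural zeros
    components=[tfd.Deterministic(0.0), tfd.Poisson(rate=3.0)])
print(zip_dist.prob(0.0))  # ~ 0.2 + 0.8 * exp(-3)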
def rnn_sim(rnn: MDNRNN, z, states, a):
    # Make one LSTM step
    z = tf.reshape(tf.cast(z, dtype=tf.float32), (1, 1, rnn.args.z_size))
    a = tf.reshape(tf.cast(a, dtype=tf.float32), (1, 1, rnn.args.a_width))
    input_x = tf.concat((z, a), axis=2)
    rnn_out, h, c = rnn.inference_base(input_x, initial_state=states)
    rnn_state = [h, c]
    rnn_out = tf.reshape(rnn_out, [-1, rnn.args.rnn_size])

    # Predict z
    mdnrnn_params = rnn.predict_z(rnn_out)
    mdnrnn_params = tf.reshape(mdnrnn_params,
                               [-1, 3 * rnn.args.rnn_num_mixture])
    mu, logstd, logpi = tf.split(mdnrnn_params, num_or_size_splits=3, axis=1)
    logpi = logpi - tf.reduce_logsumexp(
        input_tensor=logpi, axis=1, keepdims=True)  # normalize

    cat = tfd.Categorical(logits=logpi)
    component_splits = [1] * rnn.args.rnn_num_mixture
    mus = tf.split(mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(tf.exp(logstd), num_or_size_splits=component_splits,
                    axis=1)
    coll = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    mixture = tfd.Mixture(cat=cat, components=coll)
    z = tf.reshape(mixture.sample(), shape=(-1, rnn.args.z_size))

    # Predict done
    if rnn.args.rnn_predict_done:
        d_distr = rnn.predict_done(rnn_out)
        done_logit = tfd.Normal(d_distr[0][0], d_distr[0][1]).sample()
        done_dist = tfd.Binomial(total_count=1, logits=done_logit)
        done = tf.squeeze(done_dist.sample()) == 1.0
    else:
        done = False

    # Predict reward
    if rnn.args.rnn_predict_reward:
        r_distr = rnn.predict_reward(rnn_out)
        reward = tfd.Normal(r_distr[0][0], r_distr[0][1]).sample()
    else:
        reward = 1.0
    return rnn_state, z, reward, done
def _base_dist(self, *args, **kwargs):
    """
    Zero-inflated negative binomial base distribution.

    A ZeroInflatedNegativeBinomial is a mixture between a deterministic
    distribution and a NegativeBinomial distribution.
    """
    mix = kwargs.pop("mix")
    return tfd.Mixture(
        cat=tfd.Categorical(probs=[mix, 1.0 - mix]),
        components=[
            tfd.Deterministic(0.0),
            tfd.NegativeBinomial(*args, **kwargs)
        ],
        name="ZeroInflatedNegativeBinomial",
    )
def failure_cost(samples, vec_mus, mixing_coeffs, sample_invalid,
                 neg_scale=0.1):
    n_comp = mixing_coeffs.get_shape().as_list()[1]
    n_dim = samples.get_shape().as_list()[1]
    mus = tf.split(vec_mus, num_or_size_splits=n_comp, axis=1)
    smat = tf.ones(shape=(1, n_dim)) * neg_scale
    gmm_comps = [
        tfd.MultivariateNormalDiag(loc=mu, scale_diag=smat) for mu in mus
    ]
    gmm = tfd.Mixture(cat=tfd.Categorical(probs=mixing_coeffs),
                      components=gmm_comps)
    loss = gmm.log_prob(samples)
    loss = tf.reduce_sum(tf.multiply(loss, sample_invalid))
    return loss
def mix():
    # Create a mixture of two Gaussians:
    tfd = tfp.distributions
    mix = 0.3
    bimix_gauss = tfd.Mixture(
        cat=tfd.Categorical(probs=[mix, 1. - mix]),
        components=[
            tfd.Normal(loc=-1., scale=0.1),
            tfd.Normal(loc=+1., scale=0.5),
        ])

    # Plot the PDF.
    import matplotlib.pyplot as plt
    with tf.Session() as sess:
        x = tf.linspace(-2., 3., int(1e4))
        x_ = sess.run(x)
        prob = sess.run(bimix_gauss.prob(x))
        plt.plot(x_, prob)
        plt.show()
def rnn_sim(rnn, z, states, a, training=True):
    z = tf.reshape(tf.cast(z, dtype=tf.float32), (1, 1, rnn.args.z_size))
    a = tf.reshape(tf.cast(a, dtype=tf.float32), (1, 1, rnn.args.a_width))
    input_x = tf.concat((z, a), axis=2)
    rnn_out, h, c = rnn.inference_base(
        input_x, initial_state=states,
        training=training)  # set training True to use Dropout
    rnn_state = [h, c]
    rnn_out = tf.reshape(rnn_out, [-1, rnn.args.rnn_size])
    out = rnn.out_net(rnn_out)
    mdnrnn_params, r, d_logits = rnn.parse_rnn_out(out)
    mdnrnn_params = tf.reshape(mdnrnn_params,
                               [-1, 3 * rnn.args.rnn_num_mixture])
    mu, logstd, logpi = tf.split(mdnrnn_params, num_or_size_splits=3, axis=1)
    logpi = logpi / rnn.args.rnn_temperature  # temperature
    logpi = logpi - tf.reduce_logsumexp(
        input_tensor=logpi, axis=1, keepdims=True)  # normalize

    d_dist = tfd.Binomial(total_count=1, logits=d_logits)
    d = tf.squeeze(d_dist.sample()) == 1.0

    cat = tfd.Categorical(logits=logpi)
    component_splits = [1] * rnn.args.rnn_num_mixture
    mus = tf.split(mu, num_or_size_splits=component_splits, axis=1)
    # temperature
    sigs = tf.split(tf.exp(logstd) * tf.sqrt(rnn.args.rnn_temperature),
                    component_splits, axis=1)
    coll = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    mixture = tfd.Mixture(cat=cat, components=coll)
    z = tf.reshape(mixture.sample(), shape=(-1, rnn.args.z_size))
    if rnn.args.rnn_r_pred == 0:
        r = 1.0  # For Doom, the reward is always 1.0 while the agent is alive
    return rnn_state, z, r, d
def rnn_sim(rnn, z, states, a):
    if rnn.args.env_name == 'CarRacing-v0':
        raise ValueError('Not implemented yet for CarRacing')
    z = tf.reshape(tf.cast(z, dtype=tf.float32), (1, 1, rnn.args.z_size))
    a = tf.reshape(tf.cast(a, dtype=tf.float32), (1, 1, rnn.args.a_width))
    input_x = tf.concat((z, a), axis=2)
    rnn_out, h, c = rnn.inference_base(input_x, initial_state=states)
    rnn_state = [h, c]
    rnn_out = tf.reshape(rnn_out, [-1, rnn.args.rnn_size])
    out = rnn.out_net(rnn_out)
    mdnrnn_params, d_logits = out[:, :-1], out[:, -1:]
    mdnrnn_params = tf.reshape(mdnrnn_params,
                               [-1, 3 * rnn.args.rnn_num_mixture])
    mu, logstd, logpi = tf.split(mdnrnn_params, num_or_size_splits=3, axis=1)
    logpi = logpi - tf.reduce_logsumexp(
        input_tensor=logpi, axis=1, keepdims=True)  # normalize
    d_dist = tfd.Binomial(total_count=1, logits=d_logits)
    d = tf.squeeze(d_dist.sample()) == 1.0
    cat = tfd.Categorical(logits=logpi)
    component_splits = [1] * rnn.args.rnn_num_mixture
    mus = tf.split(mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(tf.exp(logstd), num_or_size_splits=component_splits,
                    axis=1)
    coll = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    mixture = tfd.Mixture(cat=cat, components=coll)
    z = tf.reshape(mixture.sample(), shape=(-1, rnn.args.z_size))
    r = 1.0  # For Doom, the reward is always 1.0 while the agent is alive
    return rnn_state, z, r, d
def _init_distribution(conditions, **kwargs):
    p, d = conditions["p"], conditions["distributions"]
    # If 'd' is a sequence of pymc distributions, then use the underlying
    # tfp distributions for the mixture.
    if isinstance(d, collections.abc.Sequence):
        if any(not isinstance(el, Distribution) for el in d):
            raise TypeError(
                "every element in 'distribution' needs to be a pymc4.Distribution object"
            )
        distr = [el._distribution for el in d]
        return tfd.Mixture(
            tfd.Categorical(probs=p, **kwargs), distr, **kwargs,
            use_static_graph=True)
    # Else if 'd' is a pymc distribution with batch_size > 1.
    elif isinstance(d, Distribution):
        return tfd.MixtureSameFamily(
            tfd.Categorical(probs=p, **kwargs), d._distribution, **kwargs)
    else:
        raise TypeError(
            "'distribution' needs to be a pymc4.Distribution object or a sequence of distributions"
        )
def sampling_func(y_pred):
    # Reshape inputs in case this is used in a TimeDistributed layer
    y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes])
    out_mu, out_sigma, out_pi = tf.split(
        y_pred,
        num_or_size_splits=[num_mixes * output_dim,
                            num_mixes * output_dim,
                            num_mixes],
        axis=1)
    cat = tfd.Categorical(logits=out_pi)
    component_splits = [output_dim] * num_mixes
    mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
    coll = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    mixture = tfd.Mixture(cat=cat, components=coll)
    samp = mixture.sample()
    # TODO: temperature adjustment for sampling function.
    return samp
def mdn_loss_func(y_true, y_pred):
    # Split the inputs into parameters: 1 for end-of-stroke, `num_mixes` for
    # the mixture weights, and 2 * num_mixes * output_dim for the component
    # means and scales.
    # y_true = tf.reshape(tensor=y_true, shape=y_pred.shape)
    y_pred = tf.reshape(y_pred,
                        [-1, (2 * num_mixes * output_dim + 1) + num_mixes],
                        name='reshape_ypreds')
    y_true = tf.reshape(y_true, [-1, output_dim + 1], name='reshape_ytrue')
    out_e, out_pi, out_mus, out_stds = tf.split(
        y_pred,
        num_or_size_splits=[1,
                            num_mixes,
                            num_mixes * output_dim,
                            num_mixes * output_dim],
        name='mdn_coef_split',
        axis=-1)
    cat = tfd.Categorical(logits=out_pi)
    components_splits = [output_dim] * num_mixes
    mus = tf.split(out_mus, num_or_size_splits=components_splits, axis=1)
    stds = tf.split(out_stds, num_or_size_splits=components_splits, axis=1)
    components = [
        tfd.MultivariateNormalDiag(loc=mu_i, scale_diag=std_i)
        for mu_i, std_i in zip(mus, stds)
    ]
    mix = tfd.Mixture(cat=cat, components=components)
    xs, ys, es = tf.unstack(y_true, axis=-1)
    X = tf.stack((xs, ys), axis=-1)
    # Squeeze out_e from [B, 1] to [B] so stroke.log_prob(es) stays [B] and
    # does not broadcast against the [B]-shaped targets.
    stroke = tfd.Bernoulli(logits=tf.squeeze(out_e, axis=-1))
    loss1 = tf.negative(mix.log_prob(X))
    loss2 = tf.negative(stroke.log_prob(es))
    loss = tf.add(loss1, loss2)
    loss = tf.reduce_mean(loss)
    return loss