def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    inputs = tf.concat([prior['mean'], prior['stddev'], obs], -1)
    hidden = tf.layers.dense(inputs, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
        sample = mean
    else:
        sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
    }
def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    hidden = tf.concat([prior['belief'], obs], -1)
    for _ in range(self._num_layers):
        hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
        sample = mean
    else:
        sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': prior['belief'],
        'rnn_state': prior['rnn_state'],
    }
def approximate_posterior(self, observations, inputs_extra=None, training=False):
    """Given an observation compute Q(latent | observation)."""
    with tf.variable_scope("encoder"):
        code, endpoints = self._encoder(observations, inputs_extra,
                                        training=training)
        del endpoints
    if inputs_extra is not None:
        self._extra_shape = inputs_extra.get_shape().as_list()
    else:
        self._extra_shape = None
    self._latent_shape = code.get_shape().as_list()
    # because code includes means and variances
    self._latent_shape[-1] = self._latent_shape[-1] // 2
    code_flat = tf.reshape(code, [-1, self._latent_shape[-1] * 2])
    mean = code_flat[Ellipsis, :self._latent_dimensions]
    sigma = self._epsilon + tf.nn.softplus(
        code_flat[Ellipsis, self._latent_dimensions:])
    return tfd.MultivariateNormalDiag(loc=mean, scale_diag=sigma)
def mse_func(y_true, y_pred):
    # Reshape inputs in case this is used in a TimeDistributed layer
    y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes],
                        name='reshape_ypreds')
    y_true = tf.reshape(y_true, [-1, output_dim], name='reshape_ytrue')
    out_mu, out_sigma, out_pi = tf.split(
        y_pred,
        num_or_size_splits=[num_mixes * output_dim,
                            num_mixes * output_dim,
                            num_mixes],
        axis=1, name='mdn_coef_split')
    cat = tfd.Categorical(logits=out_pi)
    component_splits = [output_dim] * num_mixes
    mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
    coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            for loc, scale in zip(mus, sigs)]
    mixture = tfd.Mixture(cat=cat, components=coll)
    samp = mixture.sample()
    mse = tf.reduce_mean(tf.square(samp - y_true), axis=-1)
    # TODO: temperature adjustment for the sampling function.
    return mse
def f_one_step(self, current_state, previous_kernel_results, all_states):
    new_state = tf.identity(current_state)
    # Metropolis-Hastings step
    kernel_params = (all_states[self.state_indices['kernel_params']][0],
                     all_states[self.state_indices['kernel_params']][1])
    m, K = self.fbar_prior_params(*kernel_params)
    for r in range(self.num_replicates):
        # Gibbs step
        fbar = current_state[r]
        z_i = tfd.MultivariateNormalDiag(fbar, self.step_size).sample()
        fstar = tf.zeros_like(fbar)
        for i in range(self.num_tfs):
            # (C_i + hI)^{-1} C_i
            invKsigmaK = tf.matmul(
                tf.linalg.inv(K[i] + tf.linalg.diag(self.step_size)), K[i])
            L = jitter_cholesky(K[i] - tf.matmul(K[i], invKsigmaK))
            c_mu = tf.matmul(z_i[i, None], invKsigmaK)
            fstar_i = tf.matmul(
                tf.random.normal((1, L.shape[0]), dtype='float64'), L) + c_mu
            mask = np.zeros((self.num_tfs, 1), dtype='float64')
            mask[i] = 1
            fstar = (1 - mask) * fstar + mask * fstar_i
        mask = np.zeros((self.num_replicates, 1, 1), dtype='float64')
        mask[r] = 1
        test_state = (1 - mask) * new_state + mask * fstar
        new_prob = self.calc_prob_fn(test_state, all_states)
        old_prob = self.calc_prob_fn(new_state, all_states)
        is_accepted = self.metropolis_is_accepted(new_prob, old_prob)
        prob = tf.cond(tf.equal(is_accepted, tf.constant(True)),
                       lambda: new_prob, lambda: old_prob)
        new_state = tf.cond(tf.equal(is_accepted, tf.constant(False)),
                            lambda: new_state, lambda: test_state)
    return new_state, prob, is_accepted[0]
def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""
    # Recurrent encoder.
    encoder_inputs = [obs, prev_action]
    if self._sample_to_encoder:
        encoder_inputs.append(prev_state['sample'])
    if self._decoder_to_encoder:
        encoder_inputs.append(prev_state['decoder_state'])
    encoded, encoder_state = self._encoder_cell(
        tf.concat(encoder_inputs, -1), prev_state['encoder_state'])
    # Sample sequence.
    sample_inputs = [encoded]
    if self._sample_to_sample:
        sample_inputs.append(prev_state['sample'])
    if self._decoder_to_sample:
        sample_inputs.append(prev_state['decoder_state'])
    hidden = tf.layers.dense(tf.concat(sample_inputs, -1), **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
        sample = mean
    else:
        sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    # Recurrent decoder.
    decoder_inputs = [sample]
    if self._encoder_to_decoder:
        decoder_inputs.append(prev_state['encoder_state'])
    if self._action_to_decoder:
        decoder_inputs.append(prev_action)
    decoded, decoder_state = self._decoder_cell(
        tf.concat(decoder_inputs, -1), prev_state['decoder_state'])
    return {
        'encoder_state': encoder_state,
        'decoder_state': decoder_state,
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
    }
def encoder(self, inputs):
    with tf.variable_scope('encoder'):
        g_sum = relation_sum(inputs)
        f_out = f_net(g_sum)
        loc = tf.layers.dense(f_out, self.FLAGS['z_size'], activation=None,
                              name='fc_mu')
        log_scale = tf.layers.dense(f_out, self.FLAGS['z_size'],
                                    activation=None, name='fc_log_var')
        # Shifting by softplus_inverse(1.0) makes the scale equal 1.0 when
        # log_scale is zero, so the posterior starts near unit variance.
        scale = tf.nn.softplus(log_scale + softplus_inverse(1.0))
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale, name='code')
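# A quick numerical check of the shift trick above (a sketch; assumes the
# softplus_inverse in scope is tfp.math.softplus_inverse): with zero
# pre-activations the scale comes out at exactly 1.0, i.e. unit stddev.
import tensorflow as tf
import tensorflow_probability as tfp

zero_logits = tf.zeros([4])
scale = tf.nn.softplus(zero_logits + tfp.math.softplus_inverse(1.0))
# scale == [1.0, 1.0, 1.0, 1.0]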
def create_distribution_layer(mean_and_log_std):
    mean, log_std = tf.split(mean_and_log_std, num_or_size_splits=2, axis=1)
    log_std = tf.clip_by_value(log_std, -20., 2.)
    distribution = distributions.MultivariateNormalDiag(
        loc=mean, scale_diag=tf.exp(log_std))
    raw_actions = distribution.sample()
    if not self._reparameterize:
        raw_actions = tf.stop_gradient(raw_actions)
    log_probs = distribution.log_prob(raw_actions)
    # Correct the log-probability for the tanh squashing applied below.
    log_probs -= self._squash_correction(raw_actions)
    actions = tf.tanh(raw_actions)
    return actions, log_probs
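# _squash_correction is not shown above. A minimal sketch of the usual
# SAC-style change-of-variables term (an assumption, not necessarily the
# author's implementation; the epsilon guards against log(0)):
def _squash_correction(self, raw_actions, eps=1e-6):
    # log|det d tanh(u)/du| = sum_i log(1 - tanh(u_i)^2); subtracting this
    # from the Gaussian log-prob yields the log-density of tanh(u).
    return tf.reduce_sum(
        tf.log(1.0 - tf.square(tf.tanh(raw_actions)) + eps), axis=1)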
def _transition(self, prev_state, prev_action, zero_obs):
    hidden = tf.concat([prev_state['sample'], prev_action], -1)
    for _ in range(self._num_layers):
        hidden = tf.layers.dense(hidden, **self._kwargs)
    if self._model == 'gru':
        belief, rnn_state = self._cell(hidden, prev_state['rnn_state'])
    else:
        # Reshape the flat memory into [layer, mem_len, batch, belief] for TrXL.
        mems = tf.reshape(
            prev_state['rnn_state'],
            [prev_state['rnn_state'].shape[0], self._trxl_mem_len,
             self._trxl_layer, self._belief_size])
        mems = tf.transpose(mems, perm=[2, 1, 0, 3])
        belief, state = trxl(
            dec_inp=tf.expand_dims(hidden, axis=0),
            mems=mems,
            d_model=self._belief_size,
            n_head=self._trxl_n_head,
            d_head=self._belief_size // self._trxl_n_head,
            d_inner=self._belief_size,
            mem_len=self._trxl_mem_len,
            pre_lnorm=self._trxl_pre_lnorm,
            gate=self._trxl_gate)
        state = tf.transpose(state, perm=[2, 1, 0, 3])
        state = tf.reshape(state, [state.shape[0], -1])
        rnn_state = state
    if self._future_rnn:
        hidden = belief
    for _ in range(self._num_layers):
        hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
        sample = mean
    else:
        sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': belief,
        'rnn_state': rnn_state,
    }
def test_identity(self):
    # Test that an additive SSM with a single component defines the same
    # distribution as the component model.
    y = self._build_placeholder([1.0, 2.5, 4.3, 6.1, 7.8])
    local_ssm = LocalLinearTrendStateSpaceModel(
        num_timesteps=5,
        level_scale=0.3,
        slope_scale=0.6,
        observation_noise_scale=0.1,
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder([1., 1.])))
    additive_ssm = AdditiveStateSpaceModel([local_ssm])
    local_lp = local_ssm.log_prob(y[:, np.newaxis])
    additive_lp = additive_ssm.log_prob(y[:, np.newaxis])
    self.assertAllClose(self.evaluate(local_lp), self.evaluate(additive_lp))
def __call__(self, xs, ys):
    xys = tf.concat([xs, ys], axis=1)
    # encoder mlp
    inner_layer_dims = self.layer_dims[:-1]
    output_dim = self.layer_dims[-1]
    rs = batch_mlp(xys, inner_layer_dims, output_dim, "encoder")
    # aggregate rs
    r = self._aggregate_r(rs)
    # get mu and sigma
    z_params = self._get_z_params(r)
    # distribution
    dist = tfd.MultivariateNormalDiag(loc=z_params.mu, scale_diag=z_params.sigma)
    return dist
def _transition(self, prev_state, prev_action, zero_obs):
    """Compute prior next state by applying the transition dynamics."""
    inputs = tf.concat([prev_state['sample'], prev_action], -1)
    hidden = tf.layers.dense(inputs, **self._kwargs)
    belief, rnn_state = self._cell(hidden, prev_state['rnn_state'])
    # Feed the updated belief (not the pre-cell hidden) into the state head,
    # so the prior state actually depends on the recurrent update.
    hidden = tf.layers.dense(belief, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
        sample = mean
    else:
        sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': belief,
        'rnn_state': rnn_state,
    }
def _init_distribution(conditions: dict, **kwargs):
    num_timesteps = conditions["num_timesteps"]
    coefficients = conditions["coefficients"]
    level_scale = conditions["level_scale"]
    initial_state = conditions["initial_state"]
    initial_step = conditions["initial_step"]
    coefficients = tf.convert_to_tensor(value=coefficients, name="coefficients")
    order = tf.compat.dimension_value(coefficients.shape[-1])
    time_series_object = sts.Autoregressive(order=order)
    distribution = time_series_object.make_state_space_model(
        num_timesteps=num_timesteps,
        param_vals={"coefficients": coefficients, "level_scale": level_scale},
        initial_state_prior=tfd.MultivariateNormalDiag(
            loc=initial_state, scale_diag=[1e-6] * order),
        initial_step=initial_step,
    )
    return distribution
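# A hypothetical call sketch for the factory above, assuming an AR(2)
# process; the returned object is an autoregressive state-space model with
# event shape [num_timesteps, 1]:
conditions = dict(
    num_timesteps=100,
    coefficients=[0.7, -0.1],  # length 2 implies order = 2
    level_scale=0.5,
    initial_state=[0.0, 0.0],
    initial_step=0,
)
dist = _init_distribution(conditions)
series = dist.sample()         # shape [100, 1]
log_prob = dist.log_prob(series)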
def _create_prior(self):
    if self._learn_prior == LEARN_PRIOR_MAF:
        self._prior = IndependentChannelsTransformedDistribution(
            [1, self.time_dim_latent], self._flow_params,
            self.time_dim_latent, self._latent_channels)
    elif self._variational_layer == GAUSSIAN_DIAGONAL:
        dim = tf.shape(self._gaussian_model_latent.mean())[1:3]
        zeros = tf.zeros(shape=dim)
        ones = self._alpha * tf.ones(shape=dim)
        self._prior = tfd.MultivariateNormalDiag(loc=zeros, scale_diag=ones,
                                                 name="prior")
    elif self._variational_layer == GAUSSIAN_TRI_DIAGONAL:
        dim = tf.shape(self._gaussian_model_latent.mean())[1:3]
        ch_z = self._gaussian_model_latent.mean().shape[1]
        zeros = tf.zeros(shape=dim)
        # Tridiagonal covariance: identity plus scaled, shifted off-diagonals.
        ones = tf.stack([tf.eye(dim[1]) for i in range(ch_z)])
        twos = self._alpha * 0.2 * tf.stack(
            [tf.eye(dim[1] - 1) for i in range(ch_z)])
        twos_big = tf.pad(twos, [[0, 0], [1, 0], [0, 1]], mode='CONSTANT')
        cov = ones + twos_big + tf.transpose(twos_big, perm=[0, 2, 1])
        self._prior = tfd.MultivariateNormalFullCovariance(
            loc=zeros, covariance_matrix=cov, name="prior")
    elif self._variational_layer == GAUSSIAN_TRI_DIAGONAL_PRECISION:
        dim = tf.shape(self._gaussian_model_latent.mean())[1:3]
        ch_z = self._gaussian_model_latent.mean().shape[1]
        zeros = tf.zeros(shape=dim)
        # Tridiagonal Cholesky factor of the precision matrix, then invert.
        ones = tf.stack([tf.eye(dim[1]) for i in range(ch_z)])
        twos = self._alpha * tf.stack([tf.eye(dim[1] - 1) for i in range(ch_z)])
        twos_big = tf.pad(twos, [[0, 0], [1, 0], [0, 1]], mode='CONSTANT')
        L = ones + twos_big
        L_T = tf.transpose(L, perm=[0, 2, 1])
        prec = tf.matmul(L, L_T)
        cov = tf.linalg.inv(prec)
        self._prior = tfd.MultivariateNormalFullCovariance(
            loc=zeros, covariance_matrix=cov, name="prior")
    else:
        raise ValueError("specified string is not a suitable variational "
                         "layer: %s" % str(self._variational_layer))
def gmm_elk_cost(vec_mus, vec_scales, mixing_coeffs, sample_valid, eps=1e-30):
    n_comps = mixing_coeffs.get_shape().as_list()[1]
    mus = tf.split(vec_mus, num_or_size_splits=n_comps, axis=1)
    scales = tf.split(vec_scales, num_or_size_splits=n_comps, axis=1)
    entlb = 0
    for i in range(n_comps):
        elk = 0
        for j in range(n_comps):
            # Note: a fixed bandwidth of 0.1 is used here in place of the
            # component scales (e.g. scales[i] + scales[j]).
            p_j = tfd.MultivariateNormalDiag(
                loc=mus[j], scale_diag=0 * scales[i] + 0.1)
            prob = p_j.prob(mus[i])
            prob = tf.clip_by_value(prob, 0, 10)
            elk = elk + tf.multiply(mixing_coeffs[:, j], prob)
        entlb = entlb + tf.multiply(mixing_coeffs[:, i], tf.math.log(elk + eps))
    loss = tf.reduce_mean(entlb)
    return loss
def mdn_loss_func(y_true, y_pred):
    # Reshape inputs in case this is used in a TimeDistributed layer
    y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes],
                        name='reshape_ypreds')
    y_true = tf.reshape(y_true, [-1, output_dim], name='reshape_ytrue')
    # Split the inputs into parameters
    out_mu, out_sigma, out_pi = tf.split(
        y_pred,
        num_or_size_splits=[num_mixes * output_dim,
                            num_mixes * output_dim,
                            num_mixes],
        axis=-1, name='mdn_coef_split')
    # Construct the mixture models
    cat = tfd.Categorical(logits=out_pi)
    component_splits = [output_dim] * num_mixes
    mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
    coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            for loc, scale in zip(mus, sigs)]
    mixture = tfd.Mixture(cat=cat, components=coll)
    loss = mixture.log_prob(y_true)
    loss = tf.negative(loss)
    loss = tf.reduce_mean(loss)
    return loss
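# A minimal sketch of wiring this closure into a Keras model (hypothetical;
# assumes num_mixes, output_dim, and input_dim are bound in the enclosing
# scope). The final layer must emit 2 * num_mixes * output_dim + num_mixes
# units: component means, component scales, and mixture logits.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
    tf.keras.layers.Dense((2 * num_mixes * output_dim) + num_mixes),
])
model.compile(loss=mdn_loss_func, optimizer='adam')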
def test_sampled_level_has_correct_marginals(self):
    seed = test_util.test_seed(sampler_type='stateless')
    residuals_seed, is_missing_seed, level_seed = samplers.split_seed(
        seed, 3, 'test_sampled_level_has_correct_marginals')
    num_timesteps = 10
    initial_state_prior = tfd.MultivariateNormalDiag(loc=[-30.],
                                                     scale_diag=[100.])
    observed_residuals = samplers.normal([3, 1, num_timesteps],
                                         seed=residuals_seed)
    is_missing = samplers.uniform([3, 1, num_timesteps],
                                  seed=is_missing_seed) > 0.8
    level_scale = 1.5 * tf.ones([3, 1])
    observation_noise_scale = 0.2 * tf.ones([3, 1])
    ssm = tfp.sts.LocalLevelStateSpaceModel(
        num_timesteps=num_timesteps,
        initial_state_prior=initial_state_prior,
        observation_noise_scale=observation_noise_scale,
        level_scale=level_scale)
    resample_level = gibbs_sampler._build_resample_level_fn(
        initial_state_prior, is_missing=is_missing)
    posterior_means, posterior_covs = ssm.posterior_marginals(
        observed_residuals[..., tf.newaxis], mask=is_missing)
    level_samples = resample_level(
        observed_residuals=observed_residuals,
        level_scale=level_scale,
        observation_noise_scale=observation_noise_scale,
        sample_shape=10000,
        seed=level_seed)
    posterior_means_, posterior_covs_, level_samples_ = self.evaluate(
        (posterior_means, posterior_covs, level_samples))
    self.assertAllClose(np.mean(level_samples_, axis=0),
                        posterior_means_[..., 0], atol=0.1)
    self.assertAllClose(np.std(level_samples_, axis=0),
                        np.sqrt(posterior_covs_[..., 0, 0]), atol=0.1)
def train_step(self, x):
    """Perform a training step of gradient descent on the VAE.

    Args:
        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]

    Returns:
        statistics: dict
            a dictionary that contains logging information
    """
    statistics = dict()
    with tf.GradientTape() as tape:
        latent = self.vae.encode(x, training=True)
        z = latent.mean()
        prediction = self.vae.decode(z)
        nll = -prediction.log_prob(x)
        kld = latent.kl_divergence(
            tfpd.MultivariateNormalDiag(loc=tf.zeros_like(z),
                                        scale_diag=tf.ones_like(z)))
        total_loss = tf.reduce_mean(nll) + tf.reduce_mean(kld) * self.beta
    variables = self.vae.trainable_variables
    self.vae_optim.apply_gradients(
        zip(tape.gradient(total_loss, variables), variables))
    statistics['vae/train/nll'] = nll
    statistics['vae/train/kld'] = kld
    return statistics
def compute(self, alpha, x, x_replicate, x_reconstr_mean_samples, z):
    """Implement the Renyi cost function as in the paper "Renyi Divergence
    Variational Inference", formula [5], with alpha in (0, 1):

    alpha = 0  => log p(x)
    alpha -> 1 => KL lower bound
    """
    # log conditional p(x|z)
    if self._binary == 1:
        self.log_p_x_z = -tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=x_replicate, logits=x_reconstr_mean_samples), 2)
    else:
        assert self._synthetic == 1
        self.log_p_x_z = self._gaussian_model_observed.log_pdf(x_replicate)
        self.log_p_x_z = tf.reshape(self.log_p_x_z,
                                    [-1, self._gaussian_model_latent._s])
        if not self._synthetic:
            raise AssertionError(
                "Renyi bound not yet implemented for continuous real "
                "datasets (e.g. MNIST continuous)")
    # log posterior q(z|x)
    log_q_z_x = self._gaussian_model_latent.log_pdf(z)
    log_q_z_x = tf.reshape(log_q_z_x, [-1, self._gaussian_model_latent._s])
    # log prior p(z)
    n_z = self._gaussian_model_latent._n_z
    log_p_z = tfd.MultivariateNormalDiag(tf.zeros(n_z),
                                         tf.ones(n_z)).log_prob(z)
    log_p_z = tf.reshape(log_p_z, [-1, self._gaussian_model_latent._s])
    # Exponent of the Renyi expectation; to avoid numerical issues we work
    # in log space and apply the log-sum-exp trick.
    exponent = (1 - alpha) * (self.log_p_x_z + log_p_z - log_q_z_x)
    max_exponent = tf.reduce_max(exponent, 1, keep_dims=True)
    renyi_sum = (tf.log(tf.reduce_mean(tf.exp(exponent - max_exponent), 1))
                 + tf.reduce_mean(max_exponent, 1))
    renyi_sum = -renyi_sum / (1 - alpha)
    return renyi_sum
def rnn_sim(rnn, z, states, a, training=True):
    z = tf.reshape(tf.cast(z, dtype=tf.float32), (1, 1, rnn.args.z_size))
    a = tf.reshape(tf.cast(a, dtype=tf.float32), (1, 1, rnn.args.a_width))
    input_x = tf.concat((z, a), axis=2)
    # training=True keeps Dropout active during simulation.
    rnn_out, h, c = rnn.inference_base(input_x, initial_state=states,
                                       training=training)
    rnn_state = [h, c]
    rnn_out = tf.reshape(rnn_out, [-1, rnn.args.rnn_size])
    out = rnn.out_net(rnn_out)
    mdnrnn_params, r, d_logits = rnn.parse_rnn_out(out)
    mdnrnn_params = tf.reshape(mdnrnn_params,
                               [-1, 3 * rnn.args.rnn_num_mixture])
    mu, logstd, logpi = tf.split(mdnrnn_params, num_or_size_splits=3, axis=1)
    logpi = logpi / rnn.args.rnn_temperature  # temperature
    logpi = logpi - tf.reduce_logsumexp(input_tensor=logpi, axis=1,
                                        keepdims=True)  # normalize
    d_dist = tfd.Binomial(total_count=1, logits=d_logits)
    d = tf.squeeze(d_dist.sample()) == 1.0
    cat = tfd.Categorical(logits=logpi)
    component_splits = [1] * rnn.args.rnn_num_mixture
    mus = tf.split(mu, num_or_size_splits=component_splits, axis=1)
    # Scale the component stddevs by sqrt(temperature).
    sigs = tf.split(tf.exp(logstd) * tf.sqrt(rnn.args.rnn_temperature),
                    component_splits, axis=1)
    coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            for loc, scale in zip(mus, sigs)]
    mixture = tfd.Mixture(cat=cat, components=coll)
    z = tf.reshape(mixture.sample(), shape=(-1, rnn.args.z_size))
    if rnn.args.rnn_r_pred == 0:
        r = 1.0  # For Doom, reward is always 1.0 while the agent is alive.
    return rnn_state, z, r, d
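# A sketch of a latent "dream" rollout built on this step function
# (hypothetical names: policy, initial_z, initial_state, and horizon are
# assumptions, e.g. an encoder output and a zero LSTM state; assumes eager
# execution so the done flag can drive Python control flow):
z, states = initial_z, initial_state
for step in range(horizon):
    a = policy(z, states)
    states, z, r, d = rnn_sim(rnn, z, states, a)
    if d:  # the dreamed episode terminated
        break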
def testLogprobCorrectness(self):
    # Compare the state-space model's log-prob to an explicit implementation.
    num_timesteps = 10
    observed_time_series_ = np.random.randn(num_timesteps)
    coefficients_ = np.array([.7, -.1]).astype(self.dtype)
    level_scale_ = 1.0
    observed_time_series = self._build_placeholder(observed_time_series_)
    level_scale = self._build_placeholder(level_scale_)
    expected_logp = ar_explicit_logp(observed_time_series_, coefficients_,
                                     level_scale_)
    ssm = AutoregressiveStateSpaceModel(
        num_timesteps=num_timesteps,
        coefficients=coefficients_,
        level_scale=level_scale,
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=[level_scale, 0.]))
    lp = ssm.log_prob(observed_time_series[..., tf.newaxis])
    self.assertAllClose(self.evaluate(lp), expected_logp)
def encode(self):
    """Encode the input data into a diagonal Gaussian code distribution.

    :return: distributions.MultivariateNormalDiag over the latent code
    """
    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        # Mean and stddev heads are separate towers over the flattened input.
        mean_hidden = tf.layers.dense(tf.layers.flatten(self.data), 784,
                                      tf.nn.relu)
        mean_hidden = tf.layers.dense(mean_hidden, 256, tf.nn.relu)
        mean = tf.layers.dense(mean_hidden, self.code_units)
        stddev_hidden = tf.layers.dense(tf.layers.flatten(self.data), 784,
                                        tf.nn.relu)
        stddev_hidden = tf.layers.dense(stddev_hidden, 256, tf.nn.relu)
        # Note: the second positional argument of MultivariateNormalDiag is
        # scale_diag, a standard deviation rather than a variance.
        stddev = tf.layers.dense(stddev_hidden, self.code_units, tf.nn.softplus)
        return distributions.MultivariateNormalDiag(mean, stddev)
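# A usage sketch (hypothetical `model` instance): the returned distribution
# supports sampling for the decoder and an analytic KL term against a
# standard-normal prior for the ELBO.
code_dist = model.encode()
code = code_dist.sample()  # latent code fed to the decoder
prior = distributions.MultivariateNormalDiag(
    tf.zeros(model.code_units), tf.ones(model.code_units))
kl = code_dist.kl_divergence(prior)  # analytic KL term of the ELBO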
def test_noiseless_is_consistent_with_cumsum_bijector(self):
    num_timesteps = 10
    ssm = AutoregressiveMovingAverageStateSpaceModel(
        num_timesteps=num_timesteps,
        ar_coefficients=[0.7, -0.2, 0.1],
        ma_coefficients=[0.6],
        level_scale=0.6,
        level_drift=-0.3,
        observation_noise_scale=0.,
        initial_state_prior=tfd.MultivariateNormalDiag(
            loc=tf.zeros([3]), scale_diag=tf.ones([3])))
    cumsum_ssm = IntegratedStateSpaceModel(ssm)
    x, lp = cumsum_ssm.experimental_sample_and_log_prob(
        [2], seed=test_util.test_seed())
    flatten_event = tfb.Reshape([num_timesteps],
                                event_shape_in=[num_timesteps, 1])
    cumsum_dist = tfb.Chain(
        [tfb.Invert(flatten_event), tfb.Cumsum(), flatten_event])(ssm)
    self.assertAllClose(lp, cumsum_dist.log_prob(x), atol=1e-5)
def rnn_sim(rnn, z, states, a):
    if rnn.args.env_name == 'CarRacing-v0':
        raise ValueError('Not implemented yet for CarRacing')
    z = tf.reshape(tf.cast(z, dtype=tf.float32), (1, 1, rnn.args.z_size))
    a = tf.reshape(tf.cast(a, dtype=tf.float32), (1, 1, rnn.args.a_width))
    input_x = tf.concat((z, a), axis=2)
    rnn_out, h, c = rnn.inference_base(input_x, initial_state=states)
    rnn_state = [h, c]
    rnn_out = tf.reshape(rnn_out, [-1, rnn.args.rnn_size])
    out = rnn.out_net(rnn_out)
    mdnrnn_params, d_logits = out[:, :-1], out[:, -1:]
    mdnrnn_params = tf.reshape(mdnrnn_params,
                               [-1, 3 * rnn.args.rnn_num_mixture])
    mu, logstd, logpi = tf.split(mdnrnn_params, num_or_size_splits=3, axis=1)
    logpi = logpi - tf.reduce_logsumexp(input_tensor=logpi, axis=1,
                                        keepdims=True)  # normalize
    d_dist = tfd.Binomial(total_count=1, logits=d_logits)
    d = tf.squeeze(d_dist.sample()) == 1.0
    cat = tfd.Categorical(logits=logpi)
    component_splits = [1] * rnn.args.rnn_num_mixture
    mus = tf.split(mu, num_or_size_splits=component_splits, axis=1)
    sigs = tf.split(tf.exp(logstd), num_or_size_splits=component_splits, axis=1)
    coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            for loc, scale in zip(mus, sigs)]
    mixture = tfd.Mixture(cat=cat, components=coll)
    z = tf.reshape(mixture.sample(), shape=(-1, rnn.args.z_size))
    r = 1.0  # For Doom, reward is always 1.0 while the agent is alive.
    return rnn_state, z, r, d
def __init__(self, units, make_dist_fn=None, make_dist_model=None, **kwargs):
    super(VariationalLSTMCell, self).__init__(units, **kwargs)
    self.make_dist_fn = make_dist_fn
    self.make_dist_model = make_dist_model
    # Note: constructing the distribution head in build() does not work, so
    # it is built here; as a consequence the outer VariationalRNN cannot set
    # this cell's output_size.
    if self.make_dist_fn is None:
        self.make_dist_fn = lambda t: tfd.MultivariateNormalDiag(
            loc=t[0], scale_diag=t[1])
    if self.make_dist_model is None:
        fake_cell_output = tfkl.Input((self.units,))
        loc = tfkl.Dense(self.output_size,
                         name="VarLSTMCell_loc")(fake_cell_output)
        scale = tfkl.Dense(self.output_size,
                           name="VarLSTMCell_scale")(fake_cell_output)
        # scale_shift is assumed to be defined at module level.
        scale = tf.nn.softplus(scale + scale_shift) + 1e-5
        dist_layer = tfpl.DistributionLambda(
            make_distribution_fn=self.make_dist_fn,
            # TODO: convert_to_tensor_fn=lambda s: s.sample(N_SAMPLES)
        )([loc, scale])
        self.make_dist_model = tf.keras.Model(fake_cell_output, dist_layer)
def get_dist(self, timesteps, samples=1, batch_size=1):
    """Tiles the saved loc and scale to the same shape as `posterior`, then
    uses them to create an MVN distribution of the appropriate shape. Each
    timestep shares the same loc and scale, but samples drawn from the
    distribution differ across timesteps.

    Args:
        timesteps: number of timesteps to tile the loc and scale over
        samples: number of samples to draw
        batch_size: batch size of the drawn samples

    Returns:
        A tuple of (samples, distribution), where the distribution is an
        MVNDiag (or MVNTriL) of the same shape as `posterior`, wrapped in
        tfd.Independent over the time axis.
    """
    loc = tf.tile(tf.expand_dims(self._loc, 0), [timesteps, 1])
    scale = tf.expand_dims(self._scale, 0)
    if self._offdiag:
        scale = tf.tile(scale, [timesteps, 1, 1])
        dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
    else:
        scale = tf.tile(scale, [timesteps, 1])
        dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
    dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
    return dist.sample([samples, batch_size]), dist
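# A shape sketch, assuming a latent dimension d and a hypothetical `prior`
# instance: wrapping in tfd.Independent folds the time axis into the event,
# so the distribution has event shape [timesteps, d] and empty batch shape.
samples, dist = prior.get_dist(timesteps=50, samples=4, batch_size=8)
# samples: [4, 8, 50, d]
# dist.event_shape: [50, d]; dist.batch_shape: []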
def _build(self, inputs):
    mean, covariance, scale = self.create_mean_n_cov_layers(inputs)
    mean = tf.transpose(mean, perm=[0, 2, 1])
    scale = tf.transpose(scale, perm=[0, 2, 1])
    self.set_contractive_regularizer(
        mean, covariance,
        self._contractive_regularizer_inputs,
        self._contractive_regularizer_tuple,
        self._contractive_collection_network_str)
    output_distribution = tfd.MultivariateNormalDiag(loc=mean, scale_diag=scale)
    # Attach a reconstruction_node method, which provides a mean (or median)
    # so reconstructions can be obtained without sampling.
    def reconstruction_node(self):
        return self.mean()
    output_distribution.reconstruction_node = types.MethodType(
        reconstruction_node, output_distribution)
    return output_distribution
def validate_step(self, x, y):
    """Perform a validation step for the VAE.

    Args:
        x: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of validation labels shaped like [batch_size, 1] (unused)

    Returns:
        statistics: dict
            a dictionary that contains logging information
    """
    statistics = dict()
    # build distributions for the data x and latent variable z
    dz = self.encoder.get_distribution(x, training=False)
    z = dz.sample()
    dx = self.decoder.get_distribution(z, training=False)
    # build the reconstruction loss
    nll = -dx.log_prob(x)[..., tf.newaxis]
    while len(nll.shape) > 2:
        nll = tf.reduce_sum(nll, axis=1)
    prior = tfpd.MultivariateNormalDiag(loc=tf.zeros_like(z),
                                        scale_diag=tf.ones_like(z))
    # build the kl loss
    kl = dz.kl_divergence(prior)[:, tf.newaxis]
    statistics['vae/validate/nll'] = nll
    statistics['vae/validate/kl'] = kl
    return statistics
def _transition(self, prev_state, prev_action, zero_obs):
    hidden = tf.concat([prev_state['sample'], prev_action], -1)
    for _ in range(self._num_layers):
        hidden = tf.layers.dense(hidden, **self._kwargs)
    belief, rnn_state = self._cell(hidden, prev_state['rnn_state'])
    if self._future_rnn:
        hidden = belief
    for _ in range(self._num_layers):
        hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
        sample = mean
    else:
        sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': belief,
        'rnn_state': rnn_state,
    }
def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    hidden = tf.concat(
        [prior['belief'], obs], -1)
    # TODO: why take belief instead of previous hidden state?
    for _ in range(self._num_layers):
        hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
        sample = mean
    else:
        sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': prior['belief'],
        'rnn_state': prior['rnn_state'],
    }