def latent(self, conditioning):
    """Sample the attribute-conditioned posterior and its KL to the prior.

    Reads ``conditioning['z']`` (documented by the original author as
    ``[batch, frames, latent_dims]``) and ``conditioning['attributes']``
    (``[batch, frames, attribute_dims]``). Attributes with fewer than three
    dimensions are repeated along a new frame axis.

    Returns:
        ``(posterior_sample, kl_div, [mu_z, scale_z])``. ``kl_div`` is the KL
        divergence summed over the last dimension and averaged over the rest.
    """
    encoded = conditioning['z']
    attrs = conditioning['attributes']
    n_batch, n_frames, _ = encoded.shape

    if attrs.dim() < 3:
        # broadcast the per-example attributes to every frame
        attrs = attrs.unsqueeze(1).expand(-1, n_frames, -1)

    mu_q, logscale_q = self.psi_q(encoded, encoded)

    # blend the encoder statistics with the attributes
    mu_a, logscale_a = self.attribute_latent(mu_q, logscale_q, attrs)

    # small random initial state for the temporal model
    h_0 = torch.rand(1, n_batch, self.latent_dims).to(encoded.device) * 0.01
    temp_q = self.temporal_latent_model(mu_q, logscale_q, h_0)

    # posterior
    mu_z, scale_z = self.mix_with_temp(mu_a, logscale_a, temp_q)
    posterior = Independent(Normal(mu_z, scale_z), 1)
    posterior_sample = posterior.rsample()

    # prior, conditioned on temporal features and attributes
    prior_input = torch.cat([temp_q, attrs], dim=-1)
    mu_p, logscale_p = self.psi_p(prior_input, prior_input)
    prior = Independent(Normal(mu_p, logscale_p.exp()), 1)

    kl_div = torch.mean(kl_divergence(posterior, prior))
    return posterior_sample, kl_div, [mu_z, scale_z]
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute a tanh-squashed Gaussian action and its log-probability.

    The actor must return a tuple that is unpacked into ``Normal``. When
    ``self._deterministic_eval`` is set and the module is not training, the
    first tuple element is used instead of a reparameterized sample.

    Returns a ``Batch`` with ``logits``, ``act``, ``state``, ``dist`` and
    ``log_prob``.
    """
    observation = batch[input]
    logits, hidden = self.actor(observation, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)

    use_mean = self._deterministic_eval and not self.training
    raw_action = logits[0] if use_mean else dist.rsample()

    squashed = torch.tanh(raw_action)
    act = squashed * self._action_scale + self._action_bias

    # Tanh change-of-variables correction (SAC paper, arXiv 1801.01290,
    # appendix C, Eq. 21). __eps keeps the log argument strictly positive.
    jacobian = self._action_scale * (1 - squashed.pow(2)) + self.__eps
    gaussian_log_prob = dist.log_prob(raw_action).unsqueeze(-1)
    log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)

    return Batch(logits=logits, act=act, state=hidden, dist=dist,
                 log_prob=log_prob)
def trajectory(self, current_state):
    """Roll the policy out until the environment reports no next state.

    Each step samples an action from the diagonal Gaussian produced by
    ``self.forward`` and asks ``self.env.state_and_reward`` for the next
    state and reward. The environment signals the end of the rollout by
    returning ``None`` as the next state.

    Returns:
        A list with one ``(state, action, reward, log_prob)`` tuple per step.
        ``action`` is detached from the graph; ``log_prob`` is not.
    """
    history = []
    state = current_state
    while True:
        mu, sigma = self.forward(state)
        policy = Independent(Normal(mu, sigma), 1)
        action = policy.rsample().detach()

        next_state, reward = self.env.state_and_reward(state, action)
        history.append((state, action, reward, policy.log_prob(action)))

        if next_state is None:
            # terminal step: nothing left to roll out
            return history
        state = next_state
def sample_trajectories(self, init_std=1.0, min_std=1e-6, output_size=2):
    """Roll out the current policy, yielding one environment step at a time.

    Actions are drawn from a diagonal Gaussian whose mean is the policy
    output and whose standard deviation is the constant
    ``max(init_std, min_std)`` in every action dimension.

    Args:
        init_std: standard deviation of the sampling noise.
        min_std: lower bound applied to the standard deviation.
        output_size: number of action dimensions (length of the scale vector).

    Yields:
        ``(observations, actions, rewards, batch_ids)`` for every step until
        ``self.envs.dones`` is all true.
    """
    # The scale is a constant, so build it once instead of once per step.
    log_std = max(math.log(init_std), math.log(min_std))
    scale = torch.full((output_size,), log_std).exp()

    observations = self.envs.reset()
    while not self.envs.dones.all():
        # Keep no_grad scoped to the policy call. Wrapping the `yield` in it
        # would leave gradients disabled in the consumer between steps.
        with torch.no_grad():
            output = self.policy(torch.from_numpy(observations))
            # loc is the Gaussian mean, scale its standard deviation
            p_normal = Independent(Normal(loc=output, scale=scale), 1)
            actions_tensor = p_normal.sample()
        actions = actions_tensor.cpu().numpy()

        new_observations, rewards, _, infos = self.envs.step(actions)
        batch_ids = infos['batch_ids']
        yield (observations, actions, rewards, batch_ids)
        observations = new_observations
def test_independent_expand(self):
    """Check that ``Independent.expand`` keeps shapes and log-probs consistent.

    For every example distribution, every valid number of reinterpreted
    batch dimensions and several leading shapes, the expanded distribution
    must produce samples of the expected shape, agree with the unexpanded
    distribution on ``log_prob``, keep its event shape and report the
    requested batch shape.
    """
    leading_shapes = [torch.Size(), torch.Size((2, )), torch.Size((2, 3))]
    for Dist, params in EXAMPLES:
        for param in params:
            base_dist = Dist(**param)
            n_batch_dims = len(base_dist.batch_shape)
            for ndims in range(n_batch_dims + 1):
                for lead in leading_shapes:
                    independent = Independent(base_dist, ndims)
                    target_batch = lead + independent.batch_shape
                    expanded = independent.expand(target_batch)
                    draw = expanded.sample()

                    self.assertEqual(
                        draw.shape, target_batch + independent.event_shape)
                    self.assertEqual(
                        expanded.log_prob(draw),
                        independent.log_prob(draw),
                    )
                    self.assertEqual(expanded.event_shape,
                                     independent.event_shape)
                    self.assertEqual(expanded.batch_shape, target_batch)
def KLD(self, a, b, prior_alpha, prior_beta):
    """Return the KL divergence selected by ``self.dist``.

    ``a`` and ``b`` are clamped from below to a small positive epsilon first.

    * ``"km"``: closed-form expression built from a 10-term series that
      calls ``self.Beta`` (presumably the Beta function; confirm against its
      definition).
    * ``"gamma"``: ``KL(Gamma(a, b) || Gamma(prior_alpha, prior_beta))``.
    * ``"gl"``: KL between diagonal Gaussians ``N(a, b)`` and a Gaussian
      whose mean and standard deviation are the mean and standard deviation
      of a Beta(prior_alpha, prior_beta) distribution. The result gets an
      extra axis at position 1.

    Raises:
        ValueError: if ``self.dist`` is none of the values above.
    """
    eps = 5 * torch.finfo(torch.float).eps
    a = a.clamp(eps)
    b = b.clamp(eps)

    if self.dist == "km":
        ab = (a * b) + eps
        # truncated series, terms m = 1..10
        kl = 1 / (1 + ab) * self.Beta(1 / a, b)
        for m in range(2, 11):
            kl = kl + 1 / (m + ab) * self.Beta(m / a, b)
        kl = kl * ((prior_beta - 1) * b)
        kl = kl + (a - prior_alpha) / a * (
            -np.euler_gamma - torch.digamma(b) - 1 / b)
        # normalization constants
        kl = kl + (torch.log(ab)
                   + torch.log(self.Beta(prior_alpha, prior_beta)))
        # final term
        kl = kl + (-(b - 1) / b)
    elif self.dist == "gamma":
        kl = torch.distributions.kl.kl_divergence(
            Gamma(a, b), Gamma(prior_alpha, prior_beta))
    elif self.dist == "gl":
        total = prior_alpha + prior_beta
        prior_mean = prior_alpha / total
        prior_std = torch.sqrt(
            prior_alpha * prior_beta / (total ** 2 * (total + 1)))
        kl = torch.distributions.kl.kl_divergence(
            Independent(Normal(a, b), 1),
            Independent(Normal(prior_mean, prior_std), 1)).unsqueeze(1)
    else:
        # Previously this fell through to an UnboundLocalError on `kl`.
        raise ValueError(f"Unknown distribution type: {self.dist!r}")
    return kl
def temporal_model_step(self, mu_q_t, logscale_q_t, h, attribute=None):
    """Generate one latent step ``z_t`` from encoder statistics and RNN state.

    The posterior is built from the encoder statistics concatenated with the
    processed hidden state. The prior is built from the processed hidden
    state alone, or concatenated with ``attribute`` when one is given. The
    prior scale is fixed to one; the scale output of ``psi_p`` only supplies
    the shape.

    Returns:
        ``(posterior_sample_t, kl_div_t, mu_z_t, scale_z_t)``.
    """
    h_mu, h_scale = self.h_process(h, h)

    # posterior: encoder statistics combined with the temporal information
    joint_mu = torch.cat([mu_q_t, h_mu], dim=-1)
    joint_logscale = torch.cat([logscale_q_t, h_scale], dim=-1)
    mu_z_t, logscale_z_t = self.psi_dy(joint_mu, joint_logscale)
    scale_z_t = logscale_z_t.exp()
    posterior_t = Independent(Normal(mu_z_t, scale_z_t), 1)
    posterior_sample_t = posterior_t.rsample()

    # prior: temporal information, optionally mixed with the attribute
    if attribute is None:
        prior_inputs = (h_mu, h_scale)
    else:
        prior_inputs = (
            torch.cat([h_mu, attribute], dim=-1),
            torch.cat([h_scale, attribute], dim=-1),
        )
    mu_p_t, logscale_p_t = self.psi_p(*prior_inputs)
    # unit-variance prior (the predicted log-scale is deliberately unused)
    scale_p_t = torch.ones_like(logscale_p_t)
    prior_t = Independent(Normal(mu_p_t, scale_p_t), 1)

    kl_div_t = torch.mean(kl_divergence(posterior_t, prior_t))
    return posterior_sample_t, kl_div_t, mu_z_t, scale_z_t
def forward(self, state, eval=False, with_log_prob=False):
    """Map a state to a bounded action, optionally with its log-probability.

    Args:
        state: input passed through ``fc1``/``fc2`` and the two output heads.
        eval: if true, use the mean action and return no log-probability.
        with_log_prob: if true (and not ``eval``), also return the
            log-probability of the squashed, scaled action.

    Returns:
        ``(a, log_prob)`` where ``a = ctrl_range * tanh(u)`` and ``log_prob``
        is ``None`` unless requested.
    """
    hidden = F.relu(self.fc2(F.relu(self.fc1(state))))
    mu = self.fc3(hidden)
    # log_sigma is clipped as in Haarnoja's SAC implementation:
    # https://github.com/haarnoja/sac.git
    sigma = torch.exp(torch.clamp(self.fc4(hidden), -20.0, 2.0))
    distribution = Independent(Normal(mu, sigma), 1)

    if eval:
        return self.ctrl_range * torch.tanh(mu), None

    # rsample() keeps the sample differentiable w.r.t. mu and sigma
    u = distribution.rsample()
    log_prob = None
    if with_log_prob:
        # log-determinant of the tanh-and-scale Jacobian, written in a
        # numerically stable softplus form
        correction = torch.sum(
            (np.log(2.0) + 0.5 * np.log(self.ctrl_range) - u
             - F.softplus(-2.0 * u)),
            dim=1)
        log_prob = distribution.log_prob(u) - 2.0 * correction

    # tanh keeps the action inside (-ctrl_range, ctrl_range)^D
    return self.ctrl_range * torch.tanh(u), log_prob
def get_loss(self, output, target, ignore_label=None):
    """Add a sampled aleatoric-uncertainty term to the parent loss.

    The logits are perturbed ``self.samples`` times with Gaussian noise
    (standard deviation ``self.sigma``) scaled by ``output['sigma']``. The
    softmax probabilities of the ground-truth class are summed over the
    samples, and the negative mean log of that average, weighted by 1e-4,
    is added to ``super().get_loss``.

    Side effects: ``output['sigma']`` is replaced by a detached NumPy array,
    ``output['logits']`` is set to ``None``, and the aleatoric term is stored
    in ``self.current_aleatoric_loss``.
    """
    loss = super().get_loss(output, target, ignore_label)
    eps = 1e-12
    logits = output['logits']

    # Zero-mean diagonal Gaussian shaped like the logits. It is created on
    # the logits' own device so all operands below share one device.
    normal = Independent(
        Normal(loc=torch.zeros_like(logits),
               scale=torch.empty_like(logits).fill_(self.sigma)), 1)

    # Sum of softmax outputs over the noisy logit samples.
    # Per the original author, the expected shape is [batch, nclasses, x, y].
    sum_distorted_softmax = torch.sum(torch.stack([
        self.softmax(logits + (output['sigma'] * normal.sample()))
        for _ in range(int(self.samples))
    ]), dim=0)

    # One-hot mask on the logits' device. It used to be allocated on the
    # CPU, which breaks as soon as the logits live on a GPU.
    one_hot = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1)

    # Masking leaves only the ground-truth probability non-zero, so the max
    # over the class axis selects it.
    sum_distorted_softmax, _ = torch.max(sum_distorted_softmax * one_hot, 1)

    # log of the Monte-Carlo average, summed over dims 1 and 2, mean over batch
    aleatoric_loss = -0.0001 * torch.mean(torch.sum(
        torch.log(sum_distorted_softmax + eps) - np.log(self.samples),
        dim=(1, 2)), dim=0)

    output['sigma'] = output['sigma'].cpu().detach().numpy()
    output['logits'] = None
    self.current_aleatoric_loss = aleatoric_loss.detach().cpu().numpy()
    return loss + aleatoric_loss
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute a tanh-squashed action and its corrected log-probability.

    The returned ``act`` is ``tanh`` of the Gaussian sample. When
    ``self._deterministic_eval`` is set and the module is not training, the
    first actor output is used instead of a sample. The log-probability
    correction uses the action-space half-range as scale when
    ``self.action_scaling`` is enabled and an action space is available,
    otherwise 1.
    """
    observation = batch[input]
    logits, hidden = self.actor(observation, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)

    if self._deterministic_eval and not self.training:
        raw_action = logits[0]
    else:
        raw_action = dist.rsample()
    gaussian_log_prob = dist.log_prob(raw_action).unsqueeze(-1)

    # Tanh change-of-variables correction (SAC paper, arXiv 1801.01290,
    # appendix C, Eq. 21).
    scaling_enabled = self.action_scaling and self.action_space is not None
    if scaling_enabled:
        half_range = (self.action_space.high - self.action_space.low) / 2.0
        action_scale = to_torch_as(half_range, raw_action)
    else:
        action_scale = 1.0  # type: ignore

    squashed_action = torch.tanh(raw_action)
    jacobian = action_scale * (1 - squashed_action.pow(2)) + self.__eps
    log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)

    return Batch(logits=logits, act=squashed_action, state=hidden, dist=dist,
                 log_prob=log_prob)
def choose_action(self, observation):
    """Sample an action from the actor's diagonal Gaussian policy.

    The actor's second output is exponentiated to obtain the standard
    deviation. The log-probability of the sampled action is stored in
    ``self.log_probs`` (moved to ``self.actor.device``).

    Returns:
        The sampled action (drawn with ``sample()``, so not differentiable).
    """
    mu, log_sigma = self.actor.forward(observation)
    policy = Independent(Normal(mu, T.exp(log_sigma)), 1)
    action = policy.sample()
    self.log_probs = policy.log_prob(action).to(self.actor.device)
    return action
def generate(self, synth, h_0, f0_hz, enc_frame_setting='fine',
             n_samples=16000):
    """Generate audio by rolling the prior forward frame by frame.

    Args (shapes as documented by the original author):
        synth: synth used to render audio.
        h_0: initial RNN state, ``[batch, latent_dims]``.
        f0_hz: f0 conditioning, ``[batch, f0_n_frames, 1]``.
        enc_frame_setting: selects the fft/hop size.
        n_samples: output audio length in samples.

    Returns:
        ``(params, resyn_audio)`` from the synth.
    """
    n_fft, hop_length = get_window_hop(enc_frame_setting)
    n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
    # f0 must have as many frames as z
    f0_hz = resample_frames(f0_hz, n_frames)

    state = h_0
    z = torch.zeros(h_0.shape[0], n_frames,
                    self.latent_dims).to(state.device)
    for frame in range(n_frames):
        # prior for this frame given the current RNN state
        mu_p, scale_p = self.get_prior(state)
        z_frame = Independent(Normal(mu_p, scale_p), 1).rsample()
        state = self.temporal(z_frame, state)
        z[:, frame, :] = z_frame

    cond = {'z': z, 'f0_hz': f0_hz}
    y_params = self.decode(cond)
    params = synth.fill_params(y_params, cond)
    resyn_audio, _outputs = synth(params, n_samples)
    return params, resyn_audio
class IndependentNormal(Distribution):
    """Diagonal Gaussian whose dimensions after the first form one event.

    Wraps ``Independent(Normal(loc, scale), loc.ndim - 1)`` and forwards the
    usual distribution queries to it.
    """

    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # base_dist must exist before the parent constructor runs: argument
        # validation reads `loc` and `scale` through the properties below.
        self.base_dist = Independent(
            Normal(loc=loc, scale=scale, validate_args=validate_args),
            len(loc.shape) - 1,
            validate_args=validate_args)
        super(IndependentNormal, self).__init__(
            self.base_dist.batch_shape,
            self.base_dist.event_shape,
            validate_args=validate_args)

    @property
    def loc(self):
        # Required by `arg_constraints`. Without it, construction fails
        # whenever argument validation is enabled.
        return self.base_dist.base_dist.loc

    @property
    def scale(self):
        return self.base_dist.base_dist.scale

    @property
    def support(self):
        # A Gaussian is supported on the reals, not only on positive values.
        return self.base_dist.support

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        return self.base_dist.entropy()
def forward(self, input, kl_coef):
    """Run the latent-ODE model and return its loss on the masked points.

    ``input`` is a 4-tuple ``(b, m, train_m, test_m)``; only ``b`` and
    ``test_m`` are read here (the whole tuple is also given to
    ``self.ode_rnn``). An initial latent state is sampled from the encoder's
    mean and std, integrated with ``odeint`` over the time stamps
    ``b[0, :, 0]``, and decoded.

    Returns:
        ``(loss, mse, masked_output)`` where
        ``loss = -logsumexp(likelihood - kl_coef * kl_div, dim=0)``.
    """
    b, _m, _train_m, test_m = input
    device = self.param['device']
    n_batch = self.param['batch_size']
    n_target = self.param['total_points'] - self.param['obs_points']

    mean, std = self.ode_rnn(input)  # (batch_size, LO_hidden_size) * 2

    # reparameterized draw of the initial latent state
    std_normal = Normal(torch.tensor([0.0], device = device),
                        torch.tensor([1.0], device = device))
    noise = std_normal.sample(mean.shape).squeeze(-1)
    z0 = mean + noise * std

    # (num_time_points, batch_size, LO_hidden_size) -> batch first
    z_out = odeint(self.ode_func, z0, b[0, :, 0],
                   rtol = self.param['rtol'], atol = self.param['atol'])
    z_out = z_out.permute(1, 0, 2)
    output = self.output_output(z_out).squeeze(2)

    # KL between the encoder posterior and a standard normal
    kl_div = kl_divergence(Normal(mean, std), std_normal).mean(axis = 1)

    mask = test_m.bool()
    masked_output = output[mask].reshape(n_batch, n_target)
    target = b[:, :, 1][mask].reshape(n_batch, n_target)

    gaussian = Independent(
        Normal(loc = masked_output, scale = self.param['obsrv_std']), 1)
    likelihood = gaussian.log_prob(target) / output.shape[1]

    loss = -torch.logsumexp(likelihood - kl_coef * kl_div, 0)
    mse = self.mse_func(masked_output, target)
    return loss, mse, masked_output
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute a squashed, scaled action with optional exploration noise.

    The action is ``tanh(x) * _action_scale + _action_bias`` where ``x`` is a
    reparameterized Gaussian sample, or the first actor output when
    ``self._deterministic_eval`` is set and the module is not training. The
    log-probability includes the tanh correction. While training and not
    updating, ``self._noise`` (if set) is added to the action. The action is
    always clamped to ``self._range``.
    """
    observation = batch[input]
    logits, hidden = self.actor(observation, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)

    use_mean = self._deterministic_eval and not self.training
    raw_action = logits[0] if use_mean else dist.rsample()

    squashed = torch.tanh(raw_action)
    act = squashed * self._action_scale + self._action_bias

    # tanh change-of-variables term; __eps keeps the log argument positive
    jacobian = self._action_scale * (1 - squashed.pow(2)) + self.__eps
    gaussian_log_prob = dist.log_prob(raw_action).unsqueeze(-1)
    log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)

    explore = (self._noise is not None and self.training
               and not self.updating)
    if explore:
        act += to_torch_as(self._noise(act.shape), act)
    act = act.clamp(self._range[0], self._range[1])

    return Batch(
        logits=logits, act=act, state=hidden, dist=dist, log_prob=log_prob)
class TanhNormal(Distribution):
    """Diagonal Gaussian squashed through ``tanh`` (copied from Kaixhi).

    The last dimension of ``loc``/``scale`` is treated as the event
    dimension.
    """

    # No constrained constructor arguments are exposed. Declaring this
    # avoids the warning the base class emits when it is missing.
    arg_constraints = {}
    has_rsample = True

    def __init__(self, loc, scale):
        self.normal = Independent(Normal(loc, scale), 1)
        # Report the real batch/event shapes instead of empty ones.
        super().__init__(self.normal.batch_shape, self.normal.event_shape)

    def sample(self, sample_shape=torch.Size()):
        """Draw a non-differentiable sample in (-1, 1)."""
        return torch.tanh(self.normal.sample(sample_shape))

    def rsample(self, sample_shape=torch.Size()):
        """Draw a sample with the reparameterization trick (differentiable)."""
        return torch.tanh(self.normal.rsample(sample_shape))

    def log_prob(self, value):
        """Log-density of ``value`` via the change-of-variables formula.

        ``log1p`` is used for numerical stability; the 1e-6 offset keeps the
        Jacobian term finite at ``|value| == 1``.
        """
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|
        # Sum over the event (last) dimension so any batch rank works.
        jacobian = torch.log1p(-value.pow(2) + 1e-6).sum(dim=-1)
        return self.normal.log_prob(inv_value) - jacobian

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)

    def get_std(self):
        """Return the standard deviation of the underlying Gaussian."""
        return self.normal.stddev
def forward(self, x):
    """Encode ``x``, optionally sample a VAE latent, and decode.

    When ``self.unet`` is truthy no latent is sampled: ``z`` is ``0.`` and
    the KL term is a zero tensor. Otherwise the VAE encoder output is split
    into a mean and a log-variance, a latent is drawn with the
    reparameterization trick, and the KL to the prior
    ``N(self.mu_0, self.sigma_0)`` is summed.

    Returns:
        ``(kl, xHat)``.
    """
    # U-Net encoder result
    x_enc = self.unet_encoder(x)

    if self.unet:
        kl = torch.zeros(1).to(device)
        z = 0.
    else:
        # VAE regularisation: split the embedding into mean / log-variance
        emb = self.vae_encoder(x)
        mu, log_var = torch.chunk(emb, 2, dim=1)
        # Make sure that the log variance is positive
        sigma = torch.exp(softplus(log_var) / 2)

        # approximate posterior q(z|x) and prior p(z)
        posterior = Independent(Normal(loc=mu, scale=sigma), 1)
        z = posterior.rsample()
        prior = Independent(Normal(loc=self.mu_0, scale=self.sigma_0), 1)

        # KL(q(z|x) || p(z)), summed over all entries
        kl = KLD(posterior, prior).sum()

    # reconstruction used for the MSE term
    xHat = self.decoder(x_enc, z)
    return kl, xHat
def generate(self, synth, h_0, f0_hz, attributes, enc_frame_setting='fine',
             n_samples=16000):
    """Generate audio autoregressively, one frame at a time.

    Args (shapes as documented by the original author):
        synth: synth used to render audio.
        h_0: initial RNN seed, ``[batch, latent_dims]`` (a leading axis is
            added) or already three-dimensional.
        f0_hz: f0 conditioning, ``[batch, f0_n_frames, 1]``.
        attributes: attributes, ``[batch, n_frames, attribute_size]``.
            NOTE(review): it is concatenated as-is with the permuted hidden
            state every frame, so confirm the frame axis really matches.
        enc_frame_setting: selects the fft/hop size.
        n_samples: output audio length in samples.

    Returns:
        ``(final_params, final_audio)``: per-frame synth parameters
        concatenated along dim 1, and the audio rendered from them.
    """
    h = h_0[None, :, :] if len(h_0.shape) == 2 else h_0  # 1, batch, latent

    n_fft, hop_length = get_window_hop(enc_frame_setting)
    n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
    # f0 must have as many frames as z
    f0_hz = resample_frames(f0_hz, n_frames)

    params_list = []
    for frame in range(n_frames):
        # sample z for this frame from the prior
        prior_input = torch.cat([h.permute(1, 0, 2), attributes], dim=-1)
        mu_p, logscale_p = self.psi_p(prior_input, prior_input)
        prior = Independent(Normal(mu_p, logscale_p.exp()), 1)

        f0_frame = f0_hz[:, frame, :].unsqueeze(1)
        cond = {
            'z': prior.rsample(),
            'f0_hz': f0_frame,
            'f0_scaled': hz_to_midi(f0_frame) / 127.0,
        }

        # decode and render exactly one frame of audio
        y = self.decode(cond)
        params = synth.fill_params(y, cond)
        params_list.append(params)
        x_tilde, _outputs = synth(params, n_samples=n_fft)

        # re-encode the rendered frame and advance the temporal model
        cond['audio'] = x_tilde
        cond = self.encoder(cond)
        z_enc = cond['z']
        mu_q, logscale_q = self.psi_q(z_enc, z_enc)
        psi = torch.cat([mu_q, logscale_q], dim=-1)
        _temp_q, h = self.temporal_q(psi, h)

    # render everything in one go from the frame-wise parameters
    final_params = {
        name: torch.cat([frame_params[name] for frame_params in params_list],
                        dim=1)
        for name in params_list[0].keys()
    }
    final_audio, _outputs = synth(final_params, n_samples=n_samples)
    return final_params, final_audio
def initialize(self, data, ndim):
    """Create the variational parameters and the base sampling distribution.

    ``self._mean`` and ``self._logstd`` are zero tensors of shape
    ``(data.shape[0] + 1, ndim)`` that require gradients.
    ``self._sampledist`` is a standard normal over that whole shape (both
    dimensions form the event).

    Returns:
        ``self``, to allow chaining.
    """
    shape = (data.shape[0] + 1, ndim)
    self._mean = torch.zeros(shape, requires_grad=True)
    self._logstd = torch.zeros_like(self._mean, requires_grad=True)

    # ===== Start optimization ===== #
    standard_normal = Normal(torch.zeros_like(self._mean),
                             torch.ones_like(self._logstd))
    self._sampledist = Independent(standard_normal, 2)
    return self
def get_ELBO_per_obs(self, batch, beta=1.0):
    """Return the (beta-weighted) evidence lower bound for each observation.

    Computes ``log p(x|z)`` summed over the last dimension, minus ``beta``
    times ``KL(q(z|x) || p(z))``, where the KL treats the last dimension of
    ``qz``/``pz`` as the event.

    Args:
        batch: input passed to the model and scored under ``px``.
        beta: weight of the KL term.

    Returns:
        One ELBO value per observation.
    """
    output = self(batch)
    px, pz, qz, z = [output[k] for k in ["px", "pz", "qz", "z"]]
    kl_term = kl_divergence(Independent(qz, 1), Independent(pz, 1))
    # The KL is a penalty: it is subtracted from the reconstruction term.
    elbo = px.log_prob(batch).sum(-1) - beta * kl_term
    return elbo
def __init__(self, loc, scale, validate_args=None):
    """Build the wrapped diagonal Gaussian and register its shapes.

    Every dimension of ``loc`` after the first is reinterpreted as an event
    dimension. The resulting batch and event shapes are passed on to the
    parent constructor.
    """
    normal = Normal(loc=loc, scale=scale, validate_args=validate_args)
    n_event_dims = len(loc.shape) - 1
    self.base_dist = Independent(normal, n_event_dims,
                                 validate_args=validate_args)
    super(IndependentNormal, self).__init__(
        self.base_dist.batch_shape,
        self.base_dist.event_shape,
        validate_args=validate_args,
    )
def forward(self, obs, act=None, deterministic=False):
    """Build the Gaussian policy for ``obs`` and score an action under it.

    Args:
        obs: observation fed to ``self.mu_layer``.
        act: optional action to score; if omitted one is chosen here.
        deterministic: when choosing an action, use the distribution mean
            instead of a reparameterized sample.

    Returns:
        ``(pi, act, log_prob)``: the distribution, the supplied or chosen
        action, and its log-probability.
    """
    mean = self.mu_layer(obs)
    pi = Independent(Normal(mean, torch.exp(self.log_std_layer)), 1)
    if act is None:
        if deterministic:
            act = pi.mean
        else:
            act = pi.rsample()
    return pi, act, pi.log_prob(act)
def prop_state(x, f, g):
    """Advance a three-compartment state by one binomial transition step.

    A binomial draw is made for every entry of ``x`` except the last, with
    probabilities ``f``. The first draw moves mass from ``x[0]`` to ``x[1]``
    and the second from ``x[1]`` to ``x[2]``. ``g`` is not used.

    Returns:
        Whatever ``concater`` builds from the three updated compartments.
    """
    moved = Independent(Binomial(x[:-1], f), 1).sample()
    first_to_second = moved[..., 0]
    second_to_third = moved[..., 1]

    s = x[0] - first_to_second
    i = x[1] + first_to_second - second_to_third
    r = x[2] + second_to_third
    return concater(s, i, r)
def _kl_independent_independent(p, q):
    """KL divergence between two ``Independent`` distributions.

    The reinterpreted batch dimensions that both distributions share are
    stripped, the KL of the remaining distributions is computed, and the
    shared dimensions are then summed out of the result.
    """
    shared_ndims = min(p.reinterpreted_batch_ndims,
                       q.reinterpreted_batch_ndims)

    def _strip_shared(dist):
        # keep only the event dims this distribution has beyond the shared ones
        extra = dist.reinterpreted_batch_ndims - shared_ndims
        if extra:
            return Independent(dist.base_dist, extra)
        return dist.base_dist

    kl = kl_divergence(_strip_shared(p), _strip_shared(q))
    return sum_rightmost(kl, shared_ndims) if shared_ndims else kl
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
    """Per-row Gaussian log-likelihood, normalized by the row length.

    ``obsrv_std`` is repeated to the size of the last dimension of ``mu_2d``
    and used as the standard deviation of a diagonal Gaussian centred at
    ``mu_2d``. ``indices`` is not used.

    Returns:
        ``log_prob(data_2d) / n_data_points``, or a scalar zero on
        ``data_2d``'s device when the last dimension is empty.
    """
    n_data_points = mu_2d.size()[-1]
    if n_data_points <= 0:
        # nothing to score
        return torch.zeros([1]).to(get_device(data_2d)).squeeze()

    scale = obsrv_std.repeat(n_data_points)
    gaussian = Independent(Normal(loc = mu_2d, scale = scale), 1)
    return gaussian.log_prob(data_2d) / n_data_points
def __init__(self, concentration1, concentration0, validate_args=None):
    """Build the wrapped ``RescaledBeta`` and register its shapes.

    Every dimension of ``concentration1`` after the first is reinterpreted
    as an event dimension. The resulting batch and event shapes are passed
    on to the parent constructor.
    """
    beta = RescaledBeta(concentration1, concentration0, validate_args)
    n_event_dims = len(concentration1.shape) - 1
    self.base_dist = Independent(beta, n_event_dims,
                                 validate_args=validate_args)
    super(IndependentRescaledBeta, self).__init__(
        self.base_dist.batch_shape,
        self.base_dist.event_shape,
        validate_args=validate_args,
    )
def generate(self, synth, h_0, f0_hz, enc_frame_setting='fine',
             n_samples=16000):
    """Generate audio by sampling latents autoregressively, then decoding.

    For every frame a latent is drawn from the prior given the RNN state,
    one frame of audio is rendered and re-encoded, and the RNN state is
    advanced with the re-encoded statistics and the sampled latent. The
    returned audio is rendered once from all sampled latents.

    Args (shapes as documented by the original author):
        synth: synth used to render audio.
        h_0: initial RNN state, ``[batch, latent_dims]``.
        f0_hz: f0 conditioning, ``[batch, f0_n_frames, 1]``.
        enc_frame_setting: selects the fft/hop size.
        n_samples: output audio length in samples.

    Returns:
        ``(params, resyn_audio)`` from the final synth call.
    """
    n_fft, hop_length = get_window_hop(enc_frame_setting)
    n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
    # f0 must have as many frames as z
    f0_hz = resample_frames(f0_hz, n_frames)

    h = h_0
    z = torch.zeros(h_0.shape[0], n_frames, self.latent_dims).to(h.device)
    for t in range(n_frames):
        # prior for this frame from the processed RNN state
        h_mu, h_scale = self.h_process(h, h)
        mu_t, logscale_t = self.psi_p(h_mu, h_scale)  # [batch, latent_size]
        z_t = Independent(Normal(mu_t, logscale_t.exp()), 1).rsample()
        z[:, t, :] = z_t

        f0_t = f0_hz[:, t, :].unsqueeze(1)
        frame_cond = {
            'z': z_t.unsqueeze(1),
            'f0_hz': f0_t,
            'f0_scaled': hz_to_midi(f0_t) / 127.0,
        }

        # render exactly one frame and feed it back through the encoder
        y = self.decode(frame_cond)
        frame_params = synth.fill_params(y, frame_cond)
        x_tilde, _outputs = synth(frame_params, n_samples=n_fft)
        frame_cond['audio'] = x_tilde
        frame_cond = self.encoder(frame_cond)
        z_enc = frame_cond['z'].squeeze(1)

        # advance the temporal model
        mu_q, logscale_q = self.psi_q(z_enc, z_enc)
        rnn_input = torch.cat([mu_q, logscale_q, z_t], dim=-1)
        h = self.temporal_q(rnn_input, h)

    # decode all frames in one pass
    cond = {'z': z, 'f0_hz': f0_hz}
    y_params = self.decode(cond)
    params = synth.fill_params(y_params, cond)
    resyn_audio, _outputs = synth(params, n_samples)
    return params, resyn_audio
def get_loss(self, x, return_kl=False, beta=1.0):
    """Return the mean negative ELBO of ``x`` (with a beta-weighted KL).

    Args:
        x: input passed to the model and scored under ``px``.
        return_kl: if true, also return the mean KL term.
        beta: weight of the KL term.

    Returns:
        ``loss.mean()``, or ``(loss.mean(), kl_term.mean())`` when
        ``return_kl`` is true.
    """
    output = self(x)
    px, pz, qz, _z = map(output.__getitem__, ("px", "pz", "qz", "z"))

    kl_term = kl_divergence(Independent(qz, 1), Independent(pz, 1))
    mean_loss = (-px.log_prob(x) + beta * kl_term).mean()

    if return_kl:
        return mean_loss, kl_term.mean()
    return mean_loss
def prop_state(x, beta, gamma, eta, dt):
    """Advance a three-compartment state by one binomial transition step.

    Transition probabilities come from ``_f(x, beta, gamma, eta, dt)``. A
    binomial draw is made for every compartment except the last (taken
    along the final axis of ``x``). The first draw moves mass from
    compartment 0 to 1 and the second from 1 to 2.

    Returns:
        Whatever ``concater`` builds from the three updated compartments.
    """
    probs = _f(x, beta, gamma, eta, dt)
    moved = Independent(Binomial(x[..., :-1], probs), 1).sample()
    first_to_second = moved[..., 0]
    second_to_third = moved[..., 1]

    s = x[..., 0] - first_to_second
    i = x[..., 1] + first_to_second - second_to_third
    r = x[..., 2] + second_to_third
    return concater(s, i, r)
def act(self, obs, deterministic=False):
    """Choose an action for ``obs`` and return it with its log-probability.

    The policy is a diagonal Gaussian with mean ``self.forward(obs)`` and
    standard deviation ``exp(self.log_scale)``; the last dimension is the
    event dimension.

    Args:
        obs: observation passed to ``self.forward``.
        deterministic: return the mean action instead of a reparameterized
            sample.

    Returns:
        ``(action, action_logprobs)``.
    """
    action_mean = self.forward(obs)
    dist = Independent(Normal(action_mean, torch.exp(self.log_scale)), 1)
    action = action_mean if deterministic else dist.rsample()
    # Score the action with its event dimension intact. Squeezing it first
    # drops a size-1 action axis, and log_prob then broadcasts to wrong values.
    action_logprobs = dist.log_prob(action)
    return action, action_logprobs