def temporal_model_step(self, mu_q_t, logscale_q_t, h, attribute=None):
    """
    Produce the latent z_t for one step of the autoregressive model.

    The encoder statistics (mu_q_t, logscale_q_t) are concatenated with the
    processed recurrent state ``h_process(h, h)`` and mapped by ``psi_dy`` to
    the posterior q(z_t | x_t, h). The prior p(z_t | h) takes its mean from
    ``psi_p`` applied to the processed state (concatenated with ``attribute``
    when one is given); its scale is fixed to 1 and the predicted log-scale
    is ignored.

    Returns:
        (posterior sample, KL(posterior || prior) averaged with torch.mean,
         posterior mean, posterior scale)
    """
    h_mu, h_scale = self.h_process(h, h)

    # posterior: encoder statistics mixed with the temporal context
    post_mu, post_logscale = self.psi_dy(
        torch.cat([mu_q_t, h_mu], dim=-1),
        torch.cat([logscale_q_t, h_scale], dim=-1),
    )
    post_scale = post_logscale.exp()
    posterior = Independent(Normal(post_mu, post_scale), 1)
    z_t = posterior.rsample()

    # prior: conditioned on the temporal context (and the attribute, if any)
    if attribute is not None:
        prior_mu, prior_logscale = self.psi_p(
            torch.cat([h_mu, attribute], dim=-1),
            torch.cat([h_scale, attribute], dim=-1),
        )
    else:
        prior_mu, prior_logscale = self.psi_p(h_mu, h_scale)
    # only the prior mean is used; the prior has unit scale
    prior = Independent(Normal(prior_mu, torch.ones_like(prior_logscale)), 1)

    kl = torch.mean(kl_divergence(posterior, prior))
    return z_t, kl, post_mu, post_scale
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """
    Compute a tanh-squashed, rescaled action and its log-probability.

    The actor returns Gaussian parameters as a tuple. A pre-squash value u
    (the mean in deterministic evaluation, otherwise a reparameterised
    sample) is mapped to ``tanh(u) * self._action_scale + self._action_bias``.
    ``log_prob`` subtracts the log-Jacobian of that map from the Gaussian
    log-density (SAC paper, arXiv 1801.01290, appendix C, Eq. 21).
    When ``self._noise`` is set, in training mode and not updating, noise is
    added to the action. The action is finally clamped to ``self._range``.
    """
    logits, hidden = self.actor(batch[input], state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    use_mean = self._deterministic_eval and not self.training
    pre_squash = logits[0] if use_mean else dist.rsample()

    squashed = torch.tanh(pre_squash)
    act = squashed * self._action_scale + self._action_bias
    # per-dimension |d act / d u|; __eps keeps the log away from log(0)
    jac_diag = self._action_scale * (1 - squashed.pow(2)) + self.__eps
    log_prob = dist.log_prob(pre_squash).unsqueeze(-1)
    log_prob = log_prob - torch.log(jac_diag).sum(-1, keepdim=True)

    add_noise = self._noise is not None and self.training and not self.updating
    if add_noise:
        act += to_torch_as(self._noise(act.shape), act)
    act = act.clamp(self._range[0], self._range[1])
    return Batch(logits=logits, act=act, state=hidden, dist=dist, log_prob=log_prob)
def latent(self, conditioning):
    """
    Sample the attribute-conditioned latent and compute its KL to the prior.

    conditioning['z']: encoder output z_enc, [batch, frames, latent_dims]
    conditioning['attributes']: [batch, frames, attribute_dims], or
        [batch, attribute_dims], which is repeated along the frame axis

    Posterior: ``psi_q`` statistics are combined with the attributes by
    ``attribute_latent`` and then with the ``temporal_latent_model`` output
    by ``mix_with_temp``. Prior: ``psi_p`` on the temporal features
    concatenated with the attributes (second output used as a log-scale).

    Returns:
        (posterior_sample, kl_div, [mu_z, scale_z]); kl_div is summed over
        the latent dimension and averaged over the remaining ones.
    """
    z_enc = conditioning['z']
    attributes = conditioning['attributes']
    batch_size, n_frames, _ = z_enc.shape
    if attributes.dim() < 3:
        # one attribute vector per example -> repeat it for every frame
        attributes = attributes.unsqueeze(1).expand(-1, n_frames, -1)

    mu_q, logscale_q = self.psi_q(z_enc, z_enc)
    mu_a, logscale_a = self.attribute_latent(mu_q, logscale_q, attributes)

    # small random initial state for the temporal model
    h_0 = torch.rand(1, batch_size, self.latent_dims).to(z_enc.device) * 0.01
    temp_q = self.temporal_latent_model(mu_q, logscale_q, h_0)

    mu_z, scale_z = self.mix_with_temp(mu_a, logscale_a, temp_q)
    posterior = Independent(Normal(mu_z, scale_z), 1)
    posterior_sample = posterior.rsample()

    prior_in = torch.cat([temp_q, attributes], dim=-1)
    prior_mu, prior_logscale = self.psi_p(prior_in, prior_in)
    prior = Independent(Normal(prior_mu, prior_logscale.exp()), 1)

    kl_div = torch.mean(kl_divergence(posterior, prior))
    return posterior_sample, kl_div, [mu_z, scale_z]
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """
    Sample a pre-squash Gaussian action and return its tanh-squashed version.

    ``act`` in the returned Batch is ``tanh(u)`` (in (-1, 1)). This method
    applies no affine rescaling to it. ``log_prob`` is the Gaussian
    log-density of u corrected for the tanh squashing. When
    ``self.action_scaling`` is on and ``self.action_space`` is set, the
    correction also includes the factor ``(high - low) / 2`` (SAC paper,
    arXiv 1801.01290, appendix C, Eq. 21).
    """
    logits, hidden = self.actor(batch[input], state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    if self._deterministic_eval and not self.training:
        act = logits[0]
    else:
        act = dist.rsample()
    gaussian_log_prob = dist.log_prob(act).unsqueeze(-1)

    if self.action_scaling and self.action_space is not None:
        half_range = (self.action_space.high - self.action_space.low) / 2.0
        action_scale = to_torch_as(half_range, act)
    else:
        action_scale = 1.0  # type: ignore
    squashed_action = torch.tanh(act)
    log_det = torch.log(action_scale * (1 - squashed_action.pow(2)) + self.__eps)
    log_prob = gaussian_log_prob - log_det.sum(-1, keepdim=True)
    return Batch(logits=logits, act=squashed_action, state=hidden, dist=dist, log_prob=log_prob)
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """
    Squashed-Gaussian policy step.

    The actor yields Gaussian parameters as a tuple. A pre-squash value u
    (the mean in deterministic evaluation, otherwise a reparameterised
    sample) is mapped to ``tanh(u) * self._action_scale + self._action_bias``.
    ``log_prob`` is the Gaussian log-density of u minus the log-Jacobian of
    that map (SAC paper, arXiv 1801.01290, appendix C, Eq. 21).
    """
    obs = batch[input]
    logits, hidden = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    deterministic = self._deterministic_eval and not self.training
    u = logits[0] if deterministic else dist.rsample()

    squashed = torch.tanh(u)
    act = squashed * self._action_scale + self._action_bias
    # per-dimension derivative of act w.r.t. u; __eps avoids log of zero
    jac_diag = self._action_scale * (1 - squashed.pow(2)) + self.__eps
    log_prob = dist.log_prob(u).unsqueeze(-1) - torch.log(jac_diag).sum(-1, keepdim=True)
    return Batch(logits=logits, act=act, state=hidden, dist=dist, log_prob=log_prob)
def forward(self, state, eval=False, with_log_prob=False):
    """
    Squashed-Gaussian policy head.

    Two ReLU hidden layers (fc1, fc2) feed a mean head (fc3) and a log-std
    head (fc4). The log-std is clipped to [-20, 2] as in Haarnoja's SAC
    implementation (https://github.com/haarnoja/sac.git).

    Args:
        state: batch of states.
        eval: if True, use the mean instead of sampling (log_prob is None).
        with_log_prob: when sampling, also return the log-density of the
            returned action.

    Returns:
        (a, log_prob) with a = ctrl_range * tanh(u); log_prob is None unless
        sampling with ``with_log_prob``.
    """
    hidden = F.relu(self.fc2(F.relu(self.fc1(state))))
    mu = self.fc3(hidden)
    sigma = torch.exp(torch.clamp(self.fc4(hidden), -20.0, 2.0))
    distribution = Independent(Normal(mu, sigma), 1)

    log_prob = None
    if eval:
        u = mu
    else:
        # rsample keeps the sample differentiable w.r.t. mu and sigma
        u = distribution.rsample()
        if with_log_prob:
            # log|da/du| = log(ctrl_range) + log(1 - tanh(u)^2), written with
            # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)) for stability
            log_det = 2.0 * torch.sum(
                (np.log(2.0) + 0.5 * np.log(self.ctrl_range) - u - F.softplus(-2.0 * u)),
                dim=1,
            )
            log_prob = distribution.log_prob(u) - log_det

    # tanh maps into (-1, 1)^D, then scale to the control range
    return self.ctrl_range * torch.tanh(u), log_prob
def generate(self, synth, h_0, f0_hz, enc_frame_setting='fine', n_samples=16000):
    """
    Generate audio by ancestral sampling from the learned temporal prior.

    synth: synth to generate audio
    h_0: initial state of RNN [batch, latent_dims]
    f0_hz: f0 conditioning of synth [batch, f0_n_frames, 1]
    enc_frame_setting: fft/hop size
    n_samples: output audio length in samples

    For every frame, z_t is drawn from ``get_prior(h)`` and fed back through
    ``temporal`` to update h. The whole latent sequence is then decoded and
    synthesized in one pass.

    Returns:
        (params, resyn_audio)
    """
    n_fft, hop_length = get_window_hop(enc_frame_setting)
    n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
    # f0 must have the same number of frames as z
    f0_hz = resample_frames(f0_hz, n_frames)

    state = h_0
    z = torch.zeros(h_0.shape[0], n_frames, self.latent_dims).to(state.device)
    for frame in range(n_frames):
        loc, scale = self.get_prior(state)
        z_t = Independent(Normal(loc, scale), 1).rsample()
        state = self.temporal(z_t, state)
        z[:, frame, :] = z_t

    cond = {'z': z, 'f0_hz': f0_hz}
    y_params = self.decode(cond)
    params = synth.fill_params(y_params, cond)
    resyn_audio, _ = synth(params, n_samples)
    return params, resyn_audio
def trajectory(self, current_state):
    '''
    Roll out the policy from ``current_state`` until the episode ends.

    At each step ``self.forward(state)`` gives the mean and scale of a
    diagonal Gaussian. An action is drawn with rsample and detached, then
    ``self.env.state_and_reward(state, action)`` returns the next state
    (None once the trajectory is over) and the reward.

    Returns a list with one (state, action, reward, log_prob) tuple per step.
    log_prob is the policy log-density of the detached action, so it is
    still differentiable with respect to the policy outputs.

    NOTE: steps are taken one at a time, which may not use a GPU well.
    '''
    history = []
    state = current_state
    finished = False
    while not finished:
        mu, sigma = self.forward(state)
        policy = Independent(Normal(mu, sigma), 1)
        action = policy.rsample().detach()
        next_state, reward = self.env.state_and_reward(state, action)
        history.append((state, action, reward, policy.log_prob(action)))
        finished = next_state is None
        state = next_state
    return history
class TanhNormal(Distribution):
    """
    Diagonal Gaussian pushed through tanh, so samples lie in (-1, 1).

    The underlying Gaussian is ``Independent(Normal(loc, scale), 1)``, so the
    last dimension is the event dimension and log_prob sums over it.
    Copied from Kaixhi.
    """
    # No checkable constructor arguments. Declaring this avoids the
    # "does not define arg_constraints" warning from Distribution.__init__.
    arg_constraints = {}
    # rsample below is differentiable (reparameterised).
    has_rsample = True

    def __init__(self, loc, scale):
        super().__init__()
        self.normal = Independent(Normal(loc, scale), 1)

    def sample(self, sample_shape=torch.Size()):
        return torch.tanh(self.normal.sample(sample_shape))

    # samples with re-parametrization trick (differentiable)
    def rsample(self, sample_shape=torch.Size()):
        return torch.tanh(self.normal.rsample(sample_shape))

    # Calculates log probability of value using the change-of-variables technique
    # (uses log1p = log(1 + x) for extra numerical stability)
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|, with the Jacobian term summed
        # over the event (last) dimension so any number of batch dims works.
        return self.normal.log_prob(inv_value) - torch.log1p(-value.pow(2) + 1e-6).sum(dim=-1)

    @property
    def mean(self):
        # tanh of the Gaussian mean (not the exact mean of the squashed variable)
        return torch.tanh(self.normal.mean)

    def get_std(self):
        # standard deviation of the underlying (pre-tanh) Gaussian
        return self.normal.stddev
def forward(self, x):
    """
    Encode x with the U-Net encoder and decode, optionally via a VAE bottleneck.

    When ``self.unet`` is falsy, ``self.vae_encoder(x)`` is split along dim 1
    into (mu, log_var). After a softplus on log_var, z is drawn from
    q(z|x) = N(mu, exp(log_var / 2)). The divergence to
    p(z) = N(self.mu_0, self.sigma_0) is computed with ``KLD`` and summed.
    Otherwise z is the scalar 0. and kl is a one-element zero tensor.

    Returns:
        (kl, xHat) where xHat = self.decoder(x_enc, z).
    """
    # Allocate the zero KL next to the input instead of relying on a
    # module-level ``device`` global.
    kl = torch.zeros(1, device=x.device)
    z = 0.
    # Unet encoder result
    x_enc = self.unet_encoder(x)
    # VAE regularisation
    if not self.unet:
        emb = self.vae_encoder(x)
        # Split encoder outputs into a mean and a log-variance vector
        mu, log_var = torch.chunk(emb, 2, dim=1)
        # NOTE(review): softplus makes log_var > 0, i.e. sigma > 1, so the
        # posterior can never be narrower than unit scale. If the intent is
        # "variance > 0", sigma = softplus(raw).sqrt() would express that.
        # Confirm before changing, since trained weights depend on it.
        log_var = softplus(log_var)
        sigma = torch.exp(log_var / 2)
        # Approximate posterior q(z|x): diagonal Gaussian with mean mu, std sigma
        posterior = Independent(Normal(loc=mu, scale=sigma), 1)
        z = posterior.rsample()
        # Prior p(z) with mean mu_0 and std sigma_0
        prior = Independent(Normal(loc=self.mu_0, scale=self.sigma_0), 1)
        # KLD between q(z|x) || p(z), summed
        kl = KLD(posterior, prior).sum()
    # Outputs for MSE
    xHat = self.decoder(x_enc, z)
    return kl, xHat
class IndependentNormal(Distribution):
    """
    Diagonal Normal whose every dimension after the first is an event dimension.

    Wraps ``Independent(Normal(loc, scale), loc.dim() - 1)``. For loc of
    shape [n, ...] the batch shape is [n] and the event shape is the rest
    (assuming scale has the same shape as loc). All methods delegate to the
    wrapped distribution.
    """
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # When argument validation is enabled (the default in recent PyTorch),
        # Distribution.__init__ reads every name in arg_constraints. These must
        # therefore exist before the super().__init__ call, or construction
        # raises AttributeError.
        self.loc = loc
        self.scale = scale
        self.base_dist = Independent(
            Normal(loc=loc, scale=scale, validate_args=validate_args),
            len(loc.shape) - 1,
            validate_args=validate_args)
        super(IndependentNormal, self).__init__(
            self.base_dist.batch_shape,
            self.base_dist.event_shape,
            validate_args=validate_args)

    @property
    def support(self):
        # A Normal is supported on all reals, not only positive values.
        # Delegating also accounts for the reinterpreted event dimensions.
        return self.base_dist.support

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
def generate(self, synth, h_0, f0_hz, attributes, enc_frame_setting='fine', n_samples=16000):
    """
    Autoregressively generate audio, re-encoding each synthesized frame to
    update the recurrent state.

    synth: synth to generate audio
    h_0: initial seed of RNN [batch, latent_dims] (or already [1, batch, latent_dims])
    f0_hz: f0 conditioning of synth [batch, f0_n_frames, 1]
    attributes: attributes [batch, attribute_size], [batch, 1, attribute_size]
        or [batch, n_frames, attribute_size]; frame i uses attributes[:, i]
    enc_frame_setting: fft/hop size
    n_samples: output audio length in samples

    Returns:
        (final_params, final_audio): per-frame synth params concatenated over
        frames (dim 1), and the audio synthesized from them in one pass.
    """
    if len(h_0.shape) == 2:
        h = h_0[None, :, :]  # 1, batch, latent_dims
    else:
        h = h_0
    n_fft, hop_length = get_window_hop(enc_frame_setting)
    n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
    f0_hz = resample_frames(f0_hz, n_frames)  # needs to have same dimension as z
    # Bring attributes to [batch, n_frames, attribute_size]. Each step is
    # concatenated with a single-frame [batch, 1, latent] state, so it must
    # take only its own frame's attributes.
    if len(attributes.shape) == 2:
        attributes = attributes[:, None, :]
    if attributes.shape[1] == 1:
        attributes = attributes.expand(-1, n_frames, -1)
    params_list = []
    for i in range(n_frames):
        cond = {}
        # h: [1, batch, latent] -> [batch, 1, latent], joined with frame i's attributes
        output = torch.cat([h.permute(1, 0, 2), attributes[:, i:i + 1, :]], dim=-1)
        mu, logscale = self.psi_p(output, output)
        scale = logscale.exp()
        prior = Independent(Normal(mu, scale), 1)
        prior_sample = prior.rsample()
        cond['z'] = prior_sample
        cond['f0_hz'] = f0_hz[:, i, :].unsqueeze(1)
        cond['f0_scaled'] = hz_to_midi(cond['f0_hz']) / 127.0
        # generate x
        y = self.decode(cond)
        params = synth.fill_params(y, cond)
        params_list.append(params)
        x_tilde, _outputs = synth(params, n_samples=n_fft)  # write exactly one frame
        cond['audio'] = x_tilde
        # encode the synthesized frame
        cond = self.encoder(cond)
        z_enc = cond['z']
        # get psi_q
        mu, logscale = self.psi_q(z_enc, z_enc)
        psi = torch.cat([mu, logscale], dim=-1)
        # temporal model: only the updated state is needed here
        _, h = self.temporal_q(psi, h)
    # concatenate each parameter over frames
    final_params = {
        pn: torch.cat([par[pn] for par in params_list], dim=1)
        for pn in params_list[0].keys()
    }
    final_audio, _outputs = synth(final_params, n_samples=n_samples)
    return final_params, final_audio
def forward(self, obs, act=None, deterministic=False):
    """
    Gaussian policy with a state-independent log standard deviation.

    ``self.mu_layer(obs)`` gives the action mean. ``self.log_std_layer`` is
    used directly (it is not called) as the log standard deviation.

    Args:
        obs: batch of observations.
        act: optional action to evaluate. When None, an action is produced:
            the mean if ``deterministic``, otherwise a reparameterised sample.
        deterministic: use the mean instead of sampling when ``act`` is None.

    Returns:
        (pi, act, log_prob): the distribution, the action and its log-density.
    """
    pi = Independent(Normal(self.mu_layer(obs), torch.exp(self.log_std_layer)), 1)
    if act is None:
        if deterministic:
            act = pi.mean
        else:
            act = pi.rsample()
    return pi, act, pi.log_prob(act)
def generate(self, synth, h_0, f0_hz, enc_frame_setting='fine', n_samples=16000):
    """
    Generate audio frame by frame, re-encoding each synthesized frame to
    drive the recurrent state.

    synth: synth to generate audio
    h_0: initial state of RNN [batch, latent_dims]
    f0_hz: f0 conditioning of synth [batch, f0_n_frames, 1]
    enc_frame_setting: fft/hop size
    n_samples: output audio length in samples

    Each frame does the following:
    - samples z_t from the prior built from the processed state
      (h_process -> psi_p);
    - synthesizes one n_fft-sample frame from z_t and re-encodes it;
    - updates h with temporal_q from the encoder statistics (psi_q) and z_t.

    The collected z sequence is then decoded and synthesized in one pass.

    Returns:
        (params, resyn_audio)
    """
    h = h_0
    n_fft, hop_length = get_window_hop(enc_frame_setting)
    n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
    # f0 needs the same number of frames as z
    f0_hz = resample_frames(f0_hz, n_frames)
    z = torch.zeros(h_0.shape[0], n_frames, self.latent_dims).to(h.device)

    for t in range(n_frames):
        # sample z_t from the prior given the current recurrent state
        state_mu, state_scale = self.h_process(h, h)
        prior_mu, prior_logscale = self.psi_p(state_mu, state_scale)  # [batch, latent_size]
        z_t = Independent(Normal(prior_mu, prior_logscale.exp()), 1).rsample()
        z[:, t, :] = z_t

        # synthesize exactly one frame of audio from z_t
        f0_t = f0_hz[:, t, :].unsqueeze(1)
        frame_cond = {
            'z': z_t.unsqueeze(1),
            'f0_hz': f0_t,
            'f0_scaled': hz_to_midi(f0_t) / 127.0,
        }
        y = self.decode(frame_cond)
        frame_params = synth.fill_params(y, frame_cond)
        frame_cond['audio'], _ = synth(frame_params, n_samples=n_fft)

        # re-encode the frame and advance the recurrent state
        z_enc = self.encoder(frame_cond)['z'].squeeze(1)
        q_mu, q_logscale = self.psi_q(z_enc, z_enc)
        h = self.temporal_q(torch.cat([q_mu, q_logscale, z_t], dim=-1), h)

    cond = {'z': z, 'f0_hz': f0_hz}
    y_params = self.decode(cond)
    params = synth.fill_params(y_params, cond)
    resyn_audio, _ = synth(params, n_samples)
    return params, resyn_audio
def act(self, obs, deterministic=False):
    """
    Choose an action for ``obs`` and return it with its log-probability.

    The policy is a diagonal Gaussian with mean ``self.forward(obs)`` and
    standard deviation ``exp(self.log_scale)``.

    Args:
        obs: batch of observations.
        deterministic: return the mean instead of a reparameterised sample.

    Returns:
        (action, action_logprobs): the action, and the log-density of that
        exact action (summed over the last, event, dimension).
    """
    action_mean = self.forward(obs)
    dist = Independent(Normal(action_mean, torch.exp(self.log_scale)), 1)
    action = action_mean if deterministic else dist.rsample()
    # Evaluate at the action as produced. Squashing size-1 dims (e.g. a
    # 1-D action space) made the value broadcast against the batch and
    # summed log-densities across unrelated samples.
    action_logprobs = dist.log_prob(action)
    return action, action_logprobs
def latent(self, conditioning):
    """
    Sample z from a diagonal-Gaussian posterior and compute its KL to N(0, I).

    conditioning['z']: encoder output (presumably [batch, frames, enc_dims];
        any shape accepted by ``linear_mu`` / ``linear_logvar`` works)

    Returns:
        (posterior_sample, kl_div, [mu, std]) where kl_div is the KL summed
        over the last dimension and averaged over the rest, and
        std = exp(0.5 * log_var).
    """
    z_enc = conditioning['z']
    mu = self.linear_mu(z_enc)
    log_var = self.linear_logvar(z_enc)
    # exp(0.5 * log_var) equals sqrt(exp(log_var)) but does not overflow to
    # inf for large log-variances. It is computed once and reused.
    std = torch.exp(0.5 * log_var)
    posterior = Independent(Normal(mu, std), 1)
    posterior_sample = posterior.rsample()

    # Compute KL divergence to the standard-normal prior
    prior = Independent(Normal(torch.zeros_like(mu), torch.ones_like(log_var)), 1)
    kl_div = torch.mean(kl_divergence(posterior, prior))
    return posterior_sample, kl_div, [mu, std]
def temporal_model_step(self, z_enc_t, h, attribute=None):
    """
    One autoregressive step.

    Samples z_t from the posterior ``get_posterior(h, z_enc_t)`` and computes
    its KL to the prior ``get_prior(h)``. Both scales get 1e-4 added as a
    floor. ``attribute`` is accepted but not used.

    Returns:
        (z_t sample, KL averaged with torch.mean, posterior mean,
         posterior scale including the 1e-4 offset)
    """
    min_scale = 1e-4
    post_mu, post_scale = self.get_posterior(h, z_enc_t)
    post_scale = post_scale + min_scale
    posterior = Independent(Normal(post_mu, post_scale), 1)
    z_t = posterior.rsample()

    prior_mu, prior_scale = self.get_prior(h)
    prior = Independent(Normal(prior_mu, prior_scale + min_scale), 1)

    return z_t, torch.mean(kl_divergence(posterior, prior)), post_mu, post_scale
def forward(self, obs, deterministic=False):
    """
    Tanh-squashed Gaussian policy.

    The log-std head output is clamped to [LOG_STD_MIN, LOG_STD_MAX] before
    exponentiation. u is the mean when ``deterministic``, otherwise an
    rsample.

    Returns:
        (tanh(u), log_prob). log_prob is the Gaussian log-density of u minus
        sum(log(1 - tanh(u)^2 + __eps)) over the last dimension (tanh
        change of variables).
    """
    mu = self.mu_layer(obs)
    std = torch.clamp(self.log_std_layer(obs), LOG_STD_MIN, LOG_STD_MAX).exp()
    pre_squash = Independent(Normal(mu, std), 1)
    if deterministic:
        u = mu
    else:
        u = pre_squash.rsample()
    gaussian_log_prob = pre_squash.log_prob(u)
    action = torch.tanh(u)
    log_det = torch.log((1 - action.pow(2)) + self.__eps).sum(axis=-1)
    return action, gaussian_log_prob - log_det
def _mcvi_forward(self, x): if self.certain or not self.deterministic: x_mean = x if not isinstance(x, tuple) else x[0] x_var = x_mean * x_mean else: x_mean = x[0] x_var = x[1] W_var = self._get_var(self.W_logvar) bias_var = self._get_var(self.bias_logvar) z_mean = F.conv2d(x_mean, self.W, self.bias, self.stride, self.padding) z_var = F.conv2d(x_var, W_var, bias_var, self.stride, self.padding) dst = Independent(Normal(z_mean, z_var), 1) sample = dst.rsample() return sample, None
class VAE(nn.Module):
    """
    VAE with a diagonal-Gaussian posterior, an N(0, I) prior and a
    Continuous Bernoulli likelihood.

    ``encoder(x)`` returns ``(mu, log_std)``; ``decoder(z)`` returns the
    likelihood logits.
    """

    def __init__(self, encoder, decoder, device=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if device is None:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

    def forward(self, x=None):
        """
        Encode x, draw one reparameterised z ~ q(z|x) and decode it.

        Stores ``self.pz``, ``self.qz_x`` and ``self.z`` for ``compute_loss``
        and returns the decoder output.
        """
        mu, sigma = self.encoder(x)
        # Build the prior on mu's device and dtype. ``self.device`` defaults
        # to cuda:0 whenever CUDA is available, even if the model stays on
        # the CPU, which would put p(z) and z on different devices.
        self.pz = Independent(Normal(loc=torch.zeros_like(mu),
                                     scale=torch.ones_like(mu)),
                              reinterpreted_batch_ndims=1)
        # ``sigma`` is a log standard deviation (it is exponentiated here).
        self.qz_x = Independent(Normal(loc=mu,
                                       scale=torch.exp(sigma)),
                                reinterpreted_batch_ndims=1)
        self.z = self.qz_x.rsample()
        decoded = self.decoder(self.z)
        return decoded

    def compute_loss(self, x, y, scale_kl=False):
        """
        Negative single-sample ELBO for the latest ``forward`` call.

        Args:
            x: target batch; the likelihood treats the last 3 dims as the event.
            y: decoder logits returned by ``forward``.
            scale_kl: multiplier for the latent term. False (default) or None
                leaves it unscaled; any number, including 0, is applied.

        Returns:
            (loss, kl, px) where kl is the batch mean of
            log p(z) - log q(z|x) (after scaling), a single-sample estimate of
            -KL(q || p), and px is the mean log-likelihood.
        """
        px_z = Independent(ContinuousBernoulli(logits=y),
                           reinterpreted_batch_ndims=3)
        px = px_z.log_prob(x)
        kl = self.pz.log_prob(self.z) - self.qz_x.log_prob(self.z)
        # Compare with the sentinels explicitly so a numeric 0 (e.g. at the
        # start of KL annealing) really zeroes the term instead of being
        # treated as "no scaling".
        if scale_kl is not False and scale_kl is not None:
            kl = kl * scale_kl
        loss = -(px + kl).mean()
        return loss, kl.mean().item(), px.mean().item()

    def rmse(self, input, target):
        return torch.sqrt(F.mse_loss(input, target))
def _loss_vae(self, x): batch_size = x.size(0) encoder_output = self.encoder(x) pz = Independent(Normal(loc=torch.zeros(batch_size, self.latent_dim).to(self.device), scale=torch.ones(batch_size, self.latent_dim).to(self.device)), reinterpreted_batch_ndims=1) qz_x = Independent(Normal(loc=encoder_output[:, :self.latent_dim], scale=torch.exp(encoder_output[:, self.latent_dim:])), reinterpreted_batch_ndims=1) z = qz_x.rsample() decoder_output = self.decoder(z) px_z = Independent(Bernoulli(logits=decoder_output), reinterpreted_batch_ndims=1) loss = -(px_z.log_prob(x) + pz.log_prob(z) - qz_x.log_prob(z)).mean() return loss, decoder_output
def generate(self, synth, h_0, f0_hz, attributes, enc_frame_setting='fine', n_samples=16000):
    """
    Generate audio by sampling latents from the attribute-conditioned prior.

    synth: synth to generate audio
    h_0: initial state of RNN [batch, latent_dims]
    f0_hz: f0 conditioning of synth [batch, f0_n_frames, 1]
    attributes: attributes [batch, attribute_size] or [batch, n_frames, attribute_size]
    enc_frame_setting: fft/hop size
    n_samples: output audio length in samples

    The RNN is primed with a zero latent and the first frame's attributes.
    Each frame then draws z_t from ``get_prior(h)`` and updates h with
    ``temporal(cat[z_t, attributes_t], h)``. The latent sequence is decoded
    and synthesized in one pass.

    Returns:
        (params, resyn_audio)
    """
    dev = h_0.device
    batch = h_0.shape[0]
    n_fft, hop_length = get_window_hop(enc_frame_setting)
    n_frames = math.ceil((n_samples - n_fft) / hop_length) + 1
    # f0 needs to have the same number of frames as z
    f0_hz = resample_frames(f0_hz, n_frames).to(dev)
    if len(attributes.shape) == 2:
        # one attribute vector per clip -> repeat it for every frame
        attributes = attributes[:, None, :].expand(-1, n_frames, -1)

    def advance(z_prev, frame, state):
        """Step the RNN with a latent and the attributes of ``frame``."""
        return self.temporal(torch.cat([z_prev, attributes[:, frame, :]], dim=-1), state)

    z = torch.zeros(batch, n_frames, self.latent_dims).to(dev)
    # prime the recurrent state with a zero latent and frame 0's attributes
    h = advance(torch.zeros(batch, self.latent_dims).to(dev), 0, h_0)
    for t in range(n_frames):
        loc, scale = self.get_prior(h)
        z_t = Independent(Normal(loc, scale), 1).rsample()
        h = advance(z_t, t, h)
        z[:, t, :] = z_t

    cond = {'z': z, 'f0_hz': f0_hz}
    y_params = self.decode(cond)
    params = synth.fill_params(y_params, cond)
    resyn_audio, _ = synth(params, n_samples)
    return params, resyn_audio
def loss_vae_normal(x, encoder, decoder):
    """
    Negative ELBO for a Gaussian-latent VAE with a Gaussian likelihood.

    The likelihood uses one shared sigma set to the RMSE of the
    reconstruction over dims 0-3, so x is expected to be 4-D.

    Args:
        x: input batch (4-D; the likelihood reinterprets the last 3 dims).
        encoder: maps x to [batch, 2 * d]; the first d columns are the
            posterior loc and the last d the value named ``log_scale``.
        decoder: maps z to a reconstruction shaped like x.

    Returns:
        (-elbo averaged over the batch, decoder_output)
    """
    batch_size = x.size(0)
    encoder_output = encoder(x)
    d = encoder_output.shape[1] // 2
    # Allocate the prior where the posterior lives instead of relying on a
    # module-level ``device`` global.
    prior_device = encoder_output.device
    # NOTE(review): sigmoid(0) = 0.5, so the prior mean is 0.5 rather than
    # the usual 0. Confirm this is intended.
    pz_loc = torch.sigmoid(torch.zeros(batch_size, d, device=prior_device))
    pz_scale = torch.ones(batch_size, d, device=prior_device)
    pz = Independent(Normal(loc=pz_loc, scale=pz_scale),
                     reinterpreted_batch_ndims=1)
    qz_x_loc = encoder_output[:, :d]
    qz_x_log_scale = encoder_output[:, d:]
    # NOTE(review): despite its name, this value is squared (not
    # exponentiated) to get the scale. An exact 0 therefore gives an invalid
    # scale. Confirm this is intended.
    qz_x = Independent(Normal(loc=qz_x_loc, scale=qz_x_log_scale**2),
                       reinterpreted_batch_ndims=1)
    z = qz_x.rsample()
    decoder_output = decoder(z)
    # Closed-form optimal shared sigma: RMSE of the reconstruction.
    optimal_sigma_observed = ((x - decoder_output)**2).mean(
        [0, 1, 2, 3], keepdim=True).sqrt()
    px_z = Independent(Normal(loc=decoder_output,
                              scale=optimal_sigma_observed),
                       reinterpreted_batch_ndims=3)
    elbo = (px_z.log_prob(x) - kl_divergence(qz_x, pz)).mean()
    return -elbo, decoder_output
class IndependentRescaledBeta(Distribution):
    """
    RescaledBeta with every dimension after the first reinterpreted as an event dimension.

    Thin wrapper around
    ``Independent(RescaledBeta(concentration1, concentration0), ndim - 1)``.
    log_prob, mean, variance, sampling and entropy all delegate to it.
    """
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive
    }
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        # When argument validation is enabled (the default in recent PyTorch),
        # Distribution.__init__ reads every name in arg_constraints. These must
        # therefore exist before the super().__init__ call, or construction
        # raises AttributeError.
        self.concentration1 = concentration1
        self.concentration0 = concentration0
        self.base_dist = Independent(
            RescaledBeta(concentration1, concentration0, validate_args),
            len(concentration1.shape) - 1,
            validate_args=validate_args)
        super(IndependentRescaledBeta, self).__init__(
            self.base_dist.batch_shape,
            self.base_dist.event_shape,
            validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
def latent(self, conditioning):
    """
    Posterior sample and KL for the temporal latent model.

    conditioning['z']: encoder output z_enc, [batch, frames, latent_dims]

    The ``psi_q`` statistics are run through ``temporal_latent_model``
    (started from a small random state) and mixed with its output by
    ``mix_with_temp`` to form q(z | x). The prior p(z) is ``psi_p`` applied
    to the temporal features, with its second output used as a log-scale.

    Returns:
        (posterior_sample, kl_div, [mu_z, scale_z]); kl_div is summed over
        the latent dimension and averaged over the remaining ones.
    """
    z_enc = conditioning['z']
    batch_size, _, _ = z_enc.shape
    mu_q, logscale_q = self.psi_q(z_enc, z_enc)

    # small random initial state for the temporal model
    h_0 = torch.randn(1, batch_size, self.latent_dims).to(z_enc.device) * 0.01
    temp_q = self.temporal_latent_model(mu_q, logscale_q, h_0)

    # posterior mixes the encoder statistics with the temporal features
    mu_z, scale_z = self.mix_with_temp(mu_q, logscale_q, temp_q)
    posterior = Independent(Normal(mu_z, scale_z), 1)
    posterior_sample = posterior.rsample()

    # prior conditioned on the temporal features only
    prior_mu, prior_logscale = self.psi_p(temp_q, temp_q)
    prior = Independent(Normal(prior_mu, prior_logscale.exp()), 1)

    kl_div = torch.mean(kl_divergence(posterior, prior))
    return posterior_sample, kl_div, [mu_z, scale_z]
def _zero_mean_forward(self, x): if self.certain or not self.deterministic: x_mean = x if not isinstance(x, tuple) else x[0] x_var = x_mean * x_mean else: x_mean = x[0] x_var = x[1] W_var = torch.exp(self.log_alpha) * self.weight * self.weight z_mean = F.conv2d(x_mean, torch.zeros_like(self.weight), self.bias, self.stride, self.padding) z_var = F.conv2d(x_var, W_var, bias=None, stride=self.stride, padding=self.padding) if self.deterministic: return z_mean, z_var else: dst = Independent(Normal(z_mean, z_var), 1) sample = dst.rsample() return sample, None
def loss(self, x):
    """
    Return the negative ELBO averaged over the minibatch x, and the decoder output.

    Single-sample estimate: one z ~ q(z|x) per example, where the first
    ``self.z_dim`` encoder outputs are the posterior mean and the remaining
    ones the log standard deviation. The prior is N(0, I) on ``self.device``
    and the likelihood is Bernoulli with the decoder output as logits.
    """
    batch_size = x.size(0)
    stats = self.encoder(x)
    mu, log_std = stats[:, :self.z_dim], stats[:, self.z_dim:]
    zeros = torch.zeros(batch_size, self.z_dim).to(self.device)
    ones = torch.ones(batch_size, self.z_dim).to(self.device)
    pz = Independent(Normal(loc=zeros, scale=ones), reinterpreted_batch_ndims=1)
    qz_x = Independent(Normal(loc=mu, scale=torch.exp(log_std)), reinterpreted_batch_ndims=1)
    z = qz_x.rsample()
    logits = self.decoder(z)
    px_z = Independent(Bernoulli(logits=logits), reinterpreted_batch_ndims=1)
    elbo = px_z.log_prob(x) + pz.log_prob(z) - qz_x.log_prob(z)
    return -elbo.mean(), logits
class TorchGaussianMixtureDistribution(TorchDistributionWrapper):
    """
    Equal-weight mixture of diagonal Gaussians used as an RLlib action distribution.

    Each row of the model output is reshaped to [2, num_gaussians, action_dim].
    Slice 0 holds the component means and slice 1 the component log standard
    deviations (they are exponentiated). The mixture weights are uniform,
    because the Categorical is built from a tensor of ones. ``kl`` and
    ``entropy`` are Monte-Carlo estimates from ``monte_samples``
    reparameterised draws.
    """

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        # Model output size: 2 (means + log-stds) * num_gaussians * prod(action_space.shape).
        return prod((2, model_config['custom_model_config']['num_gaussians']) +
                    action_space.shape)

    def __init__(self, inputs: List[torch.Tensor], model: TorchModelV2):
        # NOTE(review): super(TorchDistributionWrapper, self) skips
        # TorchDistributionWrapper.__init__ and runs the next class in the MRO
        # instead, so whatever the wrapper's own __init__ sets up is not done
        # here. Confirm this is intended.
        super(TorchDistributionWrapper, self).__init__(inputs, model)
        # Expects a 2-D [batch, 2 * num_gaussians * action_dim] input.
        assert len(inputs.shape) == 2
        # self.inputs is presumably stored by the base-class __init__ above (not visible here).
        self.batch_size = self.inputs.shape[0]
        self.num_gaussians = model.model_config['custom_model_config'][
            'num_gaussians']
        # Number of reparameterised draws used by kl() and entropy().
        self.monte_samples = model.model_config['custom_model_config'][
            'monte_samples']
        inputs = torch.reshape(self.inputs,
                               (self.batch_size, 2, self.num_gaussians, -1))
        self.action_dim = inputs.shape[-1]
        assert not torch.isnan(inputs).any(), "Input nan aborting"
        self.means = inputs[:, 0, :, :]  # batch_size x num_gaussians x action_dim
        self.sigmas = torch.exp(
            inputs[:, 1, :, :])  # batch_size x num_gaussians x action_dim
        # Uniform mixture weights: unnormalised ones, normalised by Categorical.
        self.cat = Categorical(
            torch.ones(self.batch_size,
                       self.num_gaussians,
                       device=inputs.device,
                       requires_grad=False))
        # One diagonal Gaussian per component:
        # batch shape [batch_size, num_gaussians], event shape [action_dim].
        self.normals = Independent(Normal(self.means, self.sigmas), 1)

    def logp(self, actions: torch.Tensor):
        # Mixture log-density log sum_k w_k N(a; mu_k, sigma_k), via logsumexp.
        actions = actions.view(
            self.batch_size, 1,
            -1)  # batch_size x 1 (broadcast to num gaussians) x action_dim
        mix_lps = self.cat.logits  # batch_size x num_gaussians (log mixture weights)
        normal_lps = self.normals.log_prob(
            actions)  # batch_size x num_gaussians (action_dim summed by Independent)
        assert not torch.isnan(mix_lps).any(), "output nan aborting"
        assert not torch.isnan(normal_lps).any(), "output nan aborting"
        return torch.logsumexp(mix_lps + normal_lps,
                               dim=1)  # reduce along num gaussians -> batch_size

    def deterministic_sample(self) -> torch.Tensor:
        # NOTE(review): this is the mean (= mode) of component 0 only, not the
        # mode of the whole mixture.
        self.last_sample = self.means[:, 0, :]  # select the mode of the first gaussian
        return self.last_sample

    def __rsamples(self):
        """ Compute samples that can be differentiated through.

        Returns a tensor of shape monte_samples x batch_size x action_dim.
        Gradients flow through the selected Gaussian draw only; the component
        index comes from a non-reparameterised Categorical sample.
        """
        # Using reparameterization trick i.e. rsample
        normal_samples = self.normals.rsample(
            (self.monte_samples, )
        )  # monte_samples x batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample(
            (self.monte_samples, ))  # monte_samples x batch_size
        # First we need to expand cat so that it has the same dimension as normal samples
        cat_samples = cat_samples.reshape(self.monte_samples, -1, 1,
                                          1).expand(-1, -1, -1,
                                                    self.action_dim)
        # We select the normal distribution based on the outputs of
        # the categorical distribution
        return torch.gather(normal_samples, 2, cat_samples).squeeze(
            dim=2)  # monte_samples x batch_size x action_dim

    def kl(self, q: ActionDistribution) -> torch.Tensor:
        """ KL(self || q) estimated with monte carlo sampling.

        Averages log self(a) - log q(a) over ``monte_samples`` draws a ~ self;
        returns one value per batch row.
        """
        rsamples = self.__rsamples().unbind(0)
        log_ratios = torch.stack(
            [self.logp(rsample) - q.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ratios).any(), "output nan aborting"
        return log_ratios.mean(0)

    def entropy(self) -> torch.Tensor:
        """ H(self) estimated with monte carlo sampling.

        Averages -log self(a) over ``monte_samples`` draws a ~ self; returns
        one value per batch row.
        """
        rsamples = self.__rsamples().unbind(0)
        log_ps = torch.stack([-self.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ps).any(), "output nan aborting"
        return log_ps.mean(0)

    def sample(self):
        # Draw one action per batch row: pick a component, then sample it.
        normal_samples = self.normals.sample(
        )  # batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample()  # batch_size
        # First we need to expand cat so that it has the same dimension as normal samples
        cat_samples = cat_samples.view(-1, 1, 1).expand(-1, -1, self.action_dim)
        # We select the normal distribution based on the outputs of
        # the categorical distribution
        self.last_sample = torch.gather(normal_samples, 1, cat_samples).squeeze(
            dim=1)  # batch_size x action_dim
        assert len(
            self.last_sample.shape) == 2, f"shape, {self.last_sample.shape}"
        return self.last_sample