def choose_action(self, observation):
    mu, sigma = self.actor.forward(observation)  # .to(self.actor.device)
    sigma = T.exp(sigma)
    action_probs = Independent(Normal(mu, sigma), 1)
    probs = action_probs.sample()
    self.log_probs = action_probs.log_prob(probs).to(self.actor.device)
    return probs
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    obs = batch[input]
    logits, h = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    if self._deterministic_eval and not self.training:
        x = logits[0]
    else:
        x = dist.rsample()
    y = torch.tanh(x)
    act = y * self._action_scale + self._action_bias
    y = self._action_scale * (1 - y.pow(2)) + self.__eps
    log_prob = dist.log_prob(x).unsqueeze(-1)
    log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
    if self._noise is not None and self.training and not self.updating:
        act += to_torch_as(self._noise(act.shape), act)
    act = act.clamp(self._range[0], self._range[1])
    return Batch(
        logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
def forward(self, input, kl_coef):
    b, m, train_m, test_m = input
    mean, std = self.ode_rnn(input)  # (batch_size, LO_hidden_size) * 2
    d = Normal(torch.tensor([0.0], device=self.param['device']),
               torch.tensor([1.0], device=self.param['device']))
    r = d.sample(mean.shape).squeeze(-1)
    z0 = mean + r * std
    z_out = odeint(self.ode_func, z0, b[0, :, 0],
                   rtol=self.param['rtol'], atol=self.param['atol'])
    # (num_time_points, batch_size, LO_hidden_size)
    z_out = z_out.permute(1, 0, 2)
    output = self.output_output(z_out).squeeze(2)
    z0_distr = Normal(mean, std)
    kl_div = kl_divergence(
        z0_distr,
        Normal(torch.tensor([0.0], device=self.param['device']),
               torch.tensor([1.0], device=self.param['device'])))
    kl_div = kl_div.mean(axis=1)
    masked_output = output[test_m.bool()].reshape(
        self.param['batch_size'],
        self.param['total_points'] - self.param['obs_points'])
    target = b[:, :, 1][test_m.bool()].reshape(
        self.param['batch_size'],
        self.param['total_points'] - self.param['obs_points'])
    gaussian = Independent(
        Normal(loc=masked_output, scale=self.param['obsrv_std']), 1)
    log_prob = gaussian.log_prob(target)
    likelihood = log_prob / output.shape[1]
    loss = -torch.logsumexp(likelihood - kl_coef * kl_div, 0)
    mse = self.mse_func(masked_output, target)
    return loss, mse, masked_output
class IndependentNormal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real  # a Normal is supported on the whole real line
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        self.base_dist = Independent(
            Normal(loc=loc, scale=scale, validate_args=validate_args),
            len(loc.shape) - 1,
            validate_args=validate_args)
        super(IndependentNormal, self).__init__(self.base_dist.batch_shape,
                                                self.base_dist.event_shape,
                                                validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        return self.base_dist.entropy()
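# Usage sketch for the IndependentNormal wrapper above (a hypothetical example,
# assuming the class definition and this import are in scope): with a batch of
# 3-dimensional diagonal Gaussians, log_prob returns one value per batch row,
# because all but the first dimension are reinterpreted as the event shape.
import torch

loc = torch.zeros(5, 3)        # batch of 5, event size 3
scale = torch.ones(5, 3)
dist = IndependentNormal(loc, scale)
x = dist.rsample()             # shape (5, 3), differentiable sample
print(dist.log_prob(x).shape)  # torch.Size([5])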
def get_true_posterior_samples_linear_gaussian_uniform_prior(
    observation: torch.Tensor,
    prior: Independent,
    num_samples: int = 1000,
    std=1,
):
    observation = utils.torchutils.atleast_2d(observation)
    assert observation.ndim == 2, "needs batch dimension in observation"
    mean = observation
    event_shape = mean.shape[1]
    posterior = MultivariateNormal(
        loc=mean, covariance_matrix=std * torch.eye(event_shape)
    )

    # generate samples from ND Gaussian truncated by prior support
    num_remaining = num_samples
    samples = []
    while num_remaining > 0:
        candidate_samples = posterior.sample(sample_shape=(num_remaining,))
        is_in_prior = torch.isfinite(prior.log_prob(candidate_samples))
        # accept if in prior
        if is_in_prior.sum():
            samples.append(candidate_samples[is_in_prior])
            num_remaining -= is_in_prior.sum().item()

    return torch.cat(samples)
def test_independent_expand(self):
    for Dist, params in EXAMPLES:
        for param in params:
            base_dist = Dist(**param)
            for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                for s in [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]:
                    indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                    expanded_shape = s + indep_dist.batch_shape
                    expanded = indep_dist.expand(expanded_shape)
                    expanded_sample = expanded.sample()
                    expected_shape = expanded_shape + indep_dist.event_shape
                    self.assertEqual(expanded_sample.shape, expected_shape)
                    self.assertEqual(
                        expanded.log_prob(expanded_sample),
                        indep_dist.log_prob(expanded_sample),
                    )
                    self.assertEqual(expanded.event_shape, indep_dist.event_shape)
                    self.assertEqual(expanded.batch_shape, expanded_shape)
def trajectory(self, current_state):
    '''
    This implementation may not utilize GPUs very well, but I am not sure.
    Final output looks like: [(s_0, a_0, r_0), ..., (s_L, a_L, r_L)]
    '''
    output_history = []
    while True:
        mu, sigma = self.forward(current_state)
        distribution = Independent(Normal(mu, sigma), 1)
        picked_action = distribution.rsample()
        action = picked_action.detach()
        # print(action)
        # Get the reward and the new state that the action in the environment
        # resulted in. None if the action caused death.
        # TODO build in environment
        new_state, reward = self.env.state_and_reward(current_state, action)
        output_history.append(
            (current_state, action, reward, distribution.log_prob(action)))
        if new_state is None:
            # essentially, you died or finished your trajectory
            break
        else:
            current_state = new_state
    return output_history
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    obs = batch[input]
    logits, h = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    if self._deterministic_eval and not self.training:
        act = logits[0]
    else:
        act = dist.rsample()
    log_prob = dist.log_prob(act).unsqueeze(-1)
    # Apply the correction for tanh squashing when computing the log-prob from
    # the Gaussian. See the original SAC paper (arXiv 1801.01290), Eq. 21 in
    # appendix C, to get some understanding of this equation.
    if self.action_scaling and self.action_space is not None:
        action_scale = to_torch_as(
            (self.action_space.high - self.action_space.low) / 2.0, act)
    else:
        action_scale = 1.0  # type: ignore
    squashed_action = torch.tanh(act)
    log_prob = log_prob - torch.log(
        action_scale * (1 - squashed_action.pow(2)) + self.__eps
    ).sum(-1, keepdim=True)
    return Batch(logits=logits,
                 act=squashed_action,
                 state=h,
                 dist=dist,
                 log_prob=log_prob)
def forward(self, state, eval=False, with_log_prob=False):
    x = F.relu(self.fc1(state))
    x = F.relu(self.fc2(x))
    mu = self.fc3(x)
    log_sigma = self.fc4(x)
    # Clip the value of log_sigma, as was done in Haarnoja's implementation
    # of SAC: https://github.com/haarnoja/sac.git
    log_sigma = torch.clamp(log_sigma, -20.0, 2.0)
    sigma = torch.exp(log_sigma)
    distribution = Independent(Normal(mu, sigma), 1)
    if not eval:
        # Use rsample() instead of sample(), as sample() does not allow
        # back-propagation through the distribution parameters.
        u = distribution.rsample()
        if with_log_prob:
            log_prob = distribution.log_prob(u)
            log_prob -= 2.0 * torch.sum(
                (np.log(2.0) + 0.5 * np.log(self.ctrl_range)
                 - u - F.softplus(-2.0 * u)),
                dim=1)
        else:
            log_prob = None
    else:
        u = mu
        log_prob = None
    # Apply tanh so that the resulting action lies in (-1, 1)^D.
    a = self.ctrl_range * torch.tanh(u)
    return a, log_prob
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    obs = batch[input]
    logits, h = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    if self._deterministic_eval and not self.training:
        x = logits[0]
    else:
        x = dist.rsample()
    y = torch.tanh(x)
    act = y * self._action_scale + self._action_bias
    # __eps is used to avoid taking the log of zero or of a negative number.
    y = self._action_scale * (1 - y.pow(2)) + self.__eps
    # Compute the log-prob from the Gaussian, then apply the correction for
    # tanh squashing. See the original SAC paper (arXiv 1801.01290), Eq. 21
    # in appendix C, to get some understanding of this equation.
    log_prob = dist.log_prob(x).unsqueeze(-1)
    log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
    return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
class TanhNormal(Distribution):
    """Copied from Kaixhi"""

    def __init__(self, loc, scale):
        super().__init__()
        self.normal = Independent(Normal(loc, scale), 1)

    def sample(self):
        return torch.tanh(self.normal.sample())

    # Samples with the re-parametrization trick (differentiable).
    def rsample(self):
        return torch.tanh(self.normal.rsample())

    # Calculates the log probability of a value using the change-of-variables
    # technique (uses log1p = log(1 + x) for extra numerical stability).
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|
        return self.normal.log_prob(inv_value) - torch.log1p(
            -value.pow(2) + 1e-6).sum(dim=1)

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)

    def get_std(self):
        return self.normal.stddev
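# A minimal sketch of how TanhNormal above might be exercised (hypothetical
# shapes, assuming the class is in scope): samples land in (-1, 1), and
# log_prob applies the tanh change of variables to the Gaussian log-density.
import torch

dist = TanhNormal(torch.zeros(4, 2), torch.ones(4, 2))
a = dist.rsample()    # shape (4, 2), values in (-1, 1)
lp = dist.log_prob(a) # shape (4,): Gaussian log-density minus log|det J|
print(a.abs().max().item() < 1.0, lp.shape)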
class VAE(nn.Module):
    def __init__(self, encoder, decoder, device=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if device is None:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

    def forward(self, x=None):
        bs = x.size(0)
        ls = self.encoder.latent_dims
        mu, sigma = self.encoder(x)
        self.pz = Independent(Normal(loc=torch.zeros(bs, ls).to(self.device),
                                     scale=torch.ones(bs, ls).to(self.device)),
                              reinterpreted_batch_ndims=1)
        self.qz_x = Independent(Normal(loc=mu, scale=torch.exp(sigma)),
                                reinterpreted_batch_ndims=1)
        self.z = self.qz_x.rsample()
        decoded = self.decoder(self.z)
        return decoded

    def compute_loss(self, x, y, scale_kl=False):
        px_z = Independent(ContinuousBernoulli(logits=y),
                           reinterpreted_batch_ndims=3)
        px = px_z.log_prob(x)
        kl = self.pz.log_prob(self.z) - self.qz_x.log_prob(self.z)
        if scale_kl:
            kl = kl * scale_kl
        loss = -(px + kl).mean()
        return loss, kl.mean().item(), px.mean().item()

    def rmse(self, input, target):
        return torch.sqrt(F.mse_loss(input, target))
def _loss_vae(self, x):
    batch_size = x.size(0)
    encoder_output = self.encoder(x)
    pz = Independent(
        Normal(loc=torch.zeros(batch_size, self.latent_dim).to(self.device),
               scale=torch.ones(batch_size, self.latent_dim).to(self.device)),
        reinterpreted_batch_ndims=1)
    qz_x = Independent(
        Normal(loc=encoder_output[:, :self.latent_dim],
               scale=torch.exp(encoder_output[:, self.latent_dim:])),
        reinterpreted_batch_ndims=1)
    z = qz_x.rsample()
    decoder_output = self.decoder(z)
    px_z = Independent(Bernoulli(logits=decoder_output),
                       reinterpreted_batch_ndims=1)
    loss = -(px_z.log_prob(x) + pz.log_prob(z) - qz_x.log_prob(z)).mean()
    return loss, decoder_output
def forward(self, obs, act=None, deterministic=False):
    # Optionally pass in an action to get the log_prob of that action.
    mu = self.mu_layer(obs)
    std = torch.exp(self.log_std_layer)
    pi = Independent(Normal(mu, std), 1)
    if act is None:
        act = pi.mean if deterministic else pi.rsample()
    log_prob = pi.log_prob(act)
    return pi, act, log_prob
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices=None):
    n_data_points = mu_2d.size()[-1]
    if n_data_points > 0:
        gaussian = Independent(
            Normal(loc=mu_2d, scale=obsrv_std.repeat(n_data_points)), 1)
        log_prob = gaussian.log_prob(data_2d)
        log_prob = log_prob / n_data_points
    else:
        log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze()
    return log_prob
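# A quick shape check for gaussian_log_likelihood above (hypothetical sizes,
# assuming the function is in scope): the result is one per-point-averaged
# log-density per batch row.
import torch

mu_2d = torch.zeros(8, 10)   # (batch, n_data_points)
data_2d = torch.randn(8, 10)
obsrv_std = torch.tensor([0.1])
print(gaussian_log_likelihood(mu_2d, data_2d, obsrv_std).shape)  # torch.Size([8])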
def act(self, obs, deterministic=False):
    action_mean = self.forward(obs)
    normal = Normal(action_mean, torch.exp(self.log_scale))
    dist = Independent(normal, 1)
    if deterministic:
        action = action_mean
    else:
        action = dist.rsample()
    action_logprobs = dist.log_prob(torch.squeeze(action))
    return action, action_logprobs
def calc_loglikelihood(inputs: torch.Tensor, outputs: torch.Tensor,
                       sigma_prior: float):
    predicted = outputs.flatten(1)
    true_mu = inputs.flatten(1)
    # torch.full_like avoids wrapping an existing tensor in torch.tensor(),
    # which copies the data and raises a UserWarning.
    base_normal = Normal(true_mu, torch.full_like(true_mu, sigma_prior))
    mvn = Independent(base_normal, 1)
    llh = mvn.log_prob(predicted)
    return llh
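# Usage sketch for calc_loglikelihood above (hypothetical shapes): both
# tensors are flattened from dim 1 onward, so any matching (N, ...) pair
# works, and the result is one log-likelihood per sample.
import torch

inputs = torch.randn(4, 3, 8, 8)
outputs = inputs + 0.1 * torch.randn_like(inputs)
print(calc_loglikelihood(inputs, outputs, sigma_prior=1.0).shape)  # torch.Size([4])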
def loss(self, x):
    """
    Returns
    1. the average value of the negative ELBO across the minibatch x
    2. the output of the decoder
    """
    batch_size = x.size(0)
    encoder_output = self.encoder(x)
    pz = Independent(
        Normal(loc=torch.zeros(batch_size, self.z_dim).to(self.device),
               scale=torch.ones(batch_size, self.z_dim).to(self.device)),
        reinterpreted_batch_ndims=1)
    qz_x = Independent(
        Normal(loc=encoder_output[:, :self.z_dim],
               scale=torch.exp(encoder_output[:, self.z_dim:])),
        reinterpreted_batch_ndims=1)
    z = qz_x.rsample()
    decoder_output = self.decoder(z)
    px_z = Independent(Bernoulli(logits=decoder_output),
                       reinterpreted_batch_ndims=1)
    loss = -(px_z.log_prob(x) + pz.log_prob(z) - qz_x.log_prob(z)).mean()
    return loss, decoder_output
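# Optional sanity check, not part of the loss above (a sketch with made-up
# parameters): for diagonal Gaussians the single-sample KL estimate used in
# the ELBO, qz_x.log_prob(z) - pz.log_prob(z), matches the analytic
# torch.distributions.kl_divergence in expectation.
import torch
from torch.distributions import Independent, Normal
from torch.distributions.kl import kl_divergence

q = Independent(Normal(torch.full((6, 2), 0.5), torch.ones(6, 2)), 1)
p = Independent(Normal(torch.zeros(6, 2), torch.ones(6, 2)), 1)
z = q.rsample((20000,))                          # (20000, 6, 2)
mc_kl = (q.log_prob(z) - p.log_prob(z)).mean(0)  # Monte Carlo estimate, (6,)
print(torch.allclose(mc_kl, kl_divergence(q, p), atol=0.1))  # True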
def get_action(self, obs):
    obs = torch.tensor(obs, dtype=torch.float).to(self.device)
    with torch.no_grad():
        mu, sigma = self.pi(obs)
        act_distribution = Independent(Normal(mu, sigma), 1)
        action = act_distribution.sample()
        log_prob = act_distribution.log_prob(action)
        val = self.V(obs)
    action = action.cpu().numpy()
    log_prob = log_prob.cpu().numpy()
    val = val.cpu().numpy()
    return action, log_prob, val
def mdn_loss_fn(self, mu, sigma, y, epsilon=1e-9):
    # Non-vectorised version:
    # result = torch.zeros(y.shape[0], self.n_gaussians).to(self.device)
    # for idx in range(self.n_gaussians):
    #     gaussian = Independent(Normal(loc=mu[:, :, idx], scale=sigma), 1)
    #     result_per_gaussian = gaussian.log_prob(y)
    #     result[:, idx] = result_per_gaussian + self.pi.log()
    # return -torch.mean(torch.logsumexp(result, dim=1))
    gaussian = Independent(
        Normal(loc=mu,
               scale=sigma.reshape(-1, self.n_outputs, 1).repeat(1, 1, mu.shape[2])),
        0)
    result = gaussian.log_prob(
        y.reshape([-1, mu.shape[1], 1]).repeat(1, 1, self.n_gaussians))
    result = torch.sum(result, dim=1) + self.pi.log()
    return -torch.mean(torch.logsumexp(result, dim=1))
def forward(self, obs, deterministic=False):
    mu = self.mu_layer(obs)
    log_std = self.log_std_layer(obs)
    std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX).exp()

    # Pre-squash distribution and sample.
    pi_distribution = Independent(Normal(mu, std), 1)
    act = mu if deterministic else pi_distribution.rsample()
    log_prob = pi_distribution.log_prob(act)
    squashed_action = torch.tanh(act)
    log_prob = log_prob - torch.log(
        (1 - squashed_action.pow(2)) + self.__eps).sum(dim=-1)
    return squashed_action, log_prob
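# Standalone illustration of the tanh-squashing correction used above (a
# hedged sketch with toy tensors, not the class's actual layers): squashing
# u ~ N(mu, std) through tanh subtracts sum(log(1 - tanh(u)^2)) from the
# Gaussian log-density, by the change-of-variables formula.
import torch
from torch.distributions import Independent, Normal

eps = torch.finfo(torch.float32).eps
pre_squash = Independent(Normal(torch.zeros(3, 2), torch.ones(3, 2)), 1)
u = pre_squash.rsample()
a = torch.tanh(u)
log_prob = pre_squash.log_prob(u) - torch.log(1 - a.pow(2) + eps).sum(dim=-1)
print(a.abs().max().item() < 1.0, log_prob.shape)  # True torch.Size([3])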
def loss_vae_normal(x, encoder, decoder):
    batch_size = x.size(0)
    encoder_output = encoder(x)
    d = encoder_output.shape[1] // 2
    # Standard-normal prior p(z) = N(0, I).
    pz_loc = torch.zeros(batch_size, d).to(device)
    pz_scale = torch.ones(batch_size, d).to(device)
    pz = Independent(Normal(loc=pz_loc, scale=pz_scale),
                     reinterpreted_batch_ndims=1)
    qz_x_loc = encoder_output[:, :d]
    qz_x_log_scale = encoder_output[:, d:]
    # The encoder's second half is a log-scale; exponentiate it to get a
    # strictly positive scale for q(z|x).
    qz_x = Independent(Normal(loc=qz_x_loc, scale=torch.exp(qz_x_log_scale)),
                       reinterpreted_batch_ndims=1)
    z = qz_x.rsample()
    decoder_output = decoder(z)
    optimal_sigma_observed = ((x - decoder_output)**2).mean(
        [0, 1, 2, 3], keepdim=True).sqrt()
    px_z = Independent(Normal(loc=decoder_output, scale=optimal_sigma_observed),
                       reinterpreted_batch_ndims=3)
    elbo = (px_z.log_prob(x) - kl_divergence(qz_x, pz)).mean()
    return -elbo, decoder_output
def test_independent_shape(self):
    for Dist, params in EXAMPLES:
        for param in params:
            base_dist = Dist(**param)
            x = base_dist.sample()
            base_log_prob_shape = base_dist.log_prob(x).shape
            for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                indep_log_prob_shape = base_log_prob_shape[
                    :len(base_log_prob_shape) - reinterpreted_batch_ndims]
                self.assertEqual(indep_dist.log_prob(x).shape,
                                 indep_log_prob_shape)
                self.assertEqual(indep_dist.sample().shape,
                                 base_dist.sample().shape)
                self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
                if indep_dist.has_rsample:
                    self.assertEqual(indep_dist.sample().shape,
                                     base_dist.sample().shape)
                try:
                    self.assertEqual(
                        indep_dist.enumerate_support().shape,
                        base_dist.enumerate_support().shape,
                    )
                    self.assertEqual(indep_dist.mean.shape,
                                     base_dist.mean.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.variance.shape,
                                     base_dist.variance.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.entropy().shape,
                                     indep_log_prob_shape)
                except NotImplementedError:
                    pass
def reinforce_loss(policy, episodes, init_std=1.0, min_std=1e-6, output_size=2):
    output = policy(
        episodes.observations.view((-1, *episodes.observation_shape)))

    min_log_std = math.log(min_std)
    # Note: this creates a fresh log-std parameter on every call, so it is
    # never updated by the optimizer.
    sigma = nn.Parameter(torch.Tensor(output_size))
    sigma.data.fill_(math.log(init_std))
    scale = torch.exp(torch.clamp(sigma, min=min_log_std))
    pi = Independent(Normal(loc=output, scale=scale), 1)

    log_probs = pi.log_prob(episodes.actions.view((-1, *episodes.action_shape)))
    log_probs = log_probs.view(len(episodes), episodes.batch_size)

    losses = -weighted_mean(log_probs * episodes.advantages,
                            lengths=episodes.lengths)
    return losses.mean()
class IndependentRescaledBeta(Distribution):
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive
    }
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.base_dist = Independent(
            RescaledBeta(concentration1, concentration0, validate_args),
            len(concentration1.shape) - 1,
            validate_args=validate_args)
        super(IndependentRescaledBeta, self).__init__(
            self.base_dist.batch_shape,
            self.base_dist.event_shape,
            validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        return self.base_dist.entropy()
def loss_function(x_hat, x, q_z, z, epoch):
    if args.loss == 'mixture':
        BCE = torch.mean(-log_mix_dep_Logistic_256(x, x_hat, average=True,
                                                   n_comps=10))
    if args.loss == 'CE':
        x_hat = x_hat.view(-1, 3, 256, 64, 64)
        x_hat = x_hat.permute(0, 1, 3, 4, 2)
        x_hat = x_hat.contiguous()
        x_hat = x_hat.view(-1, 256)
        # x_hat = torch.round(256 * x_hat.view(-1, 256))
        target = Variable(x.data.view(-1) * 255).long()
        BCE = loss(x_hat, target)

    # x = x.view(-1, x_hat.size(1))
    # tensor = torch.ones(1)
    # p_x_dist = Beta(tensor.new_full((z.size(0), z_dim), 0.5).to(device),
    #                 tensor.new_full((z.size(0), z_dim), 0.5).to(device))
    z_sqrt = int(np.sqrt(z_dim))
    if arch == 'resnet' or arch == 'convlin':
        p_x_dist = Independent(distri(torch.zeros(z.size(0), z_dim).to(device),
                                      torch.ones(z.size(0), z_dim).to(device)), 1)
    else:
        p_x_dist = Independent(
            distri(torch.zeros(z.size(0), 1, z_sqrt, z_sqrt).to(device),
                   torch.ones(z.size(0), 1, z_sqrt, z_sqrt).to(device)), 1)

    one_third = round(args.epochs / 3)
    if beta_final >= 1:
        if epoch <= one_third:
            beta = (beta_final * epoch) / one_third
        else:
            beta = beta_final
    else:
        beta = 1

    # BCE = torch.sum(-p_x.log_prob(x.view(x.size(0), x_dim**2)))
    KLD = torch.mean(q_z.log_prob(z) - p_x_dist.log_prob(z))
    print(BCE, KLD, beta)
    return (BCE + beta * KLD), BCE, KLD
def __call__(self, x, out_keys=['action'], info={}, **kwargs):
    # Output dictionary
    out_policy = {}

    # Forward pass of the feature networks to obtain features
    if self.recurrent:
        out_network = self.network(x=x,
                                   hidden_states=self.rnn_states,
                                   mask=info.get('mask', None))
        features = out_network['output']
        # Update the tracking of the current RNN hidden states
        self.rnn_states = out_network['hidden_states']
    else:
        features = self.network(x)

    # Forward pass through the mean head to obtain the Gaussian mean
    mean = self.network.mean_head(features)
    # Obtain logvar based on the options
    if isinstance(self.network.logvar_head, nn.Linear):
        # linear layer, then do a forward pass
        logvar = self.network.logvar_head(features)
    else:  # either Tensor or nn.Parameter
        logvar = self.network.logvar_head
    # Expand to the same shape as mean
    logvar = logvar.expand_as(mean)

    # Forward pass of the value head to obtain the value function if required
    if 'state_value' in out_keys:
        out_policy['state_value'] = self.network.value_head(
            features).squeeze(-1)  # squeeze the final singleton dim

    # Get std from logvar
    if self.std_style == 'exp':
        std = torch.exp(0.5 * logvar)
    elif self.std_style == 'softplus':
        std = F.softplus(logvar)

    # Lower-bound threshold for std
    min_std = torch.full(std.size(), self.min_std).type_as(std).to(self.device)
    std = torch.max(std, min_std)

    # Create independent Gaussian distributions, i.e. a diagonal Gaussian
    action_dist = Independent(Normal(loc=mean, scale=std), 1)

    # Sample an action from the distribution (no gradient).
    # Do not use rsample(): it leads to zero gradient of the mean head!
    action = action_dist.sample()
    out_policy['action'] = action

    # Calculate the log-probability of the sampled action
    if 'action_logprob' in out_keys:
        out_policy['action_logprob'] = action_dist.log_prob(action)

    # Calculate the policy entropy conditioned on the state
    if 'entropy' in out_keys:
        out_policy['entropy'] = action_dist.entropy()

    # Calculate the policy perplexity, i.e. exp(entropy)
    if 'perplexity' in out_keys:
        out_policy['perplexity'] = action_dist.perplexity()

    # Sanity check for NaN (loops forever to halt training and surface the issue)
    if torch.any(torch.isnan(action)):
        while True:
            msg = 'NaN! A workaround is to learn a state-independent std or use tanh rather than relu. '
            msg2 = f'check: \n\t mean: {mean}, logvar: {logvar}'
            print(msg + msg2)

    # Constrain the action to the valid range
    out_policy['action'] = self.constraint_action(action)

    return out_policy
def normal_log_density(means, stds, actions):
    dist = Independent(Normal(means, stds), 1)
    return dist.log_prob(actions)
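# Minimal call of normal_log_density above (hypothetical shapes): means, stds,
# and actions share a (batch, action_dim) shape, and the result has one
# log-density per batch row.
import torch

means, stds = torch.zeros(5, 3), torch.ones(5, 3)
actions = torch.randn(5, 3)
print(normal_log_density(means, stds, actions).shape)  # torch.Size([5])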
b, m, train_m, test_m = make_batch_mask(batch, param)
input_tuple = (b, m, train_m, test_m)
# tec = time.time()
# print('Batch got in %.2f sec' % (tec - tic))
optimizer.zero_grad()
output = model.forward(input_tuple)
masked_output = output[test_m.bool()].reshape(
    param['batch_size'], (param['total_points'] - param['obs_points']))
target = b[:, :, 1][test_m.bool()].reshape(
    param['batch_size'], (param['total_points'] - param['obs_points']))
log_likelihood = torch.tensor(0.0)
for i in range(masked_output.shape[0]):
    gaussian = Independent(Normal(masked_output[i], param['sigma']), 1)
    ll = gaussian.log_prob(target[i]) / masked_output.shape[1]
    log_likelihood += ll
log_likelihood /= masked_output.shape[0]
loss = -log_likelihood
mse_loss = mse(masked_output, target)
# tac = time.time()
# print('Forward finished in %.2f sec' % (tac - tec))
loss.backward()
# tuc = time.time()
# print('Backward finished in %.2f sec' % (tuc - tac))
optimizer.step()
toc = time.time()
for k in range(param['figure_per_batch']):
    plt.clf()