def model_loss(self, observation: torch.Tensor, prior: RSSMState, post: RSSMState, reward: torch.Tensor):
    """
    Compute the model loss for a batch of data. All tensors are [batch_t, batch_b, vector_dim].
    """
    model = self.agent.model
    feat = get_feat(post)
    image_pred = model.observation_decoder(feat)
    reward_pred = model.reward_model(feat)
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    image_loss = -torch.mean(image_pred.log_prob(observation))
    prior_dist = get_dist(prior)
    post_dist = get_dist(post)
    div = torch.mean(torch.distributions.kl.kl_divergence(post_dist, prior_dist))
    # Free nats: stop penalizing the KL term once it falls below self.free_nats.
    div = torch.max(div, div.new_full(div.size(), self.free_nats))
    model_loss = self.kl_scale * div + reward_loss + image_loss
    return model_loss
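
# The get_feat / get_dist helpers used above come from the RSSM module of this repo. Below is a minimal,
# hedged sketch of what they plausibly do, to make the loss code easier to follow: the "feature" is the
# concatenation of the stochastic and deterministic latent state, and the state distribution is a diagonal
# Gaussian over the stochastic part. The RSSMState field names (stoch, deter, mean, std) and the *_sketch
# function names are assumptions for illustration, not part of this file.
def _get_feat_sketch(rssm_state: RSSMState) -> torch.Tensor:
    # Concatenate the stochastic and deterministic latent state into a single feature vector.
    return torch.cat((rssm_state.stoch, rssm_state.deter), dim=-1)


def _get_dist_sketch(rssm_state: RSSMState) -> torch.distributions.Distribution:
    # Diagonal Gaussian over the stochastic state; Independent makes log_prob / KL reduce over the event dim.
    return torch.distributions.Independent(
        torch.distributions.Normal(rssm_state.mean, rssm_state.std), 1)
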
def loss(self, samples: SamplesFromReplay, sample_itr: int, opt_itr: int):
    """
    Compute the loss for a batch of data. This includes computing the model and reward losses on the given
    data, as well as using the dynamics model to generate additional rollouts, which are used for the actor
    and value components of the loss.
    :param samples: samples from replay
    :param sample_itr: sample iteration
    :param opt_itr: optimization iteration
    :return: model loss, actor loss, value loss, and a LossInfo tuple of diagnostic values
    """
    model = self.agent.model

    observation = samples.all_observation[:-1]  # [t, t+batch_length+1] -> [t, t+batch_length]
    action = samples.all_action[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
    reward = samples.all_reward[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
    reward = reward.unsqueeze(2)
    done = samples.done
    done = done.unsqueeze(2)

    # Extract tensors from the Samples object.
    # They all have the batch_t dimension first; we also convert them to floats so they can be fed into our models.
    lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(observation, 3)

    # squeeze batch sizes to single batch dimension for imagination roll-out
    batch_size = batch_t * batch_b
    # normalize image
    observation = observation.type(self.type) / 255.0 - 0.5
    # embed the image
    embed = model.observation_encoder(observation)

    prev_state = model.representation.initial_state(batch_b, device=action.device, dtype=action.dtype)
    # Rollout the model by taking the same series of actions as the real agent did
    prior, post = model.rollout.rollout_representation(batch_t, embed, action, prev_state)
    # Flatten our data (so first dimension is batch_t * batch_b = batch_size)
    # since we're going to do a new rollout starting from each state visited in each batch.

    # Compute losses for each component of the model

    # Model Loss
    feat = get_feat(post)
    image_pred = model.observation_decoder(feat)
    reward_pred = model.reward_model(feat)
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    image_loss = -torch.mean(image_pred.log_prob(observation))
    pcont_loss = torch.tensor(0.)  # placeholder if use_pcont = False
    if self.use_pcont:
        pcont_pred = model.pcont(feat)
        pcont_target = self.discount * (1 - done.float())
        pcont_loss = -torch.mean(pcont_pred.log_prob(pcont_target))
    prior_dist = get_dist(prior)
    post_dist = get_dist(post)
    div = torch.mean(torch.distributions.kl.kl_divergence(post_dist, prior_dist))
    # Free nats: stop penalizing the KL term once it falls below self.free_nats.
    div = torch.max(div, div.new_full(div.size(), self.free_nats))
    model_loss = self.kl_scale * div + reward_loss + image_loss
    if self.use_pcont:
        model_loss += self.pcont_scale * pcont_loss

    # ------------------------------------------ Gradient Barrier ------------------------------------------------
    # Don't let gradients pass through to prevent overwriting gradients.
    # Actor Loss

    # remove gradients from previously calculated tensors
    with torch.no_grad():
        if self.use_pcont:
            # "Last step could be terminal." Done in the TF2 code, but unclear why.
            flat_post = buffer_method(post[:-1, :], 'reshape', (batch_t - 1) * batch_b, -1)
        else:
            flat_post = buffer_method(post, 'reshape', batch_size, -1)
    # Rollout the policy for self.horizon steps. Variable names prefixed with imag_ indicate this data is imagined, not real.
    # imag_feat shape is [horizon, batch_t * batch_b, feature_size]
    with FreezeParameters(self.model_modules):
        imag_dist, _ = model.rollout.rollout_policy(self.horizon, model.policy, flat_post)

    # Use state features (deterministic and stochastic) to predict the image and reward
    imag_feat = get_feat(imag_dist)  # [horizon, batch_t * batch_b, feature_size]

    # Assumes these are normal distributions. In the TF code it's the mode, but for a normal distribution mean = mode.
    # If we want to use other distributions we'll have to fix this.
    # We calculate the target here, so no grad is necessary.
    # Freeze model parameters, as only action model gradients are needed.
    with FreezeParameters(self.model_modules + self.value_modules):
        imag_reward = model.reward_model(imag_feat).mean
        value = model.value_model(imag_feat).mean
    # Compute the exponentially discounted sum of rewards
    if self.use_pcont:
        with FreezeParameters([model.pcont]):
            discount_arr = model.pcont(imag_feat).mean
    else:
        discount_arr = self.discount * torch.ones_like(imag_reward)
    # Lambda-returns over the imagined trajectory (see the compute_return sketch after this method).
    returns = self.compute_return(imag_reward[:-1], value[:-1], discount_arr[:-1],
                                  bootstrap=value[-1], lambda_=self.discount_lambda)
    # Make the top row 1 so the cumulative product starts with discount^0
    discount_arr = torch.cat([torch.ones_like(discount_arr[:1]), discount_arr[1:]])
    discount = torch.cumprod(discount_arr[:-1], 0)
    actor_loss = -torch.mean(discount * returns)

    # ------------------------------------------ Gradient Barrier ------------------------------------------------
    # Don't let gradients pass through to prevent overwriting gradients.
    # Value Loss

    # remove gradients from previously calculated tensors
    with torch.no_grad():
        value_feat = imag_feat[:-1].detach()
        value_discount = discount.detach()
        value_target = returns.detach()
    value_pred = model.value_model(value_feat)
    log_prob = value_pred.log_prob(value_target)
    value_loss = -torch.mean(value_discount * log_prob.unsqueeze(2))

    # ------------------------------------------ Gradient Barrier ------------------------------------------------

    # loss info
    with torch.no_grad():
        prior_ent = torch.mean(prior_dist.entropy())
        post_ent = torch.mean(post_dist.entropy())
        loss_info = LossInfo(model_loss, actor_loss, value_loss, prior_ent, post_ent, div,
                             reward_loss, image_loss, pcont_loss)

        if self.log_video:
            if opt_itr == self.train_steps - 1 and sample_itr % self.video_every == 0:
                self.write_videos(observation, action, image_pred, post, step=sample_itr,
                                  n=self.video_summary_b, t=self.video_summary_t)

    return model_loss, actor_loss, value_loss, loss_info
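
# The loss above relies on self.compute_return for the lambda-returns used by the actor and value losses.
# Below is a minimal, hedged sketch of a compatible helper, assuming torch is imported at module level and
# that reward, value, and discount are [horizon - 1, batch, 1] tensors with a [batch, 1] bootstrap (matching
# how it is called above). It is an illustrative sketch, not necessarily the repo's exact implementation.
def _compute_return_sketch(reward: torch.Tensor,
                           value: torch.Tensor,
                           discount: torch.Tensor,
                           bootstrap: torch.Tensor,
                           lambda_: float) -> torch.Tensor:
    # v(s_{t+1}) for each step, with the bootstrap value appended for the final step.
    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    # One-step targets: r_t + gamma_t * (1 - lambda) * v(s_{t+1})
    targets = reward + discount * next_values * (1 - lambda_)
    # Accumulate the recursive lambda-return backwards in time:
    # V_lambda(t) = targets[t] + gamma_t * lambda * V_lambda(t+1), with V_lambda(H) = bootstrap.
    outputs = []
    accumulated = bootstrap
    for t in range(reward.shape[0] - 1, -1, -1):
        accumulated = targets[t] + discount[t] * lambda_ * accumulated
        outputs.append(accumulated)
    return torch.flip(torch.stack(outputs, dim=0), dims=[0])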