def forward(self, observation: torch.Tensor, prev_action: torch.Tensor = None,
            prev_state: RSSMState = None):
    """
    Run the agent model on an observation.

    The observation (together with the optional previous action and previous
    state) is turned into a state representation. The policy is then queried
    on that state, and the value and reward heads are evaluated on the
    features extracted from it.

    :param observation: observation tensor passed to ``get_state_representation``
    :param prev_action: previous action passed to ``get_state_representation`` (defaults to None)
    :param prev_state: previous state passed to ``get_state_representation`` (defaults to None)
    :return: tuple ``(action, action_dist, value, reward, state)``
    """
    rssm_state = self.get_state_representation(observation, prev_action, prev_state)
    chosen_action, chosen_dist = self.policy(rssm_state)
    value_out = self.value_model(get_feat(rssm_state))
    reward_out = self.reward_model(get_feat(rssm_state))
    return chosen_action, chosen_dist, value_out, reward_out, rssm_state
def write_videos(self, observation, action, image_pred, post, step=None, n=4, t=25):
    """
    Log a video comparing model predictions against the ground truth.

    observation shape T,N,C,H,W

    Only the first ``n`` sequences of the batch are used. The first ``t``
    frames come from the reconstruction ``image_pred``; the remaining frames
    are imagined by rolling the transition model forward from the posterior
    state at step ``t - 1`` using the recorded actions ``action[t:]``.
    Each video frame stacks three images: ground truth, model output, error.

    :param observation: observation tensor, time-major
    :param action: recorded actions, indexed as ``[time, batch]``
    :param image_pred: decoder output for the posterior states (its ``.mean`` is used)
    :param post: posterior states, indexed as ``[time, batch]``
    :param step: step value forwarded to ``video_summary``
    :param n: number of batch entries to render
    :param t: number of leading frames taken from the reconstruction
    """
    # Only the time length is needed from the inferred leading dimensions.
    _, batch_t, _, _ = infer_leading_dims(observation, 3)
    agent_model = self.agent.model

    # presumably undoes a -0.5 shift applied when the images were normalized
    # -- confirm against the caller
    truth = observation[:, :n] + 0.5

    # Frames [0, t): reconstruction.
    recon = image_pred.mean[:t, :n]

    # Frames [t, T): open-loop rollout of the transition model.
    start_state = post[t - 1, :n]
    imagined_states = agent_model.rollout.rollout_transition(
        batch_t - t, action[t:, :n], start_state)
    imagined = agent_model.observation_decoder(get_feat(imagined_states)).mean

    predicted = torch.cat((recon, imagined), dim=0) + 0.5
    # Map the signed difference into a displayable range around 0.5.
    error = (predicted - truth + 1) / 2

    # concatenate vertically on height dimension
    frames = torch.cat((truth, predicted, error), dim=3)
    frames = frames.transpose(1, 0)  # N,T,C,H,W
    video_summary('videos/model_error', torch.clamp(frames, 0., 1.), step)
def policy(self, state: RSSMState):
    """
    Choose an action for the given state.

    The state features are decoded into an action distribution. For the
    ``'tanh_normal'`` setting a reparameterized sample is drawn while
    training and the mode is used otherwise. Any other setting draws a
    plain (non-differentiable) sample.

    :param state: state whose features are fed to the action decoder
    :return: tuple ``(action, action_dist)``
    """
    feat = get_feat(state)
    action_dist = self.action_decoder(feat)
    if self.action_dist == 'tanh_normal':
        if self.training:  # use agent.train(bool) or agent.eval()
            # rsample() keeps the draw differentiable w.r.t. the decoder
            # output, so losses computed on imagined rollouts can reach the
            # action decoder. Under the torch.distributions convention,
            # sample() carries no gradient.
            # NOTE(review): assumes action_dist exposes rsample() -- confirm
            # for the distribution returned by action_decoder.
            action = action_dist.rsample()
        else:
            action = action_dist.mode()
    else:
        # cannot propagate gradients with one hot distribution
        action = action_dist.sample()
    return action, action_dist
def policy(self, state: RSSMState):
    """
    Choose an action for the given state.

    The state features are decoded into an action distribution, and the way
    the action is drawn depends on ``self.action_dist``:

    * ``'tanh_normal'``: ``rsample()`` while training, ``mode()`` otherwise
    * ``'one_hot'``: ``sample()`` plus a straight-through gradient term
    * ``'relaxed_one_hot'``: ``rsample()``
    * anything else: ``sample()``

    :param state: state whose features are fed to the action decoder
    :return: tuple ``(action, action_dist)``
    """
    dist = self.action_decoder(get_feat(state))
    kind = self.action_dist

    if kind == 'tanh_normal':
        # self.training is toggled with agent.train(bool) or agent.eval()
        action = dist.rsample() if self.training else dist.mode()
    elif kind == 'one_hot':
        hard_sample = dist.sample()
        # Adding probs and subtracting its detached copy leaves the value
        # unchanged but gives straight-through gradients via probs.
        action = hard_sample + dist.probs - dist.probs.detach()
    elif kind == 'relaxed_one_hot':
        action = dist.rsample()
    else:
        action = dist.sample()

    return action, dist
def model_loss(self, observation: torch.Tensor, prior: RSSMState, post: RSSMState,
               reward: torch.Tensor):
    """
    Compute the model loss for a bunch of data. All vectors are [batch_t, batch_x, vector_dim]

    The loss is the sum of the negative log-likelihood of the observations
    and rewards under the decoders evaluated on the posterior features, plus
    ``self.kl_scale`` times the mean KL divergence between posterior and
    prior. The KL term is floored at ``self.free_nats``.

    :param observation: observations scored under the observation decoder
    :param prior: prior states
    :param post: posterior states
    :param reward: rewards scored under the reward model
    :return: scalar model loss tensor
    """
    model = self.agent.model
    feat = get_feat(post)
    image_pred = model.observation_decoder(feat)
    reward_pred = model.reward_model(feat)
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    image_loss = -torch.mean(image_pred.log_prob(observation))

    prior_dist = get_dist(prior)
    post_dist = get_dist(post)
    div = torch.mean(torch.distributions.kl.kl_divergence(post_dist, prior_dist))
    # Free nats is a lower bound: a KL below the threshold is replaced by the
    # constant (no gradient), while a larger KL is penalized as-is. Clamping
    # from above would instead remove the penalty exactly when the KL is large.
    div = torch.clamp(div, min=self.free_nats)

    return self.kl_scale * div + reward_loss + image_loss
def loss(self, samples: SamplesFromReplay, sample_itr: int, opt_itr: int):
    """
    Compute the losses for a batch of replay data.

    The world-model loss (image, reward, optional pcont and KL terms) is
    computed on the replayed sequences. The dynamics model is then rolled
    out with the policy from every visited posterior state, and the imagined
    trajectories provide the actor and value losses.

    :param samples: samples from replay
    :param sample_itr: sample iteration
    :param opt_itr: optimization iteration
    :return: tuple ``(model_loss, actor_loss, value_loss, loss_info)``
    """
    model = self.agent.model

    # Align the replayed sequences on a common time window.
    observation = samples.all_observation[:-1]  # [t, t+batch_length+1] -> [t, t+batch_length]
    action = samples.all_action[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
    reward = samples.all_reward[1:].unsqueeze(2)  # [t-1, t+batch_length] -> [t, t+batch_length]
    done = samples.done.unsqueeze(2)

    _, batch_t, batch_b, _ = infer_leading_dims(observation, 3)
    # single flat batch dimension used for the imagination roll-out
    batch_size = batch_t * batch_b

    # normalize image
    observation = observation.type(self.type) / 255.0 - 0.5
    # embed the image
    embed = model.observation_encoder(observation)

    init_state = model.representation.initial_state(
        batch_b, device=action.device, dtype=action.dtype)
    # Rollout model by taking the same series of actions as the real model
    prior, post = model.rollout.rollout_representation(batch_t, embed, action, init_state)

    # ---------------------------------- Model loss ----------------------------------
    post_feat = get_feat(post)
    image_pred = model.observation_decoder(post_feat)
    reward_pred = model.reward_model(post_feat)
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    image_loss = -torch.mean(image_pred.log_prob(observation))

    pcont_loss = torch.tensor(0.)  # placeholder if use_pcont = False
    if self.use_pcont:
        pcont_pred = model.pcont(post_feat)
        pcont_target = self.discount * (1 - done.float())
        pcont_loss = -torch.mean(pcont_pred.log_prob(pcont_target))

    prior_dist = get_dist(prior)
    post_dist = get_dist(post)
    div = torch.mean(torch.distributions.kl.kl_divergence(post_dist, prior_dist))
    # KL term is floored at free_nats
    div = torch.max(div, div.new_full(div.size(), self.free_nats))

    model_loss = self.kl_scale * div + reward_loss + image_loss
    if self.use_pcont:
        model_loss += self.pcont_scale * pcont_loss

    # ---------------------------------- Actor loss ----------------------------------
    # Gradient barrier: the imagination start states are built without
    # gradients so the actor loss does not flow back into the model rollout.
    with torch.no_grad():
        if self.use_pcont:
            # "Last step could be terminal." Done in TF2 code, but unclear why
            start_states = buffer_method(
                post[:-1, :], 'reshape', (batch_t - 1) * (batch_b), -1)
        else:
            start_states = buffer_method(post, 'reshape', batch_size, -1)

    # Rollout the policy for self.horizon steps. Names prefixed with imag_
    # refer to imagined, not real, data.
    with FreezeParameters(self.model_modules):
        imag_states, _ = model.rollout.rollout_policy(
            self.horizon, model.policy, start_states)

    # Deterministic and stochastic state features
    imag_feat = get_feat(imag_states)  # [horizon, batch_t * batch_b, feature_size]

    # The means are used here; the original note says the TF code uses the
    # mode, which equals the mean for a normal distribution. Other
    # distributions would need a different choice.
    # Model and value parameters are frozen: only the action model needs
    # gradients from this loss.
    with FreezeParameters(self.model_modules + self.value_modules):
        imag_reward = model.reward_model(imag_feat).mean
        imag_value = model.value_model(imag_feat).mean

    # Per-step discount factors
    if self.use_pcont:
        with FreezeParameters([model.pcont]):
            discount_arr = model.pcont(imag_feat).mean
    else:
        discount_arr = self.discount * torch.ones_like(imag_reward)

    returns = self.compute_return(
        imag_reward[:-1], imag_value[:-1], discount_arr[:-1],
        bootstrap=imag_value[-1], lambda_=self.discount_lambda)

    # Make the top row 1 so the cumulative product starts with discount^0
    discount_arr = torch.cat([torch.ones_like(discount_arr[:1]), discount_arr[1:]])
    discount = torch.cumprod(discount_arr[:-1], 0)
    actor_loss = -torch.mean(discount * returns)

    # ---------------------------------- Value loss ----------------------------------
    # Gradient barrier: inputs and targets of the value loss are detached.
    with torch.no_grad():
        value_feat = imag_feat[:-1].detach()
        value_discount = discount.detach()
        value_target = returns.detach()
    value_pred = model.value_model(value_feat)
    log_prob = value_pred.log_prob(value_target)
    value_loss = -torch.mean(value_discount * log_prob.unsqueeze(2))

    # ---------------------------------- Loss info -----------------------------------
    with torch.no_grad():
        prior_ent = torch.mean(prior_dist.entropy())
        post_ent = torch.mean(post_dist.entropy())
        loss_info = LossInfo(model_loss, actor_loss, value_loss, prior_ent, post_ent,
                             div, reward_loss, image_loss, pcont_loss)

        # Videos are written only on the last optimization step of every
        # video_every-th sample iteration.
        is_video_step = (opt_itr == self.train_steps - 1
                         and sample_itr % self.video_every == 0)
        if self.log_video and is_video_step:
            self.write_videos(observation, action, image_pred, post, step=sample_itr,
                              n=self.video_summary_b, t=self.video_summary_t)

    return model_loss, actor_loss, value_loss, loss_info
def loss(self, samples: Samples):
    """
    Compute the loss for a batch of data.

    This includes computing the model and reward losses on the given data,
    as well as using the dynamics model to generate additional rollouts,
    which are used for the actor and value components of the loss.

    :param samples: rlpyt Samples object, containing a batch of data. All vectors are
        [timestep, batch, vector_dim]
    :return: FloatTensor containing the weighted sum of model, actor and value losses
    """
    model = self.agent.model
    device = next(model.parameters()).device

    # Only the time and batch sizes are needed from the leading dimensions.
    _, batch_t, batch_b, _ = infer_leading_dims(samples.env.observation, 3)
    # squeeze batch sizes to single batch dimension for imagination roll-out
    batch_size = batch_t * batch_b

    observation = samples.env.observation.to(device)
    # normalize image
    observation = observation.type(self.type) / 255.0 - 0.5
    # embed the image
    embed = model.observation_encoder(observation)

    # make actions one-hot
    action = to_onehot(samples.agent.action, model.action_size, dtype=self.type).to(device)
    reward = samples.env.reward.to(device)

    # To continue the agent state from the previous time steps instead, use:
    # prev_state = samples.agent.agent_info.prev_state[0]
    prev_state = model.representation.initial_state(batch_b, device=device)

    # Rollout model by taking the same series of actions as the real model
    # NOTE(review): confirm that rollout_representation returns (post, prior)
    # in this order; a swapped pair would silently invert the KL term.
    post, prior = model.rollout.rollout_representation(batch_t, embed, action, prev_state)

    # Flatten our data (so first dimension is batch_t * batch_b = batch_size)
    # since a new rollout is started from each state visited in each batch.
    flat_post = RSSMState(mean=post.mean.reshape(batch_size, -1),
                          std=post.std.reshape(batch_size, -1),
                          stoch=post.stoch.reshape(batch_size, -1),
                          deter=post.deter.reshape(batch_size, -1))
    flat_action = action.reshape(batch_size, -1)

    # Rollout the policy for self.horizon steps. Names prefixed with imag_
    # refer to imagined, not real, data.
    imag_dist, _ = model.rollout.rollout_policy(
        self.horizon, model.policy, flat_action, flat_post)

    # Deterministic and stochastic state features
    imag_feat = get_feat(imag_dist)  # [horizon, batch_t * batch_b, feature_size]

    # The means are used here; the original note says the TF code uses the
    # mode, which equals the mean for a normal distribution. Other
    # distributions would need a different choice.
    imag_reward = model.reward_model(imag_feat).mean
    value = model.value_model(imag_feat).mean

    # Compute the exponential discounted sum of rewards
    discount_arr = self.discount * torch.ones_like(imag_reward)
    returns = self.compute_return(imag_reward[:-1], value[:-1], discount_arr[:-1],
                                  bootstrap=value[-1], lambda_=self.discount_lambda)
    # Accumulate the discount along the horizon axis (dim 0). Dim 1 is the
    # flattened batch axis, so accumulating there would discount by batch
    # index instead of by imagined time step.
    discount = torch.cumprod(discount_arr[:-1], 0).detach()

    # Compute losses for each component of the model
    model_loss = self.model_loss(observation, prior, post, reward)
    actor_loss = self.actor_loss(discount, returns)
    value_loss = self.value_loss(imag_feat, discount, returns)

    return (self.model_weight * model_loss
            + self.actor_weight * actor_loss
            + self.value_weight * value_loss)