Example #1
 def forward(self,
             observation: torch.Tensor,
             prev_action: torch.Tensor = None,
             prev_state: RSSMState = None):
     state = self.get_state_representation(observation, prev_action,
                                           prev_state)
     action, action_dist = self.policy(state)
     value = self.value_model(get_feat(state))
     reward = self.reward_model(get_feat(state))
     return action, action_dist, value, reward, state
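
get_state_representation and get_feat come from the surrounding repository and are not shown here. As a rough sketch (an assumption, not the repository's exact helper), get_feat can be read as the concatenation of the stochastic and deterministic parts of the RSSM state, matching the stoch and deter fields of the RSSMState constructed in Example #7:

    import torch

    def get_feat(rssm_state):
        # Hypothetical sketch: feed the policy/value/reward heads the stochastic
        # sample concatenated with the deterministic recurrent state.
        return torch.cat((rssm_state.stoch, rssm_state.deter), dim=-1)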
Example #2
    def write_videos(self,
                     observation,
                     action,
                     image_pred,
                     post,
                     step=None,
                     n=4,
                     t=25):
        """
        Observations have shape [T, N, C, H, W].
        Generates n rollouts with the model.
        For the first t time steps, observations are used to generate state representations.
        For time steps t+1..T, the state transition model is used instead.
        Writes three rows of frames to the video: ground truth, reconstruction, error.
        """
        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(
            observation, 3)
        model = self.agent.model
        ground_truth = observation[:, :n] + 0.5
        reconstruction = image_pred.mean[:t, :n]

        prev_state = post[t - 1, :n]
        prior = model.rollout.rollout_transition(batch_t - t, action[t:, :n],
                                                 prev_state)
        imagined = model.observation_decoder(get_feat(prior)).mean
        model_frames = torch.cat((reconstruction, imagined), dim=0) + 0.5
        error = (model_frames - ground_truth + 1) / 2
        # concatenate vertically on height dimension
        openl = torch.cat((ground_truth, model_frames, error), dim=3)
        openl = openl.transpose(1, 0)  # N,T,C,H,W
        video_summary('videos/model_error', torch.clamp(openl, 0., 1.), step)
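
To make the tensor bookkeeping concrete, here is a hedged, self-contained walkthrough of the frame layout with dummy tensors (the shapes T, N, C, H, W follow the docstring above; video_summary itself is not reproduced):

    import torch

    # Dummy shapes: T=50 time steps, N=4 rollouts, 3-channel 64x64 frames.
    T, N, C, H, W = 50, 4, 3, 64, 64
    ground_truth = torch.rand(T, N, C, H, W)
    model_frames = torch.rand(T, N, C, H, W)        # reconstruction (first t steps) + imagination (rest)
    error = (model_frames - ground_truth + 1) / 2   # map differences in [-1, 1] to [0, 1]

    openl = torch.cat((ground_truth, model_frames, error), dim=3)  # stack on the height axis -> [T, N, C, 3*H, W]
    openl = openl.transpose(1, 0)                                  # -> [N, T, C, 3*H, W] for the video writer
    assert openl.shape == (N, T, C, 3 * H, W)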
Example #3
 def policy(self, state: RSSMState):
     feat = get_feat(state)
     action_dist = self.action_decoder(feat)
     if self.action_dist == 'tanh_normal':
         if self.training:  # use agent.train(bool) or agent.eval()
             action = action_dist.sample()
         else:
             action = action_dist.mode()
     else:
         # cannot propagate gradients with one hot distribution
         action = action_dist.sample()
     return action, action_dist
Example #4
 def policy(self, state: RSSMState):
     feat = get_feat(state)
     action_dist = self.action_decoder(feat)
     if self.action_dist == 'tanh_normal':
         if self.training:  # use agent.train(bool) or agent.eval()
             action = action_dist.rsample()
         else:
             action = action_dist.mode()
     elif self.action_dist == 'one_hot':
         action = action_dist.sample()
         # This doesn't change the value, but gives us straight-through gradients
         action = action + action_dist.probs - action_dist.probs.detach()
     elif self.action_dist == 'relaxed_one_hot':
         action = action_dist.rsample()
     else:
         action = action_dist.sample()
     return action, action_dist
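
The 'one_hot' branch uses the straight-through estimator: adding probs - probs.detach() leaves the sampled one-hot value unchanged numerically, but routes gradients through the categorical probabilities. A minimal sketch of the trick in isolation (the 4-action setup and the per-action values are made up for illustration):

    import torch
    import torch.distributions as td

    logits = torch.zeros(1, 4, requires_grad=True)
    dist = td.OneHotCategorical(logits=logits)

    sample = dist.sample()                               # hard one-hot sample, no gradient path
    action = sample + dist.probs - dist.probs.detach()   # same forward value, gradients flow via probs
    assert torch.equal(action.detach(), sample)

    per_action_value = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    loss = -(action * per_action_value).sum()
    loss.backward()
    assert logits.grad is not None                       # the logits received gradients through the trick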
Example #5
 def model_loss(self, observation: torch.Tensor, prior: RSSMState,
                post: RSSMState, reward: torch.Tensor):
     """
     Compute the model loss for a batch of data. All vectors are [batch_t, batch_b, vector_dim]
     """
     model = self.agent.model
     feat = get_feat(post)
     image_pred = model.observation_decoder(feat)
     reward_pred = model.reward_model(feat)
     reward_loss = -torch.mean(reward_pred.log_prob(reward))
     image_loss = -torch.mean(image_pred.log_prob(observation))
     prior_dist = get_dist(prior)
     post_dist = get_dist(post)
     div = torch.mean(
         torch.distributions.kl.kl_divergence(post_dist, prior_dist))
     # free nats: clamp from below so the KL is not penalized under the threshold
     div = torch.max(div, div.new_full(div.size(), self.free_nats))
     model_loss = self.kl_scale * div + reward_loss + image_loss
     return model_loss
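
get_dist is also a repository helper. A plausible sketch (an assumption; the real helper may differ) is a diagonal Gaussian over the stochastic latent built from the mean and std fields of RSSMState, and the free-nats term simply stops penalizing the KL once it drops below the threshold:

    import torch
    import torch.distributions as td

    def get_dist(rssm_state):
        # Hypothetical sketch: diagonal Normal over the stochastic latent,
        # with the last dimension treated as the event dimension.
        return td.Independent(td.Normal(rssm_state.mean, rssm_state.std), 1)

    # Free nats in isolation: KL values below the threshold are replaced by a
    # constant, so they contribute no gradient and the posterior is not pushed
    # to match the prior more closely than free_nats.
    div = torch.tensor(1.5)
    free_nats = 3.0
    clipped = torch.max(div, div.new_full(div.size(), free_nats))  # -> tensor(3.)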
Example #6
    def loss(self, samples: SamplesFromReplay, sample_itr: int, opt_itr: int):
        """
        Compute the loss for a batch of data.  This includes computing the model and reward losses on the given data,
        as well as using the dynamics model to generate additional rollouts, which are used for the actor and value
        components of the loss.
        :param samples: samples from replay
        :param sample_itr: sample iteration
        :param opt_itr: optimization iteration
        :return: model_loss, actor_loss, value_loss, and a LossInfo tuple of diagnostics
        """
        model = self.agent.model

        observation = samples.all_observation[:-1]  # [t, t+batch_length+1] -> [t, t+batch_length]
        action = samples.all_action[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
        reward = samples.all_reward[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
        reward = reward.unsqueeze(2)
        done = samples.done
        done = done.unsqueeze(2)

        # Extract tensors from the Samples object
        # They all have the batch_t dimension first, but we'll put the batch_b dimension first.
        # Also, we convert all tensors to floats so they can be fed into our models.

        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(
            observation, 3)
        # squeeze batch sizes to single batch dimension for imagination roll-out
        batch_size = batch_t * batch_b

        # normalize image
        observation = observation.type(self.type) / 255.0 - 0.5
        # embed the image
        embed = model.observation_encoder(observation)

        prev_state = model.representation.initial_state(batch_b,
                                                        device=action.device,
                                                        dtype=action.dtype)
        # Rollout model by taking the same series of actions as the real model
        prior, post = model.rollout.rollout_representation(
            batch_t, embed, action, prev_state)
        # Flatten our data (so first dimension is batch_t * batch_b = batch_size)
        # since we're going to do a new rollout starting from each state visited in each batch.

        # Compute losses for each component of the model

        # Model Loss
        feat = get_feat(post)
        image_pred = model.observation_decoder(feat)
        reward_pred = model.reward_model(feat)
        reward_loss = -torch.mean(reward_pred.log_prob(reward))
        image_loss = -torch.mean(image_pred.log_prob(observation))
        pcont_loss = torch.tensor(0.)  # placeholder if use_pcont = False
        if self.use_pcont:
            pcont_pred = model.pcont(feat)
            pcont_target = self.discount * (1 - done.float())
            pcont_loss = -torch.mean(pcont_pred.log_prob(pcont_target))
        prior_dist = get_dist(prior)
        post_dist = get_dist(post)
        div = torch.mean(
            torch.distributions.kl.kl_divergence(post_dist, prior_dist))
        div = torch.max(div, div.new_full(div.size(), self.free_nats))
        model_loss = self.kl_scale * div + reward_loss + image_loss
        if self.use_pcont:
            model_loss += self.pcont_scale * pcont_loss

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # Don't let gradients pass through to prevent overwriting gradients.
        # Actor Loss

        # remove gradients from previously calculated tensors
        with torch.no_grad():
            if self.use_pcont:
                # "Last step could be terminal." Done in TF2 code, but unclear why
                flat_post = buffer_method(post[:-1, :], 'reshape',
                                          (batch_t - 1) * (batch_b), -1)
            else:
                flat_post = buffer_method(post, 'reshape', batch_size, -1)
        # Rollout the policy for self.horizon steps. Variable names with imag_ indicate this data is imagined not real.
        # imag_feat shape is [horizon, batch_t * batch_b, feature_size]
        with FreezeParameters(self.model_modules):
            imag_dist, _ = model.rollout.rollout_policy(
                self.horizon, model.policy, flat_post)

        # Use state features (deterministic and stochastic) to predict the image and reward
        imag_feat = get_feat(
            imag_dist)  # [horizon, batch_t * batch_b, feature_size]
        # Assumes these are normal distributions. The TF code uses the mode, but for a normal distribution mean = mode.
        # If we want to use other distributions we'll have to fix this.
        # We calculate the target here so no grad necessary

        # freeze model parameters as only action model gradients needed
        with FreezeParameters(self.model_modules + self.value_modules):
            imag_reward = model.reward_model(imag_feat).mean
            value = model.value_model(imag_feat).mean
        # Compute the exponential discounted sum of rewards
        if self.use_pcont:
            with FreezeParameters([model.pcont]):
                discount_arr = model.pcont(imag_feat).mean
        else:
            discount_arr = self.discount * torch.ones_like(imag_reward)
        returns = self.compute_return(imag_reward[:-1],
                                      value[:-1],
                                      discount_arr[:-1],
                                      bootstrap=value[-1],
                                      lambda_=self.discount_lambda)
        # Make the top row 1 so the cumulative product starts with discount^0
        discount_arr = torch.cat(
            [torch.ones_like(discount_arr[:1]), discount_arr[1:]])
        discount = torch.cumprod(discount_arr[:-1], 0)
        actor_loss = -torch.mean(discount * returns)

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # Don't let gradients pass through to prevent overwriting gradients.
        # Value Loss

        # remove gradients from previously calculated tensors
        with torch.no_grad():
            value_feat = imag_feat[:-1].detach()
            value_discount = discount.detach()
            value_target = returns.detach()
        value_pred = model.value_model(value_feat)
        log_prob = value_pred.log_prob(value_target)
        value_loss = -torch.mean(value_discount * log_prob.unsqueeze(2))

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # loss info
        with torch.no_grad():
            prior_ent = torch.mean(prior_dist.entropy())
            post_ent = torch.mean(post_dist.entropy())
            loss_info = LossInfo(model_loss, actor_loss, value_loss, prior_ent,
                                 post_ent, div, reward_loss, image_loss,
                                 pcont_loss)

            if self.log_video:
                if opt_itr == self.train_steps - 1 and sample_itr % self.video_every == 0:
                    self.write_videos(observation,
                                      action,
                                      image_pred,
                                      post,
                                      step=sample_itr,
                                      n=self.video_summary_b,
                                      t=self.video_summary_t)

        return model_loss, actor_loss, value_loss, loss_info
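
compute_return is referenced but not shown in these examples. Below is a minimal sketch of a Dreamer-style lambda-return, assuming reward, value and discount are [horizon, batch, 1] tensors (as produced by the imagination rollout above) and bootstrap is the value estimate of the final imagined state; the repository's implementation may differ in details:

    import torch

    def compute_return(reward, value, discount, bootstrap, lambda_):
        # Sketch of the lambda-return: a backwards recursion that mixes one-step
        # TD targets with longer imagined rollouts, weighted by lambda_.
        next_value = torch.cat([value[1:], bootstrap.unsqueeze(0)], dim=0)
        target = reward + discount * next_value * (1 - lambda_)
        outputs = []
        accumulated = bootstrap
        for t in reversed(range(reward.shape[0])):
            accumulated = target[t] + discount[t] * lambda_ * accumulated
            outputs.append(accumulated)
        return torch.stack(list(reversed(outputs)), dim=0)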
Example #7
    def loss(self, samples: Samples):
        """
        Compute the loss for a batch of data.  This includes computing the model and reward losses on the given data,
        as well as using the dynamics model to generate additional rollouts, which are used for the actor and value
        components of the loss.
        :param samples: rlpyt Samples object, containing a batch of data.  All vectors are [timestep, batch, vector_dim]
        :return: FloatTensor containing the loss
        """
        model = self.agent.model
        device = next(model.parameters()).device

        # Extract tensors from the Samples object
        # They all have the batch_t dimension first, but we'll put the batch_b dimension first.
        # Also, we convert all tensors to floats so they can be fed into our models.

        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(
            samples.env.observation, 3)
        # squeeze batch sizes to single batch dimension for imagination roll-out
        batch_size = batch_t * batch_b

        observation = samples.env.observation.to(device)
        # normalize image
        observation = observation.type(self.type) / 255.0 - 0.5
        # embed the image
        embed = model.observation_encoder(observation)

        # get action
        action = samples.agent.action
        # make actions one-hot
        action = to_onehot(action, model.action_size,
                           dtype=self.type).to(device)

        reward = samples.env.reward.to(device)

        # if we want to continue the agent state from the previous time steps, we can do it like so:
        # prev_state = samples.agent.agent_info.prev_state[0]
        prev_state = model.representation.initial_state(batch_b, device=device)
        # Rollout model by taking the same series of actions as the real model
        post, prior = model.rollout.rollout_representation(
            batch_t, embed, action, prev_state)
        # Flatten our data (so first dimension is batch_t * batch_b = batch_size)
        # since we're going to do a new rollout starting from each state visited in each batch.
        flat_post = RSSMState(mean=post.mean.reshape(batch_size, -1),
                              std=post.std.reshape(batch_size, -1),
                              stoch=post.stoch.reshape(batch_size, -1),
                              deter=post.deter.reshape(batch_size, -1))
        flat_action = action.reshape(batch_size, -1)
        # Rollout the policy for self.horizon steps. Variable names with imag_ indicate this data is imagined not real.
        # imag_feat shape is [horizon, batch_t * batch_b, feature_size]
        imag_dist, _ = model.rollout.rollout_policy(self.horizon, model.policy,
                                                    flat_action, flat_post)

        # Use state features (deterministic and stochastic) to predict the image and reward
        imag_feat = get_feat(
            imag_dist)  # [horizon, batch_t * batch_b, feature_size]
        # Assumes these are normal distributions. The TF code uses the mode, but for a normal distribution mean = mode.
        # If we want to use other distributions we'll have to fix this.
        imag_reward = model.reward_model(imag_feat).mean
        value = model.value_model(imag_feat).mean

        # Compute the exponential discounted sum of rewards
        discount_arr = self.discount * torch.ones_like(imag_reward)
        returns = self.compute_return(imag_reward[:-1],
                                      value[:-1],
                                      discount_arr[:-1],
                                      bootstrap=value[-1],
                                      lambda_=self.discount_lambda)
        # cumulative discount over the time (horizon) dimension
        discount = torch.cumprod(discount_arr[:-1], 0).detach()

        # Compute losses for each component of the model
        model_loss = self.model_loss(observation, prior, post, reward)
        actor_loss = self.actor_loss(discount, returns)
        value_loss = self.value_loss(imag_feat, discount, returns)
        loss = self.model_weight * model_loss + self.actor_weight * actor_loss + self.value_weight * value_loss
        return loss
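
to_onehot is another repository helper used to turn integer action indices into float vectors the model can consume. A plausible sketch under that assumption (the real helper may handle leading dimensions or already-one-hot inputs differently):

    import torch
    import torch.nn.functional as F

    def to_onehot(indices, num_classes, dtype=torch.float32):
        # Hypothetical sketch: map integer action indices of shape [T, B]
        # to one-hot vectors of shape [T, B, num_classes].
        return F.one_hot(indices.long(), num_classes).to(dtype)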