Example #1
    def write_videos(self,
                     observation,
                     action,
                     image_pred,
                     post,
                     step=None,
                     n=4,
                     t=25):
        """
        observation shape T,N,C,H,W
        generates n rollouts with the model.
        For t time steps, observations are used to generate state representations.
        Then for time steps t+1:T, uses the state transition model.
        Outputs 3 different frames to video: ground truth, reconstruction, error
        """
        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(
            observation, 3)
        model = self.agent.model
        ground_truth = observation[:, :n] + 0.5  # undo the -0.5 pixel centering
        reconstruction = image_pred.mean[:t, :n]

        # Seed the open-loop rollout with the posterior state at step t - 1,
        # then imagine the remaining batch_t - t steps with the transition model.
        prev_state = post[t - 1, :n]
        prior = model.rollout.rollout_transition(batch_t - t, action[t:, :n],
                                                 prev_state)
        imagined = model.observation_decoder(get_feat(prior)).mean
        # Use a new name so the model variable above is not shadowed.
        prediction = torch.cat((reconstruction, imagined), dim=0) + 0.5
        error = (prediction - ground_truth + 1) / 2  # map [-1, 1] into [0, 1]
        # Concatenate the three panels vertically, on the height dimension.
        openl = torch.cat((ground_truth, prediction, error), dim=3)
        openl = openl.transpose(1, 0)  # N,T,C,H,W
        video_summary('videos/model_error', torch.clamp(openl, 0., 1.), step)
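The panel layout in write_videos is pure shape arithmetic (infer_leading_dims, get_feat, and video_summary are helpers from the surrounding project and are not reproduced here). A minimal standalone sketch with random tensors in place of real model outputs, where the T, N, and 64x64 frame sizes are assumptions, shows the resulting video layout:

import torch

# Hypothetical shapes: T time steps, N rollouts, 3x64x64 RGB frames.
T, N, C, H, W = 50, 4, 3, 64, 64
ground_truth = torch.rand(T, N, C, H, W)
prediction = torch.rand(T, N, C, H, W)
error = (prediction - ground_truth + 1) / 2  # map [-1, 1] into [0, 1]

# Stack the three panels along the height axis, then reorder to N,T,C,H,W.
panels = torch.cat((ground_truth, prediction, error), dim=3)
video = panels.transpose(1, 0)
assert video.shape == (N, T, C, 3 * H, W)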
Example #2
    def write_env_videos(self, env, step=None):
        o = env.reset()  # kangaroo
        done = False
        images = []
        while not done:
            images.append(env.render())
            # Move the observation into torch tensors on the agent's device.
            o = torchify_buffer(o)
            o = buffer_to(o, self.agent.device)
            a = self.agent.model(o)
            step_results = env.step(a.action.cpu().detach().numpy())
            o = step_results.observation
            done = step_results.done
            # Some wrappers signal the end of a trajectory via env_info instead.
            if getattr(step_results.env_info, "traj_done", False):
                done = True
        # Stack the H,W,C frames and reorder into an N,T,C,H,W video batch.
        openl = torch.stack([torch.tensor(img) for img in images],
                            dim=0).permute(0, 3, 1, 2).unsqueeze(0)
        video_summary('videos/real_rollout', openl, step)
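video_summary is also a project helper (as are torchify_buffer and buffer_to, which come from rlpyt). A minimal stand-in, assuming it simply wraps TensorBoard's add_video, which expects exactly the N,T,C,H,W float layout in [0, 1] produced above, could look like this:

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/dreamer')  # hypothetical log directory

def video_summary(tag, video, step=None, fps=20):
    # Assumed wrapper: log an N,T,C,H,W float tensor in [0, 1] as a video.
    writer.add_video(tag, video, global_step=step, fps=fps)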