# Module-level imports these methods rely on (paths assume an rlpyt-based
# project; get_feat and video_summary are project-specific helpers):
#   import torch
#   from rlpyt.utils.tensor import infer_leading_dims
#   from rlpyt.utils.buffer import torchify_buffer, buffer_to
def write_videos(self, observation, action, image_pred, post, step=None, n=4, t=25):
    """
    Generate n model rollouts and log them as video.

    observation has shape T,N,C,H,W. For the first t time steps, posterior
    states (conditioned on observations) are used; for steps t+1:T, the state
    transition model rolls forward on its own. Three strips are written per
    frame, stacked vertically: ground truth, model output, and error.
    """
    lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(observation, 3)
    model = self.agent.model
    # Observations are stored centered in [-0.5, 0.5]; shift to [0, 1].
    ground_truth = observation[:, :n] + 0.5
    reconstruction = image_pred.mean[:t, :n]
    prev_state = post[t - 1, :n]
    prior = model.rollout.rollout_transition(batch_t - t, action[t:, :n], prev_state)
    imagined = model.observation_decoder(get_feat(prior)).mean
    # Use a distinct name to avoid shadowing the model network above.
    model_frames = torch.cat((reconstruction, imagined), dim=0) + 0.5
    # Map the signed error into [0, 1] so it renders as an image.
    error = (model_frames - ground_truth + 1) / 2
    # Concatenate vertically on the height dimension.
    openl = torch.cat((ground_truth, model_frames, error), dim=3)
    openl = openl.transpose(1, 0)  # N,T,C,H,W
    video_summary('videos/model_error', torch.clamp(openl, 0., 1.), step)
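# A minimal, self-contained sketch of the tiling used above (not part of the
# class): ground truth, model output, and error strips are stacked along the
# height axis, and the signed error is remapped into [0, 1] so it displays as
# an image. Shapes and tensors are random stand-ins, not project data.
import torch

def _tiling_sketch():
    T, N, C, H, W = 10, 4, 3, 64, 64
    ground_truth = torch.rand(T, N, C, H, W)
    model_frames = torch.rand(T, N, C, H, W)
    error = (model_frames - ground_truth + 1) / 2                  # signed error -> [0, 1]
    tiled = torch.cat((ground_truth, model_frames, error), dim=3)  # stack on H
    video = tiled.transpose(1, 0)                                  # N,T,C,3H,W
    assert video.shape == (N, T, C, 3 * H, W)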
def write_env_videos(self, env, step=None):
    """Roll the agent out in the real environment and log the frames as video."""
    o = env.reset()  # e.g., Atari Kangaroo
    done = False
    images = []
    while not done:
        images.append(env.render())
        o = torchify_buffer(o)
        o = buffer_to(o, self.agent.device)
        a = self.agent.model(o)
        step_results = env.step(a.action.cpu().detach().numpy())
        o = step_results.observation
        done = step_results.done
        # rlpyt-style envs signal true episode end via env_info.traj_done.
        if getattr(step_results.env_info, "traj_done", False):
            done = True
    # Stack T frames of shape H,W,C, then rearrange to N,T,C,H,W with N=1.
    openl = torch.stack([torch.as_tensor(frame) for frame in images], dim=0)
    openl = openl.permute(0, 3, 1, 2).unsqueeze(0)
    video_summary('videos/real_rollout', openl, step)
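# A minimal sketch of the frame packing in write_env_videos, assuming
# env.render() returns H,W,C uint8 arrays; the frames and shapes below are
# stand-ins, not part of this class.
import numpy as np
import torch

def _frame_packing_sketch():
    frames = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(8)]
    video = torch.stack([torch.as_tensor(f) for f in frames], dim=0)  # T,H,W,C
    video = video.permute(0, 3, 1, 2).unsqueeze(0)                    # N,T,C,H,W with N=1
    assert video.shape == (1, 8, 3, 64, 64)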