def full_seq_forward(self, inputs):
    """Predict per-step actions from pairs of consecutive frame encodings.

    Uses precomputed encodings from `inputs` when available, otherwise
    encodes `inputs.traj_seq` with `self.encoder`. Returns an AttrDict with
    `actions` and, when ground truth is present in `inputs`, also
    `action_targets` and `pad_mask`.
    """
    if 'model_enc_seq' in inputs:
        # Encodings were produced upstream; take t+1 frames as the "next" half.
        enc_next = inputs.model_enc_seq[:, 1:]
        if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
            # Flatten spatial dims of the image encodings for the "current" half,
            # then trim so both halves cover the same number of steps.
            enc_curr = inputs.enc_traj_seq.reshape(
                inputs.enc_traj_seq.shape[:2] + (self._hp.nz_enc,))[:, :-1]
            enc_curr = enc_curr[:, :enc_next.shape[1]]
        else:
            enc_curr = inputs.model_enc_seq[:, :-1]
    else:
        # No precomputed encodings: encode the raw sequence ourselves.
        encoded = batch_apply(self.encoder, inputs.traj_seq)
        enc_curr, enc_next = encoded[:, :-1], encoded[:, 1:]

    if self.detach_enc:
        # Block gradients from flowing back into the encoder.
        enc_curr = enc_curr.detach()
        enc_next = enc_next.detach()

    # Predict the action connecting each (current, next) encoding pair.
    predicted_actions = batch_apply(
        self.action_pred, torch.cat([enc_curr, enc_next], dim=2))

    output = AttrDict()
    output.actions = predicted_actions
    if 'actions' in inputs:
        output.action_targets = inputs.actions
        output.pad_mask = inputs.pad_mask
    return output
def make_traj(self, agent_data, obs, policy_out):
    """Assemble a time-padded trajectory AttrDict from one rollout.

    Stacks per-step actions from `policy_out`, records a padding mask, pads
    all time dimensions to `self.max_num_actions`, and copies through any
    optional robosuite / regression metadata found in `obs`.
    """
    traj = AttrDict()
    if not self.do_not_save_images:
        traj.images = obs['images']
    traj.states = obs['state']

    # Stack each policy step's action vector into a (T, ...) array.
    traj.actions = np.stack([step['actions'] for step in policy_out], 0)
    traj.pad_mask = get_pad_mask(traj.actions.shape[0], self.max_num_actions)
    traj = pad_traj_timesteps(traj, self.max_num_actions)

    # Optional robosuite metadata — present only for robosuite environments.
    if 'robosuite_xml' in obs:
        traj.robosuite_xml = obs['robosuite_xml'][0]
    if 'robosuite_env_name' in obs:
        traj.robosuite_env_name = obs['robosuite_env_name'][0]
    if 'robosuite_full_state' in obs:
        # Minimal state containing everything needed to reposition entities.
        traj.robosuite_full_state = obs['robosuite_full_state']
    if 'regression_state' in obs:
        traj.regression_state = obs['regression_state']
    return traj