def make_traj(self, agent_data, obs, policy_out):
    """Assemble one trajectory record from observations and policy outputs.

    Args:
        agent_data: not read by this method.
        obs: mapping of observation entries. 'state' is always read, 'images'
            is read unless `self.do_not_save_images` is set, and the
            robosuite / regression entries are copied only when present.
        policy_out: iterable of per-step policy outputs, each providing an
            'actions' entry.

    Returns:
        The AttrDict returned by `pad_traj_timesteps`, extended with whichever
        optional entries were found in `obs`.
    """
    traj = AttrDict()
    if not self.do_not_save_images:
        traj.images = obs['images']
    traj.states = obs['state']

    # Stack the per-step actions along a new leading (time) axis.
    traj.actions = np.stack([step['actions'] for step in policy_out], 0)
    num_steps = traj.actions.shape[0]
    traj.pad_mask = get_pad_mask(num_steps, self.max_num_actions)
    traj = pad_traj_timesteps(traj, self.max_num_actions)

    # Optional entries, copied in a fixed order. For the first two only
    # element 0 is kept; the others are stored as-is.
    # 'regression_state' is the minimal state that contains all information
    # to position entities in the env.
    optional_entries = (
        ('robosuite_xml', True),
        ('robosuite_env_name', True),
        ('robosuite_full_state', False),
        ('regression_state', False),
    )
    for entry_name, first_only in optional_entries:
        if entry_name in obs:
            entry = obs[entry_name]
            setattr(traj, entry_name, entry[0] if first_only else entry)

    return traj
def rollout(self, state, goal_state, samples, rollout_len, prune=False):
    """Run the model once on a batch of latent samples and gather its outputs.

    Args:
        state: start state; tiled with `.repeat(batch_size, 0)` so there is
            one copy per sample.
        goal_state: goal state; tiled the same way.
        samples: latent samples; `samples.shape[0]` is used as the batch size.
        rollout_len: rollout length; the end index handed to the model is
            `rollout_len - 1` for every batch element.
        prune: when True, predictions are taken from the model's
            `pruned_prediction`; otherwise they come from
            `self._get_state_rollouts`.

    Returns:
        AttrDict with `predictions`, `actions`, `states` and `latents`, each
        passed through `self._list2np`.
    """
    device = self._model.device
    num_samples = samples.shape[0]
    tiled_state = state.repeat(num_samples, 0)
    tiled_goal = goal_state.repeat(num_samples, 0)

    def to_float(array):
        return torch.tensor(array, device=device, dtype=torch.float32)

    def to_long(array):
        return torch.tensor(array, device=device).long()

    # Assemble the model inputs.
    input_dict = AttrDict(
        I_0=to_float(tiled_state),
        I_g=to_float(tiled_goal),
        start_ind=to_long(np.zeros((num_samples, ))),
        end_ind=to_long(np.ones((num_samples, )) * (rollout_len - 1)),
        z=to_float(samples),
    )
    input_dict = self._postprocess_inputs(input_dict)

    # Run the model in validation mode and lower-bound the end index at 1.
    outputs = AttrDict()
    with self._model.val_mode():
        model_output = self._model(input_dict)
        predicted_end = model_output.end_ind
        end_ind = torch.max(predicted_end, torch.ones_like(predicted_end))

    def capped(sequences):
        return self._list2np(self._cap_to_length(sequences, end_ind))

    if prune:
        predictions = model_output.pruned_prediction
    else:
        predictions = self._get_state_rollouts(input_dict, model_output, end_ind)
    outputs.predictions = self._list2np(predictions)

    outputs.actions = capped(model_output.actions)
    outputs.states = capped(model_output.regressed_state)
    outputs.latents = capped(input_dict.model_enc_seq)
    return outputs
def __getitem__(self, index):
    """Load one trajectory from its HDF5 file and prepare it for the model.

    The flat `index` is split into a file index and a per-file trajectory
    index using `self.traj_per_file`. When `self.data_conf.one_datum` is set,
    the index is forced to 1 so the same trajectory is always returned.

    Args:
        index: flat trajectory index across all files in `self.filenames`.

    Returns:
        AttrDict with the loaded arrays plus `start_ind`, `end_ind` and
        `norep_end_ind`, after `self.process_data_dict` has been applied.

    Raises:
        ValueError: if reading or processing the file fails; the underlying
            exception is chained as `__cause__`.
    """
    if 'one_datum' in self.data_conf and self.data_conf.one_datum:
        index = 1

    file_index = index // self.traj_per_file
    path = self.filenames[file_index]

    try:
        with h5py.File(path, 'r') as F:
            ex_index = index % self.traj_per_file  # index within this file
            key = 'traj{}'.format(ex_index)

            # Fetch data into a dict.
            # `dataset[()]` reads the whole dataset; it replaces `.value`,
            # which was deprecated in h5py 2.x and removed in h5py 3.0.
            if key + '/images' in F.keys():
                data_dict = AttrDict(images=F[key + '/images'][()])
            else:
                data_dict = AttrDict()
            for name in F[key].keys():
                if name in ['states', 'actions', 'pad_mask']:
                    data_dict[name] = F[key + '/' + name][()].astype(
                        np.float32)

            # remove spurious states at end of trajectory
            if self.filter_repeated_tail:
                data_dict = self._filter_tail(data_dict)

            # maybe subsample seqs
            if self.subsampler is not None:
                data_dict = self._subsample_data(data_dict)

            if 'robosuite_full_state' in F[key].keys():
                data_dict.robosuite_full_state = F[
                    key + '/robosuite_full_state'][()]
            # When present, 'regression_state' overrides the 'states' entry.
            if 'regression_state' in F[key].keys():
                data_dict.states = F[key + '/regression_state'][()].astype(
                    np.float32)

            # Make length consistent.
            # Largest index weighted by pad_mask; presumably the last valid
            # step, assuming pad_mask is a 0/1 mask - TODO confirm.
            end_ind = np.argmax(
                data_dict.pad_mask * np.arange(data_dict.pad_mask.shape[0],
                                               dtype=np.float32), 0)
            # NOTE(review): np.random.randint(0, end_ind - 1) raises when
            # end_ind <= 1; such trajectories surface as the ValueError below.
            start_ind = (np.random.randint(0, end_ind - 1)
                         if self.randomize_start else 0)
            start_ind, end_ind, data_dict = self.sample_max_len_video(
                data_dict, start_ind, end_ind)

            # Randomize length
            if self.randomize_length:
                end_ind = self._randomize_length(start_ind, end_ind,
                                                 data_dict)

            # Remember the end index before the tail is repeated, then
            # repeat last frame until end of sequence.
            data_dict.norep_end_ind = end_ind
            if self.repeat_tail:
                data_dict, end_ind = self._repeat_tail(data_dict, end_ind)

            # Collect data into the format the model expects
            data_dict.end_ind = end_ind
            data_dict.start_ind = start_ind

            # for roboturk env rendering
            if 'robosuite_env_name' in F[key].keys():
                data_dict.robosuite_env_name = F[
                    key + '/robosuite_env_name'][()]
            if 'robosuite_xml' in F[key].keys():
                data_dict.robosuite_xml = F[key + '/robosuite_xml'][()]

            self.process_data_dict(data_dict)
    except Exception as err:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit still propagate; chain the cause to keep the real error.
        raise ValueError(
            "Problem when loading file from {}".format(path)) from err

    return data_dict