Example #1
0
    def make_traj(self, agent_data, obs, policy_out):
        """Assemble a padded trajectory record from one rollout.

        Args:
            agent_data: not referenced by this method; kept for interface
                compatibility.
            obs: mapping of observations. ``obs['state']`` is always read, and
                ``obs['images']`` is read unless ``self.do_not_save_images`` is
                set. Optional robosuite / regression entries are copied when
                present.
            policy_out: sequence of per-step policy outputs, each indexable by
                ``'actions'``.

        Returns:
            The AttrDict returned by ``pad_traj_timesteps``. It holds
            images/states/actions and a pad mask sized for
            ``self.max_num_actions`` steps. Any optional ``obs`` entries are
            attached afterwards, so ``pad_traj_timesteps`` never sees them.
        """
        traj = AttrDict()
        if not self.do_not_save_images:
            traj.images = obs['images']
        traj.states = obs['state']

        # Stack the per-step actions along a new leading (time) axis.
        traj.actions = np.stack([step['actions'] for step in policy_out], 0)
        num_steps = traj.actions.shape[0]
        traj.pad_mask = get_pad_mask(num_steps, self.max_num_actions)
        traj = pad_traj_timesteps(traj, self.max_num_actions)

        # Optional metadata, copied only after padding. The flag says whether
        # only element [0] of the observation entry is kept.
        optional_fields = (
            ('robosuite_xml', True),
            ('robosuite_env_name', True),
            ('robosuite_full_state', False),
            # minimal state that contains all information to position
            # entities in the env
            ('regression_state', False),
        )
        for field, first_only in optional_fields:
            if field not in obs:
                continue
            value = obs[field]
            setattr(traj, field, value[0] if first_only else value)
        return traj
Example #2
0
    def __getitem__(self, index):
        """Load one trajectory from the HDF5 file that stores ``index``.

        Each file holds ``self.traj_per_file`` trajectories in groups named
        ``traj0``, ``traj1``, ... The file is chosen by integer division of
        ``index`` and the group by the remainder.

        Args:
            index: global trajectory index across all files.

        Returns:
            AttrDict with the loaded arrays and the fields below, after
            ``self.process_data_dict`` has been applied:
                - ``states``, ``actions``, ``pad_mask`` as float32, when present;
                - ``images``, when present;
                - robosuite metadata, when present;
                - ``start_ind``, ``end_ind`` and ``norep_end_ind``.

        Raises:
            ValueError: if anything fails while reading or processing the file.
                The underlying exception is chained as ``__cause__``.
        """
        if 'one_datum' in self.data_conf and self.data_conf.one_datum:
            # Always serve the same trajectory (global index 1).
            index = 1

        file_index = index // self.traj_per_file
        path = self.filenames[file_index]

        try:
            with h5py.File(path, 'r') as F:
                ex_index = index % self.traj_per_file  # index within the file
                key = 'traj{}'.format(ex_index)
                group = F[key]

                def read(name):
                    # ``Dataset[()]`` reads the whole dataset and works on
                    # both h5py 2.x and 3.x. ``Dataset.value`` was removed in
                    # h5py 3.0.
                    return group[name][()]

                # Fetch data into a dict
                if 'images' in group:
                    data_dict = AttrDict(images=read('images'))
                else:
                    data_dict = AttrDict()
                for name in group.keys():
                    if name in ('states', 'actions', 'pad_mask'):
                        data_dict[name] = read(name).astype(np.float32)

                # remove spurious states at end of trajectory
                if self.filter_repeated_tail:
                    data_dict = self._filter_tail(data_dict)

                # maybe subsample seqs
                if self.subsampler is not None:
                    data_dict = self._subsample_data(data_dict)

                if 'robosuite_full_state' in group:
                    data_dict.robosuite_full_state = read('robosuite_full_state')
                if 'regression_state' in group:
                    # NOTE(review): this overwrites ``states`` after
                    # tail-filtering and subsampling, so the regression state
                    # is not passed through either step. Confirm this is
                    # intended.
                    data_dict.states = read('regression_state').astype(np.float32)

                # Make length consistent. end_ind is the index of the last
                # step with a non-zero pad_mask entry (assuming a 0/1 mask).
                end_ind = np.argmax(
                    data_dict.pad_mask *
                    np.arange(data_dict.pad_mask.shape[0], dtype=np.float32),
                    0)
                # np.random.randint(0, high) requires high > 0. For
                # trajectories with end_ind <= 1 there is nothing to
                # randomize, so start at 0 instead of raising.
                if self.randomize_start and end_ind > 1:
                    start_ind = np.random.randint(0, end_ind - 1)
                else:
                    start_ind = 0
                start_ind, end_ind, data_dict = self.sample_max_len_video(
                    data_dict, start_ind, end_ind)

                # Randomize length
                if self.randomize_length:
                    end_ind = self._randomize_length(start_ind, end_ind,
                                                     data_dict)

                # repeat last frame until end of sequence
                data_dict.norep_end_ind = end_ind
                if self.repeat_tail:
                    data_dict, end_ind = self._repeat_tail(data_dict, end_ind)

                # Collect data into the format the model expects
                data_dict.end_ind = end_ind
                data_dict.start_ind = start_ind

                # for roboturk env rendering
                if 'robosuite_env_name' in group:
                    data_dict.robosuite_env_name = read('robosuite_env_name')
                if 'robosuite_xml' in group:
                    data_dict.robosuite_xml = read('robosuite_xml')

                self.process_data_dict(data_dict)
        except Exception as e:
            # Catch only ordinary errors: KeyboardInterrupt/SystemExit must
            # propagate. Chain the cause so the real failure stays visible.
            raise ValueError(
                "Problem when loading file from {}".format(path)) from e

        return data_dict