Example #1
0
    def make_traj(self, agent_data, obs, policy_out):
        """Build a trajectory AttrDict from an observation dict and per-step policy outputs.

        Actions are stacked over time, padded (with a validity mask) to
        ``self.max_num_actions``, and optional robosuite metadata is copied over.
        """
        trajectory = AttrDict()

        # Image frames are optional (skipped to save disk space).
        if not self.do_not_save_images:
            trajectory.images = obs['images']
        trajectory.states = obs['state']

        # One action entry per policy step -> (T, action_dim) array.
        trajectory.actions = np.stack([step['actions'] for step in policy_out], 0)

        # Mark the valid timesteps, then pad the whole trajectory to fixed length.
        trajectory.pad_mask = get_pad_mask(trajectory.actions.shape[0],
                                           self.max_num_actions)
        trajectory = pad_traj_timesteps(trajectory, self.max_num_actions)

        # Optional robosuite metadata for re-creating the environment later.
        if 'robosuite_xml' in obs:
            trajectory.robosuite_xml = obs['robosuite_xml'][0]
        if 'robosuite_env_name' in obs:
            trajectory.robosuite_env_name = obs['robosuite_env_name'][0]
        if 'robosuite_full_state' in obs:
            trajectory.robosuite_full_state = obs['robosuite_full_state']

        # minimal state that contains all information to position entities in the env
        if 'regression_state' in obs:
            trajectory.regression_state = obs['regression_state']

        return trajectory
Example #2
0
    def rollout(self, state, goal_state, samples, rollout_len, prune=False):
        """Performs one model rollout.

        Args:
            state: initial state array, tiled along axis 0 to the sample batch size.
            goal_state: goal state array, tiled along axis 0 to the sample batch size.
            samples: latent samples; their first dimension defines the batch size.
            rollout_len: number of rollout timesteps (end index is rollout_len - 1).
            prune: if True, return the model's pruned predictions directly instead
                of re-rolling out states.

        Returns:
            AttrDict of numpy outputs: predictions, actions, states, latents.
        """
        # prepare inputs: one rollout per latent sample
        batch_size = samples.shape[0]
        state = state.repeat(batch_size, 0)
        goal_state = goal_state.repeat(batch_size, 0)
        device = self._model.device
        input_dict = AttrDict(
            I_0=torch.tensor(state, device=device, dtype=torch.float32),
            I_g=torch.tensor(goal_state, device=device, dtype=torch.float32),
            # every rollout spans [0, rollout_len - 1]; build the constant index
            # tensors directly instead of round-tripping through float64 NumPy
            start_ind=torch.zeros(batch_size, device=device, dtype=torch.long),
            end_ind=torch.full((batch_size,), rollout_len - 1,
                               device=device, dtype=torch.long),
            z=torch.tensor(samples, device=device, dtype=torch.float32))
        input_dict = self._postprocess_inputs(input_dict)

        # perform rollout, collect outputs
        outputs = AttrDict()
        with self._model.val_mode():
            model_output = self._model(input_dict)
            # clamp predicted end indices to >= 1 so capping never yields empty seqs
            end_ind = torch.max(model_output.end_ind,
                                torch.ones_like(model_output.end_ind))

        if prune:
            outputs.predictions = self._list2np(model_output.pruned_prediction)
        else:
            outputs.predictions = self._list2np(
                self._get_state_rollouts(input_dict, model_output, end_ind))

        # cap each sequence at its (clamped) predicted end index
        outputs.actions = self._list2np(
            self._cap_to_length(model_output.actions, end_ind))
        outputs.states = self._list2np(
            self._cap_to_length(model_output.regressed_state, end_ind))
        outputs.latents = self._list2np(
            self._cap_to_length(input_dict.model_enc_seq, end_ind))

        return outputs
Example #3
0
    def __getitem__(self, index):
        """Load one (sub)trajectory from the chunked h5 dataset files.

        Args:
            index: global trajectory index; mapped to a file via
                ``self.traj_per_file`` and an in-file ``traj{N}`` group.

        Returns:
            AttrDict with states/actions/pad_mask (float32), start_ind/end_ind,
            optionally images and robosuite metadata for env rendering.

        Raises:
            ValueError: if the trajectory cannot be read from its file
                (original error chained as the cause).
        """
        # debug mode: always return the same datum
        if 'one_datum' in self.data_conf and self.data_conf.one_datum:
            index = 1

        file_index = index // self.traj_per_file
        path = self.filenames[file_index]

        try:
            with h5py.File(path, 'r') as F:
                ex_index = index % self.traj_per_file  # get the index
                key = 'traj{}'.format(ex_index)

                # Fetch data into a dict. NOTE: ds[()] reads the full dataset;
                # the deprecated `.value` attribute was removed in h5py 3.0.
                if key + '/images' in F.keys():
                    data_dict = AttrDict(images=F[key + '/images'][()])
                else:
                    data_dict = AttrDict()
                for name in F[key].keys():
                    if name in ['states', 'actions', 'pad_mask']:
                        data_dict[name] = F[key + '/' + name][()].astype(
                            np.float32)

                # remove spurious states at end of trajectory
                if self.filter_repeated_tail:
                    data_dict = self._filter_tail(data_dict)

                # maybe subsample seqs
                if self.subsampler is not None:
                    data_dict = self._subsample_data(data_dict)

                if 'robosuite_full_state' in F[key].keys():
                    data_dict.robosuite_full_state = F[
                        key + '/robosuite_full_state'][()]
                if 'regression_state' in F[key].keys():
                    # regression state replaces the raw state representation
                    data_dict.states = F[key + '/regression_state'][()].astype(
                        np.float32)

                # Make length consistent: last timestep where pad_mask is active
                end_ind = np.argmax(
                    data_dict.pad_mask *
                    np.arange(data_dict.pad_mask.shape[0], dtype=np.float32),
                    0)
                start_ind = np.random.randint(0, end_ind -
                                              1) if self.randomize_start else 0
                start_ind, end_ind, data_dict = self.sample_max_len_video(
                    data_dict, start_ind, end_ind)

                # Randomize length
                if self.randomize_length:
                    end_ind = self._randomize_length(start_ind, end_ind,
                                                     data_dict)

                # repeat last frame until end of sequence
                data_dict.norep_end_ind = end_ind
                if self.repeat_tail:
                    data_dict, end_ind = self._repeat_tail(data_dict, end_ind)

                # Collect data into the format the model expects
                data_dict.end_ind = end_ind
                data_dict.start_ind = start_ind

                # for roboturk env rendering
                if 'robosuite_env_name' in F[key].keys():
                    data_dict.robosuite_env_name = F[
                        key + '/robosuite_env_name'][()]
                if 'robosuite_xml' in F[key].keys():
                    data_dict.robosuite_xml = F[key + '/robosuite_xml'][()]

                self.process_data_dict(data_dict)
        except Exception as e:
            # Narrowed from a bare `except:` (which also caught
            # KeyboardInterrupt/SystemExit); chain the cause so the real
            # failure (KeyError, OSError, ...) stays in the traceback.
            raise ValueError(
                "Problem when loading file from {}".format(path)) from e

        return data_dict