def get_batch(self, indices=None, torch_device=None):
        """Sample a batch made of one horizon chunk per selected episode.

        NOTE(review): this class body appears to define ``get_batch`` again
        further down. If so, that later definition rebinds the name and this
        one is unreachable through the class — confirm before relying on it.

        Args:
            indices: Optional episode indices to sample from. When None, up to
                ``self._batch_size`` distinct episodes are drawn at random
                (previously any non-None value was rejected by an assert).
            torch_device: If not None, every leaf of the returned dicts is
                converted with ``torch.from_numpy`` and moved to this device;
                otherwise leaves stay numpy arrays.

        Returns:
            ``(inputs, outputs)`` AttrDicts. ``inputs`` holds the observation
            and action entries, and ``outputs`` holds the output-observation
            entries plus ``done`` as bool. Leaf shape is
            (batch, horizon, name_dim...).
        """
        # Leaves of self._datadict are indexed per episode here, so the length
        # of `done` is the number of episodes.
        num_eps = len(self._datadict.done)
        if indices is None:
            # Sampling is without replacement, so never request more unique
            # episodes than exist (np.random.choice would raise ValueError).
            batch = min(num_eps, self._batch_size)
            indices = np.random.choice(num_eps, batch, replace=False)

        # hor_chunk presumably cuts a fixed-length horizon window out of one
        # episode (TODO confirm); np.stack then adds the leading batch axis.
        sampled_datadict = self._datadict.leaf_apply(
            lambda list_of_arr: np.stack(
                [self.hor_chunk(list_of_arr[i]) for i in indices]))

        inputs = AttrDict()
        outputs = AttrDict()
        for key in self._env_spec.observation_names:
            inputs[key] = sampled_datadict[key]
        for key in self._env_spec.action_names:
            inputs[key] = sampled_datadict[key]

        for key in self._env_spec.output_observation_names:
            outputs[key] = sampled_datadict[key]

        outputs.done = sampled_datadict.done.astype(bool)

        if torch_device is not None:
            for d in (inputs, outputs):
                d.leaf_modify(lambda x: torch.from_numpy(x).to(torch_device))

        return inputs, outputs  # shape is (batch, horizon, name_dim...)
    def get_batch(self,
                  indices=None,
                  torch_device=None,
                  get_horizon_goals=False,
                  get_action_seq=False,
                  min_idx=0):
        """Sample a batch of rows from the flat (per-index) dataset.

        Args:
            indices: Optional row indices into the dataset. When None, up to
                ``self._batch_size`` distinct indices are drawn uniformly from
                ``[min_idx, self._data_len)``. ``min_idx`` is ignored when
                indices are given.
            torch_device: If not None, every leaf of ``inputs`` and
                ``outputs`` is converted with ``torch.from_numpy`` and moved to
                this device; otherwise those leaves stay numpy arrays.
            get_horizon_goals: If True, also build and return an AttrDict of
                the ``self._env_spec.goal_names`` entries.
            get_action_seq: If True, add ``inputs['act_seq']``, taken from
                ``self._action_sequences`` at the same indices.
            min_idx: Lowest dataset index eligible for random sampling.

        Returns:
            ``(inputs, outputs)``, or ``(inputs, outputs, goals)`` when
            ``get_horizon_goals`` is True.

        Raises:
            ValueError: if ``indices`` is None and ``min_idx`` is not in
                ``[0, self._data_len)``.

        NOTE: ``inputs['act_seq']`` and the goal entries are always created
        via ``torch.from_numpy(...).to(torch_device)``, even when
        ``torch_device`` is None and the other leaves remain numpy. This is
        kept unchanged for existing callers.
        """
        # TODO fix this
        if indices is None:
            if not 0 <= min_idx < self._data_len:
                raise ValueError("min_idx must be in [0, %d), got %r"
                                 % (self._data_len, min_idx))
            # Sampling is without replacement, so clamp the request to the
            # number of eligible rows (np.random.choice would raise otherwise).
            batch = min(self._data_len - min_idx, self._batch_size)
            indices = np.random.choice(self._data_len - min_idx,
                                       batch,
                                       replace=False)
            indices += min_idx  # shift into [min_idx, self._data_len)

        # get current batch
        sampled_datadict = self._datadict.leaf_apply(lambda arr: arr[indices])

        inputs = AttrDict()
        outputs = AttrDict()
        goals = AttrDict()
        for key in self._env_spec.observation_names:
            inputs[key] = sampled_datadict[key]
        for key in self._env_spec.action_names:
            inputs[key] = sampled_datadict[key]

        for key in self._env_spec.output_observation_names:
            outputs[key] = sampled_datadict[key]

        outputs.done = sampled_datadict.done

        if torch_device is not None:
            # `goals` is still empty at this point, so only inputs/outputs
            # have leaves to convert.
            for d in (inputs, outputs):
                d.leaf_modify(lambda x: torch.from_numpy(x).to(torch_device))

        if get_action_seq:
            inputs['act_seq'] = torch.from_numpy(
                self._action_sequences[indices]).to(torch_device)

        if get_horizon_goals:
            for key in self._env_spec.goal_names:
                goals[key] = torch.from_numpy(
                    sampled_datadict[key]).to(torch_device)
            return inputs, outputs, goals

        return inputs, outputs  # shape is (batch, horizon, name_dim...)