Exemplo n.º 1
0
    def sample_memories(self, batch_size, use_gpu=False, batch_first=False):
        """
        Build a training batch from randomly chosen stored memories.

        ``batch_size`` indices are drawn uniformly from
        ``[0, self.memory_size)`` (independently, so with replacement) and the
        corresponding entries are fetched through ``self.deque_sample``.

        :param batch_size: number of samples to return
        :param use_gpu: whether to put samples on gpu
        :param batch_first: If True, the first dimension of data is batch_size.
            If False (default), the first dimension is SEQ_LEN. Therefore,
            state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example. By
            default, MDN-RNN consumes data with SEQ_LEN as the first dimension.
        :return: ``rlt.PreprocessedTrainingBatch`` whose ``training_input`` is
            an ``rlt.PreprocessedMemoryNetworkInput``. Every field is a float
            tensor on the chosen device; ``time_diff`` is all ones shaped like
            ``reward``; ``step`` and ``extras`` are None.
        """
        chosen = np.random.randint(self.memory_size, size=batch_size)
        target_device = torch.device("cuda" if use_gpu else "cpu")

        # Regroup the sampled memories field-by-field and stack each field
        # into a single float tensor. Layout notes from the original author:
        # state/next_state are batch_size x seq_len x state_dim, action is
        # batch_size x seq_len x action_dim, reward/not_terminal are
        # batch_size x seq_len (all before any transpose below).
        fields = [
            stack(per_field).float().to(target_device)
            for per_field in zip(*self.deque_sample(chosen))
        ]

        if not batch_first:
            # NOTE(review): `transpose` is a helper defined elsewhere; per the
            # docstring it should produce the SEQ_LEN-first layout.
            fields = transpose(*fields)

        state, action, next_state, reward, not_terminal = fields

        memory_input = rlt.PreprocessedMemoryNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            next_state=rlt.PreprocessedFeatureVector(float_features=next_state),
            not_terminal=not_terminal,
            step=None,
        )
        return rlt.PreprocessedTrainingBatch(
            training_input=memory_input, extras=None
        )
    def sample_memories(self, batch_size, model_type, chunk=None):
        """
        Samples transitions from replay memory uniformly at random by default
        or pass chunk for deterministic sample.

        *Note*: 1-D vectors such as state & action get stacked to make a 2-D
        matrix, while a 2-D matrix such as possible_actions (in the parametric
        case) gets concatenated to make a bigger 2-D matrix.

        :param batch_size: Number of sampled transitions to return.
        :param model_type: Model type (discrete, parametric). Compared against
            ``ModelType.PYTORCH_PARAMETRIC_DQN.value``; any other value takes
            the discrete path.
        :param chunk: Index of chunk of data (for deterministic sampling).
            Selects transitions ``[chunk * batch_size,
            (chunk + 1) * batch_size)`` of ``self.replay_memory``.
        :return: A ``TrainingDataPage``. ``propensities`` is always None;
            ``rewards`` (float32), ``not_terminal`` (int32) and ``time_diffs``
            (int32) are reshaped to column vectors ``(-1, 1)``. The
            ``*_state_concat`` fields are None for non-parametric models.
        :raises IndexError: if ``chunk`` is given and the requested range is
            negative or extends past the end of the replay memory.
        """
        memory_len = len(self.replay_memory)
        if chunk is None:
            indices = np.random.randint(0, memory_len, size=batch_size)
        else:
            start_idx = chunk * batch_size
            end_idx = start_idx + batch_size
            # Without this check a negative chunk would silently wrap around
            # via negative indexing, and an over-long chunk would fail midway
            # through gathering with an uninformative container IndexError.
            if start_idx < 0 or end_idx > memory_len:
                raise IndexError(
                    "chunk {} with batch_size {} selects indices [{}, {}), "
                    "outside replay memory of size {}".format(
                        chunk, batch_size, start_idx, end_idx, memory_len
                    )
                )
            indices = range(start_idx, end_idx)

        # Transpose the selected memory tuples into one list per field.
        cols = [[] for _ in range(12)]
        for idx in indices:
            for col, value in zip(cols, self.replay_memory[idx]):
                col.append(value)

        # Field order of each stored memory tuple, as consumed below.
        # NOTE(review): must match the order used when inserting memories.
        (
            state_col,
            action_col,
            reward_col,
            next_state_col,
            next_action_col,
            terminal_col,
            possible_next_actions_col,
            possible_next_actions_mask_col,
            possible_actions_col,
            possible_actions_mask_col,
            time_diff_col,
            _unused_col,
        ) = cols

        states = stack(state_col)
        next_states = stack(next_state_col)

        assert states.dim() == 2
        assert next_states.dim() == 2

        if model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
            # Every transition is assumed to offer the same number of
            # candidate actions; the count is read from the first
            # possible-next-actions mask.
            num_possible_actions = len(possible_next_actions_mask_col[0])

            actions = stack(action_col)
            next_actions = stack(next_action_col)

            def _tile_concat(base, candidates):
                # Repeat each row of `base` num_possible_actions times
                # consecutively, then append the candidate action columns.
                tiled = base.repeat(1, num_possible_actions).reshape(
                    -1, base.shape[1])
                return torch.cat((tiled, candidates), dim=1)

            possible_actions_state_concat = _tile_concat(
                states, torch.cat(possible_actions_col))
            possible_actions_mask = stack(possible_actions_mask_col)

            possible_next_actions_state_concat = _tile_concat(
                next_states, torch.cat(possible_next_actions_col))
            possible_next_actions_mask = stack(possible_next_actions_mask_col)
        else:
            possible_actions_state_concat = None
            possible_next_actions_state_concat = None
            # Masks are optional here; whether they are present is decided by
            # the first sampled transition.
            if possible_next_actions_mask_col[0] is None:
                possible_next_actions_mask = None
            else:
                possible_next_actions_mask = stack(
                    possible_next_actions_mask_col)
            if possible_actions_mask_col[0] is None:
                possible_actions_mask = None
            else:
                possible_actions_mask = stack(possible_actions_mask_col)

            actions = stack(action_col)
            next_actions = stack(next_action_col)

            assert actions.dim() == 2
            assert next_actions.dim() == 2

        rewards = torch.tensor(reward_col, dtype=torch.float32).reshape(-1, 1)
        # Memories store a terminal flag; the page wants its complement.
        not_terminal = (
            1 - torch.tensor(terminal_col, dtype=torch.int32)).reshape(-1, 1)
        time_diffs = torch.tensor(
            time_diff_col, dtype=torch.int32).reshape(-1, 1)

        return TrainingDataPage(
            states=states,
            actions=actions,
            propensities=None,
            rewards=rewards,
            next_states=next_states,
            next_actions=next_actions,
            not_terminal=not_terminal,
            time_diffs=time_diffs,
            possible_actions_mask=possible_actions_mask,
            possible_actions_state_concat=possible_actions_state_concat,
            possible_next_actions_mask=possible_next_actions_mask,
            possible_next_actions_state_concat=(
                possible_next_actions_state_concat),
        )