Example #1
    def sample_memories(self,
                        batch_size,
                        use_gpu=False) -> rlt.MemoryNetworkInput:
        """
        :param batch_size: number of samples to return
        :param use_gpu: whether to put samples on gpu
        Returned tensors are shaped SEQ_LEN x BATCH_SIZE x FEATURE_DIM
        (e.g., state is SEQ_LEN x BATCH_SIZE x STATE_DIM), since by default
        the MDN-RNN consumes data with SEQ_LEN as the first dimension.
        """
        sample_indices = np.random.randint(self.memory_size, size=batch_size)
        device = torch.device("cuda") if use_gpu else torch.device("cpu")
        # state/next state shape: batch_size x seq_len x state_dim
        # action shape: batch_size x seq_len x action_dim
        # reward/not_terminal shape: batch_size x seq_len
        state, action, next_state, reward, not_terminal = map(
            lambda x: stack(x).float().to(device),
            zip(*self.deque_sample(sample_indices)),
        )

        # make shapes seq_len x batch_size x feature_dim
        state, action, next_state, reward, not_terminal = transpose(
            state, action, next_state, reward, not_terminal)

        return rlt.MemoryNetworkInput(
            state=rlt.FeatureData(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            next_state=rlt.FeatureData(float_features=next_state),
            not_terminal=not_terminal,
            step=None,
        )
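
The sampler relies on the repo's `stack` and `transpose` helpers to move from a batch-first to a sequence-first layout. Below is a minimal, self-contained sketch of that shape handling using plain `torch` calls; the dimension sizes are arbitrary, and using `torch.stack`/`Tensor.transpose` as stand-ins for those helpers is an assumption about their behavior.

import torch

batch_size, seq_len, state_dim = 2, 5, 4

# Each sampled memory contributes a (seq_len x state_dim) state tensor.
states = [torch.randn(seq_len, state_dim) for _ in range(batch_size)]

# Stacking yields batch_size x seq_len x state_dim, matching the comment above.
stacked = torch.stack(states).float()
assert stacked.shape == (batch_size, seq_len, state_dim)

# Moving the sequence dimension to the front yields seq_len x batch_size x state_dim,
# the layout the MDN-RNN consumes by default.
seq_first = stacked.transpose(0, 1)
assert seq_first.shape == (seq_len, batch_size, state_dim)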
Example #2
    def __call__(self, batch):
        action = batch.action
        if self.num_actions is not None:
            assert len(action.shape) == 2, f"{action.shape}"
            # one hot makes shape (batch_size, stack_size, feature_dim)
            action = F.one_hot(batch.action, self.num_actions).float()
            # transpose to (batch_size, feature_dim, stack_size)
            action = action.transpose(1, 2)

        # For (1-dimensional) vector fields, the replay buffer returns
        # (batch_size, state_dim) or (batch_size, state_dim, stack_size).
        # We want all of these to be (stack_size, batch_size, state_dim), so
        # unsqueeze the former case (which only happens when stack_size = 1),
        # then permute.
        permutation = [2, 0, 1]
        vector_fields = {
            "state": batch.state,
            "action": action,
            "next_state": batch.next_state,
        }
        for name, tensor in vector_fields.items():
            if len(tensor.shape) == 2:
                tensor = tensor.unsqueeze(2)
            assert len(tensor.shape) == 3, f"{name} has shape {tensor.shape}"
            vector_fields[name] = tensor.permute(permutation)

        # For scalar fields, the replay buffer returns (batch_size,) or
        # (batch_size, stack_size). Do the same as above, except transpose
        # instead of permute.
        scalar_fields = {
            "reward": batch.reward,
            "not_terminal": 1.0 - batch.terminal.float(),
        }
        for name, tensor in scalar_fields.items():
            if len(tensor.shape) == 1:
                tensor = tensor.unsqueeze(1)
            assert len(tensor.shape) == 2, f"{name} has shape {tensor.shape}"
            scalar_fields[name] = tensor.transpose(0, 1)

        # If stack_size > 1, pad not_terminal with 1's, since the
        # previous states couldn't have been terminal.
        if scalar_fields["reward"].shape[0] > 1:
            batch_size = scalar_fields["reward"].shape[1]
            assert scalar_fields["not_terminal"].shape == (
                1,
                batch_size,
            ), f"{scalar_fields['not_terminal'].shape}"
            stacked_not_terminal = torch.ones_like(scalar_fields["reward"])
            stacked_not_terminal[-1] = scalar_fields["not_terminal"]
            scalar_fields["not_terminal"] = stacked_not_terminal

        return rlt.MemoryNetworkInput(
            state=rlt.FeatureData(float_features=vector_fields["state"]),
            next_state=rlt.FeatureData(
                float_features=vector_fields["next_state"]),
            action=vector_fields["action"],
            reward=scalar_fields["reward"],
            not_terminal=scalar_fields["not_terminal"],
            step=None,
            time_diff=None,
        )
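
For reference, here is a small self-contained sketch of the layout conversions this preprocessor performs, using only `torch`; the dimension sizes are arbitrary and chosen for illustration.

import torch

batch_size, state_dim, stack_size = 3, 4, 2

# Vector field as the replay buffer returns it when stacked:
# (batch_size, state_dim, stack_size).
state = torch.randn(batch_size, state_dim, stack_size)
# permute([2, 0, 1]) -> (stack_size, batch_size, state_dim), as in the loop above.
assert state.permute([2, 0, 1]).shape == (stack_size, batch_size, state_dim)

# Scalar field: (batch_size, stack_size) -> (stack_size, batch_size) via transpose.
reward = torch.randn(batch_size, stack_size)
assert reward.transpose(0, 1).shape == (stack_size, batch_size)

# With stack_size > 1, only the newest frame can be terminal, which is why
# not_terminal is padded with ones for the earlier frames in the code above.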
Example #3
    def test_get_Q(self):
        NUM_ACTION = 2
        MULTI_STEPS = 3
        BATCH_SIZE = 2
        STATE_DIM = 4
        all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
        seq2reward_network = FakeSeq2RewardNetwork()
        batch = rlt.MemoryNetworkInput(
            state=rlt.FeatureData(
                float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
            ),
            next_state=rlt.FeatureData(
                float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
            ),
            action=rlt.FeatureData(
                float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, NUM_ACTION)
            ),
            reward=torch.zeros(1),
            time_diff=torch.zeros(1),
            step=torch.zeros(1),
            not_terminal=torch.zeros(1),
        )
        q_values = get_Q(seq2reward_network, batch, all_permut)
        expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
        logger.info(f"q_values: {q_values}")
        assert torch.all(expected_q_values == q_values)
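
The test assumes `gen_permutations(MULTI_STEPS, NUM_ACTION)` enumerates every length-3 action sequence over the 2 discrete actions, so `get_Q` can score each sequence with the seq2reward network. The sketch below shows one plausible way such an enumeration could be built; the function name, output layout, and one-hot encoding are assumptions made for illustration, not the repo's actual implementation.

import itertools

import torch
import torch.nn.functional as F


def enumerate_action_sequences(num_steps: int, num_actions: int) -> torch.Tensor:
    """Illustrative stand-in for gen_permutations: all num_actions ** num_steps
    action sequences, one-hot encoded, with the sequence dimension first."""
    sequences = list(itertools.product(range(num_actions), repeat=num_steps))
    indices = torch.tensor(sequences)                   # (num_sequences, num_steps)
    one_hot = F.one_hot(indices, num_actions).float()   # (num_sequences, num_steps, num_actions)
    return one_hot.transpose(0, 1)                      # (num_steps, num_sequences, num_actions)


seqs = enumerate_action_sequences(3, 2)
assert seqs.shape == (3, 2 ** 3, 2)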