Example #1
def create_data(state_dim, action_dim, seq_len, batch_size, num_batches):
    SCALE = 2
    weight = SCALE * torch.randn(state_dim + action_dim)
    data = [None for _ in range(num_batches)]
    for i in range(num_batches):
        state = SCALE * torch.randn(seq_len, batch_size, state_dim)
        action = SCALE * torch.randn(seq_len, batch_size, action_dim)
        # random valid step
        valid_step = torch.randint(1, seq_len + 1, (batch_size, 1))

        # reward_matrix shape: batch_size x seq_len
        reward_matrix = torch.matmul(
            torch.cat((state, action), dim=2), weight
        ).transpose(0, 1)
        mask = torch.arange(seq_len).repeat(batch_size, 1)
        mask = (mask >= (seq_len - valid_step)).float()
        reward = (reward_matrix * mask).sum(dim=1).reshape(-1, 1)
        data[i] = rlt.MemoryNetworkInput(
            state=rlt.FeatureData(state),
            action=action,
            valid_step=valid_step,
            reward=reward,
            # the remaining fields are not used
            next_state=torch.tensor([]),
            step=torch.tensor([]),
            not_terminal=torch.tensor([]),
            time_diff=torch.tensor([]),
        )
    return weight, data
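A quick usage sketch for create_data (dimension values below are arbitrary and chosen only for illustration; it assumes torch and reagent.core.types as rlt are imported as in the snippet above, and that rlt.FeatureData stores its tensor as float_features, as Example #2 suggests):

# Hypothetical dimensions, for illustration only.
STATE_DIM, ACTION_DIM, SEQ_LEN, BATCH_SIZE, NUM_BATCHES = 3, 2, 5, 4, 10

weight, data = create_data(STATE_DIM, ACTION_DIM, SEQ_LEN, BATCH_SIZE, NUM_BATCHES)
assert weight.shape == (STATE_DIM + ACTION_DIM,)

batch = data[0]
# state is stored sequence-first: SEQ_LEN x BATCH_SIZE x STATE_DIM
assert batch.state.float_features.shape == (SEQ_LEN, BATCH_SIZE, STATE_DIM)
# reward sums the step rewards over the last valid_step steps of each sequence
assert batch.reward.shape == (BATCH_SIZE, 1)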
Example #2
    def sample_memories(self, batch_size, use_gpu=False) -> rlt.MemoryNetworkInput:
        """
        :param batch_size: number of samples to return
        :param use_gpu: whether to put samples on gpu
        State's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example.
        By default, MDN-RNN consumes data with SEQ_LEN as the first dimension.
        """
        sample_indices = np.random.randint(self.memory_size, size=batch_size)
        device = torch.device("cuda") if use_gpu else torch.device("cpu")
        # state/next state shape: batch_size x seq_len x state_dim
        # action shape: batch_size x seq_len x action_dim
        # reward/not_terminal shape: batch_size x seq_len
        state, action, next_state, reward, not_terminal = map(
            lambda x: stack(x).float().to(device),
            zip(*self.deque_sample(sample_indices)),
        )

        # make shapes seq_len x batch_size x feature_dim
        state, action, next_state, reward, not_terminal = transpose(
            state, action, next_state, reward, not_terminal
        )

        return rlt.MemoryNetworkInput(
            state=rlt.FeatureData(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            next_state=rlt.FeatureData(float_features=next_state),
            not_terminal=not_terminal,
            step=None,
        )
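sample_memories relies on stack and transpose helpers that are not shown here. Below is a minimal sketch of what they could look like, inferred from the shape comments above (an assumption, not necessarily the library's actual implementation):

import torch

def stack(tensors):
    # Stack per-sample tensors along a new batch dimension (dim 0),
    # e.g. producing batch_size x seq_len x state_dim for states.
    return torch.stack([torch.as_tensor(t) for t in tensors], dim=0)

def transpose(*tensors):
    # Swap the batch and sequence dimensions of every tensor:
    # (batch_size, seq_len, ...) -> (seq_len, batch_size, ...)
    return tuple(t.transpose(0, 1) for t in tensors)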
Example #3
def create_sequence_data(state_dim, action_dim, seq_len, batch_size, num_batches):
    SCALE = 2
    weight = SCALE * torch.randn(state_dim + action_dim)

    data = [None for _ in range(num_batches)]

    for i in range(num_batches):
        state = SCALE * torch.randn(seq_len, batch_size, state_dim)
        action = SCALE * torch.randn(seq_len, batch_size, action_dim)
        # random valid step
        valid_step = torch.randint(1, seq_len + 1, (batch_size, 1))

        feature_mask = torch.arange(seq_len).repeat(batch_size, 1)
        feature_mask = (feature_mask >= (seq_len - valid_step)).float()
        assert feature_mask.shape == (batch_size, seq_len), feature_mask.shape
        feature_mask = feature_mask.transpose(0, 1).unsqueeze(-1)
        assert feature_mask.shape == (seq_len, batch_size, 1), feature_mask.shape

        feature = torch.cat((state, action), dim=2)
        masked_feature = feature * feature_mask

        # seq_len, batch_size, state_dim + action_dim
        left_shifted = torch.cat(
            (
                masked_feature.narrow(0, 1, seq_len - 1),
                torch.zeros(1, batch_size, state_dim + action_dim),
            ),
            dim=0,
        )
        # seq_len, batch_size, state_dim + action_dim
        right_shifted = torch.cat(
            (
                torch.zeros(1, batch_size, state_dim + action_dim),
                masked_feature.narrow(0, 0, seq_len - 1),
            ),
            dim=0,
        )
        # reward_matrix shape: batch_size x seq_len
        reward_matrix = torch.matmul(
            left_shifted + right_shifted, weight
        ).transpose(0, 1)

        mask = torch.arange(seq_len).repeat(batch_size, 1)
        mask = (mask >= (seq_len - valid_step)).float()
        reward = (reward_matrix * mask).sum(dim=1).reshape(-1, 1)

        data[i] = rlt.MemoryNetworkInput(
            state=rlt.FeatureData(state),
            action=action,
            valid_step=valid_step,
            reward=reward,
            # the remaining fields are not used
            next_state=torch.tensor([]),
            step=torch.tensor([]),
            not_terminal=torch.tensor([]),
            time_diff=torch.tensor([]),
        )

    return weight, data
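A toy illustration (sizes arbitrary) of the narrow-based shifting used above: left_shifted drops the first step and zero-pads the end, right_shifted zero-pads the front and drops the last step, so their sum pairs each step with the features of its previous and next steps.

import torch

seq = torch.arange(1.0, 5.0).reshape(4, 1, 1)  # seq_len=4, batch_size=1, dim=1
left_shifted = torch.cat((seq.narrow(0, 1, 3), torch.zeros(1, 1, 1)), dim=0)
right_shifted = torch.cat((torch.zeros(1, 1, 1), seq.narrow(0, 0, 3)), dim=0)
print(left_shifted.flatten().tolist())   # [2.0, 3.0, 4.0, 0.0]
print(right_shifted.flatten().tolist())  # [0.0, 1.0, 2.0, 3.0]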
Example #4
    def __call__(self, batch):
        action = batch.action
        if self.num_actions is not None:
            assert len(action.shape) == 2, f"{action.shape}"
            # one hot makes shape (batch_size, stack_size, feature_dim)
            action = F.one_hot(batch.action, self.num_actions).float()
            # make shape to (batch_size, feature_dim, stack_size)
            action = action.transpose(1, 2)

        # For (1-dimensional) vector fields, the replay buffer returns
        # (batch_size, state_dim) or (batch_size, state_dim, stack_size).
        # We want all of these to be (stack_size, batch_size, state_dim), so
        # unsqueeze in the former case (which only happens for stack_size = 1)
        # and then permute.
        permutation = [2, 0, 1]
        vector_fields = {
            "state": batch.state,
            "action": action,
            "next_state": batch.next_state,
        }
        for name, tensor in vector_fields.items():
            if len(tensor.shape) == 2:
                tensor = tensor.unsqueeze(2)
            assert len(tensor.shape) == 3, f"{name} has shape {tensor.shape}"
            vector_fields[name] = tensor.permute(permutation)

        # For scalar fields, the replay buffer returns (batch_size,) or
        # (batch_size, stack_size). Do the same as above, but transpose instead.
        scalar_fields = {
            "reward": batch.reward,
            "not_terminal": 1.0 - batch.terminal.float(),
        }
        for name, tensor in scalar_fields.items():
            if len(tensor.shape) == 1:
                tensor = tensor.unsqueeze(1)
            assert len(tensor.shape) == 2, f"{name} has shape {tensor.shape}"
            scalar_fields[name] = tensor.transpose(0, 1)

        # If stack_size > 1, pad not_terminal with 1's, since the earlier
        # stacked states could not have been terminal.
        if scalar_fields["reward"].shape[0] > 1:
            batch_size = scalar_fields["reward"].shape[1]
            assert scalar_fields["not_terminal"].shape == (
                1,
                batch_size,
            ), f"{scalar_fields['not_terminal'].shape}"
            stacked_not_terminal = torch.ones_like(scalar_fields["reward"])
            stacked_not_terminal[-1] = scalar_fields["not_terminal"]
            scalar_fields["not_terminal"] = stacked_not_terminal

        return rlt.MemoryNetworkInput(
            state=rlt.FeatureData(float_features=vector_fields["state"]),
            next_state=rlt.FeatureData(float_features=vector_fields["next_state"]),
            action=vector_fields["action"],
            reward=scalar_fields["reward"],
            not_terminal=scalar_fields["not_terminal"],
            step=None,
            time_diff=None,
        )
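A toy walk-through (sizes arbitrary) of the shape handling above: vector fields go from (batch_size, feature_dim, stack_size) to (stack_size, batch_size, feature_dim) via permute, and not_terminal is padded with ones for all but the most recent stacked frame.

import torch

batch_size, stack_size, state_dim = 2, 3, 4
state = torch.randn(batch_size, state_dim, stack_size)
print(state.permute([2, 0, 1]).shape)      # torch.Size([3, 2, 4])

reward = torch.randn(batch_size, stack_size).transpose(0, 1)  # (stack_size, batch_size)
not_terminal = torch.ones(1, batch_size)   # only the newest frame can be terminal
stacked_not_terminal = torch.ones_like(reward)
stacked_not_terminal[-1] = not_terminal
print(stacked_not_terminal.shape)          # torch.Size([3, 2])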
Example #5
def _create_input():
    state = torch.randn(SEQ_LEN, BATCH_SIZE, STATE_DIM)
    valid_step = torch.tensor([[1], [4]])
    action = torch.tensor([
        [[0, 1], [1, 0]],
        [[0, 1], [1, 0]],
        [[1, 0], [0, 1]],
        [[0, 1], [1, 0]],
    ])
    input = rlt.MemoryNetworkInput(
        state=rlt.FeatureData(state),
        action=action,
        valid_step=valid_step,
        # the remaining fields are not used
        next_state=torch.tensor([]),
        reward=torch.tensor([]),
        step=torch.tensor([]),
        not_terminal=torch.tensor([]),
        time_diff=torch.tensor([]),
    )
    return input
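This snippet depends on module-level constants that are not shown. The values below are inferred from the hard-coded action and valid_step tensors; STATE_DIM cannot be recovered from the snippet, so its value here is only a placeholder.

SEQ_LEN = 4     # the action tensor has 4 steps along dim 0
BATCH_SIZE = 2  # valid_step has shape (2, 1)
STATE_DIM = 3   # hypothetical; any positive value works for the random state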
Example #6
def _create_preprocessed_input(
    input: rlt.MemoryNetworkInput,
    state_preprocessor: Preprocessor,
    action_preprocessor: Preprocessor,
):
    preprocessed_state = state_preprocessor(
        input.state.float_features.reshape(SEQ_LEN * BATCH_SIZE, STATE_DIM),
        torch.ones(SEQ_LEN * BATCH_SIZE, STATE_DIM),
    ).reshape(SEQ_LEN, BATCH_SIZE, STATE_DIM)
    preprocessed_action = action_preprocessor(
        input.action.reshape(SEQ_LEN * BATCH_SIZE, ACTION_DIM),
        torch.ones(SEQ_LEN * BATCH_SIZE, ACTION_DIM),
    ).reshape(SEQ_LEN, BATCH_SIZE, ACTION_DIM)
    return rlt.MemoryNetworkInput(
        state=rlt.FeatureData(preprocessed_state),
        action=preprocessed_action,
        valid_step=input.valid_step,
        next_state=input.next_state,
        reward=input.reward,
        step=input.step,
        not_terminal=input.not_terminal,
        time_diff=input.time_diff,
    )
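The reshape-preprocess-reshape pattern above works because each preprocessor acts row by row (the torch.ones second argument presumably marks every feature as present): flattening (SEQ_LEN, BATCH_SIZE, DIM) to (SEQ_LEN * BATCH_SIZE, DIM) and reshaping back preserves the layout. A small self-contained check, using a stand-in elementwise op instead of a real Preprocessor:

import torch

SEQ_LEN, BATCH_SIZE, STATE_DIM = 3, 2, 4  # arbitrary sizes
x = torch.randn(SEQ_LEN, BATCH_SIZE, STATE_DIM)
flat = x.reshape(SEQ_LEN * BATCH_SIZE, STATE_DIM)
y = (flat * 2.0).reshape(SEQ_LEN, BATCH_SIZE, STATE_DIM)  # stand-in for the preprocessor
assert torch.equal(y, x * 2.0)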
Example #7
def _create_input():
    state = torch.randn(SEQ_LEN, BATCH_SIZE, STATE_DIM)
    # generate valid_step with shape (BATCH_SIZE, 1), values in [1, SEQ_LEN] (inclusive)
    valid_step = torch.randint(1, SEQ_LEN + 1, size=(BATCH_SIZE, 1))
    # create one-hot action values from random action labels
    action_label = torch.randint(ACTION_DIM, (SEQ_LEN * BATCH_SIZE, 1))
    action = torch.zeros(SEQ_LEN * BATCH_SIZE, ACTION_DIM)
    action.scatter_(1, action_label, 1)
    action = action.reshape(SEQ_LEN, BATCH_SIZE, ACTION_DIM)

    input = rlt.MemoryNetworkInput(
        state=rlt.FeatureData(state),
        action=action,
        valid_step=valid_step,
        # the remaining fields are not used
        next_state=torch.tensor([]),
        reward=torch.tensor([]),
        step=torch.tensor([]),
        not_terminal=torch.tensor([]),
        time_diff=torch.tensor([]),
    )
    return input
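A small check (sizes arbitrary) that the scatter_-based one-hot construction above matches torch.nn.functional.one_hot:

import torch
import torch.nn.functional as F

labels = torch.randint(0, 3, (5, 1))
onehot = torch.zeros(5, 3)
onehot.scatter_(1, labels, 1)
assert torch.equal(onehot, F.one_hot(labels.squeeze(1), 3).float())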
Example #8
def create_string_game_data(dataset_size=10000,
                            training_data_ratio=0.9,
                            filter_short_sequence=False):
    SEQ_LEN = 6
    NUM_ACTION = 2
    NUM_MDP_PER_BATCH = 5

    env = Gym(env_name="StringGame-v0", set_max_steps=SEQ_LEN)
    df = create_df_from_replay_buffer(
        env=env,
        problem_domain=ProblemDomain.DISCRETE_ACTION,
        desired_size=dataset_size,
        multi_steps=None,
        ds="2020-10-10",
    )

    if filter_short_sequence:
        batch_size = NUM_MDP_PER_BATCH
        time_diff = torch.ones(SEQ_LEN, batch_size)
        valid_step = SEQ_LEN * torch.ones(batch_size, dtype=torch.int64)[:, None]
        not_terminal = torch.Tensor(
            [0 if i == SEQ_LEN - 1 else 1 for i in range(SEQ_LEN)]
        )
        not_terminal = torch.transpose(not_terminal.tile(NUM_MDP_PER_BATCH, 1), 0, 1)
    else:
        batch_size = NUM_MDP_PER_BATCH * SEQ_LEN
        time_diff = torch.ones(SEQ_LEN, batch_size)
        valid_step = torch.arange(SEQ_LEN, 0, -1).tile(NUM_MDP_PER_BATCH)[:, None]
        not_terminal = torch.transpose(
            torch.tril(torch.ones(SEQ_LEN, SEQ_LEN), diagonal=-1).tile(
                NUM_MDP_PER_BATCH, 1
            ),
            0,
            1,
        )

    num_batches = int(dataset_size / SEQ_LEN / NUM_MDP_PER_BATCH)
    batches = [None for _ in range(num_batches)]
    batch_count, batch_seq_count = 0, 0
    batch_reward = torch.zeros(SEQ_LEN, batch_size)
    batch_action = torch.zeros(SEQ_LEN, batch_size, NUM_ACTION)
    batch_state = torch.zeros(SEQ_LEN, batch_size, NUM_ACTION)
    for mdp_id in sorted(set(df.mdp_id)):
        mdp = df[df["mdp_id"] == mdp_id].sort_values("sequence_number",
                                                     ascending=True)
        if len(mdp) != SEQ_LEN:
            continue

        all_step_reward = torch.Tensor(list(mdp["reward"]))
        all_step_state = torch.Tensor(
            [list(s.values()) for s in mdp["state_features"]])
        all_step_action = torch.zeros_like(all_step_state)
        all_step_action[torch.arange(SEQ_LEN),
                        [int(a) for a in mdp["action"]]] = 1.0

        for j in range(SEQ_LEN):
            if filter_short_sequence and j > 0:
                break

            reward = torch.zeros_like(all_step_reward)
            reward[:SEQ_LEN - j] = all_step_reward[-(SEQ_LEN - j):]
            batch_reward[:, batch_seq_count] = reward

            state = torch.zeros_like(all_step_state)
            state[:SEQ_LEN - j] = all_step_state[-(SEQ_LEN - j):]
            batch_state[:, batch_seq_count] = state

            action = torch.zeros_like(all_step_action)
            action[:SEQ_LEN - j] = all_step_action[-(SEQ_LEN - j):]
            batch_action[:, batch_seq_count] = action

            batch_seq_count += 1

        if batch_seq_count == batch_size:
            batches[batch_count] = rlt.MemoryNetworkInput(
                reward=batch_reward,
                action=batch_action,
                state=rlt.FeatureData(float_features=batch_state),
                next_state=rlt.FeatureData(float_features=torch.zeros_like(
                    batch_state)),  # fake, not used anyway
                not_terminal=not_terminal,
                time_diff=time_diff,
                valid_step=valid_step,
                step=None,
            )
            batch_count += 1
            batch_seq_count = 0
            batch_reward = torch.zeros_like(batch_reward)
            batch_action = torch.zeros_like(batch_action)
            batch_state = torch.zeros_like(batch_state)
    assert batch_count == num_batches

    num_training_batches = int(training_data_ratio * num_batches)
    training_data = batches[:num_training_batches]
    eval_data = batches[num_training_batches:]
    return training_data, eval_data
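A toy illustration of the sub-sequence construction in the inner loop above: for shift j, the last SEQ_LEN - j steps are moved to the front and the tail is left as zeros.

import torch

SEQ_LEN = 4  # shortened for readability
all_step_reward = torch.tensor([1.0, 2.0, 3.0, 4.0])
for j in range(SEQ_LEN):
    reward = torch.zeros_like(all_step_reward)
    reward[: SEQ_LEN - j] = all_step_reward[-(SEQ_LEN - j):]
    print(j, reward.tolist())
# 0 [1.0, 2.0, 3.0, 4.0]
# 1 [2.0, 3.0, 4.0, 0.0]
# 2 [3.0, 4.0, 0.0, 0.0]
# 3 [4.0, 0.0, 0.0, 0.0]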