コード例 #1
0
 def input_prototype(self):
     """Build a dummy ``rlt.PreprocessedStateAction`` filled with noise.

     Both the state and the action carry standard-normal float features of
     shape ``(1, 1, dim)``, using ``self.state_dim`` and ``self.action_dim``.
     NOTE(review): presumably used as an example input (e.g. for tracing or
     export) — confirm with the callers.
     """
     # The state is sampled before the action, as in the original ordering.
     state = rlt.FeatureVector(
         float_features=torch.randn(1, 1, self.state_dim))
     action = rlt.FeatureVector(
         float_features=torch.randn(1, 1, self.action_dim))
     return rlt.PreprocessedStateAction(state=state, action=action)
コード例 #2
0
 def input_prototype(self):
     """Build a dummy ``rlt.PreprocessedState`` with float and id-list features.

     The float features are standard-normal noise of shape
     ``(1, self.state_dim)``. The id-list features hold a single ``"page_id"``
     entry: a pair of length-1 int64 tensors (zeros, ones).
     NOTE(review): the pair presumably follows the (offsets, values)
     convention of ``rlt.FeatureVector.id_list_features`` — confirm.
     """
     page_id_pair = (
         torch.zeros(1, dtype=torch.long),
         torch.ones(1, dtype=torch.long),
     )
     features = rlt.FeatureVector(
         float_features=torch.randn(1, self.state_dim),
         id_list_features={"page_id": page_id_pair},
     )
     return rlt.PreprocessedState(state=features)
コード例 #3
0
    def preprocess(self, batch) -> rlt.RawTrainingBatch:
        """Turn a column-oriented raw batch into an ``rlt.RawTrainingBatch``.

        Sparse state / next-state features are densified through
        ``self.sparse_to_dense_processor`` into (value, presence) pairs.
        Scalar per-row columns become ``(batch, 1)`` tensors.

        Args:
            batch: mapping from column name to per-row values. Reads
                "state_features", "next_state_features", "mdp_id",
                "sequence_number", "reward", "time_diff" and, when present,
                "action_probability".

        Returns:
            A ``RawTrainingBatch``:
            - ``training_input`` is an ``rlt.RawBaseInput`` with ``step`` and
              ``not_terminal`` set to ``None``;
            - ``extras`` holds the MDP ids (a numpy array), sequence numbers
              and action propensities.
        """
        def as_column(key, dtype):
            # One scalar per row -> tensor reshaped to (batch, 1).
            return torch.tensor(batch[key], dtype=dtype).reshape(-1, 1)

        dense_state, dense_state_presence = self.sparse_to_dense_processor(
            batch["state_features"])
        dense_next, dense_next_presence = self.sparse_to_dense_processor(
            batch["next_state_features"])

        # mdp_id is kept as a numpy array rather than a tensor.
        # NOTE(review): presumably because ids may be non-numeric — confirm.
        mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
        sequence_numbers = as_column("sequence_number", torch.int32)
        rewards = as_column("reward", torch.float32)
        time_diffs = as_column("time_diff", torch.int32)

        # Without an "action_probability" column every row gets propensity 1.0.
        if "action_probability" not in batch:
            propensities = torch.ones(rewards.shape, dtype=torch.float32)
        else:
            propensities = as_column("action_probability", torch.float32)

        training_input = rlt.RawBaseInput(  # type: ignore
            state=rlt.FeatureVector(float_features=rlt.ValuePresence(
                value=dense_state, presence=dense_state_presence)),
            next_state=rlt.FeatureVector(float_features=rlt.ValuePresence(
                value=dense_next, presence=dense_next_presence)),
            reward=rewards,
            time_diff=time_diffs,
            step=None,
            not_terminal=None,
        )
        extras = rlt.ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        )
        return rlt.RawTrainingBatch(training_input=training_input,
                                    extras=extras)
コード例 #4
0
    def preprocess(self, batch) -> rlt.RawTrainingBatch:
        """Extend the base batch with the inputs a parametric DQN needs.

        Starts from ``super().preprocess(batch)`` and adds:
        - dense action and next-action features;
        - dense possible-action and possible-next-action features, each with
          a padding mask;
        - a ``not_terminal`` column;
        - a next-state tensor tiled once per candidate next action.

        Args:
            batch: mapping from column name to per-row values. Besides the
                columns read by the base class, this reads "action",
                "next_action", "possible_actions" and
                "possible_next_actions". The last two hold, per row, a list
                of sparse action dicts.

        Returns:
            The base ``RawTrainingBatch`` with ``training_input`` replaced by
            an ``rlt.RawParametricDqnInput``.
        """
        training_batch = super().preprocess(batch)

        actions, actions_presence = self.action_sparse_to_dense(
            batch["action"])

        next_actions, next_actions_presence = self.action_sparse_to_dense(
            batch["next_action"])

        # Possible actions and possible next actions are padded to ONE shared
        # width. Previously only possible_next_actions set the width, so a
        # longer possible_actions row got no padding. Its mask row came out
        # ragged (torch.Tensor raises), or the two widths silently diverged.
        # default=0 keeps an empty batch from raising inside max().
        max_action_size = max(
            (len(candidates)
             for key in ("possible_next_actions", "possible_actions")
             for candidates in batch[key]),
            default=0,
        )

        def pad_candidates(rows):
            # Pad each row's list of sparse action dicts to max_action_size.
            # Returns:
            # - a uint8 mask with one row per input row (1 = real entry,
            #   0 = padding);
            # - a flat row-major list of dicts, with {} used as padding.
            mask_rows = []
            flat: List[Dict[int, float]] = []
            for row in rows:
                n_pad = max_action_size - len(row)
                mask_rows.append([1] * len(row) + [0] * n_pad)
                flat.extend(row)
                flat.extend({} for _ in range(n_pad))
            return torch.Tensor(mask_rows).byte(), flat

        pnas_mask, flat_pnas = pad_candidates(batch["possible_next_actions"])

        # A row is marked non-terminal iff any next-action feature is present.
        not_terminal = torch.any(next_actions_presence > 0,
                                 1).float().reshape(-1, 1)
        pnas, pnas_presence = self.action_sparse_to_dense(flat_pnas)

        base_input = cast(rlt.RawBaseInput, training_batch.training_input)
        # Repeat each next-state row max_action_size times. Row i's copies
        # then line up with row i's block of padded candidates in ``pnas``.
        tiled_next_state = torch.repeat_interleave(
            base_input.next_state.float_features.value, max_action_size, dim=0)
        tiled_next_state_presence = torch.repeat_interleave(
            base_input.next_state.float_features.presence,
            max_action_size,
            dim=0)

        pas_mask, flat_pas = pad_candidates(batch["possible_actions"])
        pas, pas_presence = self.action_sparse_to_dense(flat_pas)

        return training_batch._replace(
            training_input=rlt.RawParametricDqnInput(
                state=base_input.state,
                reward=base_input.reward,
                time_diff=base_input.time_diff,
                action=rlt.FeatureVector(float_features=rlt.ValuePresence(
                    value=actions, presence=actions_presence)),
                next_action=rlt.FeatureVector(float_features=rlt.ValuePresence(
                    value=next_actions, presence=next_actions_presence)),
                not_terminal=not_terminal,
                next_state=base_input.next_state,
                tiled_next_state=rlt.FeatureVector(
                    float_features=rlt.ValuePresence(
                        value=tiled_next_state,
                        presence=tiled_next_state_presence)),
                possible_actions=rlt.FeatureVector(
                    float_features=rlt.ValuePresence(value=pas,
                                                     presence=pas_presence)),
                possible_actions_mask=pas_mask,
                possible_next_actions=rlt.FeatureVector(
                    float_features=rlt.ValuePresence(value=pnas,
                                                     presence=pnas_presence)),
                possible_next_actions_mask=pnas_mask,
                step=None,
            ))