コード例 #1
0
    def extract(self, ws, input_record, extract_record):
        """Build a ``mt.TrainingBatch`` out of blobs fetched from ``ws``.

        When ``self.max_q_learning`` is truthy the batch wraps a
        ``mt.MaxQLearningInput``; otherwise it wraps a ``mt.SARSAInput``.
        Actions are wrapped in ``mt.FeatureVector`` only when
        ``self.sorted_action_features`` is not ``None``.

        Args:
            ws: Object exposing ``fetch_blob(name)``; each fetched value is
                handed to ``torch.tensor``.
            input_record: Supplies the ``reward`` and ``action_probability``
                blob references.
            extract_record: Supplies the state, action and next-step blob
                references.

        Returns:
            ``mt.TrainingBatch`` carrying ``training_input`` and ``extras``.
        """

        def load(ref):
            # A reference is invoked and stringified to obtain the blob key.
            return torch.tensor(ws.fetch_blob(str(ref())))

        def load_features(ref):
            return mt.FeatureVector(float_features=load(ref))

        def load_action(ref):
            if self.sorted_action_features is not None:
                return load_features(ref)
            return load(ref)

        state = load_features(extract_record.state)
        action = load_action(extract_record.action)
        # Reward is reshaped into a single column.
        reward = load(input_record.reward).reshape(-1, 1)

        if not self.max_q_learning:
            next_state = load_features(extract_record.next_state)
            next_action = load_action(extract_record.next_action)
            # HACK (kept from the original): every row is marked as
            # not-terminal because nothing better is available here.
            not_terminal = torch.ones_like(reward)
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
            )
        else:
            # Exactly one of next_state / tiled_next_state is populated,
            # depending on whether sorted action features are configured.
            next_state = tiled_next_state = None
            if self.sorted_action_features is None:
                next_state = load_features(extract_record.next_state)
            else:
                tiled_next_state = load_features(
                    extract_record.tiled_next_state)

            next_lengths = load(extract_record.possible_next_actions["lengths"])
            next_actions = load_action(
                extract_record.possible_next_actions["values"])
            possible_next_actions = mt.PossibleActions(
                lengths=next_lengths,
                actions=next_actions,
            )

            # A row is not terminal when it has at least one possible next
            # action.
            has_next_action = possible_next_actions.lengths > 0
            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                tiled_next_state=tiled_next_state,
                possible_next_actions=possible_next_actions,
                reward=reward,
                not_terminal=has_next_action.float().reshape(-1, 1),
            )

        # TODO: stuff other fields in here
        action_probability = load(input_record.action_probability)
        extras = mt.ExtraData(
            action_probability=action_probability.reshape(-1, 1))

        return mt.TrainingBatch(training_input=training_input, extras=extras)
コード例 #2
0
 def as_discrete_maxq_training_batch(self):
     """Package this object's fields as a max-Q ``rlt.TrainingBatch``.

     States and next states are wrapped in ``rlt.FeatureVector``; actions
     and next actions are passed through unwrapped. ``tiled_next_state``,
     ``possible_actions`` and ``possible_next_actions`` are set to ``None``,
     while both action masks are forwarded.

     Returns:
         ``rlt.TrainingBatch`` whose ``training_input`` is a
         ``rlt.MaxQLearningInput`` and whose ``extras`` is a
         ``rlt.ExtraData``.
     """
     state = rlt.FeatureVector(float_features=self.states)
     next_state = rlt.FeatureVector(float_features=self.next_states)

     training_input = rlt.MaxQLearningInput(
         state=state,
         action=self.actions,
         next_state=next_state,
         next_action=self.next_actions,
         tiled_next_state=None,
         possible_actions=None,
         possible_actions_mask=self.possible_actions_mask,
         possible_next_actions=None,
         possible_next_actions_mask=self.possible_next_actions_mask,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )

     extras = rlt.ExtraData(
         mdp_id=self.mdp_ids,
         sequence_number=self.sequence_numbers,
         action_probability=self.propensities,
         max_num_actions=self.max_num_actions,
         metrics=self.metrics,
     )

     return rlt.TrainingBatch(training_input=training_input, extras=extras)
コード例 #3
0
ファイル: feature_extractor.py プロジェクト: hyzcn/Horizon
    def extract(self, ws, input_record, extract_record):
        """Build a ``mt.TrainingBatch`` out of blobs fetched from ``ws``.

        Produces a ``mt.MaxQLearningInput`` when ``self.max_q_learning`` is
        truthy and a ``mt.SARSAInput`` otherwise. In both cases
        ``is_terminal`` is passed as ``None`` and the batch carries no
        extras.

        Args:
            ws: Object exposing ``fetch_blob(name)``; each fetched value is
                handed to ``torch.tensor``.
            input_record: Supplies the ``reward`` blob reference.
            extract_record: Supplies the state, action and next-step blob
                references.

        Returns:
            ``mt.TrainingBatch`` with ``extras=None``.
        """

        def load(ref):
            # A reference is invoked and stringified to obtain the blob key.
            return torch.tensor(ws.fetch_blob(str(ref())))

        def load_features(ref):
            return mt.FeatureVector(float_features=load(ref))

        def load_action(ref):
            # Actions become feature vectors only when sorted action
            # features are configured.
            if self.sorted_action_features is not None:
                return load_features(ref)
            return load(ref)

        state = load_features(extract_record.state)
        next_state = load_features(extract_record.next_state)
        action = load_action(extract_record.action)
        reward = load(input_record.reward)

        # is_terminal should be filled by preprocessor
        if not self.max_q_learning:
            next_action = load_action(extract_record.next_action)
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                is_terminal=None,
            )
        else:
            next_lengths = load(extract_record.possible_next_actions["lengths"])
            next_actions = load_action(
                extract_record.possible_next_actions["values"])
            possible_next_actions = mt.PossibleActions(
                lengths=next_lengths,
                actions=next_actions,
            )
            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                possible_next_actions=possible_next_actions,
                reward=reward,
                is_terminal=None,
            )

        # TODO: stuff other fields in here
        return mt.TrainingBatch(training_input=training_input, extras=None)
コード例 #4
0
 def as_parametric_maxq_training_batch(self):
     """Package this object's fields as a parametric max-Q batch.

     ``self.possible_next_actions_state_concat`` is split column-wise at
     ``self.states.shape[1]``: the leading columns become
     ``tiled_next_state`` and the remaining columns become
     ``possible_next_actions``. ``next_state``, ``next_action`` and
     ``possible_actions`` are set to ``None``.

     Returns:
         ``rlt.TrainingBatch`` whose ``training_input`` is a
         ``rlt.MaxQLearningInput`` and whose ``extras`` is an empty
         ``rlt.ExtraData()``.
     """
     state_dim = self.states.shape[1]
     concat = self.possible_next_actions_state_concat

     state = rlt.FeatureVector(float_features=self.states)
     action = rlt.FeatureVector(float_features=self.actions)
     tiled_next_state = rlt.FeatureVector(
         float_features=concat[:, :state_dim])
     possible_next_actions = rlt.FeatureVector(
         float_features=concat[:, state_dim:])

     training_input = rlt.MaxQLearningInput(
         state=state,
         action=action,
         next_state=None,
         next_action=None,
         tiled_next_state=tiled_next_state,
         possible_actions=None,
         possible_actions_mask=self.possible_actions_mask,
         possible_next_actions=possible_next_actions,
         possible_next_actions_mask=self.possible_next_actions_mask,
         reward=self.rewards,
         not_terminal=self.not_terminal,
         step=self.step,
         time_diff=self.time_diffs,
     )
     return rlt.TrainingBatch(
         training_input=training_input,
         extras=rlt.ExtraData(),
     )
コード例 #5
0
    def extract(self, ws, input_record, extract_record):
        """Build a ``mt.TrainingBatch`` out of blobs fetched from ``ws``.

        Produces a ``mt.MaxQLearningInput`` when
        ``self.include_possible_actions`` is truthy and a ``mt.SARSAInput``
        otherwise. ``step`` is fetched only when ``self.multi_steps`` is not
        ``None``.

        Args:
            ws: Object exposing ``fetch_blob(name)``; each fetched value is
                handed to ``torch.tensor``.
            input_record: Supplies the ``step``, ``reward``,
                ``not_terminal``, ``time_diff`` and ``action_probability``
                blob references.
            extract_record: Supplies the state, action and possible-action
                blob references.

        Returns:
            ``mt.TrainingBatch`` carrying ``training_input`` and ``extras``.

        Raises:
            AssertionError: If possible actions are requested while
                ``self.max_num_actions`` is ``None``.
        """

        def load(ref):
            # A reference is invoked and stringified to obtain the blob key.
            return torch.tensor(ws.fetch_blob(str(ref())))

        def load_column(ref):
            # Fetch and reshape into a single column.
            return load(ref).reshape(-1, 1)

        def load_features(ref):
            return mt.FeatureVector(float_features=load(ref))

        def load_action(ref):
            if self.sorted_action_features is not None:
                return load_features(ref)
            return load(ref)

        state = load_features(extract_record.state_features)
        next_state = load_features(extract_record.next_state_features)

        action = load_action(extract_record.action)
        next_action = load_action(extract_record.next_action)

        step = None
        if self.multi_steps is not None:
            step = load_column(input_record.step)
        reward = load_column(input_record.reward)

        # is_terminal should be filled by preprocessor
        not_terminal = load_column(input_record.not_terminal)
        time_diff = load_column(input_record.time_diff)

        if not self.include_possible_actions:
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )
        else:
            # TODO: this will need to be more complicated to support sparse features
            assert self.max_num_actions is not None, "Missing max_num_actions"

            # Masks are laid out with one column per action slot.
            raw_mask = load(extract_record.possible_actions_mask)
            possible_actions_mask = raw_mask.reshape(-1, self.max_num_actions)
            raw_next_mask = load(extract_record.possible_next_actions_mask)
            possible_next_actions_mask = raw_next_mask.reshape(
                -1, self.max_num_actions)

            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None
            if self.sorted_action_features is not None:
                possible_actions = load_features(
                    extract_record.possible_actions)
                possible_next_actions = load_features(
                    extract_record.possible_next_actions)
                # Repeat each next-state row max_num_actions times, then
                # restore the original feature width.
                next_features = next_state.float_features
                repeated = next_features.repeat(1, self.max_num_actions)
                tiled_next_state = mt.FeatureVector(
                    float_features=repeated.reshape(
                        -1, next_state.float_features.shape[1]))

            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                tiled_next_state=tiled_next_state,
                possible_actions=possible_actions,
                possible_actions_mask=possible_actions_mask,
                possible_next_actions=possible_next_actions,
                possible_next_actions_mask=possible_next_actions_mask,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )

        # TODO: stuff other fields in here
        extras = mt.ExtraData(
            action_probability=load_column(input_record.action_probability))

        return mt.TrainingBatch(training_input=training_input, extras=extras)
コード例 #6
0
    def extract(self, ws, input_record, extract_record):
        """Build a ``mt.TrainingBatch`` out of blobs fetched from ``ws``.

        Produces a ``mt.MaxQLearningInput`` when
        ``self.include_possible_actions`` is truthy and a ``mt.SARSAInput``
        otherwise. Extras carry the action probability plus, when their
        blobs are available, the sequence number and MDP id, and metrics
        when ``self.metrics_to_score`` is truthy.

        Args:
            ws: Object exposing ``fetch_blob(name)``. Only ``np.ndarray``
                results are used; anything else is treated as missing.
            input_record: Supplies the ``step``, ``reward``,
                ``not_terminal``, ``time_diff``, ``mdp_id``,
                ``sequence_number`` and ``action_probability`` blob
                references.
            extract_record: Supplies the state, action, possible-action and
                metrics blob references.

        Returns:
            ``mt.TrainingBatch`` carrying ``training_input`` and ``extras``.

        Raises:
            AssertionError: If possible actions are requested while
                ``self.max_num_actions`` is ``None``.
        """

        def load(ref, to_torch=True):
            raw = ws.fetch_blob(str(ref()))
            if isinstance(raw, np.ndarray):
                return torch.tensor(raw) if to_torch else raw
            # Blob uninitialized, return None and handle downstream
            return None

        def load_features(ref):
            return mt.FeatureVector(float_features=load(ref))

        def load_action(ref):
            if self.sorted_action_features is not None:
                return load_features(ref)
            return load(ref)

        state = load_features(extract_record.state_features)
        next_state = load_features(extract_record.next_state_features)

        action = load_action(extract_record.action)
        next_action = load_action(extract_record.next_action)

        # Only reported in the extras when sorted action features are used
        # together with possible actions (see below).
        max_num_actions = None

        step = None
        if self.multi_steps is not None:
            step = load(input_record.step).reshape(-1, 1)
        reward = load(input_record.reward).reshape(-1, 1)

        # is_terminal should be filled by preprocessor
        not_terminal = load(input_record.not_terminal).reshape(-1, 1)
        time_diff = load(input_record.time_diff).reshape(-1, 1)

        if not self.include_possible_actions:
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )
        else:
            # TODO: this will need to be more complicated to support sparse features
            assert self.max_num_actions is not None, "Missing max_num_actions"

            # Only the current-step mask is cast to a float tensor; the
            # next-step mask keeps the dtype it was fetched with.
            raw_mask = load(extract_record.possible_actions_mask)
            possible_actions_mask = raw_mask.reshape(
                -1, self.max_num_actions).type(torch.FloatTensor)
            raw_next_mask = load(extract_record.possible_next_actions_mask)
            possible_next_actions_mask = raw_next_mask.reshape(
                -1, self.max_num_actions)

            possible_actions = None
            possible_next_actions = None
            tiled_next_state = None
            if self.sorted_action_features is not None:
                possible_actions = load_features(
                    extract_record.possible_actions)
                possible_next_actions = load_features(
                    extract_record.possible_next_actions)
                # Repeat each next-state row max_num_actions times, then
                # restore the original feature width.
                next_features = next_state.float_features
                repeated = next_features.repeat(1, self.max_num_actions)
                tiled_next_state = mt.FeatureVector(
                    float_features=repeated.reshape(
                        -1, next_state.float_features.shape[1]))
                max_num_actions = self.max_num_actions

            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                tiled_next_state=tiled_next_state,
                possible_actions=possible_actions,
                possible_actions_mask=possible_actions_mask,
                possible_next_actions=possible_next_actions,
                possible_next_actions_mask=possible_next_actions_mask,
                next_action=next_action,
                reward=reward,
                not_terminal=not_terminal,
                step=step,
                time_diff=time_diff,
            )

        # mdp_id is kept as the raw fetched array rather than a tensor.
        mdp_id = load(input_record.mdp_id, to_torch=False)
        sequence_number = load(input_record.sequence_number)

        metrics = None
        if self.metrics_to_score:
            metrics = load(extract_record.metrics)

        # TODO: stuff other fields in here
        action_probability = load(
            input_record.action_probability).reshape(-1, 1)
        # Optional fields stay None when their blob was not available.
        sequence_column = None
        if sequence_number is not None:
            sequence_column = sequence_number.reshape(-1, 1)
        mdp_id_column = None
        if mdp_id is not None:
            mdp_id_column = mdp_id.reshape(-1, 1)

        extras = mt.ExtraData(
            action_probability=action_probability,
            sequence_number=sequence_column,
            mdp_id=mdp_id_column,
            max_num_actions=max_num_actions,
            metrics=metrics,
        )

        return mt.TrainingBatch(training_input=training_input, extras=extras)