Example #1
class DiscreteDQN(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: DQNTrainerParameters = field(
        default_factory=DQNTrainerParameters)
    net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `Dueling`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), f"DiscreteDQNModel needs at least 2 actions. Got {self.action_names}."
        if self.trainer_param.minibatch_size % 8 != 0:
            logger.warning(
                f"minibatch size ({self.trainer_param.minibatch_size}) "
                "should be divisible by 8 for performance reasons!")

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self) -> DQNTrainer:
        net_builder = self.net_builder.value
        q_network = net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )

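        # q_network_target is a separate copy of q_network used to compute TD
        # targets; the trainer keeps it in sync rather than training it directly.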
        q_network_target = q_network.get_target_network()

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.eval_parameters.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `actions`.
                self.trainer_param.actions)
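            # e.g. 3 metrics_to_score and 2 actions -> (3 + 1) * 2 = 8 output
            # nodes: one output per action for the reward and for each metric.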

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )

            q_network_cpe_target = q_network_cpe.get_target_network()

        # pyre-fixme[16]: `DiscreteDQN` has no attribute `_q_network`.
        self._q_network = q_network
        trainer = DQNTrainer(
            q_network=q_network,
            q_network_target=q_network_target,
            reward_network=reward_network,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            evaluation=self.eval_parameters,
            # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )
        return trainer

    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    def build_serving_module(self) -> torch.nn.Module:
        """
        Returns a TorchScript predictor module
        """
        assert self._q_network is not None, "_q_network was not initialized"

        net_builder = self.net_builder.value
        return net_builder.build_serving_module(
            self._q_network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
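
A minimal usage sketch for the manager above. This is hypothetical glue code, not part of the example: it assumes DQNTrainerParameters accepts an `actions` list (the attribute is read in __post_init_post_parse__) and that the DiscreteDQNBase workflow has already populated state_feature_config, state_normalization_data, metrics_to_score, and eval_parameters before build_trainer() is called.

# Hypothetical sketch; action names are made up for illustration.
model = DiscreteDQN(
    trainer_param=DQNTrainerParameters(actions=["no_op", "act"]),
)
trainer = model.build_trainer()            # DQNTrainer with target and optional CPE networks
predictor = model.build_serving_module()   # TorchScript module for serving
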
Example #2
class Reinforce(ModelManager):
    __hash__ = param_hash

    trainer_param: ReinforceTrainerParameters = field(
        default_factory=ReinforceTrainerParameters)
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])))
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.action_names = self.trainer_param.actions
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"REINFORCE needs at least 2 actions. Got {self.action_names}."

    # pyre-ignore
    def build_trainer(self) -> ReinforceTrainer:
        policy_net_builder = self.policy_net_builder.value
        # pyre-ignore
        self._policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )
        value_net = None
        if self.value_net_builder:
            value_net_builder = self.value_net_builder.value  # pyre-ignore
            value_net = value_net_builder.build_value_network(
                self.state_normalization_data)
        trainer = ReinforceTrainer(
            policy=self.create_policy(),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(self, serving: bool = False):
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module())
        else:
            if self._policy is None:
                sampler = SoftmaxActionSampler(
                    temperature=self.sampler_temperature)
                # pyre-ignore
                self._policy = Policy(scorer=self._policy_network,
                                      sampler=sampler)
            return self._policy

    def build_serving_module(self) -> torch.nn.Module:
        assert self._policy_network is not None
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=self._policy_network,
            state_normalization_data=self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        raise NotImplementedError

    @property
    def required_normalization_keys(self) -> List[str]:
        return [NormalizationKey.STATE]

    @property
    def should_generate_eval_dataset(self) -> bool:
        raise NotImplementedError

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        raise NotImplementedError

    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions],
    ) -> RLTrainingOutput:
        raise NotImplementedError

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
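
A similarly hedged sketch for the REINFORCE manager, assuming ReinforceTrainerParameters accepts an `actions` list and that state normalization data has been attached by the surrounding workflow before build_trainer() runs:

# Hypothetical sketch; parameter values are made up for illustration.
manager = Reinforce(
    trainer_param=ReinforceTrainerParameters(actions=["left", "right"]),
    sampler_temperature=0.5,
)
trainer = manager.build_trainer()   # builds the policy network (and value net, if configured)
policy = manager.create_policy()    # Policy(scorer=policy_network, sampler=SoftmaxActionSampler)
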
Example #3
class DiscreteDQN(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: DQNTrainerParameters = field(
        default_factory=DQNTrainerParameters)
    net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `Dueling`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    # TODO: move evaluation parameters here from trainer_param.evaluation.
    # Note that only DiscreteDQN and QRDQN call RLTrainer._initialize_cpe,
    # so it could perhaps be removed from the RLTrainer class.

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl
        self.eval_parameters = self.trainer_param.evaluation
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), "DiscreteDQNModel needs at least 2 actions"
        assert (
            self.trainer_param.minibatch_size % 8 == 0
        ), "The minibatch size must be divisible by 8 for performance reasons."

    def build_trainer(self) -> DQNTrainer:
        net_builder = self.net_builder.value
        q_network = net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )

        if self.use_gpu:
            q_network = q_network.cuda()

        q_network_target = q_network.get_target_network()

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.trainer_param.evaluation.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                self.trainer_param.actions)

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )

            if self.use_gpu:
                reward_network.cuda()
                q_network_cpe.cuda()

            q_network_cpe_target = q_network_cpe.get_target_network()

        # pyre-fixme[16]: `DiscreteDQN` has no attribute `_q_network`.
        self._q_network = q_network
        trainer = DQNTrainer(
            q_network,
            q_network_target,
            reward_network,
            self.trainer_param,
            self.use_gpu,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            loss_reporter=NoOpLossReporter(),
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """
        Returns a TorchScript predictor module
        """
        assert self._q_network is not None, "_q_network was not initialized"

        net_builder = self.net_builder.value
        return net_builder.build_serving_module(
            self._q_network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
Example #4
class DiscreteCRR(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: CRRTrainerParameters = field(
        default_factory=CRRTrainerParameters)

    actor_net_builder: DiscreteActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteActorNetBuilder__Union(
            FullyConnected=DiscreteFullyConnected()))

    critic_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `Dueling`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )

    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._actor_network: Optional[ModelBase] = None
        self._q1_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), f"DiscreteDQNModel needs at least 2 actions. Got {self.action_names}."

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self, use_gpu: bool) -> DiscreteCRRTrainer:
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `DiscreteCRR` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, len(self.action_names))

        # The arguments to q1_network and q2_network below are modeled after those in discrete_dqn.py.
        # The target networks will be created in DiscreteCRRTrainer.
        critic_net_builder = self.critic_net_builder.value

        # pyre-fixme[16]: `DiscreteCRR` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_data,
            len(self.action_names),
        )

        q2_network = (
            critic_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                len(self.action_names),
            )
            # pyre-fixme[16]: `CRRTrainerParameters` has no attribute
            #  `double_q_learning`.
            if self.trainer_param.double_q_learning else None)

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.eval_parameters.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                # pyre-fixme[16]: `CRRTrainerParameters` has no attribute `actions`.
                self.trainer_param.actions)

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_data,
                num_output_nodes,
            )

            q_network_cpe_target = q_network_cpe.get_target_network()

        trainer = DiscreteCRRTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            reward_network=reward_network,
            q2_network=q2_network,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            evaluation=self.eval_parameters,
            # pyre-fixme[16]: `CRRTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )
        return trainer

    def create_policy(self, serving: bool) -> Policy:
        """Create online actor critic policy."""
        if serving:
            return create_predictor_policy_from_model(
                self.build_actor_module())
        else:
            return ActorPolicyWrapper(self._actor_network)

    def get_reporter(self):
        return DiscreteCRRReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )

    # Note: when using test_gym.py as the entry point, the normalization data
    # is set when the line `normalization = build_normalizer(env)` is executed.
    # The code then calls build_state_normalizer() and build_action_normalizer()
    # in utils.py.

    def serving_module_names(self):
        module_names = ["default_model", "dqn", "actor_dqn"]
        if len(self.action_names) == 2:
            module_names.append("binary_difference_scorer")
        return module_names

    def build_serving_modules(self):
        """
        `actor_dqn` is the actor module wrapped in the DQN predictor wrapper.
        This helps putting the actor in places where DQN predictor wrapper is expected.
        If the policy is greedy, then this wrapper would work.
        """
        serving_modules = {
            "default_model": self.build_actor_module(),
            "dqn": self._build_dqn_module(self._q1_network),
            "actor_dqn": self._build_dqn_module(ActorDQN(self._actor_network)),
        }
        if len(self.action_names) == 2:
            serving_modules.update({
                "binary_difference_scorer":
                self._build_binary_difference_scorer(
                    ActorDQN(self._actor_network)),
            })
        return serving_modules

    def _build_dqn_module(self, network):
        critic_net_builder = self.critic_net_builder.value
        assert network is not None
        return critic_net_builder.build_serving_module(
            network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )

    def _build_binary_difference_scorer(self, network):
        critic_net_builder = self.critic_net_builder.value
        assert network is not None
        return critic_net_builder.build_binary_difference_scorer(
            network,
            self.state_normalization_data,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )

    # Also, even though the build_serving_module call below is routed to
    # discrete_actor_net_builder.py, which returns an ActorPredictorWrapper
    # just like continuous_actor_net_builder.py does, the outputs of the
    # discrete actor are still computed differently from those of the
    # continuous actor. During serving, the act() function of the Agent class
    # in gym/agents/agents.py returns self.action_extractor(actor_output).
    # That extractor is created in create_for_env_with_serving_policy, when
    # env.get_serving_action_extractor() is called. At serving time,
    # action_extractor calls serving_action_extractor() in env_wrapper.py,
    # which checks the type of action_space and treats spaces.Discrete
    # differently from spaces.Box (continuous).
    def build_actor_module(self) -> torch.nn.Module:
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            action_feature_ids=list(range(len(self.action_names))),
        )
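
To see how the serving side of this manager fits together, here is a hedged sketch. It assumes CRRTrainerParameters accepts an `actions` list and that normalization data and feature configs have been set up by the ReAgent workflow before build_trainer is called:

# Hypothetical sketch; action names are made up for illustration.
manager = DiscreteCRR(
    trainer_param=CRRTrainerParameters(actions=["hold", "act"]),
)
trainer = manager.build_trainer(use_gpu=False)
modules = manager.build_serving_modules()
# With exactly 2 actions, `modules` contains the keys
# "default_model", "dqn", "actor_dqn", and "binary_difference_scorer".
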
Example #5
class PPO(ModelManager):
    __hash__ = param_hash

    trainer_param: PPOTrainerParameters = field(default_factory=PPOTrainerParameters)
    # using DQN net here because it supports `possible_actions_mask`
    policy_net_builder: DiscreteDQNNetBuilder__Union = field(
        # pyre-ignore
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = None
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-ignore
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    sampler_temperature: float = 1.0

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self._policy: Optional[Policy] = None
        assert (
            len(self.action_names) > 1
        ), f"PPO needs at least 2 actions. Got {self.action_names}."

    @property
    def action_names(self):
        return self.trainer_param.actions

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> PPOTrainer:
        policy_net_builder = self.policy_net_builder.value
        policy_network = policy_net_builder.build_q_network(
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            len(self.action_names),
        )
        value_net = None
        value_net_builder = self.value_net_builder
        if value_net_builder:
            value_net_builder = value_net_builder.value
            value_net = value_net_builder.build_value_network(
                normalization_data_map[NormalizationKey.STATE]
            )
        trainer = PPOTrainer(
            policy=self._create_policy(policy_network),
            value_net=value_net,
            **self.trainer_param.asdict(),  # pyre-ignore
        )
        return trainer

    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ):
        assert isinstance(trainer_module, PPOTrainer)
        if serving:
            assert normalization_data_map is not None
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map)
            )
        else:
            return self._create_policy(trainer_module.scorer)

    def _create_policy(self, policy_network):
        if self._policy is None:
            sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
            self._policy = Policy(scorer=policy_network, sampler=sampler)
        return self._policy

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        assert isinstance(trainer_module, PPOTrainer)
        policy_serving_module = self.policy_net_builder.value.build_serving_module(
            q_network=trainer_module.scorer,
            state_normalization_data=normalization_data_map[NormalizationKey.STATE],
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )
        return policy_serving_module

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        return self.state_feature_config_provider.value.get_model_feature_config()
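
The PPO manager uses the newer build_trainer signature that takes a normalization_data_map explicitly. A hedged sketch, assuming `state_normalization_data` is a NormalizationData produced by the usual normalization step and that PPOTrainerParameters accepts an `actions` list:

# Hypothetical sketch; `state_normalization_data` is assumed to exist already.
manager = PPO(trainer_param=PPOTrainerParameters(actions=["left", "right"]))
normalization_data_map = {NormalizationKey.STATE: state_normalization_data}
trainer = manager.build_trainer(normalization_data_map, use_gpu=False)
policy = manager.create_policy(trainer_module=trainer)  # training-time softmax policy
serving = manager.build_serving_module(
    trainer_module=trainer, normalization_data_map=normalization_data_map
)
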
Example #6
class DiscreteDQN(DiscreteDQNBase):
    __hash__ = param_hash

    trainer_param: DQNTrainerParameters = field(
        default_factory=DQNTrainerParameters)
    net_builder: DiscreteDQNNetBuilder__Union = field(
        default_factory=lambda: DiscreteDQNNetBuilder__Union(Dueling=Dueling())
    )
    cpe_net_builder: DiscreteDQNNetBuilder__Union = field(
        default_factory=lambda: DiscreteDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl
        self.eval_parameters = self.trainer_param.evaluation
        self.action_names = self.trainer_param.actions
        assert (
            len(self.action_names) > 1
        ), "DiscreteDQNModel needs at least 2 actions"
        assert (
            self.trainer_param.minibatch_size % 8 == 0
        ), "The minibatch size must be divisible by 8 for performance reasons."

    def build_trainer(self) -> DQNTrainer:
        net_builder = self.net_builder.value
        q_network = net_builder.build_q_network(
            self.state_feature_config,
            self.state_normalization_parameters,
            len(self.action_names),
        )

        if self.use_gpu:
            q_network = q_network.cuda()

        q_network_target = q_network.get_target_network()

        reward_network, q_network_cpe, q_network_cpe_target = None, None, None
        if self.trainer_param.evaluation.calc_cpe_in_training:
            # Metrics + reward
            num_output_nodes = (len(self.metrics_to_score) + 1) * len(
                self.trainer_param.actions)

            cpe_net_builder = self.cpe_net_builder.value
            reward_network = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_parameters,
                num_output_nodes,
            )
            q_network_cpe = cpe_net_builder.build_q_network(
                self.state_feature_config,
                self.state_normalization_parameters,
                num_output_nodes,
            )

            if self.use_gpu:
                reward_network.cuda()
                q_network_cpe.cuda()

            q_network_cpe_target = q_network_cpe.get_target_network()

        self._q_network = q_network
        trainer = DQNTrainer(
            q_network,
            q_network_target,
            reward_network,
            self.trainer_param,
            self.use_gpu,
            q_network_cpe=q_network_cpe,
            q_network_cpe_target=q_network_cpe_target,
            metrics_to_score=self.metrics_to_score,
            loss_reporter=NoOpLossReporter(),
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """
        Returns a TorchScript predictor module
        """
        assert self._q_network is not None, "_q_network was not initialized"
        net_builder = self.net_builder.value
        return net_builder.build_serving_module(
            self._q_network,
            self.state_normalization_parameters,
            action_names=self.action_names,
            state_feature_config=self.state_feature_config,
        )