Ejemplo n.º 1
0
class ParametricDQN(ParametricDQNBase):
    """Build the trainer and serving module for a parametric DQN model.

    All networks are created through the builder selected by ``net_builder``.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``ParametricDQNTrainer`` through ``asdict()``.
    trainer_param: ParametricDQNTrainerParameters = field(
        default_factory=ParametricDQNTrainerParameters
    )
    # Network-builder selection; by default it wraps a ``FullyConnected`` builder.
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    @property
    def rl_parameters(self):
        """Return the ``rl`` attribute of ``trainer_param``."""
        rl_section = self.trainer_param.rl
        return rl_section

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> ParametricDQNTrainer:
        """Create the Q-network, the reward network and the trainer.

        Args:
            normalization_data_map: Normalization data; the entries under
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``
                are read.
            use_gpu: Not read by this method.
            reward_options: Provides ``metric_reward_values``; a fresh
                ``RewardOptions()`` is used when this argument is falsy.

        Returns:
            A ``ParametricDQNTrainer`` given the Q-network, the network
            returned by its ``get_target_network()``, and the reward network.
        """
        builder = self.net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        # pyre-fixme[16]: `ParametricDQN` has no attribute `_q_network`.
        self._q_network = builder.build_q_network(state_norm, action_norm)

        # Metrics + reward: one output per metric to score, plus one more.
        options = reward_options or RewardOptions()
        scored_metrics = get_metrics_to_score(options.metric_reward_values)
        reward_network = builder.build_q_network(
            state_norm,
            action_norm,
            output_dim=len(scored_metrics) + 1,
        )

        target_network = self._q_network.get_target_network()
        # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute
        #  `asdict`.
        trainer_kwargs = self.trainer_param.asdict()
        return ParametricDQNTrainer(
            q_network=self._q_network,
            q_network_target=target_network,
            reward_network=reward_network,
            **trainer_kwargs,
        )

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Build the serving module from a trained ``ParametricDQNTrainer``.

        Args:
            trainer_module: Must be a ``ParametricDQNTrainer``; its
                ``q_network`` is the network that gets served.
            normalization_data_map: Normalization data; the STATE and ACTION
                entries are read.

        Returns:
            Whatever the selected builder's ``build_serving_module`` returns.
        """
        assert isinstance(trainer_module, ParametricDQNTrainer)
        builder = self.net_builder.value
        trained_q_network = trainer_module.q_network
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]
        return builder.build_serving_module(trained_q_network, state_norm, action_norm)
Ejemplo n.º 2
0
class ParametricDQN(ParametricDQNBase):
    """Build the trainer and serving module for a parametric DQN model.

    Normalization data, ``use_gpu`` and ``metrics_to_score`` are read from the
    instance. The Q-network created by ``build_trainer`` is kept on the
    instance so that ``build_serving_module`` can reuse it.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``ParametricDQNTrainer`` through ``asdict()``.
    trainer_param: ParametricDQNTrainerParameters = field(
        default_factory=ParametricDQNTrainerParameters)
    # Network-builder selection; by default it wraps a ``FullyConnected`` builder.
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()))

    def __post_init_post_parse__(self):
        """Run the parent hook and set up per-instance state."""
        super().__post_init_post_parse__()
        # Start without a network so that build_serving_module() fails on its
        # own assertion (rather than an AttributeError) when it is called
        # before build_trainer().
        self._q_network = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> ParametricDQNTrainer:
        """Create the Q-network, the reward network and the trainer.

        Both networks are moved with ``.cuda()`` when ``self.use_gpu`` is
        truthy. The Q-network is stored on ``self._q_network``.

        Returns:
            A ``ParametricDQNTrainer`` given the Q-network, the network
            returned by its ``get_target_network()``, and the reward network.
        """
        net_builder = self.net_builder.value
        # pyre-fixme[16]: `ParametricDQN` has no attribute `_q_network`.
        self._q_network = net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data)
        # Metrics + reward: one output per metric to score, plus one more.
        reward_output_dim = len(self.metrics_to_score) + 1
        reward_network = net_builder.build_q_network(
            self.state_normalization_data,
            self.action_normalization_data,
            output_dim=reward_output_dim,
        )

        if self.use_gpu:
            self._q_network = self._q_network.cuda()
            reward_network = reward_network.cuda()

        # The target network is requested after the optional move to GPU.
        q_network_target = self._q_network.get_target_network()
        # pyre-fixme[29]: `Type[ParametricDQNTrainer]` is not a function.
        return ParametricDQNTrainer(
            q_network=self._q_network,
            q_network_target=q_network_target,
            reward_network=reward_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute
            #  `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Build the serving module around the Q-network from ``build_trainer``.

        Raises:
            AssertionError: If ``build_trainer`` has not produced a network yet.
        """
        net_builder = self.net_builder.value
        assert self._q_network is not None
        return net_builder.build_serving_module(
            self._q_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
Ejemplo n.º 3
0
class SlateQ(SlateQBase):
    """Build the trainer and serving module for a SlateQ model.

    ``slate_size`` and ``num_candidates`` default to ``-1`` and must be set to
    positive values; this is checked in ``__post_init_post_parse__``.
    """

    __hash__ = param_hash

    # Must be > 0; forwarded to ``SlateQTrainer``.
    slate_size: int = -1
    # Must be > 0.
    num_candidates: int = -1
    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``SlateQTrainer`` through ``asdict()``.
    trainer_param: SlateQTrainerParameters = field(
        default_factory=SlateQTrainerParameters
    )
    # Network-builder selection; by default it wraps a ``FullyConnected`` builder.
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        """Validate the size fields and expose the evaluation parameters."""
        super().__post_init_post_parse__()
        assert (
            self.slate_size > 0
        ), f"Please set valid slate_size (currently {self.slate_size})"
        assert (
            self.num_candidates > 0
        ), f"Please set valid num_candidates (currently {self.num_candidates})"
        self.eval_parameters = self.trainer_param.evaluation

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> SlateQTrainer:
        """Create the Q-network and the trainer.

        Args:
            normalization_data_map: Normalization data; the entries under
                ``NormalizationKey.STATE`` and ``NormalizationKey.ITEM`` are
                read.
            use_gpu: Not read by this method.
            reward_options: Not read by this method.

        Returns:
            A ``SlateQTrainer`` given the Q-network and the network returned
            by its ``get_target_network()``.
        """
        builder = self.net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        item_norm = normalization_data_map[NormalizationKey.ITEM]
        q_net = builder.build_q_network(state_norm, item_norm)

        target_net = q_net.get_target_network()
        # pyre-fixme[16]: `SlateQTrainerParameters` has no attribute `asdict`.
        trainer_kwargs = self.trainer_param.asdict()
        return SlateQTrainer(
            q_network=q_net,
            q_network_target=target_net,
            slate_size=self.slate_size,
            **trainer_kwargs,
        )

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Build the serving module from a trained ``SlateQTrainer``.

        Args:
            trainer_module: Must be a ``SlateQTrainer``; its ``q_network`` is
                the network that gets served.
            normalization_data_map: Normalization data; the STATE and ITEM
                entries are read.

        Returns:
            Whatever the selected builder's ``build_serving_module`` returns.
        """
        assert isinstance(trainer_module, SlateQTrainer)
        builder = self.net_builder.value
        trained_q_network = trainer_module.q_network
        state_norm = normalization_data_map[NormalizationKey.STATE]
        item_norm = normalization_data_map[NormalizationKey.ITEM]
        return builder.build_serving_module(trained_q_network, state_norm, item_norm)
Ejemplo n.º 4
0
class SlateQ(SlateQBase):
    """Build the trainer and serving module for a SlateQ model.

    Normalization data and ``use_gpu`` are read from the instance. The
    Q-network created by ``build_trainer`` is kept on ``self._q_network`` for
    ``build_serving_module``.
    """

    __hash__ = param_hash

    # Must be > 0 (checked after initialisation).
    slate_size: int = -1
    # Must be > 0 (checked after initialisation).
    num_candidates: int = -1
    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``SlateQTrainer`` through ``asdict()``.
    trainer_param: SlateQTrainerParameters = field(
        default_factory=SlateQTrainerParameters
    )
    # Network-builder selection; by default it wraps a ``FullyConnected`` builder.
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        """Validate the size fields and set up per-instance state."""
        super().__post_init_post_parse__()
        assert self.slate_size > 0, (
            f"Please set valid slate_size (currently {self.slate_size})"
        )
        assert self.num_candidates > 0, (
            f"Please set valid num_candidates (currently {self.num_candidates})"
        )
        # No network exists until build_trainer() has run.
        self._q_network: Optional[ModelBase] = None
        self.eval_parameters = self.trainer_param.evaluation

    def build_trainer(self) -> SlateQTrainer:
        """Create the Q-network (moved with ``.cuda()`` if requested) and the trainer.

        Returns:
            A ``SlateQTrainer`` given the Q-network and the network returned
            by its ``get_target_network()``.
        """
        builder = self.net_builder.value
        q_net = builder.build_q_network(
            self.state_normalization_data, self.item_normalization_data
        )
        # pyre-fixme[16]: `SlateQ` has no attribute `_q_network`.
        self._q_network = q_net
        if self.use_gpu:
            q_net = q_net.cuda()
            self._q_network = q_net

        target_net = q_net.get_target_network()
        gpu_flag = self.use_gpu
        # pyre-fixme[16]: `SlateQTrainerParameters` has no attribute `asdict`.
        trainer_kwargs = self.trainer_param.asdict()
        return SlateQTrainer(
            q_network=q_net,
            q_network_target=target_net,
            use_gpu=gpu_flag,
            **trainer_kwargs,
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Build the serving module around the Q-network from ``build_trainer``.

        Raises:
            AssertionError: If no Q-network has been built yet.
        """
        builder = self.net_builder.value
        assert self._q_network is not None
        return builder.build_serving_module(
            self._q_network,
            self.state_normalization_data,
            self.item_normalization_data,
        )
Ejemplo n.º 5
0
class SAC(ActorCriticBase):
    """Build a ``SACTrainer`` and the serving module for its actor.

    Normalization data is read from the instance. The actor network created
    by ``build_trainer`` is kept on ``self._actor_network`` for
    ``build_serving_module``.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``SACTrainer`` through ``asdict()``.
    trainer_param: SACTrainerParameters = field(
        default_factory=SACTrainerParameters
    )
    # Actor builder selection; by default wraps ``GaussianFullyConnected``.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    # Critic builder selection; by default wraps ``FullyConnected``.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    # Value-network builder selection; no value network is built when falsy.
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When true, a second Q-network is built with the same critic builder.
    use_2_q_functions: bool = True
    # Forwarded to the actor builder's ``build_serving_module``.
    serve_mean_policy: bool = True

    def __post_init_post_parse__(self):
        """Run the parent hook and set up per-instance state."""
        super().__post_init_post_parse__()
        # No actor exists until build_trainer() has run.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self, use_gpu: bool) -> SACTrainer:
        """Create the actor, critic(s) and optional value network, then the trainer.

        Args:
            use_gpu: Not read by this method.

        Returns:
            A ``SACTrainer`` wired with the networks built here.
        """
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = actor_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        second_q_network = None
        if self.use_2_q_functions:
            second_q_network = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        state_value_network = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_builder = self.value_net_builder.value
            state_value_network = value_builder.build_value_network(
                self.state_normalization_data
            )

        return SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=state_value_network,
            q2_network=second_q_network,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``SACReporter``."""
        return SACReporter()

    # NOTE(review): the annotation says ``Dict[str, torch.nn.Module]`` but the
    # value returned is whatever the actor builder's ``build_serving_module``
    # yields -- confirm which of the two is intended.
    def build_serving_module(self) -> Dict[str, torch.nn.Module]:
        """Build the actor serving module from the network made by ``build_trainer``.

        Raises:
            AssertionError: If no actor network has been built yet.
        """
        assert self._actor_network is not None
        actor_builder = self.actor_net_builder.value
        return actor_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
            serve_mean_policy=self.serve_mean_policy,
        )
Ejemplo n.º 6
0
class TD3(ActorCriticBase):
    """Build a ``TD3Trainer`` and the serving module for its actor.

    Normalization data and ``use_gpu`` are read from the instance. The actor
    network created by ``build_trainer`` is kept on ``self._actor_network``
    for ``build_serving_module``.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``TD3Trainer`` through ``asdict()``.
    trainer_param: TD3TrainerParameters = field(
        default_factory=TD3TrainerParameters
    )
    # Actor builder selection; by default wraps ``ContinuousFullyConnected``.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    # Critic builder selection; by default wraps ``ParametricFullyConnected``.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # When true, a second Q-network is built with the same critic builder.
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters
    )

    def __post_init_post_parse__(self):
        """Run the parent hook and set up per-instance state."""
        super().__post_init_post_parse__()
        # No actor exists until build_trainer() has run.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> TD3Trainer:
        """Create the actor and critic network(s), then the trainer.

        When ``self.use_gpu`` is truthy, ``.cuda()`` is called on every
        network that was built; the return values of those calls are not
        used.

        Returns:
            A ``TD3Trainer`` wired with the networks built here.
        """
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = actor_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        second_q_network = None
        if self.use_2_q_functions:
            second_q_network = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        if self.use_gpu:
            self._q1_network.cuda()
            if second_q_network:
                second_q_network.cuda()
            self._actor_network.cuda()

        return TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=second_q_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Build the actor serving module from the network made by ``build_trainer``.

        Raises:
            AssertionError: If no actor network has been built yet.
        """
        actor_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return actor_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
Ejemplo n.º 7
0
class SAC(ActorCriticBase):
    """Build a ``SACTrainer`` and the serving module for its actor.

    Normalization data is supplied by the caller through
    ``normalization_data_map`` instead of being read from the instance.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``SACTrainer`` through ``asdict()``.
    trainer_param: SACTrainerParameters = field(
        default_factory=SACTrainerParameters
    )
    # Actor builder selection; by default wraps ``GaussianFullyConnected``.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    # Critic builder selection; by default wraps ``FullyConnected``.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    # Value-network builder selection; no value network is built when falsy.
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When true, a second Q-network is built with the same critic builder.
    use_2_q_functions: bool = True
    # Forwarded to the actor builder's ``build_serving_module``.
    serve_mean_policy: bool = True

    def __post_init_post_parse__(self):
        """Run the parent hook, then expose ``trainer_param.rl`` as ``rl_parameters``."""
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> SACTrainer:
        """Create the actor, critic(s) and optional value network, then the trainer.

        Args:
            normalization_data_map: Normalization data; the entries under
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``
                are read.
            use_gpu: Not read by this method.
            reward_options: Not read by this method.

        Returns:
            A ``SACTrainer`` wired with the networks built here.
        """
        actor_builder = self.actor_net_builder.value
        feature_config = self.state_feature_config
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]
        actor_net = actor_builder.build_actor(feature_config, state_norm, action_norm)

        critic_builder = self.critic_net_builder.value
        first_q_net = critic_builder.build_q_network(state_norm, action_norm)
        second_q_net = None
        if self.use_2_q_functions:
            second_q_net = critic_builder.build_q_network(state_norm, action_norm)

        state_value_net = None
        value_builder_choice = self.value_net_builder
        if value_builder_choice:
            value_builder = value_builder_choice.value
            state_value_net = value_builder.build_value_network(state_norm)

        return SACTrainer(
            actor_network=actor_net,
            q1_network=first_q_net,
            value_network=state_value_net,
            q2_network=second_q_net,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return ``None``; this manager provides no reporter object."""
        return None

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Build the actor serving module from a trained ``SACTrainer``.

        Args:
            trainer_module: Must be a ``SACTrainer``; its ``actor_network`` is
                the network that gets served.
            normalization_data_map: Normalization data; the STATE and ACTION
                entries are read.

        Returns:
            Whatever the actor builder's ``build_serving_module`` returns.
        """
        assert isinstance(trainer_module, SACTrainer)
        actor_builder = self.actor_net_builder.value
        return actor_builder.build_serving_module(
            trainer_module.actor_network,
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            normalization_data_map[NormalizationKey.ACTION],
            serve_mean_policy=self.serve_mean_policy,
        )
 def test_fully_connected(self):
     """Run ``_test_parametric_dqn_net_builder`` on the FullyConnected builder."""
     # The long dotted path is deliberate: it checks that the builder is
     # reachable through the package's __init__.py.
     fully_connected = parametric_dqn.fully_connected.FullyConnected()
     builder_union = ParametricDQNNetBuilder__Union(FullyConnected=fully_connected)
     self._test_parametric_dqn_net_builder(builder_union)
Ejemplo n.º 9
0
class TD3(ActorCriticBase):
    """Build a ``TD3Trainer``, its reporter and the serving module for its actor.

    Normalization data is read from the instance. The actor network created
    by ``build_trainer`` is kept on ``self._actor_network`` for
    ``build_serving_module``.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``TD3Trainer`` through ``asdict()``.
    trainer_param: TD3TrainerParameters = field(
        default_factory=TD3TrainerParameters
    )
    # Actor builder selection; by default wraps ``ContinuousFullyConnected``.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    # Critic builder selection; by default wraps ``ParametricFullyConnected``.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # Why isn't this a parameter in the .yaml config file?
    # When true, a second Q-network is built with the same critic builder.
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters
    )

    def __post_init_post_parse__(self):
        """Run the parent hook and set up per-instance state."""
        super().__post_init_post_parse__()
        # No actor exists until build_trainer() has run.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self) -> TD3Trainer:
        """Create the actor and critic network(s), then the trainer.

        Returns:
            A ``TD3Trainer`` wired with the networks built here.
        """
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = actor_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        second_q_network = None
        if self.use_2_q_functions:
            second_q_network = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        return TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=second_q_network,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``TD3Reporter``."""
        return TD3Reporter()

    def build_serving_module(self) -> torch.nn.Module:
        """Build the actor serving module from the network made by ``build_trainer``.

        Raises:
            AssertionError: If no actor network has been built yet.
        """
        actor_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return actor_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
Ejemplo n.º 10
0
class SoftActorCritic(ActorCriticBase):
    """Build a ``SACTrainer`` and the serving module for its actor.

    Normalization data and ``use_gpu`` are read from the instance. The actor
    network created by ``build_trainer`` is kept on ``self._actor_network``
    for ``build_serving_module``.
    """

    __hash__ = param_hash

    # Trainer parameters; passed to ``SACTrainer`` as a single object.
    trainer_param: SACTrainerParameters = field(
        default_factory=SACTrainerParameters
    )
    # Actor builder selection; by default wraps ``GaussianFullyConnected``.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    # Critic builder selection; by default wraps ``FullyConnected``.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    # Value-network builder selection; no value network is built when falsy.
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When true, a second Q-network is built with the same critic builder.
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters
    )

    def __post_init_post_parse__(self):
        """Run the parent hook and set up per-instance state."""
        super().__post_init_post_parse__()
        # No actor exists until build_trainer() has run.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> SACTrainer:
        """Create the actor, critic(s) and optional value network, then the trainer.

        When ``self.use_gpu`` is truthy, ``.cuda()`` is called on every
        network that was built; the return values of those calls are not
        used.

        Returns:
            A ``SACTrainer`` wired with the networks built here.
        """
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SoftActorCritic` has no attribute `_actor_network`.
        self._actor_network = actor_builder.build_actor(
            self.get_normalization_data(NormalizationKey.STATE),
            self.get_normalization_data(NormalizationKey.ACTION),
        )

        # NOTE(review): the critics receive ``*_normalization_parameters``
        # whereas the actor and value network receive
        # ``get_normalization_data(...)`` -- confirm this difference is intended.
        critic_builder = self.critic_net_builder.value
        first_q_net = critic_builder.build_q_network(
            self.state_normalization_parameters,
            self.action_normalization_parameters,
        )
        second_q_net = None
        if self.use_2_q_functions:
            second_q_net = critic_builder.build_q_network(
                self.state_normalization_parameters,
                self.action_normalization_parameters,
            )

        state_value_net = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_builder = self.value_net_builder.value
            state_value_net = value_builder.build_value_network(
                self.get_normalization_data(NormalizationKey.STATE)
            )

        if self.use_gpu:
            first_q_net.cuda()
            if second_q_net:
                second_q_net.cuda()
            if state_value_net:
                state_value_net.cuda()
            self._actor_network.cuda()

        return SACTrainer(
            first_q_net,
            self._actor_network,
            self.trainer_param,
            value_network=state_value_net,
            q2_network=second_q_net,
            use_gpu=self.use_gpu,
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Build the actor serving module from the network made by ``build_trainer``.

        Raises:
            AssertionError: If no actor network has been built yet.
        """
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SoftActorCritic` has no attribute `_actor_network`.
        assert self._actor_network is not None
        return actor_builder.build_serving_module(
            self._actor_network,
            self.get_normalization_data(NormalizationKey.STATE),
            self.get_normalization_data(NormalizationKey.ACTION),
        )
Ejemplo n.º 11
0
class SAC(ActorCriticBase):
    """Build a ``SACTrainer`` and the serving module for its actor.

    Normalization data and ``use_gpu`` are read from the instance. The actor
    network created by ``build_trainer`` is kept on ``self._actor_network``
    for ``build_serving_module``.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``SACTrainer`` through ``asdict()``.
    trainer_param: SACTrainerParameters = field(
        default_factory=SACTrainerParameters
    )
    # Actor builder selection; by default wraps ``GaussianFullyConnected``.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    # Critic builder selection; by default wraps ``FullyConnected``.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    # Value-network builder selection; no value network is built when falsy.
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When true, a second Q-network is built with the same critic builder.
    use_2_q_functions: bool = True

    def __post_init_post_parse__(self):
        """Run the parent hook and set up per-instance state."""
        super().__post_init_post_parse__()
        # No actor exists until build_trainer() has run.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> SACTrainer:
        """Create the actor, critic(s) and optional value network, then the trainer.

        When ``self.use_gpu`` is truthy, ``.cuda()`` is called on every
        network that was built; the return values of those calls are not
        used.

        Returns:
            A ``SACTrainer`` wired with the networks built here.
        """
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = actor_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        second_q_net = None
        if self.use_2_q_functions:
            second_q_net = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        state_value_net = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_builder = self.value_net_builder.value
            state_value_net = value_builder.build_value_network(
                self.state_normalization_data
            )

        if self.use_gpu:
            self._q1_network.cuda()
            if second_q_net:
                second_q_net.cuda()
            if state_value_net:
                state_value_net.cuda()
            self._actor_network.cuda()

        # pyre-fixme[29]: `Type[reagent.training.sac_trainer.SACTrainer]` is not a
        #  function.
        return SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=state_value_net,
            q2_network=second_q_net,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Build the actor serving module from the network made by ``build_trainer``.

        Raises:
            AssertionError: If no actor network has been built yet.
        """
        actor_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return actor_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
Ejemplo n.º 12
0
class TD3(ActorCriticBase):
    """Build a ``TD3Trainer``, its reporter and the serving module for its actor.

    Normalization data is supplied by the caller through
    ``normalization_data_map`` instead of being read from the instance.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; expanded into keyword arguments of
    # ``TD3Trainer`` through ``asdict()``.
    trainer_param: TD3TrainerParameters = field(
        default_factory=TD3TrainerParameters
    )
    # Actor builder selection; by default wraps ``ContinuousFullyConnected``.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    # Critic builder selection; by default wraps ``ParametricFullyConnected``.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # Why isn't this a parameter in the .yaml config file?
    # When true, a second Q-network is built with the same critic builder.
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters
    )

    def __post_init_post_parse__(self):
        """Run the parent hook, then expose ``trainer_param.rl`` as ``rl_parameters``."""
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> TD3Trainer:
        """Create the actor and critic network(s), then the trainer.

        Args:
            normalization_data_map: Normalization data; the entries under
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``
                are read.
            use_gpu: Not read by this method.
            reward_options: Not read by this method.

        Returns:
            A ``TD3Trainer`` wired with the networks built here.
        """
        actor_builder = self.actor_net_builder.value
        feature_config = self.state_feature_config
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]
        actor_net = actor_builder.build_actor(feature_config, state_norm, action_norm)

        critic_builder = self.critic_net_builder.value
        first_q_net = critic_builder.build_q_network(state_norm, action_norm)
        second_q_net = None
        if self.use_2_q_functions:
            second_q_net = critic_builder.build_q_network(state_norm, action_norm)

        return TD3Trainer(
            actor_network=actor_net,
            q1_network=first_q_net,
            q2_network=second_q_net,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``TD3Reporter``."""
        return TD3Reporter()

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Build the actor serving module from a trained ``TD3Trainer``.

        Args:
            trainer_module: Must be a ``TD3Trainer``; its ``actor_network`` is
                the network that gets served.
            normalization_data_map: Normalization data; the STATE and ACTION
                entries are read.

        Returns:
            Whatever the actor builder's ``build_serving_module`` returns.
        """
        assert isinstance(trainer_module, TD3Trainer)
        actor_builder = self.actor_net_builder.value
        return actor_builder.build_serving_module(
            trainer_module.actor_network,
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            normalization_data_map[NormalizationKey.ACTION],
        )