예제 #1
0
class TD3(ActorCriticBase):
    """Configures TD3 training; builds the trainer, reporter and serving module.

    ``build_trainer`` stores the actor and the first critic on the instance,
    and ``build_serving_module`` relies on that stored actor, so the trainer
    has to be built before a serving module is requested.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; unpacked into ``TD3Trainer`` as keyword arguments.
    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    # Builder union for the actor network; defaults to the fully connected builder.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    # Builder union for the critic network(s); defaults to the fully connected builder.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # When True, ``build_trainer`` builds a second Q-network.
    # Open question from the original author: why is this not a parameter
    # in the .yaml config file?
    use_2_q_functions: bool = True
    # Evaluation settings (not referenced by the methods of this class).
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        """Run the parent hook, then set up the actor slot and ``rl_parameters``."""
        super().__post_init_post_parse__()
        # Stays None until ``build_trainer`` has created the actor.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self) -> TD3Trainer:
        """Build the actor and critic network(s) and return a ``TD3Trainer``.

        The actor and first critic are kept on the instance as
        ``_actor_network`` and ``_q1_network``. The second critic is only
        built when ``use_2_q_functions`` is true; otherwise ``None`` is passed
        to the trainer in its place.
        """
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = actor_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )

        if self.use_2_q_functions:
            second_critic = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )
        else:
            second_critic = None

        return TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=second_critic,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``TD3Reporter``."""
        return TD3Reporter()

    def build_serving_module(self) -> torch.nn.Module:
        """Wrap the stored actor network into a serving module.

        Raises:
            AssertionError: if no actor has been stored yet, i.e.
                ``build_trainer`` has not been called.
        """
        serving_builder = self.actor_net_builder.value
        trained_actor = self._actor_network
        assert trained_actor is not None
        return serving_builder.build_serving_module(
            trained_actor,
            self.state_normalization_data,
            self.action_normalization_data,
        )
예제 #2
0
파일: td3.py 프로젝트: zachkeer/ReAgent
class TD3(ActorCriticBase):
    """Configures TD3 training; builds the trainer and the serving module.

    ``build_trainer`` stores the actor and the first critic on the instance
    and moves every built network to the GPU when ``self.use_gpu`` is true.
    ``build_serving_module`` relies on the stored actor, so the trainer has to
    be built first.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; unpacked into ``TD3Trainer`` as keyword arguments.
    trainer_param: TD3TrainerParameters = field(
        default_factory=TD3TrainerParameters)
    # Builder union for the actor network; defaults to the fully connected builder.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()))
    # Builder union for the critic network(s); defaults to the fully connected builder.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()))
    # When True, ``build_trainer`` builds a second Q-network.
    use_2_q_functions: bool = True
    # Evaluation settings (not referenced by the methods of this class).
    eval_parameters: EvaluationParameters = field(
        default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        """Run the parent hook, then set up the actor slot and ``rl_parameters``."""
        super().__post_init_post_parse__()
        # Stays None until ``build_trainer`` has created the actor.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> TD3Trainer:
        """Build the actor and critic network(s) and return a ``TD3Trainer``.

        The actor and first critic are kept on the instance as
        ``_actor_network`` and ``_q1_network``. The second critic is only
        built when ``use_2_q_functions`` is true; otherwise ``None`` is passed
        to the trainer in its place. When ``self.use_gpu`` is true, every
        network that was built is moved with ``.cuda()`` before the trainer is
        constructed.
        """
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data)

        critic_net_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data)
        q2_network = (critic_net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data)
                      if self.use_2_q_functions else None)

        if self.use_gpu:
            self._q1_network.cuda()
            # Test for None explicitly rather than by truthiness: an object
            # that defines __len__/__bool__ can be falsy even though it was
            # built, and would then be left off the GPU.
            if q2_network is not None:
                q2_network.cuda()
            self._actor_network.cuda()

        trainer = TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """Wrap the stored actor network into a serving module.

        Raises:
            AssertionError: if no actor has been stored yet, i.e.
                ``build_trainer`` has not been called.
        """
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
예제 #3
0
class TD3(ActorCriticBase):
    """Configures TD3 training; builds the trainer, reporter and serving module.

    Normalization data is passed to each method through
    ``normalization_data_map``; the networks built by ``build_trainer`` are
    handed to the trainer and are not stored on this object.
    """

    __hash__ = param_hash

    # Trainer hyper-parameters; unpacked into ``TD3Trainer`` as keyword arguments.
    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    # Builder union for the actor network; defaults to the fully connected builder.
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    # Builder union for the critic network(s); defaults to the fully connected builder.
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # When True, ``build_trainer`` builds a second Q-network.
    # Open question from the original author: why is this not a parameter
    # in the .yaml config file?
    use_2_q_functions: bool = True
    # Evaluation settings (not referenced by the methods of this class).
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        """Run the parent hook, then expose ``trainer_param.rl`` as ``rl_parameters``."""
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> TD3Trainer:
        """Build the actor and critic network(s) and return a ``TD3Trainer``.

        Args:
            normalization_data_map: must contain entries for
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``.
            use_gpu: accepted but not used by this method.
            reward_options: accepted but not used by this method.

        Returns:
            A ``TD3Trainer`` whose ``q2_network`` is ``None`` when
            ``use_2_q_functions`` is false.
        """
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        actor = self.actor_net_builder.value.build_actor(
            self.state_feature_config, state_norm, action_norm
        )

        critic_builder = self.critic_net_builder.value
        first_critic = critic_builder.build_q_network(state_norm, action_norm)
        if self.use_2_q_functions:
            second_critic = critic_builder.build_q_network(state_norm, action_norm)
        else:
            second_critic = None

        return TD3Trainer(
            actor_network=actor,
            q1_network=first_critic,
            q2_network=second_critic,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``TD3Reporter``."""
        return TD3Reporter()

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Wrap the trainer's actor network into a serving module.

        Args:
            trainer_module: must be a ``TD3Trainer``; its ``actor_network``
                is the network that gets wrapped.
            normalization_data_map: must contain entries for
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``.

        Raises:
            AssertionError: if ``trainer_module`` is not a ``TD3Trainer``.
        """
        assert isinstance(trainer_module, TD3Trainer)
        serving_builder = self.actor_net_builder.value
        return serving_builder.build_serving_module(
            trainer_module.actor_network,
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            normalization_data_map[NormalizationKey.ACTION],
        )