コード例 #1
0
class SAC(ActorCriticBase):
    """Model manager that assembles the networks consumed by ``SACTrainer``.

    Each ``*_net_builder`` field wraps a concrete builder that is read through
    its ``.value`` attribute.  ``value_net_builder`` is optional: when it is
    falsy no value network is built and ``None`` is handed to the trainer.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When False, only one Q-network is built and q2_network is None.
    use_2_q_functions: bool = True
    # Forwarded as the `serve_mean_policy` keyword when building the serving module.
    serve_mean_policy: bool = True

    def __post_init_post_parse__(self):
        """Run the parent hook, then initialise per-instance state."""
        super().__post_init_post_parse__()
        # Filled in by build_trainer(); build_serving_module() requires it.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self, use_gpu: bool) -> SACTrainer:
        """Build the actor, critic(s) and optional value network and wrap them.

        Args:
            use_gpu: Accepted for interface compatibility; not read here.

        Returns:
            A ``SACTrainer`` constructed from the freshly built networks and
            the keyword arguments in ``self.trainer_param.asdict()``.
        """
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = self.actor_net_builder.value.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        # The second critic is only created on request.
        q2_network = None
        if self.use_2_q_functions:
            q2_network = critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        value_network = None
        if self.value_net_builder:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            selected_value_builder = self.value_net_builder.value
            value_network = selected_value_builder.build_value_network(
                self.state_normalization_data
            )

        return SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=value_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``SACReporter``."""
        return SACReporter()

    def build_serving_module(self) -> Dict[str, torch.nn.Module]:
        """Build the serving module for the actor created by ``build_trainer``.

        Raises:
            AssertionError: If ``self._actor_network`` is still ``None``.
        """
        assert self._actor_network is not None
        selected_actor_builder = self.actor_net_builder.value
        return selected_actor_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
            serve_mean_policy=self.serve_mean_policy,
        )
コード例 #2
0
ファイル: sac.py プロジェクト: vishalbelsare/Horizon
class SAC(ActorCriticBase):
    """Model manager that assembles the networks consumed by ``SACTrainer``.

    Normalization data is passed in per call through ``normalization_data_map``
    (keyed by ``NormalizationKey``), and the trainer module is handed back to
    ``build_serving_module``.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When False, only one Q-network is built and q2_network is None.
    use_2_q_functions: bool = True
    # Forwarded as the `serve_mean_policy` keyword when building the serving module.
    serve_mean_policy: bool = True

    def __post_init_post_parse__(self):
        """Run the parent hook, then expose the trainer's RL parameters."""
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> SACTrainer:
        """Build the actor, critic(s) and optional value network and wrap them.

        Args:
            normalization_data_map: Must contain entries for
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``.
            use_gpu: Accepted for interface compatibility; not read here.
            reward_options: Accepted for interface compatibility; not read here.

        Returns:
            A ``SACTrainer`` built from the new networks plus the keyword
            arguments in ``self.trainer_param.asdict()``.
        """
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        selected_actor_builder = self.actor_net_builder.value
        actor_network = selected_actor_builder.build_actor(
            self.state_feature_config, state_norm, action_norm
        )

        selected_critic_builder = self.critic_net_builder.value
        q1_network = selected_critic_builder.build_q_network(state_norm, action_norm)
        # The second critic is only created on request.
        q2_network = None
        if self.use_2_q_functions:
            q2_network = selected_critic_builder.build_q_network(
                state_norm, action_norm
            )

        value_network = None
        value_union = self.value_net_builder
        if value_union:
            value_network = value_union.value.build_value_network(state_norm)

        return SACTrainer(
            actor_network=actor_network,
            q1_network=q1_network,
            value_network=value_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return ``None``; this manager supplies no reporter."""
        return None

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Build the serving module around the trainer's actor network.

        Args:
            trainer_module: Must be a ``SACTrainer``; its ``actor_network`` is
                the network that gets wrapped.
            normalization_data_map: Must contain entries for
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``.

        Raises:
            AssertionError: If ``trainer_module`` is not a ``SACTrainer``.
        """
        assert isinstance(trainer_module, SACTrainer)
        selected_actor_builder = self.actor_net_builder.value
        return selected_actor_builder.build_serving_module(
            trainer_module.actor_network,
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            normalization_data_map[NormalizationKey.ACTION],
            serve_mean_policy=self.serve_mean_policy,
        )
コード例 #3
0
ファイル: td3.py プロジェクト: zachkeer/ReAgent
class TD3(ActorCriticBase):
    """Model manager that assembles the networks consumed by ``TD3Trainer``.

    Each ``*_net_builder`` field wraps a concrete builder that is read through
    its ``.value`` attribute.  The actor network built by ``build_trainer`` is
    kept on the instance so ``build_serving_module`` can wrap it later.
    """

    __hash__ = param_hash

    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # When False, only one Q-network is built and q2_network is None.
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        """Run the parent hook, then initialise per-instance state."""
        super().__post_init_post_parse__()
        # Filled in by build_trainer(); build_serving_module() requires it.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> TD3Trainer:
        """Build the actor and critic network(s) and wrap them in a trainer.

        When ``self.use_gpu`` is truthy, every network that was built is moved
        with ``.cuda()`` before the trainer is constructed.

        Returns:
            A ``TD3Trainer`` built from the new networks plus the keyword
            arguments in ``self.trainer_param.asdict()``.
        """
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_net_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        q2_network = (
            critic_net_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )
            if self.use_2_q_functions
            else None
        )

        if self.use_gpu:
            self._q1_network.cuda()
            # Test for presence, not truthiness: a network object that defines
            # __len__/__bool__ could be falsy and would silently stay on CPU.
            if q2_network is not None:
                q2_network.cuda()
            self._actor_network.cuda()

        trainer = TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """Build the serving module for the actor created by ``build_trainer``.

        Raises:
            AssertionError: If ``self._actor_network`` is still ``None``.
        """
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
コード例 #4
0
class TD3(ActorCriticBase):
    """Model manager that assembles the networks consumed by ``TD3Trainer``.

    Each ``*_net_builder`` field wraps a concrete builder that is read through
    its ``.value`` attribute.  The actor network built by ``build_trainer`` is
    kept on the instance so ``build_serving_module`` can wrap it later.
    """

    __hash__ = param_hash

    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # Why isn't this a parameter in the .yaml config file?
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        """Run the parent hook, then initialise per-instance state."""
        super().__post_init_post_parse__()
        # Filled in by build_trainer(); build_serving_module() requires it.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self) -> TD3Trainer:
        """Build the actor and critic network(s) and wrap them in a trainer.

        Returns:
            A ``TD3Trainer`` built from the new networks plus the keyword
            arguments in ``self.trainer_param.asdict()``.
        """
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = self.actor_net_builder.value.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        selected_critic_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = selected_critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        # The second critic is only created on request.
        q2_network = None
        if self.use_2_q_functions:
            q2_network = selected_critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        return TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``TD3Reporter``."""
        return TD3Reporter()

    def build_serving_module(self) -> torch.nn.Module:
        """Build the serving module for the actor created by ``build_trainer``.

        Raises:
            AssertionError: If ``self._actor_network`` is still ``None``.
        """
        selected_actor_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        serving_module = selected_actor_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
        return serving_module
コード例 #5
0
 def test_dirichlet_fully_connected(self):
     """Run the shared actor-net-builder check on the Dirichlet builder."""
     # The builder is deliberately reached through its full module path so the
     # test also verifies that it is exported from __init__.py.
     dirichlet_builder = (
         continuous_actor.dirichlet_fully_connected.DirichletFullyConnected()
     )
     union = ContinuousActorNetBuilder__Union(
         DirichletFullyConnected=dirichlet_builder
     )
     self._test_actor_net_builder(union)
コード例 #6
0
ファイル: soft_actor_critic.py プロジェクト: zhaonann/ReAgent
class SoftActorCritic(ActorCriticBase):
    """Model manager that assembles the networks consumed by ``SACTrainer``.

    Each ``*_net_builder`` field wraps a concrete builder that is read through
    its ``.value`` attribute.  ``value_net_builder`` is optional: when it is
    ``None`` no value network is built.  The actor network built by
    ``build_trainer`` is kept on the instance for ``build_serving_module``.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When False, only one Q-network is built and q2_network is None.
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        """Run the parent hook, then initialise per-instance state."""
        super().__post_init_post_parse__()
        # Filled in by build_trainer(); build_serving_module() requires it.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> SACTrainer:
        """Build the actor, critic(s) and optional value network and wrap them.

        When ``self.use_gpu`` is truthy, every network that was built is moved
        with ``.cuda()`` before the trainer is constructed.

        Returns:
            A ``SACTrainer`` receiving the Q-network, the actor network and
            ``self.trainer_param`` positionally, plus the optional networks and
            ``use_gpu`` as keywords.
        """
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SoftActorCritic` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.get_normalization_data(NormalizationKey.STATE),
            self.get_normalization_data(NormalizationKey.ACTION),
        )

        # NOTE(review): the critic builder is fed the *_normalization_parameters
        # attributes while the actor/value builders get get_normalization_data()
        # results - presumably the builder signatures differ; confirm.
        critic_net_builder = self.critic_net_builder.value
        q1_network = critic_net_builder.build_q_network(
            self.state_normalization_parameters,
            self.action_normalization_parameters,
        )
        q2_network = (
            critic_net_builder.build_q_network(
                self.state_normalization_parameters,
                self.action_normalization_parameters,
            )
            if self.use_2_q_functions
            else None
        )

        value_network = None
        # The field is Optional, so test for None explicitly rather than
        # relying on the builder object's truthiness.
        if self.value_net_builder is not None:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_net_builder = self.value_net_builder.value
            value_network = value_net_builder.build_value_network(
                self.get_normalization_data(NormalizationKey.STATE)
            )

        if self.use_gpu:
            q1_network.cuda()
            # Test for presence, not truthiness: a network object that defines
            # __len__/__bool__ could be falsy and would silently stay on CPU.
            if q2_network is not None:
                q2_network.cuda()
            if value_network is not None:
                value_network.cuda()
            self._actor_network.cuda()

        trainer = SACTrainer(
            q1_network,
            self._actor_network,
            self.trainer_param,
            value_network=value_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """Build the serving module for the actor created by ``build_trainer``.

        Raises:
            AssertionError: If ``self._actor_network`` is still ``None``.
        """
        net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SoftActorCritic` has no attribute `_actor_network`.
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.get_normalization_data(NormalizationKey.STATE),
            self.get_normalization_data(NormalizationKey.ACTION),
        )
コード例 #7
0
class SAC(ActorCriticBase):
    """Model manager that assembles the networks consumed by ``SACTrainer``.

    Each ``*_net_builder`` field wraps a concrete builder that is read through
    its ``.value`` attribute.  ``value_net_builder`` is optional: when it is
    ``None`` no value network is built.  The actor network built by
    ``build_trainer`` is kept on the instance for ``build_serving_module``.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    # When False, only one Q-network is built and q2_network is None.
    use_2_q_functions: bool = True

    def __post_init_post_parse__(self):
        """Run the parent hook, then initialise per-instance state."""
        super().__post_init_post_parse__()
        # Filled in by build_trainer(); build_serving_module() requires it.
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> SACTrainer:
        """Build the actor, critic(s) and optional value network and wrap them.

        When ``self.use_gpu`` is truthy, every network that was built is moved
        with ``.cuda()`` before the trainer is constructed.

        Returns:
            A ``SACTrainer`` built from the new networks, ``use_gpu`` and the
            keyword arguments in ``self.trainer_param.asdict()``.
        """
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_net_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        q2_network = (
            critic_net_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )
            if self.use_2_q_functions
            else None
        )

        value_network = None
        # The field is Optional, so test for None explicitly rather than
        # relying on the builder object's truthiness.
        if self.value_net_builder is not None:
            # pyre-fixme[16]: `Optional` has no attribute `value`.
            value_net_builder = self.value_net_builder.value
            value_network = value_net_builder.build_value_network(
                self.state_normalization_data
            )

        if self.use_gpu:
            self._q1_network.cuda()
            # Test for presence, not truthiness: a network object that defines
            # __len__/__bool__ could be falsy and would silently stay on CPU.
            if q2_network is not None:
                q2_network.cuda()
            if value_network is not None:
                value_network.cuda()
            self._actor_network.cuda()

        # pyre-fixme[29]: `Type[reagent.training.sac_trainer.SACTrainer]` is not a
        #  function.
        trainer = SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=value_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )
        return trainer

    def build_serving_module(self) -> torch.nn.Module:
        """Build the serving module for the actor created by ``build_trainer``.

        Raises:
            AssertionError: If ``self._actor_network`` is still ``None``.
        """
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
コード例 #8
0
class TD3(ActorCriticBase):
    """Model manager that assembles the networks consumed by ``TD3Trainer``.

    Normalization data is passed in per call through ``normalization_data_map``
    (keyed by ``NormalizationKey``), and the trainer module is handed back to
    ``build_serving_module``.
    """

    __hash__ = param_hash

    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # Why isn't this a parameter in the .yaml config file?
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        """Run the parent hook, then expose the trainer's RL parameters."""
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> TD3Trainer:
        """Build the actor and critic network(s) and wrap them in a trainer.

        Args:
            normalization_data_map: Must contain entries for
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``.
            use_gpu: Accepted for interface compatibility; not read here.
            reward_options: Accepted for interface compatibility; not read here.

        Returns:
            A ``TD3Trainer`` built from the new networks plus the keyword
            arguments in ``self.trainer_param.asdict()``.
        """
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        selected_actor_builder = self.actor_net_builder.value
        actor_network = selected_actor_builder.build_actor(
            self.state_feature_config, state_norm, action_norm
        )

        selected_critic_builder = self.critic_net_builder.value
        q1_network = selected_critic_builder.build_q_network(state_norm, action_norm)
        # The second critic is only created on request.
        q2_network = None
        if self.use_2_q_functions:
            q2_network = selected_critic_builder.build_q_network(
                state_norm, action_norm
            )

        return TD3Trainer(
            actor_network=actor_network,
            q1_network=q1_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        """Return a new ``TD3Reporter``."""
        return TD3Reporter()

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Build the serving module around the trainer's actor network.

        Args:
            trainer_module: Must be a ``TD3Trainer``; its ``actor_network`` is
                the network that gets wrapped.
            normalization_data_map: Must contain entries for
                ``NormalizationKey.STATE`` and ``NormalizationKey.ACTION``.

        Raises:
            AssertionError: If ``trainer_module`` is not a ``TD3Trainer``.
        """
        assert isinstance(trainer_module, TD3Trainer)
        selected_actor_builder = self.actor_net_builder.value
        serving_module = selected_actor_builder.build_serving_module(
            trainer_module.actor_network,
            self.state_feature_config,
            normalization_data_map[NormalizationKey.STATE],
            normalization_data_map[NormalizationKey.ACTION],
        )
        return serving_module