class ParametricDQN(ParametricDQNBase):
    """Model manager for a parametric DQN built from state and action features.

    ``build_trainer`` creates a Q-network, its target copy and a reward
    network with several outputs, then wraps them in a ``ParametricDQNTrainer``.
    ``build_serving_module`` exports the trainer's Q-network.
    """

    __hash__ = param_hash

    trainer_param: ParametricDQNTrainerParameters = field(
        default_factory=ParametricDQNTrainerParameters
    )
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    @property
    def rl_parameters(self):
        """RL hyper-parameters, read straight through from ``trainer_param.rl``."""
        return self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> ParametricDQNTrainer:
        """Build the networks and return a configured ``ParametricDQNTrainer``.

        Args:
            normalization_data_map: must contain ``NormalizationKey.STATE`` and
                ``NormalizationKey.ACTION`` entries.
            use_gpu: part of the interface; not read by this method.
            reward_options: a default ``RewardOptions()`` is used when omitted.

        Side effect: the new Q-network is also stored on ``self._q_network``.
        """
        builder = self.net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        # pyre-fixme[16]: `ParametricDQN` has no attribute `_q_network`.
        self._q_network = builder.build_q_network(state_norm, action_norm)

        options = reward_options or RewardOptions()
        scored_metrics = get_metrics_to_score(options.metric_reward_values)
        # One output per scored metric, plus one.
        reward_network = builder.build_q_network(
            state_norm,
            action_norm,
            output_dim=len(scored_metrics) + 1,
        )

        return ParametricDQNTrainer(
            q_network=self._q_network,
            q_network_target=self._q_network.get_target_network(),
            reward_network=reward_network,
            # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute
            #  `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Wrap the trainer's Q-network in a serving module.

        Raises:
            AssertionError: if ``trainer_module`` is not a ``ParametricDQNTrainer``.
        """
        assert isinstance(trainer_module, ParametricDQNTrainer)
        builder = self.net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]
        return builder.build_serving_module(
            trainer_module.q_network, state_norm, action_norm
        )
class ParametricDQN(ParametricDQNBase):
    """Model manager for a parametric DQN (legacy normalization-attribute API).

    Network inputs come from ``self.state_normalization_data`` and
    ``self.action_normalization_data``. The Q-network built by
    ``build_trainer`` is kept on the manager, so ``build_serving_module`` can
    export it later.
    """

    __hash__ = param_hash

    trainer_param: ParametricDQNTrainerParameters = field(
        default_factory=ParametricDQNTrainerParameters
    )
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # Declared up front so that build_serving_module() fails its assertion,
        # rather than raising AttributeError, when build_trainer() has not run.
        self._q_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> ParametricDQNTrainer:
        """Build the Q-network, its target and a reward network; return the trainer.

        The reward network has ``len(self.metrics_to_score) + 1`` outputs.
        When ``self.use_gpu`` is set, both networks are moved with ``.cuda()``
        before the target network is derived.
        """
        net_builder = self.net_builder.value
        q_network = net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        # One output per scored metric, plus one.
        reward_network = net_builder.build_q_network(
            self.state_normalization_data,
            self.action_normalization_data,
            output_dim=len(self.metrics_to_score) + 1,
        )
        if self.use_gpu:
            q_network = q_network.cuda()
            reward_network = reward_network.cuda()

        # Remembered so build_serving_module() can export it afterwards.
        self._q_network = q_network
        q_network_target = q_network.get_target_network()
        # pyre-fixme[29]: `Type[ParametricDQNTrainer]` is not a function.
        return ParametricDQNTrainer(
            q_network=q_network,
            q_network_target=q_network_target,
            reward_network=reward_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute
            #  `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Export the Q-network created by ``build_trainer`` for serving.

        Raises:
            AssertionError: if ``build_trainer`` has not been called yet.
        """
        q_network = self._q_network
        assert (
            q_network is not None
        ), "build_trainer() must be called before build_serving_module()"
        return self.net_builder.value.build_serving_module(
            q_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
class SlateQ(SlateQBase):
    """SlateQ model manager built on the parametric-DQN network builders.

    The Q-network is built from the STATE and ITEM normalization entries.
    ``slate_size`` and ``num_candidates`` default to -1 and must be set to
    positive values.
    """

    __hash__ = param_hash

    slate_size: int = -1
    num_candidates: int = -1
    trainer_param: SlateQTrainerParameters = field(
        default_factory=SlateQTrainerParameters
    )
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # The -1 defaults act as "unset" markers; reject them here.
        slate_size, num_candidates = self.slate_size, self.num_candidates
        assert slate_size > 0, f"Please set valid slate_size (currently {self.slate_size})"
        assert (
            num_candidates > 0
        ), f"Please set valid num_candidates (currently {self.num_candidates})"
        self.eval_parameters = self.trainer_param.evaluation

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> SlateQTrainer:
        """Build the Q-network and its target and return a ``SlateQTrainer``.

        ``use_gpu`` and ``reward_options`` are part of the interface but are
        not read here.
        """
        builder = self.net_builder.value
        network = builder.build_q_network(
            normalization_data_map[NormalizationKey.STATE],
            normalization_data_map[NormalizationKey.ITEM],
        )
        return SlateQTrainer(
            q_network=network,
            q_network_target=network.get_target_network(),
            slate_size=self.slate_size,
            # pyre-fixme[16]: `SlateQTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Export the trainer's Q-network (STATE + ITEM normalization) for serving.

        Raises:
            AssertionError: if ``trainer_module`` is not a ``SlateQTrainer``.
        """
        assert isinstance(trainer_module, SlateQTrainer)
        builder = self.net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        item_norm = normalization_data_map[NormalizationKey.ITEM]
        return builder.build_serving_module(
            trainer_module.q_network, state_norm, item_norm
        )
class SlateQ(SlateQBase):
    """SlateQ model manager (legacy normalization-attribute API).

    The Q-network is built from ``self.state_normalization_data`` and
    ``self.item_normalization_data`` and kept on the manager for serving.
    """

    __hash__ = param_hash

    slate_size: int = -1
    num_candidates: int = -1
    trainer_param: SlateQTrainerParameters = field(
        default_factory=SlateQTrainerParameters
    )
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # The -1 defaults act as "unset" markers; reject them here.
        slate_size, num_candidates = self.slate_size, self.num_candidates
        assert slate_size > 0, f"Please set valid slate_size (currently {self.slate_size})"
        assert (
            num_candidates > 0
        ), f"Please set valid num_candidates (currently {self.num_candidates})"
        # Populated by build_trainer(); checked by build_serving_module().
        self._q_network: Optional[ModelBase] = None
        self.eval_parameters = self.trainer_param.evaluation

    def build_trainer(self) -> SlateQTrainer:
        """Build (and, if ``self.use_gpu``, move to GPU) the Q-network; return the trainer."""
        builder = self.net_builder.value
        # pyre-fixme[16]: `SlateQ` has no attribute `_q_network`.
        self._q_network = builder.build_q_network(
            self.state_normalization_data, self.item_normalization_data
        )
        if self.use_gpu:
            self._q_network = self._q_network.cuda()
        return SlateQTrainer(
            q_network=self._q_network,
            q_network_target=self._q_network.get_target_network(),
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `SlateQTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Export the Q-network from ``build_trainer``; asserts it has been built."""
        builder = self.net_builder.value
        assert self._q_network is not None
        state_norm = self.state_normalization_data
        item_norm = self.item_normalization_data
        return builder.build_serving_module(self._q_network, state_norm, item_norm)
class SAC(ActorCriticBase):
    """Soft Actor-Critic model manager (variant whose build_trainer takes ``use_gpu``).

    Builds an actor (``GaussianFullyConnected`` by default), one or two
    Q-networks and an optional value network. The actor is exported for
    serving.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    use_2_q_functions: bool = True
    serve_mean_policy: bool = True

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # Populated by build_trainer(); checked by build_serving_module().
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self, use_gpu: bool) -> SACTrainer:
        """Build the actor, critics and optional value network; return a ``SACTrainer``.

        Args:
            use_gpu: accepted for the caller's interface. It is not read here:
                no ``.cuda()`` calls are made and it is not forwarded to the
                trainer.
        """
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_net_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        # The second critic exists only when use_2_q_functions is enabled.
        q2_network = (
            critic_net_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )
            if self.use_2_q_functions
            else None
        )

        # The value network is optional: built only when a builder is configured.
        value_network = None
        value_builder_union = self.value_net_builder
        if value_builder_union is not None:
            value_network = value_builder_union.value.build_value_network(
                self.state_normalization_data
            )

        return SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=value_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        return SACReporter()

    def build_serving_module(self) -> torch.nn.Module:
        """Export the actor built by ``build_trainer`` as a single serving module.

        ``serve_mean_policy`` is forwarded to the actor builder.

        Raises:
            AssertionError: if ``build_trainer`` has not been called yet.
        """
        assert self._actor_network is not None
        return self.actor_net_builder.value.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
            serve_mean_policy=self.serve_mean_policy,
        )
class TD3(ActorCriticBase):
    """TD3 model manager: an actor and one or two critics, with optional GPU placement.

    The actor is kept on the manager so ``build_serving_module`` can export it.
    """

    __hash__ = param_hash

    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # Populated by build_trainer(); checked by build_serving_module().
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> TD3Trainer:
        """Build the actor and critics, move them to GPU if requested, return the trainer."""
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_net_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        # The second critic exists only when use_2_q_functions is enabled.
        q2_network = (
            critic_net_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )
            if self.use_2_q_functions
            else None
        )

        if self.use_gpu:
            # Return values are ignored: this relies on .cuda() moving the
            # module in place (true for torch.nn.Module; assumed for these nets).
            self._q1_network.cuda()
            if q2_network is not None:
                q2_network.cuda()
            self._actor_network.cuda()

        return TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Export the actor built by ``build_trainer``; asserts it has been built."""
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
class SAC(ActorCriticBase):
    """Soft Actor-Critic model manager (normalization-map based API).

    Builds the actor, one or two critics and an optional value network from
    ``normalization_data_map``. Serving exports the trainer's actor network.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    use_2_q_functions: bool = True
    serve_mean_policy: bool = True

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> SACTrainer:
        """Create the SAC networks and wrap them in a ``SACTrainer``.

        ``use_gpu`` and ``reward_options`` are part of the interface but are
        not read here.
        """
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        actor_network = self.actor_net_builder.value.build_actor(
            self.state_feature_config, state_norm, action_norm
        )

        critic_builder = self.critic_net_builder.value
        q1_network = critic_builder.build_q_network(state_norm, action_norm)
        # The second critic exists only when use_2_q_functions is enabled.
        q2_network = None
        if self.use_2_q_functions:
            q2_network = critic_builder.build_q_network(state_norm, action_norm)

        # The value network is skipped when no value builder is configured.
        value_builder_union = self.value_net_builder
        value_network = (
            value_builder_union.value.build_value_network(state_norm)
            if value_builder_union
            else None
        )

        return SACTrainer(
            actor_network=actor_network,
            q1_network=q1_network,
            value_network=value_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        return None

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Export the trainer's actor; ``serve_mean_policy`` is forwarded to the builder.

        Raises:
            AssertionError: if ``trainer_module`` is not a ``SACTrainer``.
        """
        assert isinstance(trainer_module, SACTrainer)
        actor_builder = self.actor_net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]
        return actor_builder.build_serving_module(
            trainer_module.actor_network,
            self.state_feature_config,
            state_norm,
            action_norm,
            serve_mean_policy=self.serve_mean_policy,
        )
def test_fully_connected(self):
    """Run the shared parametric-DQN builder checks against ``FullyConnected``."""
    # The long module path is deliberate: it checks that ``fully_connected``
    # is included in the package's __init__.py.
    fully_connected_cls = parametric_dqn.fully_connected.FullyConnected
    builder_union = ParametricDQNNetBuilder__Union(
        FullyConnected=fully_connected_cls()
    )
    self._test_parametric_dqn_net_builder(builder_union)
class TD3(ActorCriticBase):
    """TD3 model manager (legacy normalization-attribute API) with a ``TD3Reporter``.

    The actor is kept on the manager so ``build_serving_module`` can export it.
    """

    __hash__ = param_hash

    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # Open question: why isn't this a parameter in the .yaml config file?
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # Populated by build_trainer(); checked by build_serving_module().
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    # pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
    #  inconsistently.
    def build_trainer(self) -> TD3Trainer:
        """Build the actor and one or two critics, then return a ``TD3Trainer``."""
        actor_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `TD3` has no attribute `_actor_network`.
        self._actor_network = actor_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_builder = self.critic_net_builder.value

        def make_critic():
            # Each call builds a fresh, independent critic network.
            return critic_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )

        # pyre-fixme[16]: `TD3` has no attribute `_q1_network`.
        self._q1_network = make_critic()
        q2_network = make_critic() if self.use_2_q_functions else None

        return TD3Trainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        return TD3Reporter()

    def build_serving_module(self) -> torch.nn.Module:
        """Export the actor built by ``build_trainer``; asserts it has been built."""
        actor_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        state_norm = self.state_normalization_data
        action_norm = self.action_normalization_data
        return actor_builder.build_serving_module(
            self._actor_network, state_norm, action_norm
        )
class SoftActorCritic(ActorCriticBase):
    """Soft Actor-Critic model manager (legacy positional ``SACTrainer`` API).

    Builds an actor, one or two critics and an optional value network, moves
    them to GPU when ``self.use_gpu`` is set, and exports the actor for
    serving.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # Populated by build_trainer(); checked by build_serving_module().
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> SACTrainer:
        """Build all SAC networks and return a ``SACTrainer``."""
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SoftActorCritic` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.get_normalization_data(NormalizationKey.STATE),
            self.get_normalization_data(NormalizationKey.ACTION),
        )

        critic_net_builder = self.critic_net_builder.value
        # NOTE(review): the critics are built from the normalization *parameters*,
        # while the actor and value network use get_normalization_data().
        # Presumably this mirrors the builders' signatures in this version — confirm.
        q1_network = critic_net_builder.build_q_network(
            self.state_normalization_parameters, self.action_normalization_parameters
        )
        # The second critic exists only when use_2_q_functions is enabled.
        q2_network = (
            critic_net_builder.build_q_network(
                self.state_normalization_parameters,
                self.action_normalization_parameters,
            )
            if self.use_2_q_functions
            else None
        )

        # The value network is optional: built only when a builder is configured.
        value_network = None
        value_builder_union = self.value_net_builder
        if value_builder_union is not None:
            value_network = value_builder_union.value.build_value_network(
                self.get_normalization_data(NormalizationKey.STATE)
            )

        if self.use_gpu:
            # Return values are ignored: this relies on .cuda() moving the
            # module in place (true for torch.nn.Module; assumed for these nets).
            q1_network.cuda()
            if q2_network is not None:
                q2_network.cuda()
            if value_network is not None:
                value_network.cuda()
            self._actor_network.cuda()

        return SACTrainer(
            q1_network,
            self._actor_network,
            self.trainer_param,
            value_network=value_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Export the actor built by ``build_trainer``; asserts it has been built."""
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.get_normalization_data(NormalizationKey.STATE),
            self.get_normalization_data(NormalizationKey.ACTION),
        )
class SAC(ActorCriticBase):
    """Soft Actor-Critic model manager (legacy normalization-attribute API).

    Builds an actor, one or two critics and an optional value network, moves
    them to GPU when ``self.use_gpu`` is set, and exports the actor for
    serving.
    """

    __hash__ = param_hash

    trainer_param: SACTrainerParameters = field(default_factory=SACTrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `GaussianFullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            GaussianFullyConnected=GaussianFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )
    value_net_builder: Optional[ValueNetBuilder__Union] = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ValueNetBuilder__Union(
            FullyConnected=ValueFullyConnected()
        )
    )
    use_2_q_functions: bool = True

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        # Populated by build_trainer(); checked by build_serving_module().
        self._actor_network: Optional[ModelBase] = None
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(self) -> SACTrainer:
        """Build all SAC networks, optionally move them to GPU, return a ``SACTrainer``."""
        actor_net_builder = self.actor_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
        self._actor_network = actor_net_builder.build_actor(
            self.state_normalization_data, self.action_normalization_data
        )

        critic_net_builder = self.critic_net_builder.value
        # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
        self._q1_network = critic_net_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
        # The second critic exists only when use_2_q_functions is enabled.
        q2_network = (
            critic_net_builder.build_q_network(
                self.state_normalization_data, self.action_normalization_data
            )
            if self.use_2_q_functions
            else None
        )

        # The value network is optional: built only when a builder is configured.
        value_network = None
        value_builder_union = self.value_net_builder
        if value_builder_union is not None:
            value_network = value_builder_union.value.build_value_network(
                self.state_normalization_data
            )

        if self.use_gpu:
            # Return values are ignored: this relies on .cuda() moving the
            # module in place (true for torch.nn.Module; assumed for these nets).
            self._q1_network.cuda()
            if q2_network is not None:
                q2_network.cuda()
            if value_network is not None:
                value_network.cuda()
            self._actor_network.cuda()

        # pyre-fixme[29]: `Type[reagent.training.sac_trainer.SACTrainer]` is not a
        #  function.
        return SACTrainer(
            actor_network=self._actor_network,
            q1_network=self._q1_network,
            value_network=value_network,
            q2_network=q2_network,
            use_gpu=self.use_gpu,
            # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def build_serving_module(self) -> torch.nn.Module:
        """Export the actor built by ``build_trainer``; asserts it has been built."""
        net_builder = self.actor_net_builder.value
        assert self._actor_network is not None
        return net_builder.build_serving_module(
            self._actor_network,
            self.state_normalization_data,
            self.action_normalization_data,
        )
class TD3(ActorCriticBase):
    """TD3 model manager (normalization-map based API).

    Builds an actor and one or two critics from ``normalization_data_map``.
    Serving exports the trainer's actor network.
    """

    __hash__ = param_hash

    trainer_param: TD3TrainerParameters = field(default_factory=TD3TrainerParameters)
    actor_net_builder: ContinuousActorNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ContinuousActorNetBuilder__Union(
            FullyConnected=ContinuousFullyConnected()
        )
    )
    critic_net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=ParametricFullyConnected()
        )
    )
    # Open question: why isn't this a parameter in the .yaml config file?
    use_2_q_functions: bool = True
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)

    def __post_init_post_parse__(self):
        super().__post_init_post_parse__()
        self.rl_parameters = self.trainer_param.rl

    def build_trainer(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        use_gpu: bool,
        reward_options: Optional[RewardOptions] = None,
    ) -> TD3Trainer:
        """Create the actor and critics and wrap them in a ``TD3Trainer``.

        ``use_gpu`` and ``reward_options`` are part of the interface but are
        not read here.
        """
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        actor_network = self.actor_net_builder.value.build_actor(
            self.state_feature_config, state_norm, action_norm
        )

        critic_builder = self.critic_net_builder.value
        q1_network = critic_builder.build_q_network(state_norm, action_norm)
        # The second critic exists only when use_2_q_functions is enabled.
        q2_network = None
        if self.use_2_q_functions:
            q2_network = critic_builder.build_q_network(state_norm, action_norm)

        return TD3Trainer(
            actor_network=actor_network,
            q1_network=q1_network,
            q2_network=q2_network,
            # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`.
            **self.trainer_param.asdict(),
        )

    def get_reporter(self):
        return TD3Reporter()

    def build_serving_module(
        self,
        trainer_module: ReAgentLightningModule,
        normalization_data_map: Dict[str, NormalizationData],
    ) -> torch.nn.Module:
        """Export the trainer's actor network for serving.

        Raises:
            AssertionError: if ``trainer_module`` is not a ``TD3Trainer``.
        """
        assert isinstance(trainer_module, TD3Trainer)
        actor_builder = self.actor_net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]
        return actor_builder.build_serving_module(
            trainer_module.actor_network,
            self.state_feature_config,
            state_norm,
            action_norm,
        )