def build_trainer(self, use_gpu: bool) -> SACTrainer:
    """Build actor/critic/value networks and wrap them in a SACTrainer.

    Args:
        use_gpu: accepted for interface compatibility but not referenced in
            this body. NOTE(review): device placement presumably happens
            elsewhere (e.g. by the training framework) — confirm.

    Returns:
        A SACTrainer constructed from this instance's net builders,
        normalization data, and trainer parameters.
    """
    # The actor is kept on `self` so it can be reused later (e.g. export).
    # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
    self._actor_network = self.actor_net_builder.value.build_actor(
        self.state_normalization_data, self.action_normalization_data
    )

    critic_builder = self.critic_net_builder.value
    # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
    self._q1_network = critic_builder.build_q_network(
        self.state_normalization_data, self.action_normalization_data
    )
    if self.use_2_q_functions:
        # Second critic for the twin-Q SAC variant.
        q2_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
    else:
        q2_network = None

    value_network = None
    if self.value_net_builder:
        # pyre-fixme[16]: `Optional` has no attribute `value`.
        value_network = self.value_net_builder.value.build_value_network(
            self.state_normalization_data
        )

    return SACTrainer(
        actor_network=self._actor_network,
        q1_network=self._q1_network,
        value_network=value_network,
        q2_network=q2_network,
        # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
        **self.trainer_param.asdict(),
    )
def build_trainer(self) -> SACTrainer:
    """Assemble actor, critic(s), and optional value network into a SACTrainer.

    All networks come from this instance's configured net builders and
    normalization data. When ``self.use_gpu`` is set, each network is moved
    to CUDA before the trainer is constructed.

    Returns:
        A SACTrainer wired with the freshly built networks.
    """
    # The actor is kept on `self` so it can be reused later (e.g. export).
    # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
    self._actor_network = self.actor_net_builder.value.build_actor(
        self.state_normalization_data, self.action_normalization_data
    )

    critic_builder = self.critic_net_builder.value
    # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
    self._q1_network = critic_builder.build_q_network(
        self.state_normalization_data, self.action_normalization_data
    )
    if self.use_2_q_functions:
        # Second critic for the twin-Q SAC variant.
        q2_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
    else:
        q2_network = None

    value_network = None
    if self.value_net_builder:
        # pyre-fixme[16]: `Optional` has no attribute `value`.
        value_network = self.value_net_builder.value.build_value_network(
            self.state_normalization_data
        )

    if self.use_gpu:
        self._q1_network.cuda()
        if q2_network:
            q2_network.cuda()
        if value_network:
            value_network.cuda()
        self._actor_network.cuda()

    # pyre-fixme[29]: `Type[reagent.training.sac_trainer.SACTrainer]` is not a
    # function.
    return SACTrainer(
        actor_network=self._actor_network,
        q1_network=self._q1_network,
        value_network=value_network,
        q2_network=q2_network,
        use_gpu=self.use_gpu,
        # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
        **self.trainer_param.asdict(),
    )
def build_trainer(self) -> SACTrainer:
    """Construct a SACTrainer from this instance's builders.

    Builds the actor, the Q1 critic (plus an optional twin Q2 critic), and
    an optional state-value network, moves everything to CUDA when
    ``self.use_gpu`` is set, and passes them positionally to SACTrainer.

    Returns:
        The constructed SACTrainer.
    """
    # The actor is kept on `self` so it can be reused later (e.g. export).
    # pyre-fixme[16]: `SAC` has no attribute `_actor_network`.
    self._actor_network = self.actor_net_builder.value.build_actor(
        self.state_normalization_data, self.action_normalization_data
    )

    critic_builder = self.critic_net_builder.value
    # pyre-fixme[16]: `SAC` has no attribute `_q1_network`.
    self._q1_network = critic_builder.build_q_network(
        self.state_normalization_data, self.action_normalization_data
    )
    if self.use_2_q_functions:
        # Second critic for the twin-Q SAC variant.
        q2_network = critic_builder.build_q_network(
            self.state_normalization_data, self.action_normalization_data
        )
    else:
        q2_network = None

    value_network = None
    if self.value_net_builder:
        # pyre-fixme[16]: `Optional` has no attribute `value`.
        value_network = self.value_net_builder.value.build_value_network(
            self.state_normalization_data
        )

    if self.use_gpu:
        self._q1_network.cuda()
        if q2_network:
            q2_network.cuda()
        if value_network:
            value_network.cuda()
        self._actor_network.cuda()

    # NOTE: this SACTrainer signature takes q1, actor, and the parameter
    # struct positionally, in that order.
    return SACTrainer(
        self._q1_network,
        self._actor_network,
        self.trainer_param,
        value_network=value_network,
        q2_network=q2_network,
        use_gpu=self.use_gpu,
    )
def build_trainer(
    self,
    normalization_data_map: Dict[str, NormalizationData],
    use_gpu: bool,
    reward_options: Optional[RewardOptions] = None,
) -> SACTrainer:
    """Build the SAC networks from ``normalization_data_map`` and return a trainer.

    Args:
        normalization_data_map: keyed by NormalizationKey; STATE and ACTION
            entries are used to size the networks.
        use_gpu: accepted but not referenced here. NOTE(review): device
            placement presumably handled by the surrounding framework — confirm.
        reward_options: accepted but not referenced here.

    Returns:
        A SACTrainer built from the locally constructed networks.
    """
    state_norm = normalization_data_map[NormalizationKey.STATE]
    action_norm = normalization_data_map[NormalizationKey.ACTION]

    actor_network = self.actor_net_builder.value.build_actor(
        self.state_feature_config, state_norm, action_norm
    )

    critic_builder = self.critic_net_builder.value
    q1_network = critic_builder.build_q_network(state_norm, action_norm)
    if self.use_2_q_functions:
        # Second critic for the twin-Q SAC variant.
        q2_network = critic_builder.build_q_network(state_norm, action_norm)
    else:
        q2_network = None

    value_network = None
    if self.value_net_builder:
        value_network = self.value_net_builder.value.build_value_network(state_norm)

    return SACTrainer(
        actor_network=actor_network,
        q1_network=q1_network,
        value_network=value_network,
        q2_network=q2_network,
        # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`.
        **self.trainer_param.asdict(),
    )
def build_trainer(self) -> SACTrainer:
    """Build actor/critic/value networks and return a SACTrainer.

    The actor and value networks are sized via ``get_normalization_data``,
    while the critics use the ``*_normalization_parameters`` accessors —
    preserved as-is from the original wiring. Networks are moved to CUDA
    when ``self.use_gpu`` is set.

    Returns:
        The constructed SACTrainer.
    """
    # The actor is kept on `self` so it can be reused later (e.g. export).
    # pyre-fixme[16]: `SoftActorCritic` has no attribute `_actor_network`.
    self._actor_network = self.actor_net_builder.value.build_actor(
        self.get_normalization_data(NormalizationKey.STATE),
        self.get_normalization_data(NormalizationKey.ACTION),
    )

    critic_builder = self.critic_net_builder.value
    q1_network = critic_builder.build_q_network(
        self.state_normalization_parameters, self.action_normalization_parameters
    )
    if self.use_2_q_functions:
        # Second critic for the twin-Q SAC variant.
        q2_network = critic_builder.build_q_network(
            self.state_normalization_parameters,
            self.action_normalization_parameters,
        )
    else:
        q2_network = None

    value_network = None
    if self.value_net_builder:
        # pyre-fixme[16]: `Optional` has no attribute `value`.
        value_network = self.value_net_builder.value.build_value_network(
            self.get_normalization_data(NormalizationKey.STATE)
        )

    if self.use_gpu:
        q1_network.cuda()
        if q2_network:
            q2_network.cuda()
        if value_network:
            value_network.cuda()
        self._actor_network.cuda()

    # NOTE: this SACTrainer signature takes q1, actor, and the parameter
    # struct positionally, in that order.
    return SACTrainer(
        q1_network,
        self._actor_network,
        self.trainer_param,
        value_network=value_network,
        q2_network=q2_network,
        use_gpu=self.use_gpu,
    )