def create_sac_optimizer(self) -> SACOptimizer:
    """Create the SAC optimizer that matches ``self.framework``.

    When ``self.framework`` equals ``FrameworkType.PYTORCH``, ``self.policy``
    is cast to ``TorchPolicy`` and wrapped in a ``TorchSACOptimizer``.
    For any other framework value, it is cast to ``TFPolicy`` and wrapped in a
    ``SACOptimizer``. Both receive ``self.trainer_settings``.

    The ``cast`` calls only inform the type checker; they do not convert or
    validate the policy at runtime.
    """
    if self.framework == FrameworkType.PYTORCH:
        torch_policy = cast(TorchPolicy, self.policy)
        # The declared return type is SACOptimizer, hence the type: ignore.
        return TorchSACOptimizer(torch_policy, self.trainer_settings)  # type: ignore
    tf_policy = cast(TFPolicy, self.policy)
    return SACOptimizer(tf_policy, self.trainer_settings)  # type: ignore
def create_sac_optimizer_mock(dummy_config, use_rnn, use_discrete, use_visual):
    """Build a ``TorchSACOptimizer`` on top of a mock behavior spec for tests.

    Args:
        dummy_config: Trainer settings used as a template. The helper
            deep-copies it before editing the memory settings, so the caller's
            object is left unchanged. So is anything that shares sub-objects
            with it.
        use_rnn: If truthy, set ``network_settings.memory`` to
            ``NetworkSettings.MemorySettings(sequence_length=16,
            memory_size=12)``; otherwise set it to ``None``.
        use_discrete: Forwarded to ``mb.setup_test_behavior_specs``. It also
            selects ``DISCRETE_ACTION_SPACE`` rather than
            ``VECTOR_ACTION_SPACE`` as ``vector_action_space``.
        use_visual: Forwarded to ``mb.setup_test_behavior_specs``. When
            truthy, ``vector_obs_space`` is 0 instead of ``VECTOR_OBS_SPACE``.

    Returns:
        A ``TorchSACOptimizer`` built from a new ``TorchPolicy`` and the
        copied trainer settings.
    """
    import copy

    mock_brain = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=(
            DISCRETE_ACTION_SPACE if use_discrete else VECTOR_ACTION_SPACE
        ),
        vector_obs_space=0 if use_visual else VECTOR_OBS_SPACE,
    )
    # Copy before mutating. Plain assignment would alias the caller's config,
    # and the memory change below would leak into it (and into any settings
    # object it shares ``network_settings`` with).
    trainer_settings = copy.deepcopy(dummy_config)
    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings(sequence_length=16, memory_size=12)
        if use_rnn
        else None
    )
    policy = TorchPolicy(0, mock_brain, trainer_settings)
    return TorchSACOptimizer(policy, trainer_settings)
def create_sac_optimizer(self) -> TorchSACOptimizer:
    """Wrap ``self.policy`` and ``self.trainer_settings`` in a ``TorchSACOptimizer``.

    ``cast`` is a type-checker hint only. It does not check at runtime that
    ``self.policy`` is actually a ``TorchPolicy``.
    """
    torch_policy = cast(TorchPolicy, self.policy)
    return TorchSACOptimizer(torch_policy, self.trainer_settings)  # type: ignore