def create_optimizer_mock(
    trainer_config, reward_signal_config, use_rnn, use_discrete, use_visual
):
    """Build an initialized PPO or SAC optimizer backed by a test TFPolicy.

    Note that ``trainer_config`` is modified in place: its ``reward_signals``
    and ``network_settings.memory`` attributes are overwritten before the
    policy is constructed.

    :param trainer_config: Trainer settings object; its ``trainer_type``
        decides between ``SACOptimizer`` and ``PPOOptimizer``.
    :param reward_signal_config: Value assigned to
        ``trainer_config.reward_signals``.
    :param use_rnn: When truthy, memory settings (sequence_length=16,
        memory_size=10) are installed; otherwise memory is set to ``None``.
    :param use_discrete: Selects ``DISCRETE_ACTION_SPACE`` over
        ``VECTOR_ACTION_SPACE`` and is forwarded to the behavior-spec helper.
    :param use_visual: When truthy, the vector observation space is 0;
        also forwarded to the behavior-spec helper.
    :return: The optimizer, after ``optimizer.policy.initialize()`` has run.
    """
    if use_discrete:
        action_space = DISCRETE_ACTION_SPACE
    else:
        action_space = VECTOR_ACTION_SPACE

    if use_visual:
        obs_space = 0
    else:
        obs_space = VECTOR_OBS_SPACE

    behavior_specs = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=action_space,
        vector_obs_space=obs_space,
    )

    # The caller's settings object is configured directly (no copy is made).
    trainer_config.reward_signals = reward_signal_config
    memory = None
    if use_rnn:
        memory = NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
    trainer_config.network_settings.memory = memory

    test_policy = TFPolicy(
        0, behavior_specs, trainer_config, "test", False, create_tf_graph=False
    )

    # Everything that is not SAC falls back to the PPO optimizer.
    if trainer_config.trainer_type == TrainerType.SAC:
        optimizer_cls = SACOptimizer
    else:
        optimizer_cls = PPOOptimizer
    built = optimizer_cls(test_policy, trainer_config)
    built.policy.initialize()
    return built
def create_sac_optimizer(self) -> SACOptimizer:
    """Instantiate the SAC optimizer that matches ``self.framework``.

    A ``TorchSACOptimizer`` is returned when the framework is
    ``FrameworkType.PYTORCH``; any other framework value yields the
    TensorFlow ``SACOptimizer``. In both cases the optimizer receives
    ``self.policy`` (cast to the framework's policy type for the type
    checker only) and ``self.trainer_settings``.

    :return: The newly constructed optimizer.
    """
    if self.framework == FrameworkType.PYTORCH:
        torch_policy = cast(TorchPolicy, self.policy)
        # The Torch optimizer is returned under the SACOptimizer annotation,
        # hence the type-checker suppression.
        return TorchSACOptimizer(  # type: ignore
            torch_policy, self.trainer_settings
        )

    tf_policy = cast(TFPolicy, self.policy)
    return SACOptimizer(tf_policy, self.trainer_settings)  # type: ignore
def create_sac_optimizer_mock(dummy_config, use_rnn, use_discrete, use_visual):
    """Build an initialized ``SACOptimizer`` backed by a test TFPolicy.

    Note that ``dummy_config`` is modified in place: its
    ``network_settings.memory`` attribute is overwritten before the policy
    is constructed.

    :param dummy_config: Trainer settings object handed to both the policy
        and the optimizer.
    :param use_rnn: When truthy, memory settings (sequence_length=16,
        memory_size=10) are installed; otherwise memory is set to ``None``.
    :param use_discrete: Selects ``DISCRETE_ACTION_SPACE`` over
        ``VECTOR_ACTION_SPACE`` and is forwarded to the behavior-spec helper.
    :param use_visual: When truthy, the vector observation space is 0;
        also forwarded to the behavior-spec helper.
    :return: The optimizer, after ``optimizer.policy.initialize()`` has run.
    """
    if use_discrete:
        action_space = DISCRETE_ACTION_SPACE
    else:
        action_space = VECTOR_ACTION_SPACE

    if use_visual:
        obs_space = 0
    else:
        obs_space = VECTOR_OBS_SPACE

    behavior_specs = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=action_space,
        vector_obs_space=obs_space,
    )

    # The caller's settings object is configured directly (no copy is made).
    memory = None
    if use_rnn:
        memory = NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
    dummy_config.network_settings.memory = memory

    test_policy = TFPolicy(
        0, behavior_specs, dummy_config, "test", False, create_tf_graph=False
    )
    sac_optimizer = SACOptimizer(test_policy, dummy_config)
    sac_optimizer.policy.initialize()
    return sac_optimizer