def create_ppo_optimizer(self) -> PPOOptimizer:
    # Dispatch on the configured backend: PyTorch gets TorchPPOOptimizer,
    # anything else falls through to the TensorFlow PPOOptimizer.
    if self.framework == FrameworkType.PYTORCH:
        return TorchPPOOptimizer(  # type: ignore
            cast(TorchPolicy, self.policy), self.trainer_settings  # type: ignore
        )  # type: ignore
    else:
        return PPOOptimizer(  # type: ignore
            cast(TFPolicy, self.policy), self.trainer_settings  # type: ignore
        )  # type: ignore
# Imports as used by the ml-agents test suite (module paths assumed from the
# repo layout of that era; DISCRETE_ACTION_SPACE, VECTOR_ACTION_SPACE, and
# VECTOR_OBS_SPACE are module-level constants defined elsewhere in the file).
import attr

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.tests import mock_brain as mb


def create_test_ppo_optimizer(dummy_config, use_rnn, use_discrete, use_visual):
    # Build mock behavior specs matching the requested observation/action setup.
    mock_specs = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
    )
    # Copy the shared config fixture so individual tests do not mutate it.
    trainer_settings = attr.evolve(dummy_config)
    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings(sequence_length=16, memory_size=10)
        if use_rnn
        else None
    )
    policy = TorchPolicy(0, mock_specs, trainer_settings, "test", False)
    optimizer = TorchPPOOptimizer(policy, trainer_settings)
    return optimizer
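# Usage sketch (an assumption, not a test copied from the repo): the helper
# above is typically driven from a parametrized pytest case. "dummy_config"
# stands in for the suite's TrainerSettings fixture, and the assertion is
# illustrative only.
import pytest


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
def test_create_optimizer_smoke(dummy_config, discrete):
    optimizer = create_test_ppo_optimizer(
        dummy_config, use_rnn=False, use_discrete=discrete, use_visual=False
    )
    # TorchOptimizer keeps a reference to the policy it was constructed with.
    assert optimizer.policy is not None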
def create_ppo_optimizer(self) -> TorchPPOOptimizer:
    # PyTorch-only variant: no framework branch, the policy is always Torch.
    return TorchPPOOptimizer(  # type: ignore
        cast(TorchPolicy, self.policy), self.trainer_settings  # type: ignore
    )  # type: ignore
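# Wiring sketch (assumed and simplified; the real PPOTrainer.add_policy also
# handles step counts and reward-signal setup): the factory reads self.policy
# via cast(), so the trainer assigns the policy before calling it.
def add_policy(self, parsed_behavior_id, policy: TorchPolicy) -> None:
    self.policy = policy
    self.optimizer = self.create_ppo_optimizer()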