def test_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    scaler,
    action_scaler,
    augmentation,
):
    """Build an ``AWRImpl`` (keyword form) and run ``torch_impl_tester`` on it.

    The single ``encoder_factory`` is shared by the actor and critic
    encoders. The impl is created with ``use_gpu=None`` and checked with
    ``discrete=False, test_with_std=False``.

    NOTE(review): ``test_awr_impl`` is defined again later in this file, so
    this definition is shadowed at import time and not collected by pytest.
    """
    shared_encoder_factory = encoder_factory
    impl = AWRImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        actor_encoder_factory=shared_encoder_factory,
        critic_encoder_factory=shared_encoder_factory,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    torch_impl_tester(impl, discrete=False, test_with_std=False)
def test_ddpg_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    reguralizing_rate,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Build a ``DDPGImpl`` and run ``torch_impl_tester`` with ``discrete=False``.

    ``encoder_factory`` is passed twice positionally (presumably for the
    actor and the critic encoder — confirm against ``DDPGImpl``).
    ``deterministic_best_action`` is disabled only when ``q_func_type`` is
    ``"iqn"``.
    """
    impl = DDPGImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        reguralizing_rate,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    alpha,
    scaler,
    reward_scaler,
):
    """Build a ``DiscreteCQLImpl`` (keyword form) and run ``torch_impl_tester``.

    ``q_func_factory`` is a name that is turned into a factory object via
    ``create_q_func_factory``. The tester runs with ``discrete=True``, and
    ``deterministic_best_action`` is disabled only for ``"iqn"``.

    NOTE(review): ``test_discrete_cql_impl`` is defined again later in this
    file, so this definition is shadowed and not collected by pytest.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = DiscreteCQLImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        alpha=alpha,
        use_gpu=None,
        scaler=scaler,
        reward_scaler=reward_scaler,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=True, deterministic_best_action=is_deterministic
    )
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    bootstrap,
    share_encoder,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Build an ``AWACImpl`` (positional form) and run ``torch_impl_tester``.

    ``encoder_factory`` fills two consecutive positional encoder slots. The
    tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only when ``q_func_type`` is ``"iqn"``.

    NOTE(review): ``test_awac_impl`` is defined again later in this file, so
    this definition is shadowed and not collected by pytest.
    """
    impl = AWACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        lam,
        n_action_samples,
        max_weight,
        n_critics,
        bootstrap,
        share_encoder,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_discrete_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    scaler,
    augmentation,
):
    """Build a ``DiscreteAWRImpl`` (positional form) and run ``torch_impl_tester``.

    ``encoder_factory`` is passed for both positional encoder slots. The
    tester runs with ``discrete=True, test_with_std=False``.

    NOTE(review): ``test_discrete_awr_impl`` is defined again later in this
    file, so this definition is shadowed and not collected by pytest.
    """
    impl = DiscreteAWRImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    torch_impl_tester(impl, discrete=True, test_with_std=False)
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    eps,
    use_batch_norm,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Build a ``SACImpl`` (older positional signature) and run ``torch_impl_tester``.

    The tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only when ``q_func_type`` is ``"iqn"``.

    NOTE(review): ``test_sac_impl`` is defined again later in this file, so
    this definition is shadowed and not collected by pytest.
    """
    impl = SACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        eps,
        use_batch_norm,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_double_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    scaler,
    augmentation,
):
    """Build a ``DoubleDQNImpl`` (keyword form) and run ``torch_impl_tester``.

    The Q-function name is converted with ``create_q_func_factory``. The
    tester runs with ``discrete=True``; ``deterministic_best_action`` is
    disabled only for ``"iqn"``.

    NOTE(review): ``test_double_dqn_impl`` is defined again later in this
    file, so this definition is shadowed and not collected by pytest.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = DoubleDQNImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        use_gpu=None,
        scaler=scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=True, deterministic_best_action=is_deterministic
    )
def test_td3_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    gamma,
    tau,
    reguralizing_rate,
    n_critics,
    bootstrap,
    share_encoder,
    target_smoothing_sigma,
    target_smoothing_clip,
    eps,
    use_batch_norm,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Build a ``TD3Impl`` (older positional signature) and run ``torch_impl_tester``.

    The tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only when ``q_func_type`` is ``"iqn"``.

    NOTE(review): ``test_td3_impl`` is defined again later in this file, so
    this definition is shadowed and not collected by pytest.
    """
    impl = TD3Impl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        gamma,
        tau,
        reguralizing_rate,
        n_critics,
        bootstrap,
        share_encoder,
        target_smoothing_sigma,
        target_smoothing_clip,
        eps,
        use_batch_norm,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
):
    """Build a ``DiscreteBCQImpl`` (keyword form) and run ``torch_impl_tester``.

    The Q-function name is converted with ``create_q_func_factory``. The
    tester runs with ``discrete=True``; ``deterministic_best_action`` is
    disabled only for ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = DiscreteBCQImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        action_flexibility=action_flexibility,
        beta=beta,
        use_gpu=None,
        scaler=scaler,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=True, deterministic_best_action=is_deterministic
    )
def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
    augmentation,
):
    """Build a ``DiscreteCQLImpl`` (positional form) and run ``torch_impl_tester``.

    The Q-function name is converted with ``create_q_func_factory`` before
    being passed positionally. The tester runs with ``discrete=True``;
    ``deterministic_best_action`` is disabled only for ``"iqn"``.

    NOTE(review): an earlier ``test_discrete_cql_impl`` exists in this file;
    this later definition replaces it at import time.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = DiscreteCQLImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        q_func,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=True, deterministic_best_action=is_deterministic
    )
def test_discrete_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Build a ``DiscreteSACImpl`` (positional form) and run ``torch_impl_tester``.

    ``encoder_factory`` fills two consecutive positional encoder slots. The
    tester runs with ``discrete=True``; ``deterministic_best_action`` is
    disabled only when ``q_func_type`` is ``"iqn"``.
    """
    impl = DiscreteSACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=True, deterministic_best_action=is_deterministic
    )
def test_discrete_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    scaler,
    reward_scaler,
):
    """Build a ``DiscreteAWRImpl`` (keyword form) and run ``torch_impl_tester``.

    The actor and critic share one ``encoder_factory``, ``action_scaler`` is
    explicitly ``None``, and the tester runs with
    ``discrete=True, test_with_std=False``.

    NOTE(review): an earlier ``test_discrete_awr_impl`` exists in this file;
    this later definition replaces it at import time.
    """
    shared_encoder_factory = encoder_factory
    impl = DiscreteAWRImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        actor_encoder_factory=shared_encoder_factory,
        critic_encoder_factory=shared_encoder_factory,
        use_gpu=None,
        scaler=scaler,
        action_scaler=None,
        reward_scaler=reward_scaler,
    )
    torch_impl_tester(impl, discrete=True, test_with_std=False)
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    conservative_weight,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build a ``CQLImpl`` (keyword form) and run ``torch_impl_tester``.

    The actor and critic share one ``encoder_factory``, and the Q-function
    name is converted with ``create_q_func_factory``. The tester runs with
    ``discrete=False``; ``deterministic_best_action`` is disabled only for
    ``"iqn"``.

    NOTE(review): ``test_cql_impl`` is defined again later in this file, so
    this definition is shadowed and not collected by pytest.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = CQLImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        temp_learning_rate=temp_learning_rate,
        alpha_learning_rate=alpha_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        temp_optim_factory=temp_optim_factory,
        alpha_optim_factory=alpha_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        initial_temperature=initial_temperature,
        initial_alpha=initial_alpha,
        alpha_threshold=alpha_threshold,
        conservative_weight=conservative_weight,
        n_action_samples=n_action_samples,
        soft_q_backup=soft_q_backup,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a ``CQLImpl`` (positional form) and run ``torch_impl_tester``.

    ``encoder_factory`` fills two consecutive positional encoder slots, and
    the Q-function name is converted with ``create_q_func_factory``. The
    tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only for ``"iqn"``.

    NOTE(review): an earlier ``test_cql_impl`` exists in this file; this
    later definition replaces it at import time.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = CQLImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        n_action_samples,
        soft_q_backup,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_bear_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    lam,
    n_action_samples,
    mmd_sigma,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Build a ``BEARImpl``, check the shape of ``compute_target``, then run the common tester.

    After ``impl.build()``, a random batch of 32 observations is pushed
    through ``compute_target``:

    - For ``q_func_type == "mean"``, the target must be ``(32, 1)``.
    - Otherwise its width must equal ``n_quantiles`` of the first Q-function.

    ``encoder_factory`` fills three consecutive positional encoder slots.
    """
    impl = BEARImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        lam,
        n_action_samples,
        mmd_sigma,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    impl.build()

    batch_size = 32
    observations = torch.rand(batch_size, *observation_shape)
    target = impl.compute_target(observations)

    # n_quantiles is only looked up for non-"mean" Q-functions, as before.
    expected_shape = (
        (batch_size, 1)
        if q_func_type == "mean"
        else (batch_size, impl.q_func.q_funcs[0].n_quantiles)
    )
    assert target.shape == expected_shape

    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a ``PLASWithPerturbationImpl`` (positional form) and run ``torch_impl_tester``.

    ``encoder_factory`` fills three consecutive positional encoder slots, and
    the Q-function name is converted with ``create_q_func_factory``. The
    tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only for ``"iqn"``.

    NOTE(review): ``test_plas_with_perturbation_impl`` is defined again later
    in this file, so this definition is shadowed and not collected by pytest.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = PLASWithPerturbationImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        lam,
        beta,
        action_flexibility,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_bc_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    scaler,
    augmentation,
    n_augmentations,
):
    """Build a ``BCImpl`` (optimizer/encoder-factory signature) and run ``torch_impl_tester``.

    The tester runs with ``discrete=False, imitator=True``.

    NOTE(review): ``test_bc_impl`` is defined again later in this file, so
    this definition is shadowed and not collected by pytest.
    """
    impl = BCImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    torch_impl_tester(impl, discrete=False, imitator=True)
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build a ``PLASWithPerturbationImpl`` (keyword form) and run ``torch_impl_tester``.

    The actor, critic and imitator encoders all use the same
    ``encoder_factory``; the Q-function name is converted with
    ``create_q_func_factory``. The tester runs with ``discrete=False``;
    ``deterministic_best_action`` is disabled only for ``"iqn"``.

    NOTE(review): an earlier ``test_plas_with_perturbation_impl`` exists in
    this file; this later definition replaces it at import time.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = PLASWithPerturbationImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        imitator_learning_rate=imitator_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        imitator_optim_factory=imitator_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        imitator_encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        lam=lam,
        beta=beta,
        action_flexibility=action_flexibility,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_bc_impl(
    observation_shape,
    action_size,
    learning_rate,
    eps,
    use_batch_norm,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Build a ``BCImpl`` (eps/batch-norm signature) and run ``torch_impl_tester``.

    The tester runs with ``discrete=False, imitator=True``.

    NOTE(review): an earlier ``test_bc_impl`` exists in this file; this later
    definition replaces it at import time.
    """
    impl = BCImpl(
        observation_shape,
        action_size,
        learning_rate,
        eps,
        use_batch_norm,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    torch_impl_tester(impl, discrete=False, imitator=True)
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    initial_temperature,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a ``SACImpl`` (positional factory form) and run ``torch_impl_tester``.

    ``encoder_factory`` fills two consecutive positional encoder slots, and
    the Q-function name is converted with ``create_q_func_factory``. The
    tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only for ``"iqn"``.

    NOTE(review): ``test_sac_impl`` is defined both earlier and later in this
    file; this definition is shadowed by the later one.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = SACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        initial_temperature,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    momentum,
    use_batch_norm,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Build an ``AWRImpl`` (momentum/batch-norm signature) and run ``torch_impl_tester``.

    The tester runs with ``discrete=False, test_with_std=False``.

    NOTE(review): an earlier ``test_awr_impl`` exists in this file; this later
    definition replaces it at import time.
    """
    impl = AWRImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        momentum,
        use_batch_norm,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    torch_impl_tester(impl, discrete=False, test_with_std=False)
def test_plas_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    lam,
    beta,
    scaler,
    augmentation,
):
    """Build a ``PLASImpl`` (positional form) and run ``torch_impl_tester``.

    ``encoder_factory`` fills three consecutive positional encoder slots, and
    the Q-function name is converted with ``create_q_func_factory``. The
    tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only for ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = PLASImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        lam,
        beta,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_crr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    beta,
    n_action_samples,
    advantage_type,
    weight_type,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a ``CRRImpl`` (keyword form) and run ``torch_impl_tester``.

    The actor and critic share one ``encoder_factory``; the Q-function name
    is converted with ``create_q_func_factory``. Unlike the sibling tests,
    ``deterministic_best_action`` is always ``False`` here, regardless of the
    Q-function type.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = CRRImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        beta=beta,
        n_action_samples=n_action_samples,
        advantage_type=advantage_type,
        weight_type=weight_type,
        max_weight=max_weight,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)
def test_td3_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_smoothing_sigma,
    target_smoothing_clip,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a ``TD3Impl`` (positional factory form) and run ``torch_impl_tester``.

    ``encoder_factory`` fills two consecutive positional encoder slots, and
    the Q-function name is converted with ``create_q_func_factory``. The
    tester runs with ``discrete=False``; ``deterministic_best_action`` is
    disabled only for ``"iqn"``.

    NOTE(review): an earlier ``test_td3_impl`` exists in this file; this later
    definition replaces it at import time.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = TD3Impl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_smoothing_sigma,
        target_smoothing_clip,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_td3_plus_bc_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    target_smoothing_sigma,
    target_smoothing_clip,
    alpha,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build a ``TD3PlusBCImpl`` (keyword form) and run ``torch_impl_tester``.

    The actor and critic share one ``encoder_factory``; the Q-function name
    is converted with ``create_q_func_factory``. The tester runs with
    ``discrete=False``; ``deterministic_best_action`` is disabled only for
    ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = TD3PlusBCImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        target_smoothing_sigma=target_smoothing_sigma,
        target_smoothing_clip=target_smoothing_clip,
        alpha=alpha,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build a ``SACImpl`` (keyword form) and run ``torch_impl_tester``.

    The actor and critic share one ``encoder_factory``; the Q-function name
    is converted with ``create_q_func_factory``. The tester runs with
    ``discrete=False``; ``deterministic_best_action`` is disabled only for
    ``"iqn"``.

    NOTE(review): earlier ``test_sac_impl`` definitions exist in this file;
    this later definition replaces them at import time.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = SACImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        temp_learning_rate=temp_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        temp_optim_factory=temp_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        initial_temperature=initial_temperature,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Build an ``AWACImpl`` (keyword form) and run ``torch_impl_tester``.

    The actor and critic share one ``encoder_factory``; the Q-function name
    is converted with ``create_q_func_factory``. The tester runs with
    ``discrete=False``; ``deterministic_best_action`` is disabled only for
    ``"iqn"``.

    NOTE(review): an earlier ``test_awac_impl`` exists in this file; this
    later definition replaces it at import time.
    """
    q_func = create_q_func_factory(q_func_factory)
    impl = AWACImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        tau=tau,
        lam=lam,
        n_action_samples=n_action_samples,
        max_weight=max_weight,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=is_deterministic
    )
def test_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Build a ``DQNImpl`` (positional form) and run ``torch_impl_tester``.

    The tester runs with ``discrete=True``; ``deterministic_best_action`` is
    disabled only when ``q_func_type`` is ``"iqn"``.
    """
    impl = DQNImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=True, deterministic_best_action=is_deterministic
    )
def test_discrete_bc_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    beta,
    scaler,
):
    """Build a ``DiscreteBCImpl`` (keyword form) and run ``torch_impl_tester``.

    The tester runs with ``discrete=True, imitator=True``.
    """
    impl = DiscreteBCImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        beta=beta,
        use_gpu=None,
        scaler=scaler,
    )
    torch_impl_tester(impl, discrete=True, imitator=True)
def test_double_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    eps,
    use_batch_norm,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Build a ``DoubleDQNImpl`` (eps/batch-norm signature) and run ``torch_impl_tester``.

    The tester runs with ``discrete=True``; ``deterministic_best_action`` is
    disabled only when ``q_func_type`` is ``"iqn"``.

    NOTE(review): an earlier ``test_double_dqn_impl`` exists in this file;
    this later definition replaces it at import time.
    """
    impl = DoubleDQNImpl(
        observation_shape,
        action_size,
        learning_rate,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        eps,
        use_batch_norm,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    is_deterministic = q_func_type != "iqn"
    torch_impl_tester(
        impl, discrete=True, deterministic_best_action=is_deterministic
    )