def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    alpha,
    scaler,
    reward_scaler,
):
    """Build a CPU DiscreteCQLImpl and run it through the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = DiscreteCQLImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        alpha=alpha,
        use_gpu=None,
        scaler=scaler,
        reward_scaler=reward_scaler,
    )
    # IQN samples quantiles, so greedy-action determinism only holds otherwise.
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=True, deterministic_best_action=deterministic)
def test_discrete_fqe_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    scaler,
):
    """Instantiate a CPU DiscreteFQEImpl and exercise the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    estimator = DiscreteFQEImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        n_critics=n_critics,
        use_gpu=None,
        scaler=scaler,
        action_scaler=None,
    )
    torch_impl_tester(estimator, True)
def test_discrete_fqe_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
):
    """Build a CPU DiscreteFQEImpl (positional-constructor variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: the constructor here takes positional arguments, so ordering matters.
    estimator = DiscreteFQEImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        q_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        use_gpu=False,
        scaler=scaler,
        action_scaler=None,
    )
    torch_impl_tester(estimator, True)
def test_double_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    scaler,
    augmentation,
):
    """Construct a CPU DoubleDQNImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = DoubleDQNImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        use_gpu=None,
        scaler=scaler,
        augmentation=augmentation,
    )
    # IQN samples quantiles, so greedy-action determinism only holds otherwise.
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=True, deterministic_best_action=deterministic)
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
):
    """Construct a CPU DiscreteBCQImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = DiscreteBCQImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        action_flexibility=action_flexibility,
        beta=beta,
        use_gpu=None,
        scaler=scaler,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=True, deterministic_best_action=deterministic)
def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
    augmentation,
):
    """Build a CPU DiscreteCQLImpl (positional-constructor variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: the constructor here takes positional arguments, so ordering matters.
    model = DiscreteCQLImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        q_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=True, deterministic_best_action=deterministic)
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    conservative_weight,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a CPU continuous-action CQLImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = CQLImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        temp_learning_rate=temp_learning_rate,
        alpha_learning_rate=alpha_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        temp_optim_factory=temp_optim_factory,
        alpha_optim_factory=alpha_optim_factory,
        # the same encoder factory is shared by actor and critic in this test
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        initial_temperature=initial_temperature,
        initial_alpha=initial_alpha,
        alpha_threshold=alpha_threshold,
        conservative_weight=conservative_weight,
        n_action_samples=n_action_samples,
        soft_q_backup=soft_q_backup,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a CPU CQLImpl (positional-constructor variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: the constructor here takes positional arguments, so ordering matters;
    # encoder_factory is passed twice (actor and critic encoders).
    model = CQLImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        encoder_factory,
        encoder_factory,
        q_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        n_action_samples,
        soft_q_backup,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_create_q_func_factory(name):
    """Check that create_q_func_factory maps each name to its factory class."""
    factory = create_q_func_factory(name)
    # Dispatch table replaces the original if/elif chain; names outside the
    # table are (as before) not asserted on.
    expected_types = {
        "mean": MeanQFunctionFactory,
        "qr": QRQFunctionFactory,
        "iqn": IQNQFunctionFactory,
        "fqf": FQFQFunctionFactory,
    }
    expected = expected_types.get(name)
    if expected is not None:
        assert isinstance(factory, expected)
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a CPU PLASWithPerturbationImpl (positional variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: positional constructor — encoder_factory is passed three times
    # (actor, critic and imitator encoders).
    model = PLASWithPerturbationImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        lam,
        beta,
        action_flexibility,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a CPU PLASWithPerturbationImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = PLASWithPerturbationImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        imitator_learning_rate=imitator_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        imitator_optim_factory=imitator_optim_factory,
        # a single encoder factory serves actor, critic and imitator here
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        imitator_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        lam=lam,
        beta=beta,
        action_flexibility=action_flexibility,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a CPU AWACImpl (positional-constructor variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: positional constructor — encoder_factory is passed twice
    # (actor and critic encoders).
    model = AWACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        q_factory,
        gamma,
        tau,
        lam,
        n_action_samples,
        max_weight,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_td3_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    target_smoothing_sigma,
    target_smoothing_clip,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a CPU TD3Impl (positional-constructor variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: positional constructor — encoder_factory is passed twice
    # (actor and critic encoders).
    model = TD3Impl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        q_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        target_smoothing_sigma,
        target_smoothing_clip,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a CPU SACImpl (positional-constructor variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: positional constructor — encoder_factory is passed twice
    # (actor and critic encoders).
    model = SACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        encoder_factory,
        encoder_factory,
        q_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_crr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    beta,
    n_action_samples,
    advantage_type,
    weight_type,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a CPU CRRImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = CRRImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        # the same encoder factory is shared by actor and critic in this test
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        beta=beta,
        n_action_samples=n_action_samples,
        advantage_type=advantage_type,
        weight_type=weight_type,
        max_weight=max_weight,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # CRR samples actions, so the best action is never checked for determinism.
    torch_impl_tester(model, discrete=False, deterministic_best_action=False)
def test_td3_plus_bc_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    target_smoothing_sigma,
    target_smoothing_clip,
    alpha,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a CPU TD3PlusBCImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = TD3PlusBCImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        # the same encoder factory is shared by actor and critic in this test
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        target_smoothing_sigma=target_smoothing_sigma,
        target_smoothing_clip=target_smoothing_clip,
        alpha=alpha,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a CPU AWACImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = AWACImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        # the same encoder factory is shared by actor and critic in this test
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        lam=lam,
        n_action_samples=n_action_samples,
        max_weight=max_weight,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a CPU SACImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = SACImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        temp_learning_rate=temp_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        temp_optim_factory=temp_optim_factory,
        # the same encoder factory is shared by actor and critic in this test
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        initial_temperature=initial_temperature,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_discrete_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    initial_temperature,
    scaler,
    reward_scaler,
):
    """Construct a CPU DiscreteSACImpl and run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = DiscreteSACImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        temp_learning_rate=temp_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        temp_optim_factory=temp_optim_factory,
        # the same encoder factory is shared by actor and critic in this test
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        n_critics=n_critics,
        initial_temperature=initial_temperature,
        use_gpu=None,
        scaler=scaler,
        reward_scaler=reward_scaler,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=True, deterministic_best_action=deterministic)
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
    augmentation,
):
    """Build a CPU DiscreteBCQImpl (positional-constructor variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: the constructor here takes positional arguments, so ordering matters.
    model = DiscreteBCQImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        q_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        action_flexibility,
        beta,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=True, deterministic_best_action=deterministic)
def test_bcq_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    lam,
    n_action_samples,
    action_flexibility,
    latent_size,
    beta,
    scaler,
    action_scaler,
):
    """Construct a CPU BCQImpl, check its internal helpers' shapes, then run
    the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = BCQImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        imitator_learning_rate=imitator_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        imitator_optim_factory=imitator_optim_factory,
        # a single encoder factory serves actor, critic and imitator here
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        imitator_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        lam=lam,
        n_action_samples=n_action_samples,
        action_flexibility=action_flexibility,
        latent_size=latent_size,
        beta=beta,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
    )
    model.build()

    batch_size = 32
    observations = torch.rand(batch_size, *observation_shape)

    # observation is tiled once per sampled action
    tiled = model._repeat_observation(observations)
    assert tiled.shape == (batch_size, n_action_samples) + observation_shape

    # one candidate action per tiled observation
    candidates = model._sample_repeated_action(tiled)
    assert candidates.shape == (batch_size, n_action_samples, action_size)

    # each critic scores every (observation, action) pair
    values = model._predict_value(tiled, candidates)
    assert values.shape == (n_critics, batch_size * n_action_samples, 1)

    # best action collapses the candidate dimension
    greedy = model._predict_best_action(observations)
    assert greedy.shape == (batch_size, action_size)

    # BCQ samples actions, so determinism of the best action is not asserted.
    torch_impl_tester(model, discrete=False, deterministic_best_action=False)
def test_bear_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    lam,
    n_action_samples,
    mmd_kernel,
    mmd_sigma,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a CPU BEARImpl, check compute_target's output shape, then run
    the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    model = BEARImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        imitator_learning_rate=imitator_learning_rate,
        temp_learning_rate=temp_learning_rate,
        alpha_learning_rate=alpha_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        imitator_optim_factory=imitator_optim_factory,
        temp_optim_factory=temp_optim_factory,
        alpha_optim_factory=alpha_optim_factory,
        # a single encoder factory serves actor, critic and imitator here
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        imitator_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        initial_temperature=initial_temperature,
        initial_alpha=initial_alpha,
        alpha_threshold=alpha_threshold,
        lam=lam,
        n_action_samples=n_action_samples,
        mmd_kernel=mmd_kernel,
        mmd_sigma=mmd_sigma,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    model.build()

    batch_size = 32
    observations = torch.rand(batch_size, *observation_shape)
    target = model.compute_target(observations)
    if q_func_factory == "mean":
        # mean Q-functions emit a single scalar per observation
        assert target.shape == (batch_size, 1)
    else:
        # distributional Q-functions emit one value per quantile
        n_quantiles = model._q_func.q_funcs[0]._n_quantiles
        assert target.shape == (batch_size, n_quantiles)

    deterministic = q_func_factory != "iqn"
    torch_impl_tester(model, discrete=False, deterministic_best_action=deterministic)
def test_bear_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    lam,
    n_action_samples,
    n_target_samples,
    n_mmd_action_samples,
    mmd_kernel,
    mmd_sigma,
    vae_kl_weight,
    scaler,
    action_scaler,
):
    """Construct a CPU BEARImpl (extended-parameter variant) and test it."""
    q_factory = create_q_func_factory(q_func_factory)
    model = BEARImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        imitator_learning_rate=imitator_learning_rate,
        temp_learning_rate=temp_learning_rate,
        alpha_learning_rate=alpha_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        imitator_optim_factory=imitator_optim_factory,
        temp_optim_factory=temp_optim_factory,
        alpha_optim_factory=alpha_optim_factory,
        # a single encoder factory serves actor, critic and imitator here
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        imitator_encoder_factory=encoder_factory,
        q_func_factory=q_factory,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        initial_temperature=initial_temperature,
        initial_alpha=initial_alpha,
        alpha_threshold=alpha_threshold,
        lam=lam,
        n_action_samples=n_action_samples,
        n_target_samples=n_target_samples,
        n_mmd_action_samples=n_mmd_action_samples,
        mmd_kernel=mmd_kernel,
        mmd_sigma=mmd_sigma,
        vae_kl_weight=vae_kl_weight,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
    )
    model.build()
    # BEAR samples actions, so determinism of the best action is not asserted.
    torch_impl_tester(model, discrete=False, deterministic_best_action=False)
def test_bcq_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    lam,
    n_action_samples,
    action_flexibility,
    latent_size,
    beta,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a CPU BCQImpl (positional variant), verify internal helper shapes
    and compute_target, then run the shared impl tester."""
    q_factory = create_q_func_factory(q_func_factory)
    # NOTE: positional constructor — encoder_factory is passed three times
    # (actor, critic and imitator encoders).
    model = BCQImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        lam,
        n_action_samples,
        action_flexibility,
        latent_size,
        beta,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    model.build()

    batch_size = 32
    observations = torch.rand(batch_size, *observation_shape)

    # observation is tiled once per sampled action
    tiled = model._repeat_observation(observations)
    assert tiled.shape == (batch_size, n_action_samples) + observation_shape

    # one candidate action per tiled observation
    candidates = model._sample_repeated_action(tiled)
    assert candidates.shape == (batch_size, n_action_samples, action_size)

    # each critic scores every (observation, action) pair
    values = model._predict_value(tiled, candidates)
    assert values.shape == (n_critics, batch_size * n_action_samples, 1)

    target = model.compute_target(observations)
    if q_func_factory == "mean":
        # mean Q-functions emit a single scalar per observation
        assert target.shape == (batch_size, 1)
    else:
        # distributional Q-functions emit one value per quantile
        assert target.shape == (batch_size, model._q_func.q_funcs[0]._n_quantiles)

    # best action collapses the candidate dimension
    greedy = model._predict_best_action(observations)
    assert greedy.shape == (batch_size, action_size)

    # BCQ samples actions, so determinism of the best action is not asserted.
    torch_impl_tester(model, discrete=False, deterministic_best_action=False)