def test_discrete_fqe_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
):
    """Build a DiscreteFQEImpl on CPU and run the shared torch impl checks.

    ``q_func_factory`` arrives as a factory name and is resolved through
    ``create_q_func_factory`` before being handed to the implementation.
    """
    q_func = create_q_func_factory(q_func_factory)
    fqe = DiscreteFQEImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        q_func,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        use_gpu=False,
        scaler=scaler,
    )
    # Discrete-action evaluator, so the shared tester runs in discrete mode.
    torch_impl_tester(fqe, discrete=True)
def test_create_q_func_factory(name):
    """Check that ``create_q_func_factory`` maps a name to the right class.

    Every name under test must have an entry in ``expected_types``. Without
    that entry, a factory name accepted by ``create_q_func_factory`` but
    missing from the table would pass without any type check at all.
    """
    expected_types = {
        "mean": MeanQFunctionFactory,
        "qr": QRQFunctionFactory,
        "iqn": IQNQFunctionFactory,
        "fqf": FQFQFunctionFactory,
    }
    factory = create_q_func_factory(name)
    # Fail loudly instead of silently passing for an unlisted name.
    assert name in expected_types, f"no expected factory type for {name!r}"
    assert isinstance(factory, expected_types[name])
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    lam,
    beta,
    action_flexibility,
    scaler,
    augmentation,
):
    """Build a PLASWithPerturbationImpl on CPU and run the shared checks."""
    q_func = create_q_func_factory(q_func_factory)
    # NOTE(review): best actions are only required to be deterministic for
    # non-IQN Q-functions; presumably IQN samples quantiles randomly — confirm.
    deterministic = q_func_factory != "iqn"

    # The same encoder factory fills all three encoder slots of the impl.
    impl = PLASWithPerturbationImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        lam,
        beta,
        action_flexibility,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=deterministic
    )
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
    augmentation,
):
    """Build an AWACImpl on CPU and run the shared continuous-action checks."""
    q_func = create_q_func_factory(q_func_factory)
    # NOTE(review): best actions are only required to be deterministic for
    # non-IQN Q-functions; presumably IQN samples quantiles randomly — confirm.
    deterministic = q_func_factory != "iqn"

    # The same encoder factory fills both encoder slots of the impl.
    impl = AWACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        lam,
        n_action_samples,
        max_weight,
        n_critics,
        bootstrap,
        share_encoder,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=deterministic
    )
def test_td3_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_smoothing_sigma,
    target_smoothing_clip,
    scaler,
    augmentation,
):
    """Build a TD3Impl on CPU and run the shared continuous-action checks."""
    q_func = create_q_func_factory(q_func_factory)
    # NOTE(review): best actions are only required to be deterministic for
    # non-IQN Q-functions; presumably IQN samples quantiles randomly — confirm.
    deterministic = q_func_factory != "iqn"

    # The same encoder factory fills both encoder slots of the impl.
    impl = TD3Impl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_smoothing_sigma,
        target_smoothing_clip,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    torch_impl_tester(
        impl, discrete=False, deterministic_best_action=deterministic
    )
def test_bear_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    lam,
    n_action_samples,
    mmd_sigma,
    scaler,
    augmentation,
):
    """Build a BEARImpl on CPU, check the target shape, then run shared checks.

    The target returned by ``compute_target`` must be ``(batch, 1)`` for the
    "mean" Q-function. For any other factory it must be
    ``(batch, n_quantiles)``, where the count is read from the first
    Q-function.
    """
    q_func = create_q_func_factory(q_func_factory)
    # The same encoder factory fills all three encoder slots of the impl.
    impl = BEARImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        lam,
        n_action_samples,
        mmd_sigma,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    impl.build()

    batch_size = 32
    observations = torch.rand(batch_size, *observation_shape)
    target = impl.compute_target(observations)

    if q_func_factory == "mean":
        expected_shape = (batch_size, 1)
    else:
        n_quantiles = impl._q_func.q_funcs[0]._n_quantiles
        expected_shape = (batch_size, n_quantiles)
    assert target.shape == expected_shape

    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=q_func_factory != "iqn",
    )
def test_bcq_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    lam,
    n_action_samples,
    action_flexibility,
    latent_size,
    beta,
    scaler,
    augmentation,
):
    """Build a BCQImpl on CPU, check its internal helpers, then run shared checks.

    The private helpers are checked for output shapes only:
    ``_repeat_observation``, ``_sample_action``, ``_predict_value``,
    ``compute_target`` and ``_predict_best_action``.
    """
    q_func = create_q_func_factory(q_func_factory)
    # The same encoder factory fills all three encoder slots of the impl.
    impl = BCQImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        lam,
        n_action_samples,
        action_flexibility,
        latent_size,
        beta,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    impl.build()

    batch_size = 32
    observations = torch.rand(batch_size, *observation_shape)

    # Each observation gets an extra n_action_samples axis after the batch.
    # observation_shape is concatenated with a tuple, so it must be a tuple.
    repeated = impl._repeat_observation(observations)
    assert repeated.shape == (batch_size, n_action_samples) + observation_shape

    sampled_actions = impl._sample_action(repeated)
    assert sampled_actions.shape == (batch_size, n_action_samples, action_size)

    # Values come back per critic, over the flattened batch * sample axis.
    values = impl._predict_value(repeated, sampled_actions)
    assert values.shape == (n_critics, batch_size * n_action_samples, 1)

    target = impl.compute_target(observations)
    if q_func_factory == "mean":
        assert target.shape == (batch_size, 1)
    else:
        n_quantiles = impl._q_func.q_funcs[0]._n_quantiles
        assert target.shape == (batch_size, n_quantiles)

    best_action = impl._predict_best_action(observations)
    assert best_action.shape == (batch_size, action_size)

    # NOTE(review): BCQ best actions are never required to be deterministic
    # here, presumably because they are sampled (cf. _sample_action) — confirm.
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)