def test_awac_impl(observation_shape, action_size, actor_learning_rate,
                   critic_learning_rate, actor_optim_factory,
                   critic_optim_factory, encoder_factory, gamma, tau, lam,
                   n_action_samples, max_weight, n_critics, bootstrap,
                   share_encoder, q_func_type, scaler, augmentation,
                   n_augmentations):
    """Smoke-test an ``AWACImpl`` built with the legacy positional signature.

    ``encoder_factory`` is passed twice, filling two consecutive positional
    slots of ``AWACImpl`` (presumably the actor and critic encoder
    factories -- TODO confirm against the AWACImpl signature). The impl is
    built with ``use_gpu=False`` and checked by ``torch_impl_tester`` with
    ``discrete=False``.

    NOTE(review): a second ``test_awac_impl`` is defined later in this
    module. That later ``def`` rebinds the name at import time, so this
    definition is never collected by pytest. Consider deleting it, or
    giving it a distinct name if it is still meant to run.
    """
    # Best-action selection is only expected to be deterministic when the
    # Q-function type is not 'iqn'.
    deterministic = q_func_type != 'iqn'

    impl = AWACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        lam,
        n_action_samples,
        max_weight,
        n_critics,
        bootstrap,
        share_encoder,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )

    torch_impl_tester(impl,
                      discrete=False,
                      deterministic_best_action=deterministic)
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Smoke-test ``AWACImpl`` constructed entirely from keyword arguments.

    The single ``encoder_factory`` is used for both the actor and the critic
    encoder. The ``q_func_factory`` value is compared against ``"iqn"``
    below, so it is presumably a string key (TODO confirm). It is converted
    to a factory object via ``create_q_func_factory`` before construction.
    The impl is built with ``use_gpu=None`` and checked by
    ``torch_impl_tester`` with ``discrete=False``.
    """
    # Resolve the Q-function factory up front; the raw value is still
    # needed afterwards for the determinism check.
    resolved_q_func_factory = create_q_func_factory(q_func_factory)

    impl = AWACImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=resolved_q_func_factory,
        gamma=gamma,
        tau=tau,
        lam=lam,
        n_action_samples=n_action_samples,
        max_weight=max_weight,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )

    # Best-action selection is only expected to be deterministic for
    # non-"iqn" Q-function factories.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )