def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
):
    """Construct a ``DiscreteBCQImpl`` and hand it to ``torch_impl_tester``.

    Every argument is forwarded to the constructor by keyword under its own
    name, except ``q_func_factory``, which is first converted with
    ``create_q_func_factory``. ``use_gpu`` is always passed as ``None``.
    """
    # Insertion order mirrors the keyword order of the constructor call.
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "action_flexibility": action_flexibility,
        "beta": beta,
        "use_gpu": None,
        "scaler": scaler,
    }
    impl = DiscreteBCQImpl(**impl_kwargs)

    # The best action is expected to be deterministic for every Q-function
    # except "iqn" (presumably because IQN samples quantiles -- confirm).
    expect_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=expect_deterministic,
    )
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
    augmentation,
):
    """Construct a ``DiscreteBCQImpl`` and hand it to ``torch_impl_tester``.

    The leading constructor arguments are passed positionally in the order
    listed below; ``q_func_factory`` is converted with
    ``create_q_func_factory`` before being passed. ``use_gpu`` is always
    ``False``; ``scaler`` and ``augmentation`` are passed by keyword.
    """
    q_func = create_q_func_factory(q_func_factory)

    # Positional order matters here: it must match the constructor signature,
    # which is not visible from this test.
    positional = (
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        q_func,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        action_flexibility,
        beta,
    )
    impl = DiscreteBCQImpl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )

    # The best action is expected to be deterministic for every Q-function
    # except "iqn" (presumably because IQN samples quantiles -- confirm).
    expect_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=expect_deterministic,
    )
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    action_flexibility,
    beta,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Construct a ``DiscreteBCQImpl`` and hand it to ``torch_impl_tester``.

    The first twelve arguments (through ``q_func_type``) are passed to the
    constructor positionally, in signature order. ``use_gpu`` is always
    ``False``; ``scaler``, ``augmentation`` and ``n_augmentations`` are
    passed by keyword.
    """
    impl = DiscreteBCQImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        action_flexibility,
        beta,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )

    # The best action is expected to be deterministic for every Q-function
    # type except 'iqn' (presumably because IQN samples quantiles -- confirm).
    expect_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=expect_deterministic,
    )