def test_bcq_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    lam,
    n_action_samples,
    action_flexibility,
    latent_size,
    beta,
    scaler,
    action_scaler,
):
    """Build a ``BCQImpl`` and check the shapes its internal helpers return.

    Every argument is forwarded to the ``BCQImpl`` constructor. The single
    ``encoder_factory`` is reused for the actor, critic and imitator
    encoders, and ``q_func_factory`` is first converted with
    ``create_q_func_factory``. The arguments are presumably supplied by
    pytest parametrization declared outside this block.
    """
    batch_size = 32

    # Convert the q-function identifier before handing it to the constructor.
    q_func = create_q_func_factory(q_func_factory)

    impl = BCQImpl(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        imitator_learning_rate=imitator_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        imitator_optim_factory=imitator_optim_factory,
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        imitator_encoder_factory=encoder_factory,
        q_func_factory=q_func,
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        lam=lam,
        n_action_samples=n_action_samples,
        action_flexibility=action_flexibility,
        latent_size=latent_size,
        beta=beta,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
    )
    impl.build()

    # --- internal helpers: each step feeds the next, so order matters ---
    observations = torch.rand(batch_size, *observation_shape)

    # Observations are repeated once per sampled action.
    tiled_observations = impl._repeat_observation(observations)
    expected_tiled_shape = (batch_size, n_action_samples) + observation_shape
    assert tiled_observations.shape == expected_tiled_shape

    # One action is sampled for every repeated observation.
    sampled_actions = impl._sample_repeated_action(tiled_observations)
    expected_action_shape = (batch_size, n_action_samples, action_size)
    assert sampled_actions.shape == expected_action_shape

    # Values come back per critic, with batch and sample axes flattened.
    values = impl._predict_value(tiled_observations, sampled_actions)
    expected_value_shape = (n_critics, batch_size * n_action_samples, 1)
    assert values.shape == expected_value_shape

    # The best action is reduced back to a single action per observation.
    chosen_action = impl._predict_best_action(observations)
    assert chosen_action.shape == (batch_size, action_size)

    # Run the shared implementation checks for a continuous-action algorithm.
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)
def test_bcq_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    lam,
    n_action_samples,
    action_flexibility,
    latent_size,
    beta,
    scaler,
    augmentation,
):
    """Build a ``BCQImpl`` and check the shapes its internal helpers return.

    The constructor receives its leading arguments positionally, in the
    order listed here. ``encoder_factory`` is passed three times in a row
    and ``q_func_factory`` is first converted with
    ``create_q_func_factory``. The arguments are presumably supplied by
    pytest parametrization declared outside this block.
    """
    batch_size = 32

    # Convert the q-function identifier before handing it to the constructor.
    q_func = create_q_func_factory(q_func_factory)

    # Positional order must be kept exactly: the constructor's parameter
    # names for these slots are not visible from this test.
    impl = BCQImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        lam,
        n_action_samples,
        action_flexibility,
        latent_size,
        beta,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    impl.build()

    # --- internal helpers: each step feeds the next, so order matters ---
    observations = torch.rand(batch_size, *observation_shape)

    # Observations are repeated once per sampled action.
    tiled_observations = impl._repeat_observation(observations)
    expected_tiled_shape = (batch_size, n_action_samples) + observation_shape
    assert tiled_observations.shape == expected_tiled_shape

    # One action is sampled for every repeated observation.
    sampled_actions = impl._sample_action(tiled_observations)
    expected_action_shape = (batch_size, n_action_samples, action_size)
    assert sampled_actions.shape == expected_action_shape

    # Values come back per critic, with batch and sample axes flattened.
    values = impl._predict_value(tiled_observations, sampled_actions)
    expected_value_shape = (n_critics, batch_size * n_action_samples, 1)
    assert values.shape == expected_value_shape

    # The target's last dimension is 1 for the "mean" q-function; otherwise
    # it equals the first q-function's ``_n_quantiles``.
    target = impl.compute_target(observations)
    target_width = (
        1
        if q_func_factory == "mean"
        else impl._q_func.q_funcs[0]._n_quantiles
    )
    assert target.shape == (batch_size, target_width)

    # The best action is reduced back to a single action per observation.
    chosen_action = impl._predict_best_action(observations)
    assert chosen_action.shape == (batch_size, action_size)

    # Run the shared implementation checks for a continuous-action algorithm.
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)