import numpy as np
import pytest
import torch
from torch.nn import functional as F

from garage.envs import GymEnv, MultiEnvWrapper
from garage.envs.multi_env_wrapper import round_robin_strategy
from garage.experiment import deterministic
from garage.replay_buffer import PathBuffer
from garage.torch.algos import MTSAC
from garage.torch.policies import TanhGaussianMLPPolicy
from garage.torch.q_functions import ContinuousMLPQFunction


def test_mtsac_get_log_alpha_incorrect_num_tasks(monkeypatch):
    """Check that MTSAC raises an exception if the num_tasks passed does
    not match the number of tasks in the environment.

    MTSAC uses disentangled alphas, meaning that it keeps one log-alpha
    (entropy coefficient) per task, so _log_alpha must have exactly
    num_tasks entries.
    """
    env_names = ['CartPole-v0', 'CartPole-v1']
    task_envs = [GymEnv(name, max_episode_length=150) for name in env_names]
    env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy)
    deterministic.set_seed(0)
    policy = TanhGaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=[1, 1],
        hidden_nonlinearity=torch.nn.ReLU,
        output_nonlinearity=None,
        min_std=np.exp(-20.),
        max_std=np.exp(2.),
    )
    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[1, 1],
                                 hidden_nonlinearity=F.relu)
    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[1, 1],
                                 hidden_nonlinearity=F.relu)
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    buffer_batch_size = 2
    mtsac = MTSAC(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  sampler=None,
                  gradient_steps_per_itr=150,
                  eval_env=[env],
                  env_spec=env.spec,
                  num_tasks=4,
                  steps_per_epoch=5,
                  replay_buffer=replay_buffer,
                  min_buffer_size=1e3,
                  target_update_tau=5e-3,
                  discount=0.99,
                  buffer_batch_size=buffer_batch_size)
    monkeypatch.setattr(mtsac, '_log_alpha', torch.Tensor([1., 2.]))
    error_string = ('The number of tasks in the environment does '
                    'not match self._num_tasks. Are you sure that you passed '
                    'The correct number of tasks?')
    obs = torch.Tensor([env.reset()[0]] * buffer_batch_size)
    with pytest.raises(ValueError, match=error_string):
        mtsac._get_log_alpha(dict(observation=obs))
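

# Illustrative sketch (not part of the test suite) of why num_tasks=4 fails
# above: MultiEnvWrapper, in its default 'add-onehot' mode, appends a one-hot
# task ID to every observation, one entry per task env. _get_log_alpha slices
# the trailing num_tasks entries off as that one-hot, so with only two tasks
# (and a length-2 _log_alpha) a num_tasks of 4 cannot line up, and the
# ValueError matched above is raised. The helper name below is hypothetical.
def _demo_one_hot_obs_width():
    env_names = ['CartPole-v0', 'CartPole-v1']
    task_envs = [GymEnv(name, max_episode_length=150) for name in env_names]
    env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy)
    base_dim = task_envs[0].observation_space.flat_dim  # 4 for CartPole
    # the appended one-hot adds len(env_names) dims to each observation
    assert env.observation_space.flat_dim == base_dim + len(env_names)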


def test_mtsac_get_log_alpha(monkeypatch):
    """Check that the private function _get_log_alpha functions correctly.

    MTSAC uses disentangled alphas, meaning that it keeps one log-alpha
    per task; _get_log_alpha selects each sample's alpha from the one-hot
    task ID appended to its observation.
    """
    env_names = ['CartPole-v0', 'CartPole-v1']
    task_envs = [GymEnv(name, max_episode_length=150) for name in env_names]
    env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy)
    deterministic.set_seed(0)
    policy = TanhGaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=[1, 1],
        hidden_nonlinearity=torch.nn.ReLU,
        output_nonlinearity=None,
        min_std=np.exp(-20.),
        max_std=np.exp(2.),
    )
    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[1, 1],
                                 hidden_nonlinearity=F.relu)
    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[1, 1],
                                 hidden_nonlinearity=F.relu)
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    num_tasks = 2
    buffer_batch_size = 2
    mtsac = MTSAC(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  sampler=None,
                  gradient_steps_per_itr=150,
                  eval_env=[env],
                  env_spec=env.spec,
                  num_tasks=num_tasks,
                  steps_per_epoch=5,
                  replay_buffer=replay_buffer,
                  min_buffer_size=1e3,
                  target_update_tau=5e-3,
                  discount=0.99,
                  buffer_batch_size=buffer_batch_size)
    monkeypatch.setattr(mtsac, '_log_alpha', torch.Tensor([1., 2.]))
    for i, _ in enumerate(env_names):
        obs = torch.Tensor([env.reset()[0]] * buffer_batch_size)
        log_alpha = mtsac._get_log_alpha(dict(observation=obs))
        assert (log_alpha == torch.Tensor([i + 1, i + 1])).all().item()
        assert log_alpha.size() == torch.Size([mtsac._buffer_batch_size])
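

# A minimal standalone sketch (an illustration, not the garage
# implementation) of the selection that both tests exercise: the trailing
# one-hot task ID picks out each sample's log-alpha via a matrix product,
# which is why the assertions above expect [i + 1, i + 1] for task i when
# _log_alpha is [1., 2.]. The function name below is hypothetical.
def _demo_select_log_alpha(observations, log_alpha, num_tasks):
    """Return one log-alpha per sample based on its one-hot task ID."""
    one_hots = observations[:, -num_tasks:]  # (batch, num_tasks)
    if one_hots.shape[1] != log_alpha.shape[0]:
        raise ValueError('num_tasks does not match the one-hot width')
    # (batch, num_tasks) @ (num_tasks, 1) -> (batch,)
    return torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze(-1)


# Example: a batch with one sample from each of two tasks.
# obs = torch.Tensor([[0., 0., 0., 0., 1., 0.],
#                     [0., 0., 0., 0., 0., 1.]])
# _demo_select_log_alpha(obs, torch.Tensor([1., 2.]), 2)  # -> tensor([1., 2.])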