def test_std_network_output_values_with_batch(input_dim, output_dim,
                                              hidden_sizes, learn_std):
    """Check mean/variance/sample shape of the dist for batched input.

    With all-ones weights and no nonlinearity, each output unit is the
    product of the input dim and every hidden layer width; the variance
    is init_std squared under the 'exp' parameterization.
    """
    init_std = 2.
    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=init_std,
                               hidden_nonlinearity=None,
                               std_parameterization='exp',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_,
                               learn_std=learn_std)

    n_batch = 5
    dist = module(torch.ones([n_batch, input_dim]))

    # Product of all layer widths times the (all-ones) input.
    expected_mean = torch.full(
        (n_batch, output_dim),
        input_dim * (torch.Tensor(hidden_sizes).prod().item()),
        dtype=torch.float)
    expected_variance = init_std**2

    assert dist.mean.equal(expected_mean)
    assert dist.variance.equal(
        torch.full((n_batch, output_dim), expected_variance,
                   dtype=torch.float))
    assert dist.rsample().shape == (n_batch, output_dim)
def test_softplus_min_std(input_dim, output_dim, hidden_sizes):
    """Check that min_std clamps the std under softplus parameterization.

    Zero-initialized weights drive the raw std output to its floor, so
    the resulting variance must equal softplus(min_std) squared.
    """
    floor = 2.
    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=1.,
                               min_std=floor,
                               hidden_nonlinearity=None,
                               std_parameterization='softplus',
                               hidden_w_init=nn.init.zeros_,
                               output_w_init=nn.init.zeros_)

    dist = module(torch.ones(input_dim))

    # softplus(floor)**2 == (log(1 + exp(floor)))**2
    expected_variance = torch.log(torch.Tensor([floor]).exp() + 1.)**2

    assert dist.variance.equal(
        torch.full((output_dim, ), expected_variance[0]))
def test_exp_max_std(input_dim, output_dim, hidden_sizes):
    """Check that max_std caps the std under the 'exp' parameterization.

    init_std is far above the cap, so the module must clamp and the
    variance must come out as max_std squared.
    """
    cap = 1.
    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=10.,
                               max_std=cap,
                               hidden_nonlinearity=None,
                               std_parameterization='exp',
                               hidden_w_init=nn.init.zeros_,
                               output_w_init=nn.init.zeros_)

    dist = module(torch.ones(input_dim))

    expected_variance = cap**2
    assert dist.variance.equal(
        torch.full((output_dim, ), expected_variance))
def test_softplus_max_std(input_dim, output_dim, hidden_sizes):
    """Check that max_std caps the std under softplus parameterization.

    All-ones weights with a large init_std push the raw std above the
    cap, so the variance must equal softplus(max_std) squared.
    """
    cap = 1.
    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=10,
                               max_std=cap,
                               hidden_nonlinearity=None,
                               std_parameterization='softplus',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)

    dist = module(torch.ones(input_dim))

    # softplus(cap)**2 == (log(1 + exp(cap)))**2
    expected_variance = torch.log(torch.Tensor([cap]).exp() + 1.)**2

    assert torch.equal(
        dist.variance,
        torch.full((output_dim, ), expected_variance[0],
                   dtype=torch.float))
def test_std_network_output_values(input_dim, output_dim, hidden_sizes):
    """Check mean/variance/sample shape of the dist for unbatched input.

    With all-ones weights and no nonlinearity, each output unit is the
    product of the input dim and every hidden layer width; the variance
    is init_std squared under the 'exp' parameterization.
    """
    init_std = 2.
    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=init_std,
                               hidden_nonlinearity=None,
                               std_parameterization='exp',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)

    dist = module(torch.ones(input_dim))

    # Product of all layer widths times the (all-ones) input.
    expected_mean = torch.full(
        (output_dim, ),
        input_dim * (torch.Tensor(hidden_sizes).prod().item()))
    expected_variance = init_std**2

    assert dist.mean.equal(expected_mean)
    assert dist.variance.equal(
        torch.full((output_dim, ), expected_variance))
    assert dist.rsample().shape == (output_dim, )
def test_softplus_std_network_output_values(input_dim, output_dim,
                                            hidden_sizes):
    """Check mean/variance/sample shape under softplus parameterization.

    The mean follows the all-ones-weights product rule; the variance is
    softplus(init_std) squared rather than init_std squared.
    """
    init_std = 2.
    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=init_std,
                               hidden_nonlinearity=None,
                               std_parameterization='softplus',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)

    dist = module(torch.ones(input_dim))

    expected_mean = input_dim * torch.Tensor(hidden_sizes).prod().item()
    # softplus(init_std)**2 == (log(1 + exp(init_std)))**2
    expected_variance = torch.log(torch.Tensor([init_std]).exp() + 1.)**2

    assert dist.mean.equal(
        torch.full((output_dim, ), expected_mean, dtype=torch.float))
    assert dist.variance.equal(
        torch.full((output_dim, ), expected_variance[0],
                   dtype=torch.float))
    assert dist.rsample().shape == (output_dim, )
def test_unknown_std_parameterization():
    """An unrecognized std_parameterization must raise NotImplementedError."""
    with pytest.raises(NotImplementedError):
        GaussianMLPModule(input_dim=1,
                          output_dim=1,
                          std_parameterization='unknown')