def test_dirichlet_classification_likelihood(self, cuda=False):
     """Check DirichletClassificationLikelihood under float and double dtypes.

     For each dtype the test verifies that:
       * the likelihood's ``noise_covar`` is a ``FixedGaussianNoise``;
       * assigning ``noise`` stores exactly the tensor that was assigned;
       * calling the likelihood on a matching-size MVN adds the noise to
         the unit variance;
       * calling it on a wrong-size MVN emits a ``UserWarning``;
       * passing ``targets=`` at call time uses the noise derived from
         those targets instead.

     Args:
         cuda: When truthy, all tensors are created on the CUDA device;
             otherwise on the CPU.
     """
     device = torch.device("cuda" if cuda else "cpu")

     def unit_mvn(size, tensor_kwargs):
         # Zero-mean MVN with identity (diagonal ones) covariance.
         loc = torch.zeros(size, **tensor_kwargs)
         diag_covar = DiagLazyTensor(torch.ones(size, **tensor_kwargs))
         return MultivariateNormal(loc, diag_covar)

     for dtype in (torch.float, torch.double):
         tensor_kwargs = {"device": device, "dtype": dtype}

         # --- construction: random 0/1 labels of length 6 ---
         init_labels = (torch.rand(6, **tensor_kwargs) > 0.5).long()
         likelihood = DirichletClassificationLikelihood(init_labels, dtype=dtype)
         self.assertIsInstance(likelihood.noise_covar, FixedGaussianNoise)

         # --- noise setter: a fresh label draw, converted via _prepare_targets ---
         fresh_labels = (torch.rand(6, **tensor_kwargs) > 0.5).long()
         new_noise, _, _ = likelihood._prepare_targets(fresh_labels, dtype=dtype)
         likelihood.noise = new_noise
         self.assertTrue(torch.equal(likelihood.noise, new_noise))

         # --- __call__ with matching size: variance should be 1 + noise ---
         marginal = likelihood(unit_mvn(6, tensor_kwargs))
         self.assertTrue(torch.allclose(marginal.variance, 1 + new_noise))

         # --- __call__ with mismatched size (5 vs 6) must warn ---
         short_mvn = unit_mvn(5, tensor_kwargs)
         with self.assertWarns(UserWarning):
             likelihood(short_mvn)

         # --- __call__ with explicit targets of the matching (5) size ---
         call_labels = 0.1 + torch.rand(5, **tensor_kwargs)
         call_labels = (call_labels > 0.5).long()
         marginal = likelihood(short_mvn, targets=call_labels)
         expected_noise, _, _ = likelihood._prepare_targets(call_labels, dtype=dtype)
         self.assertTrue(torch.allclose(marginal.variance, 1.0 + expected_noise))
 def create_likelihood(self):
     """Build a DirichletClassificationLikelihood from 15 random labels.

     Labels are obtained by rounding standard-normal draws to integers and
     clamping them to be non-negative, so every label is a valid class
     index in ``0 .. labels.max()``.

     Returns:
         A ``DirichletClassificationLikelihood`` constructed from the labels.
     """
     train_x = torch.randn(15)
     # Rounded normal draws are frequently negative, and negative integers
     # are not class labels. NOTE(review): GPyTorch appears to use the labels
     # as column indices into a (max + 1)-wide matrix, so a negative label
     # would wrap to the wrong class or raise IndexError - confirm.
     # Clamping at 0 keeps max + 1 unchanged whenever the maximum is >= 0.
     labels = torch.round(train_x).long().clamp(min=0)
     return DirichletClassificationLikelihood(labels)
        num_samples=gp_classif_args["num_examples"],
        dimension=gp_classif_args["dim"],
        num_classes=gp_classif_args["num_classes"])
    x_train, y_train, x_val, y_val = x_train.to(device), y_train.to(
        device), x_val.to(device), y_val.to(device)
    # get dataloaders
    train_loader = DataLoader(torch.utils.data.TensorDataset(x_train, y_train),
                              batch_size=x_train.shape[0],
                              shuffle=True)
    test_loader = DataLoader(torch.utils.data.TensorDataset(x_val, y_val),
                             batch_size=x_val.shape[0],
                             shuffle=False)

    # initialize likelihood and model
    # we let the DirichletClassificationLikelihood compute the targets for us
    likelihood = DirichletClassificationLikelihood(
        y_train.long(), learn_additional_noise=True).to(device)
    model = DirichletGPModel(x_train,
                             likelihood.transformed_targets,
                             likelihood,
                             num_classes=likelihood.num_classes).to(device)

    # Find optimal model hyperparameters
    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)