def get_kcg(models, labeled_data_size, unlabeled_loader):
    # CoreSet selection: extract backbone features for the unlabeled subset
    # followed by the labeled pool, then pick ADDENDUM new points with
    # k-Center-Greedy, using the labeled pool as the initial centers.
    models['backbone'].eval()
    with torch.cuda.device(CUDA_VISIBLE_DEVICES):
        features = torch.tensor([]).cuda()

    with torch.no_grad():
        for inputs, _, _ in unlabeled_loader:
            with torch.cuda.device(CUDA_VISIBLE_DEVICES):
                inputs = inputs.cuda()
            _, features_batch, _ = models['backbone'](inputs)
            features = torch.cat((features, features_batch), 0)
        feat = features.detach().cpu().numpy()
        # The labeled points sit at the end of the feature matrix (indices
        # SUBSET .. SUBSET + labeled_data_size - 1) and act as the initial
        # centers for the greedy selection.
        new_av_idx = np.arange(SUBSET, (SUBSET + labeled_data_size))
        sampling = kCenterGreedy(feat)
        batch = sampling.select_batch_(new_av_idx, ADDENDUM)
        batch_set = set(batch)  # set lookup avoids O(SUBSET * ADDENDUM) scans
        other_idx = [x for x in range(SUBSET) if x not in batch_set]
    return other_idx + batch
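

# A minimal sketch of the farthest-first traversal that
# `kCenterGreedy.select_batch_` is assumed to implement (the CoreSet
# strategy of Sener & Savarese, 2018). `_kcenter_greedy_sketch` is a
# hypothetical stand-in for illustration, not the class used above.
def _kcenter_greedy_sketch(feat, labeled_idx, amount):
    # Distance from every point to its nearest already-selected center;
    # the labeled pool provides the initial centers.
    min_dist = np.linalg.norm(
        feat[:, None, :] - feat[None, labeled_idx, :], axis=2).min(axis=1)
    selected = []
    for _ in range(amount):
        idx = int(np.argmax(min_dist))  # the farthest point becomes a center
        selected.append(idx)
        # Update nearest-center distances to account for the new center.
        min_dist = np.minimum(min_dist,
                              np.linalg.norm(feat - feat[idx], axis=1))
    return selected
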
def query_samples(model, method, data_unlabeled, subset, labeled_set, cycle,
                  args):

    if method == 'Random':
        # Random baseline: a permutation of the unlabeled subset, so the
        # tail of `arg` is a uniform random query. (np.random.randint would
        # sample with replacement and could repeat indices.)
        arg = np.random.permutation(SUBSET)

    if method in ('UncertainGCN', 'CoreGCN'):
        # Loader over the unlabeled subset followed by the labeled pool;
        # SubsetSequentialSampler preserves this order, which the index
        # arithmetic below relies on.
        unlabeled_loader = DataLoader(
            data_unlabeled,
            batch_size=BATCH,
            sampler=SubsetSequentialSampler(subset + labeled_set),
            pin_memory=True)
        # Node labels used only for evaluation below: 0 for the SUBSET
        # unlabeled points, 1 for the labeled points appended after them.
        binary_labels = torch.cat(
            (torch.zeros([SUBSET, 1]), torch.ones([len(labeled_set), 1])), 0)

        features = get_features(model, unlabeled_loader)
        features = nn.functional.normalize(features)
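        # aff_to_adj is assumed to turn the normalized-feature similarity
        # matrix (roughly X @ X.T) into a row-normalized adjacency with
        # self-loops, so the GCN can propagate labeled/unlabeled evidence
        # between similar points.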
        adj = aff_to_adj(features)

        gcn_module = GCN(nfeat=features.shape[1],
                         nhid=args.hidden_units,
                         nclass=1,
                         dropout=args.dropout_rate).cuda()

        models = {'gcn_module': gcn_module}

        optim_gcn = optim.Adam(models['gcn_module'].parameters(),
                               lr=LR_GCN,
                               weight_decay=WDECAY)
        optimizers = {'gcn_module': optim_gcn}

        # Graph node layout: unlabeled points occupy indices 0..SUBSET-1,
        # labeled points occupy SUBSET..SUBSET + (cycle + 1) * ADDENDUM - 1.
        lbl = np.arange(SUBSET, SUBSET + (cycle + 1) * ADDENDUM, 1)
        nlbl = np.arange(0, SUBSET, 1)

        # Train the GCN for a fixed 200 full-batch steps over the graph.
        for _ in range(200):
            optimizers['gcn_module'].zero_grad()
            outputs, _, _ = models['gcn_module'](features, adj)
            lamda = args.lambda_loss
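            # BCEAdjLoss is assumed to be the labeled/unlabeled binary
            # cross-entropy from the sequential-GCN active-learning setup:
            # -mean(log s_lbl) - lamda * mean(log(1 - s_nlbl)), pushing
            # labeled nodes towards score 1 and unlabeled nodes towards 0.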
            loss = BCEAdjLoss(outputs, lbl, nlbl, lamda)
            loss.backward()
            optimizers['gcn_module'].step()

        models['gcn_module'].eval()
        with torch.no_grad():
            with torch.cuda.device(CUDA_VISIBLE_DEVICES):
                inputs = features.cuda()
                labels = binary_labels.cuda()
            scores, _, feat = models['gcn_module'](inputs, adj)

            if method == "CoreGCN":
                feat = feat.detach().cpu().numpy()
                new_av_idx = np.arange(SUBSET,
                                       (SUBSET + (cycle + 1) * ADDENDUM))
                sampling2 = kCenterGreedy(feat)
                batch2 = sampling2.select_batch_(new_av_idx, ADDENDUM)
                other_idx = [x for x in range(SUBSET) if x not in batch2]
                arg = other_idx + batch2

            else:
                # UncertainGCN: score each unlabeled point by its distance
                # from the decision margin s_margin. Sorting by negative
                # distance puts the most confident points first, leaving the
                # most uncertain ones at the tail of `arg`.
                s_margin = args.s_margin
                scores_median = np.squeeze(
                    torch.abs(scores[:SUBSET] -
                              s_margin).detach().cpu().numpy())
                arg = np.argsort(-scores_median)

            print("Max confidence value: ", torch.max(scores.data))
            print("Mean confidence value: ", torch.mean(scores.data))
            preds = torch.round(scores)
            correct_labeled = (preds[SUBSET:, 0]
                               == labels[SUBSET:, 0]).sum().item() / (
                                   (cycle + 1) * ADDENDUM)
            correct_unlabeled = (preds[:SUBSET, 0]
                                 == labels[:SUBSET, 0]).sum().item() / SUBSET
            correct = (preds[:, 0]
                       == labels[:, 0]).sum().item() / (SUBSET +
                                                        (cycle + 1) * ADDENDUM)
            print("Labeled classified: ", correct_labeled)
            print("Unlabeled classified: ", correct_unlabeled)
            print("Total classified: ", correct)

    if method == 'CoreSet':
        # Loader over the unlabeled subset followed by the labeled pool;
        # get_kcg relies on this ordering to locate the labeled block.
        unlabeled_loader = DataLoader(
            data_unlabeled,
            batch_size=BATCH,
            sampler=SubsetSequentialSampler(subset + labeled_set),
            pin_memory=True)

        arg = get_kcg(model, ADDENDUM * (cycle + 1), unlabeled_loader)

    if method == 'lloss':
        # Loader over the unlabeled subset only.
        unlabeled_loader = DataLoader(data_unlabeled,
                                      batch_size=BATCH,
                                      sampler=SubsetSequentialSampler(subset),
                                      pin_memory=True)

        # Measure the uncertainty of each data point in the subset with the
        # loss-prediction module; ascending argsort leaves the highest
        # predicted-loss points at the tail of `arg`.
        uncertainty = get_uncertainty(model, unlabeled_loader)
        arg = np.argsort(uncertainty)

    if method == 'VAAL':
        # Create unlabeled dataloader for the unlabeled subset
        unlabeled_loader = DataLoader(data_unlabeled,
                                      batch_size=BATCH,
                                      sampler=SubsetSequentialSampler(subset),
                                      pin_memory=True)
        labeled_loader = DataLoader(
            data_unlabeled,
            batch_size=BATCH,
            sampler=SubsetSequentialSampler(labeled_set),
            pin_memory=True)
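        # Image-size-dependent architectures: FashionMNIST is 28x28 and
        # single-channel, the default (CIFAR-style) case is 32x32. The exact
        # VAE/Discriminator constructor signatures are taken on trust from
        # the surrounding code.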
        if args.dataset == 'fashionmnist':
            vae = VAE(28, 1, 3)
            discriminator = Discriminator(28)
        else:
            vae = VAE()
            discriminator = Discriminator(32)
        models = {'vae': vae, 'discriminator': discriminator}

        optim_vae = optim.Adam(vae.parameters(), lr=5e-4)
        optim_discriminator = optim.Adam(discriminator.parameters(), lr=5e-4)
        optimizers = {'vae': optim_vae, 'discriminator': optim_discriminator}

        train_vaal(models, optimizers, labeled_loader, unlabeled_loader,
                   cycle + 1)
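
        # train_vaal is assumed to jointly train the VAE and the
        # discriminator as in VAAL (Sinha et al., 2019): the discriminator
        # learns to separate latent codes of labeled images from unlabeled
        # ones, and its score is used below to rank the pool.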

        all_preds, all_indices = [], []

        for images, _, indices in unlabeled_loader:
            images = images.cuda()
            with torch.no_grad():
                _, _, mu, _ = vae(images)
                preds = discriminator(mu)

            preds = preds.cpu().data
            all_preds.extend(preds)
            all_indices.extend(indices)

        all_preds = torch.stack(all_preds)
        all_preds = all_preds.view(-1)
        # Negate so that the ascending sort ranks the points the
        # discriminator thinks are most likely labeled first; the tail of
        # `arg` then holds the most unlabeled-looking points, which are the
        # ones VAAL queries.
        all_preds *= -1
        _, arg = torch.sort(all_preds)
    return arg
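

# A minimal sketch of how `query_samples` is typically consumed in the
# active-learning loop. The tail-of-`arg` convention is inferred from the
# selection code above; the bookkeeping here is an assumption, not a
# verbatim copy of the training script:
#
#     arg = query_samples(model, method, data_unlabeled, subset,
#                         labeled_set, cycle, args)
#     # Every method orders `arg` so its last ADDENDUM entries index the
#     # points to query next (most uncertain / newly chosen centers).
#     queried = list(torch.tensor(subset)[arg][-ADDENDUM:].numpy())
#     labeled_set += queried
#     subset = list(torch.tensor(subset)[arg][:-ADDENDUM].numpy())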