def validate(val_loader: DataLoader, model: ImageClassifier, esem,
             source_classes: list, args: argparse.Namespace) -> float:
    """Evaluate open-set classification accuracy over `val_loader`.

    For each sample the ensemble `esem` produces five auxiliary heads whose
    confidence / entropy / consistency are combined into a single
    "known-ness" score. Samples scoring at or above ``args.threshold`` are
    treated as known (source) classes and judged by the classifier's
    prediction; samples below the threshold are treated as unknown
    (the extra ``-1`` bucket).

    Returns the mean per-class accuracy from ``AccuracyCounter``.
    """
    # Switch both networks to evaluation mode.
    model.eval()
    esem.eval()

    confidences = []
    consistencies = []
    entropies = []
    pred_indices = []
    true_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            output, feats = model(images)
            # Argmax prediction of the main classifier.
            _, preds = torch.max(F.softmax(output, -1), 1)

            heads = esem(feats)  # five auxiliary classifier outputs
            confidences.extend(get_confidence(*heads))
            entropies.extend(get_entropy(*heads))
            consistencies.extend(get_consistency(*heads))
            pred_indices.extend(preds)
            true_labels.extend(labels)

    # Normalize each criterion to a common scale, then average.
    # Higher confidence and lower consistency/entropy => more "known".
    conf = norm(torch.tensor(confidences))
    cons = norm(torch.tensor(consistencies))
    ent = norm(torch.tensor(entropies))
    scores = (conf + 1 - cons + 1 - ent) / 3

    # One extra slot (index -1) accumulates the "unknown" class.
    counters = AccuracyCounter(len(source_classes) + 1)
    for pred, label, score in zip(pred_indices, true_labels, scores):
        if label in source_classes:
            counters.add_total(label)
            # Known sample: must clear the threshold AND be classified right.
            if score >= args.threshold and pred == label:
                counters.add_correct(label)
        else:
            counters.add_total(-1)
            # Unknown sample: correct iff it falls below the threshold.
            if score < args.threshold:
                counters.add_correct(-1)

    print('---counters---')
    print(counters.each_accuracy())
    print(counters.mean_accuracy())
    print(counters.h_score())

    return counters.mean_accuracy()
def evaluate_source_common(val_loader: DataLoader, model: ImageClassifier,
                           esem, source_classes: list,
                           args: argparse.Namespace):
    """Estimate per-class weights of the shared (common) label space.

    Runs the model over `val_loader`, scores each sample with the ensemble
    `esem` (confidence / entropy / consistency combined into one score),
    and averages the softmax outputs of all samples whose score clears
    ``args.threshold``. The normalized average is returned as
    ``source_weight`` — an estimate of how strongly each source class is
    represented in the target domain. Also prints score histograms for
    samples whose true label is in `source_classes` vs. the rest.

    Returns:
        Normalized 1-D tensor of length ``len(source_classes)``.
    """
    temperature = 1
    # Switch both networks to evaluation mode.
    model.eval()
    esem.eval()

    common = []          # scores of samples whose label is a source class
    target_private = []  # scores of target-private (unknown-class) samples

    all_confidence = []
    all_consistency = []
    all_entropy = []
    all_labels = []
    all_output = []

    source_weight = torch.zeros(len(source_classes)).to(device)
    cnt = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)

            output, f = model(images)
            # BUG FIX: temperature must scale the *logits* before softmax.
            # The previous `F.softmax(output, -1) / temperature` neither
            # sharpens/softens the distribution nor keeps it normalized for
            # any temperature != 1 (a no-op only because temperature == 1).
            output = F.softmax(output / temperature, -1)

            heads = esem(f)  # five auxiliary classifier outputs
            all_confidence.extend(get_confidence(*heads))
            all_entropy.extend(get_entropy(*heads))
            all_consistency.extend(get_consistency(*heads))
            all_labels.extend(labels)
            all_output.extend(output)  # one probability row per sample

    # Normalize each criterion, then average into a single "known-ness"
    # score: higher confidence, lower consistency/entropy => more known.
    all_confidence = norm(torch.tensor(all_confidence))
    all_consistency = norm(torch.tensor(all_consistency))
    all_entropy = norm(torch.tensor(all_entropy))
    all_score = (all_confidence + 1 - all_consistency + 1 - all_entropy) / 3

    for score, label, probs in zip(all_score, all_labels, all_output):
        if score >= args.threshold:
            # Confidently-known sample: contributes to the class-weight sum.
            source_weight += probs
            cnt += 1
        if label in source_classes:
            common.append(score)
        else:
            target_private.append(score)

    # Diagnostic: score distribution of known vs. target-private samples.
    hist, bin_edges = np.histogram(common, bins=10, range=(0, 1))
    print(hist)

    hist, bin_edges = np.histogram(target_private, bins=10, range=(0, 1))
    print(hist)

    # Guard against division by zero when no sample clears the threshold
    # (previously produced NaN/Inf); the weight vector stays all-zero then.
    source_weight = norm(source_weight / max(cnt, 1))
    print('---source_weight---')
    print(source_weight)
    return source_weight