def validate(val_loader: DataLoader, model: ImageClassifier, esem, source_classes: list,
             args: argparse.Namespace) -> float:
    """Print open-set accuracy statistics for ``val_loader`` and return the mean accuracy.

    Every batch goes through ``model`` (which returns ``(logits, features)``). The
    arg-max of the softmaxed logits is the predicted class. The features go through
    the five-output ensemble ``esem``, and ``get_confidence`` / ``get_entropy`` /
    ``get_consistency`` turn those outputs into per-sample values.

    Each of the three collections is passed through ``norm``. They are combined into
    one score per sample::

        score = (confidence + 1 - consistency + 1 - entropy) / 3

    Samples are then tallied with an ``AccuracyCounter`` of size
    ``len(source_classes) + 1``:

    * label in ``source_classes``: counted under its own label. It is correct when
      ``score >= args.threshold`` and the prediction equals the label.
    * any other label: counted under index ``-1`` (presumably the extra "unknown"
      slot; TODO confirm against ``AccuracyCounter``). It is correct when
      ``score < args.threshold``.

    Args:
        val_loader: yields ``(images, labels)`` batches.
        model: classifier returning ``(logits, features)``.
        esem: ensemble head returning exactly five outputs for ``features``.
        source_classes: labels treated as known/shared classes.
        args: must provide ``threshold``.

    Returns:
        ``counters.mean_accuracy()``, after the per-class accuracies, the mean
        accuracy and the H-score have been printed.
    """
    # Both networks are only evaluated here.
    model.eval()
    esem.eval()

    confidences = []
    consistencies = []
    entropies = []
    predictions = []
    targets = []
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits, features = model(batch_images)
            _, predicted = torch.max(F.softmax(logits, -1), 1)

            head_1, head_2, head_3, head_4, head_5 = esem(features)
            heads = (head_1, head_2, head_3, head_4, head_5)
            batch_confidence = get_confidence(*heads)
            batch_entropy = get_entropy(*heads)
            batch_consistency = get_consistency(*heads)

            confidences.extend(batch_confidence)
            consistencies.extend(batch_consistency)
            entropies.extend(batch_entropy)
            predictions.extend(predicted)
            targets.extend(batch_labels)

    confidence = norm(torch.tensor(confidences))
    consistency = norm(torch.tensor(consistencies))
    entropy = norm(torch.tensor(entropies))
    # High confidence, low disagreement and low entropy all push the score up.
    scores = (confidence + 1 - consistency + 1 - entropy) / 3

    counters = AccuracyCounter(len(source_classes) + 1)
    for predicted_class, true_class, score in zip(predictions, targets, scores):
        if true_class not in source_classes:
            # Sample outside the source label set: correct when it is rejected.
            counters.add_total(-1)
            if score < args.threshold:
                counters.add_correct(-1)
            continue
        # Shared-class sample: correct when accepted and classified correctly.
        counters.add_total(true_class)
        if score >= args.threshold and predicted_class == true_class:
            counters.add_correct(true_class)

    print('---counters---')
    print(counters.each_accuracy())
    print(counters.mean_accuracy())
    print(counters.h_score())
    return counters.mean_accuracy()
def evaluate_source_common(val_loader: DataLoader, model: ImageClassifier, esem, source_classes: list,
                           args: argparse.Namespace):
    """Estimate a source-class weight vector from the target samples the ensemble trusts.

    Every batch goes through ``model`` (which returns ``(logits, features)``). The
    features go through the five-output ensemble ``esem``. Each sample is scored
    from the ``norm``-ed ensemble confidence, consistency and entropy::

        score = (confidence + 1 - consistency + 1 - entropy) / 3

    Samples with ``score >= args.threshold`` add their softmax prediction to a
    running sum. The mean of those predictions, passed through ``norm``, is returned.

    As a diagnostic, two histograms of all scores (10 bins over [0, 1]) are printed:
    one for samples whose label is in ``source_classes``, one for the rest.

    If no sample reaches the threshold, the mean is taken over every sample instead.
    This replaces the previous division by zero, which produced NaN weights, and a
    warning is printed.

    Args:
        val_loader: yields ``(images, labels)`` batches.
        model: classifier returning ``(logits, features)``.
        esem: ensemble head returning exactly five outputs for ``features``.
        source_classes: labels treated as shared classes. Its length sets the size
            of the weight accumulator.
        args: must provide ``threshold``.

    Returns:
        ``norm`` applied to the averaged softmax predictions (length
        ``len(source_classes)`` before normalization).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    temperature = 1
    # switch to evaluate mode
    model.eval()
    esem.eval()

    all_confidence = list()
    all_consistency = list()
    all_entropy = list()
    all_labels = list()
    all_output = list()
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            # Labels stay on the CPU; they are only compared against source_classes.
            output, f = model(images)
            # Temperature scaling divides the logits *before* the softmax. Dividing
            # the probabilities afterwards would only rescale them uniformly.
            output = F.softmax(output / temperature, -1)
            yt_1, yt_2, yt_3, yt_4, yt_5 = esem(f)
            all_confidence.extend(get_confidence(yt_1, yt_2, yt_3, yt_4, yt_5))
            all_entropy.extend(get_entropy(yt_1, yt_2, yt_3, yt_4, yt_5))
            all_consistency.extend(get_consistency(yt_1, yt_2, yt_3, yt_4, yt_5))
            all_labels.extend(labels)
            all_output.extend(output)

    if not all_output:
        raise ValueError('val_loader yielded no samples; cannot estimate source weights')

    all_confidence = norm(torch.tensor(all_confidence))
    all_consistency = norm(torch.tensor(all_consistency))
    all_entropy = norm(torch.tensor(all_entropy))
    # High confidence, low disagreement and low entropy all push the score up.
    all_score = (all_confidence + 1 - all_consistency + 1 - all_entropy) / 3

    source_weight = torch.zeros(len(source_classes)).to(device)
    cnt = 0
    common = []
    target_private = []
    for score, label, each_output in zip(all_score, all_labels, all_output):
        if score >= args.threshold:
            source_weight += each_output
            cnt += 1
        # Diagnostic split by ground-truth label, over all samples.
        if label in source_classes:
            common.append(score)
        else:
            target_private.append(score)

    hist, _ = np.histogram(common, bins=10, range=(0, 1))
    print(hist)
    hist, _ = np.histogram(target_private, bins=10, range=(0, 1))
    print(hist)

    if cnt == 0:
        # Averaging over zero samples would divide by zero and yield NaN weights.
        # Fall back to averaging the predictions of every sample.
        print('warning: no sample reached threshold {}; averaging over all {} samples'.format(
            args.threshold, len(all_output)))
        for each_output in all_output:
            source_weight += each_output
        cnt = len(all_output)

    source_weight = norm(source_weight / cnt)
    print('---source_weight---')
    print(source_weight)
    return source_weight