def __init__(self, dataset, num_classes, batch_size, train, model_num,
             teachers=[], unlabel_split=.0, use_gpu=False, progress_bar=True):
    self.data = []
    self.teacher_num = len(teachers)
    devices = get_devices(model_num, use_gpu)

    if unlabel_split:
        # randomly select a subset of the dataset according to unlabel_split
        len_ds = len(dataset)
        all_indices = range(len_ds)
        subset_indices = set(random.sample(
            all_indices, int(len_ds * unlabel_split)))
        counter = 0

    if progress_bar:
        pbar = tqdm(total=len(dataset), smoothing=.005)

    if teachers:
        accuracies = []
        # prepare the teachers and assert that their output matches num_classes
        _batch = dataset[0][0].unsqueeze(0)
        for teacher, device in zip(teachers, devices):
            teacher = teacher.to(device)
            batch = _batch.clone().to(device)
            teacher.eval()
            sample_classes = teacher(batch).shape[1]
            assert sample_classes == num_classes, \
                f"Num classes of the output is {sample_classes}, {num_classes} required"
            accuracies.append(RunningAverageMeter())

        # create the pseudo labels
        with torch.no_grad():
            for image, label in dataset:
                unlabelled = 1
                if unlabel_split:
                    if counter in subset_indices:
                        unlabelled = 0
                    counter += 1

                _psuedo_labels = []
                for i, (teacher, device) in enumerate(zip(teachers, devices)):
                    if use_gpu:
                        image = image.to(device)
                    # add a batch dimension to comply with the expected input
                    # shape (a batch containing a single image)
                    pred = teacher(image.unsqueeze(0)).cpu()
                    _psuedo_labels.append(pred)
                    # keep track of the accuracy of the teacher model
                    acc_at_1 = accuracy(
                        pred, torch.tensor([[label]]), topk=(1,))[0]
                    accuracies[i].update(acc_at_1.item())

                image = image.cpu()
                psuedo_labels = torch.stack(_psuedo_labels, -1).squeeze(0)
                self.data.append((image, label, psuedo_labels, unlabelled))

                if progress_bar:
                    pbar.update(1)

        if progress_bar:
            pbar.close()

        print(
            f"Accuracies of the loaded models are "
            f"{', '.join([str(round(acc.avg, 2)) + '%' for acc in accuracies])}, respectively")
    else:
        # no teachers given: store a placeholder pseudo-label tensor per sample
        dummy_psuedo_label = torch.empty(num_classes, model_num)
        for image, label in dataset:
            unlabelled = 1
            if unlabel_split:
                if counter in subset_indices:
                    unlabelled = 0
                counter += 1
            self.data.append(
                (image, label, dummy_psuedo_label, unlabelled)
            )
            if progress_bar:
                pbar.update(1)
        if progress_bar:
            pbar.close()
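# Usage sketch (illustrative only): assuming this __init__ belongs to a
# Dataset-style wrapper -- the class name `DistillDataset` and the teacher
# variables below are hypothetical -- it could be built from a torchvision
# dataset plus a list of pre-trained teacher models, e.g.:
#
#   base = torchvision.datasets.CIFAR10(root='./data', train=True,
#                                       transform=transforms.ToTensor())
#   ds = DistillDataset(base, num_classes=10, batch_size=128, train=True,
#                       model_num=2, teachers=[teacher_a, teacher_b],
#                       unlabel_split=0.5, use_gpu=True)
#
# Each entry of ds.data is then (image, label, pseudo_labels, unlabelled),
# where pseudo_labels has shape (num_classes, len(teachers)).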
def test(self, config, best=False, return_results=True):
    """
    Test the models on the held-out test data.

    This function should only be called at the very end, once the
    models have finished training.
    """
    keep_track_of_results = return_results or self.use_wandb

    if best:
        self.load_checkpoints(best=True, inplace=True, verbose=False)

    if not hasattr(self, 'test_loader'):
        kwargs = {}
        if not config.disable_cuda and torch.cuda.is_available():
            kwargs = {'num_workers': 4, 'pin_memory': True}
        data_dict = get_dataset(config.dataset, config.data_dir, 'test')
        kwargs.update(data_dict)
        self.test_loader = get_test_loader(batch_size=config.batch_size,
                                           **kwargs)

    if keep_track_of_results:
        results = {}
        all_accs = []

    for net, model_name in zip(self.nets, self.model_names):
        net.eval()
        # reset the meters for every model so the reported numbers are
        # per model instead of accumulated over all of them
        losses = RunningAverageMeter()
        top1 = RunningAverageMeter()
        top5 = RunningAverageMeter()

        if self.progress_bar:
            pbar = tqdm(total=len(self.test_loader.dataset),
                        leave=False, desc=f'Testing {model_name}')

        for i, (images, labels, _, _) in enumerate(self.test_loader):
            if self.use_gpu:
                images, labels = images.cuda(), labels.cuda()

            # forward pass
            with torch.no_grad():
                outputs = net(images)
            loss = self.loss_ce(outputs, labels).mean()

            # measure accuracy and record loss
            prec_at_1, prec_at_5 = accuracy(outputs.data, labels.data,
                                            topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec_at_1.item(), images.size(0))
            top5.update(prec_at_5.item(), images.size(0))

            if self.progress_bar:
                pbar.update(self.test_loader.batch_size)

        if self.progress_bar:
            pbar.write(
                '[*] {:5}: Test loss: {:.3f}, top1_acc: {:.3f}%, top5_acc: {:.3f}%'
                .format(model_name, losses.avg, top1.avg, top5.avg))
            pbar.close()

        fold = 'best' if best else 'last'
        if self.use_wandb:
            wandb.run.summary[f"{fold} test acc {model_name}"] = top1.avg
        if keep_track_of_results:
            results[f'{model_name} test loss'] = losses.avg
            results[f'{model_name} test acc @ 1'] = top1.avg
            results[f'{model_name} test acc @ 5'] = top5.avg
            all_accs.append(top1.avg)

    if keep_track_of_results:
        results['average test acc'] = sum(all_accs) / len(all_accs)
        results['min test acc'] = min(all_accs)
        results['max test acc'] = max(all_accs)

    if best:
        self.load_checkpoints(best=False, inplace=True, verbose=False)
    if self.use_wandb:
        wandb.log(results)
    if return_results:
        return results
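# Usage sketch (illustrative only): assuming `trainer` is an instance of the
# surrounding trainer class and `config` is the parsed run configuration,
# the final evaluation could look like:
#
#   results = trainer.test(config, best=True)
#   print(results['average test acc'], results['min test acc'],
#         results['max test acc'])
#
# With best=True the best checkpoints are loaded before testing and the last
# checkpoints are restored afterwards.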