def test(args):
    """Evaluate a saved generator/classifier pair on the test split.

    Builds the feature extractor selected by ``args.net``, attaches a
    ``Predictor`` head sized to the dataset's class list, restores both from
    the checkpoints in ``args.checkpath`` and reports the accuracy computed by
    the project-level ``eval`` helper (not the builtin).

    Args:
        args: Parsed options. This function reads ``args.net``, ``args.T``,
            ``args.checkpath`` and ``args.loss``; ``args`` is also forwarded
            to ``return_dataset``.

    Returns:
        The test accuracy as returned by ``eval``.

    Raises:
        ValueError: If ``args.net`` is not one of the supported backbones.
    """
    # 1. get dataset (only the test loader and the class list are needed)
    _, _, test_loader, class_list = return_dataset(args)

    # 2. generator
    # NOTE(review): the torchvision models keep pretrained=True even though
    # the checkpoint overwrites their weights; some constructor defaults
    # appear to depend on that flag -- confirm before changing it.
    if args.net == 'resnet50':
        G = ResBase50()
        inc = 2048
    elif args.net == 'resnet101':
        G = ResBase101()
        inc = 2048
    elif args.net == "alexnet":
        G = AlexNetBase()
        inc = 4096
    elif args.net == "vgg":
        G = VGGBase()
        inc = 4096
    elif args.net == "inception_v3":
        G = models.inception_v3(pretrained=True)
        inc = 1000
    elif args.net == "googlenet":
        G = models.googlenet(pretrained=True)
        inc = 1000
    elif args.net == "densenet":
        G = models.densenet161(pretrained=True)
        inc = 1000
    elif args.net == "resnext":
        G = models.resnext50_32x4d(pretrained=True)
        inc = 1000
    elif args.net == "squeezenet":
        G = models.squeezenet1_0(pretrained=True)
        inc = 1000
    else:
        raise ValueError('Model cannot be recognized.')
    G.cuda()

    # 3. classifier
    F = Predictor(num_class=len(class_list), inc=inc, temp=args.T)
    F.cuda()

    # 4. load pre-trained model
    g_path = os.path.join(
        args.checkpath,
        "G_net_{}_loss_{}.pth".format(args.net, args.loss))
    f_path = os.path.join(
        args.checkpath,
        "F_net_{}_loss_{}.pth".format(args.net, args.loss))
    G.load_state_dict(torch.load(g_path))
    F.load_state_dict(torch.load(f_path))

    # Inference must run in eval mode: in train mode BatchNorm uses batch
    # statistics and Dropout stays active, and torchvision's inception_v3 /
    # googlenet additionally return auxiliary outputs.
    G.eval()
    F.eval()

    # 5. testing
    acc_test = eval(test_loader, G, F, class_list)
    print('Testing accuracy: {:.3f}\n'.format(acc_test))
    return acc_test
def test_ensemble(args, alphas=None):
    """Evaluate an ensemble of saved generator/classifier pairs.

    For every backbone name in ``args.ensemble`` a feature extractor and a
    ``Predictor`` head are built, restored from the checkpoints in
    ``args.checkpath`` and collected; the two lists are then handed to
    ``eval_ensemble`` together with the test loader.

    Args:
        args: Parsed options. This function reads ``args.ensemble``,
            ``args.T``, ``args.checkpath`` and ``args.loss``. NOTE:
            ``args.net`` is overwritten in place with ``args.ensemble[0]``
            before the dataset is loaded, and is not restored afterwards.
        alphas: Forwarded unchanged to ``eval_ensemble`` (presumably
            per-model weights -- confirm against ``eval_ensemble``).

    Returns:
        The test accuracy as returned by ``eval_ensemble``.

    Raises:
        ValueError: If an entry of ``args.ensemble`` is not a supported
            backbone.
    """
    # 1. get dataset
    # NOTE: inception_v3 crops the dataset differently than all the other
    # backbones, yet a single loader (built for the first ensemble member)
    # is shared by every model. Overwriting args.net here is a known hack
    # that might introduce problems for callers reusing ``args``.
    args.net = args.ensemble[0]
    print("args .net when loading: ", args.net)
    _, _, test_loader, class_list = return_dataset(args)
    print("Loading in ")

    # 2. generator
    G_list = []  # one feature extractor per entry in args.ensemble
    F_list = []  # one predictor per entry in args.ensemble
    for classifier in args.ensemble:
        print("classifier: ", classifier)
        # NOTE(review): the torchvision models keep pretrained=True even
        # though the checkpoint overwrites their weights; some constructor
        # defaults appear to depend on that flag -- confirm before changing.
        if classifier == 'resnet50':
            G = ResBase50()
            inc = 2048
        elif classifier == 'resnet101':
            G = ResBase101()
            inc = 2048
        elif classifier == "alexnet":
            G = AlexNetBase()
            inc = 4096
        elif classifier == "vgg":
            G = VGGBase()
            inc = 4096
        elif classifier == "inception_v3":
            G = models.inception_v3(pretrained=True)
            inc = 1000
        elif classifier == "googlenet":
            G = models.googlenet(pretrained=True)
            inc = 1000
        elif classifier == "densenet":
            G = models.densenet161(pretrained=True)
            inc = 1000
        elif classifier == "resnext":
            G = models.resnext50_32x4d(pretrained=True)
            inc = 1000
        elif classifier == "squeezenet":
            G = models.squeezenet1_0(pretrained=True)
            inc = 1000
        else:
            raise ValueError('Model cannot be recognized.')
        G.cuda()

        # 3. classifier
        F = Predictor(num_class=len(class_list), inc=inc, temp=args.T)
        F.cuda()

        # 4. load pre-trained model
        g_path = os.path.join(
            args.checkpath,
            "G_net_{}_loss_{}.pth".format(classifier, args.loss))
        f_path = os.path.join(
            args.checkpath,
            "F_net_{}_loss_{}.pth".format(classifier, args.loss))
        G.load_state_dict(torch.load(g_path))
        F.load_state_dict(torch.load(f_path))

        # Inference must run in eval mode: in train mode BatchNorm uses
        # batch statistics and Dropout stays active, and torchvision's
        # inception_v3 / googlenet additionally return auxiliary outputs.
        G.eval()
        F.eval()

        G_list.append(G)
        F_list.append(F)

    # 5. testing
    print("evaluating accuracy")
    acc_test = eval_ensemble(
        args, test_loader, G_list, F_list, class_list, alphas)
    print('Testing accuracy: {:.3f}\n'.format(acc_test))
    return acc_test