コード例 #1
0
ファイル: main.py プロジェクト: wisemanl12/884project
def test(args):
    """Evaluate a saved generator/classifier pair on the test split.

    Args:
        args: Parsed options object. This function reads ``args.net``
            (backbone name), ``args.T`` (forwarded to ``Predictor`` as
            ``temp``), ``args.checkpath`` (checkpoint directory) and
            ``args.loss`` (part of the checkpoint file names). ``args`` is
            also forwarded unchanged to ``return_dataset``.

    Returns:
        Whatever the project's ``eval`` helper returns for the test loader;
        it is formatted as a float and printed as the testing accuracy.

    Raises:
        ValueError: If ``args.net`` is not one of the supported backbones.
    """
    # 1. get dataset (only the test loader and the class list are needed)
    _, _, test_loader, class_list = return_dataset(args)

    # 2. generator
    # Each entry lazily builds (feature extractor, feature dimension), so
    # only the requested backbone is instantiated.
    backbones = {
        'resnet50': lambda: (ResBase50(), 2048),
        'resnet101': lambda: (ResBase101(), 2048),
        "alexnet": lambda: (AlexNetBase(), 4096),
        "vgg": lambda: (VGGBase(), 4096),
        "inception_v3": lambda: (models.inception_v3(pretrained=True), 1000),
        "googlenet": lambda: (models.googlenet(pretrained=True), 1000),
        "densenet": lambda: (models.densenet161(pretrained=True), 1000),
        "resnext": lambda: (models.resnext50_32x4d(pretrained=True), 1000),
        "squeezenet": lambda: (models.squeezenet1_0(pretrained=True), 1000),
    }
    if args.net not in backbones:
        raise ValueError('Model cannot be recognized.')
    G, inc = backbones[args.net]()
    G.cuda()
    # Inference only: use eval mode so BatchNorm/Dropout layers run in
    # inference behaviour and torchvision's inception_v3/googlenet return a
    # plain tensor instead of the training-time output with aux logits.
    G.eval()

    # 3. classifier
    F = Predictor(num_class=len(class_list), inc=inc, temp=args.T)
    F.cuda()
    F.eval()

    # 4. load pre-trained model
    g_path = os.path.join(
        args.checkpath, "G_net_{}_loss_{}.pth".format(args.net, args.loss))
    f_path = os.path.join(
        args.checkpath, "F_net_{}_loss_{}.pth".format(args.net, args.loss))
    G.load_state_dict(torch.load(g_path))
    F.load_state_dict(torch.load(f_path))

    # 5. testing
    # NOTE(review): `eval` is called with four arguments, so it must be the
    # project's evaluation helper shadowing the builtin -- confirm.
    acc_test = eval(test_loader, G, F, class_list)
    print('Testing accuracy: {:.3f}\n'.format(acc_test))
    return acc_test
コード例 #2
0
ファイル: main.py プロジェクト: wisemanl12/884project
def test_ensemble(args, alphas=None):
    """Evaluate an ensemble of saved generator/classifier pairs on the test split.

    Args:
        args: Parsed options object. This function reads ``args.ensemble``
            (sequence of backbone names), ``args.T`` (forwarded to
            ``Predictor`` as ``temp``), ``args.checkpath`` (checkpoint
            directory) and ``args.loss`` (part of the checkpoint file
            names). ``args.net`` is overwritten with the first ensemble
            member before the dataset is built.
        alphas: Forwarded unchanged to ``eval_ensemble``; presumably
            per-model weights -- confirm against that helper.

    Returns:
        Whatever ``eval_ensemble`` returns; it is formatted as a float and
        printed as the testing accuracy.

    Raises:
        ValueError: If an entry of ``args.ensemble`` is not a supported
            backbone.
    """
    # 1. get dataset
    # Known problem: inception_v3 crops the dataset differently than all the
    # other backbones, but a single loader is shared by the whole ensemble.
    # The loader is built for the first ensemble member; this mutates
    # args.net and is kept because eval_ensemble also receives args.
    args.net = args.ensemble[0]
    print("args .net when loading: ", args.net)
    _, _, test_loader, class_list = return_dataset(args)
    print("Loading in ")

    # 2. generator
    # Each entry lazily builds (feature extractor, feature dimension), so
    # only the requested backbones are instantiated.
    backbones = {
        'resnet50': lambda: (ResBase50(), 2048),
        'resnet101': lambda: (ResBase101(), 2048),
        "alexnet": lambda: (AlexNetBase(), 4096),
        "vgg": lambda: (VGGBase(), 4096),
        "inception_v3": lambda: (models.inception_v3(pretrained=True), 1000),
        "googlenet": lambda: (models.googlenet(pretrained=True), 1000),
        "densenet": lambda: (models.densenet161(pretrained=True), 1000),
        "resnext": lambda: (models.resnext50_32x4d(pretrained=True), 1000),
        "squeezenet": lambda: (models.squeezenet1_0(pretrained=True), 1000),
    }
    G_list = []  # one feature extractor per ensemble member
    F_list = []  # one predictor per ensemble member, same order as G_list

    for classifier in args.ensemble:
        print("classifier: ", classifier)
        if classifier not in backbones:
            raise ValueError('Model cannot be recognized.')
        G, inc = backbones[classifier]()
        G.cuda()
        # Inference only: use eval mode so BatchNorm/Dropout layers run in
        # inference behaviour and torchvision's inception_v3/googlenet
        # return a plain tensor instead of the training-time output.
        G.eval()

        # 3. classifier
        F = Predictor(num_class=len(class_list), inc=inc, temp=args.T)
        F.cuda()
        F.eval()

        # 4. load pre-trained model
        g_path = os.path.join(
            args.checkpath,
            "G_net_{}_loss_{}.pth".format(classifier, args.loss))
        f_path = os.path.join(
            args.checkpath,
            "F_net_{}_loss_{}.pth".format(classifier, args.loss))
        G.load_state_dict(torch.load(g_path))
        F.load_state_dict(torch.load(f_path))
        G_list.append(G)
        F_list.append(F)

    # 5. testing
    print("evaluating accuracy")
    acc_test = eval_ensemble(args, test_loader, G_list, F_list, class_list,
                             alphas)
    print('Testing accuracy: {:.3f}\n'.format(acc_test))
    return acc_test