Example #1
import os

import torch
import torch.optim as optim
from torchvision import models

# Project-specific helpers used below (return_dataset, ResBase50, ResBase101,
# AlexNetBase, VGGBase, Predictor, weights_init, inv_lr_scheduler, the loss
# functions, eval, eval_adaboost_error_rate, update_weights_adaboost) are
# assumed to come from the surrounding repository.
def train(args, weights=None):
    if not os.path.exists(args.checkpath):
        os.mkdir(args.checkpath)

    # 1. get dataset
    train_loader, val_loader, test_loader, class_list = return_dataset(args)

    # 2. generator
    if args.net == 'resnet50':
        G = ResBase50()
        inc = 2048
    elif args.net == 'resnet101':
        G = ResBase101()
        inc = 2048        
    elif args.net == "alexnet":
        G = AlexNetBase()
        inc = 4096
    elif args.net == "vgg":
        G = VGGBase()
        inc = 4096
    elif args.net == "inception_v3":
        G = models.inception_v3(pretrained=True) 
        inc = 1000
    elif args.net == "googlenet":
        G = models.googlenet(pretrained=True)
        inc = 1000
    elif args.net == "densenet":
        G = models.densenet161(pretrained=True)
        inc = 1000
    elif args.net == "resnext":
        G = models.resnext50_32x4d(pretrained=True)
        inc = 1000
    elif args.net == "squeezenet":
        G = models.squeezenet1_0(pretrained=True)
        inc = 1000
    else:
        raise ValueError('Model cannot be recognized.')

    params = []
    for key, value in G.named_parameters():
        if value.requires_grad:
            if 'classifier' not in key:
                params += [{'params': [value], 'lr': args.multi,
                            'weight_decay': 0.0005}]
            else:
                params += [{'params': [value], 'lr': args.multi * 10,
                            'weight_decay': 0.0005}]
    G.cuda()
    G.train()

    # 3. classifier
    F = Predictor(num_class=len(class_list), inc=inc, temp=args.T)
    weights_init(F)
    F.cuda()
    F.train()  

    # 4. optimizer
    optimizer_g = optim.SGD(params, momentum=0.9,
                            weight_decay=0.0005, nesterov=True)
    optimizer_f = optim.SGD(list(F.parameters()), lr=1.0, momentum=0.9,
                            weight_decay=0.0005, nesterov=True) 
    optimizer_g.zero_grad()
    optimizer_f.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])

    # 5. training
    data_iter_train = iter(train_loader)
    len_train = len(train_loader)
    best_acc = 0
    for step in range(args.steps):
        # update optimizer and lr
        optimizer_g = inv_lr_scheduler(param_lr_g, optimizer_g, step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f, optimizer_f, step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']
        if step % len_train == 0:
            data_iter_train = iter(train_loader)

        # forwarding
        data = next(data_iter_train)        
        im_data = data[0].cuda()
        gt_label = data[1].cuda()

        feature = G(im_data)
        if args.net == 'inception_v3':
            # inception_v3 returns an InceptionOutputs object rather than a plain tensor
            feature = feature.logits  # keep only the main logits tensor
        # per-batch sample weights, if AdaBoost-style weights were passed in
        batch_weights = None if weights is None else weights[step % len_train]
        if args.loss == 'CrossEntropy':
            loss = crossentropy(F, feature, gt_label, batch_weights)
        elif args.loss == 'FocalLoss':
            loss = focal_loss(F, feature, gt_label, batch_weights)
        elif args.loss == 'ASoftmaxLoss':
            loss = asoftmax_loss(F, feature, gt_label, batch_weights)
        elif args.loss == 'SmoothCrossEntropy':
            loss = smooth_crossentropy(F, feature, gt_label, batch_weights)
        else:
            raise ValueError('Loss cannot be recognized.')
        loss.backward()

        # backpropagation
        optimizer_g.step()
        optimizer_f.step()       
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        G.zero_grad()
        F.zero_grad()

        if step % args.log_interval == 0:
            log_train = 'Train iter: {} lr: {} Loss Classification: {:.6f}'.format(
                step, lr, loss.item())
            print(log_train)
        if step and step % args.save_interval == 0:
            # evaluate and save
            acc_val = eval(val_loader, G, F, class_list)     
            G.train()  
            F.train()  
            if args.save_check and acc_val >= best_acc:
                best_acc = acc_val
                print('saving model')
                print('best_acc: '+str(best_acc) + '  acc_val: '+str(acc_val))
                torch.save(G.state_dict(), os.path.join(args.checkpath,
                                        "G_net_{}_loss_{}.pth".format(args.net, args.loss)))
                torch.save(F.state_dict(), os.path.join(args.checkpath,
                                        "F_net_{}_loss_{}.pth".format(args.net, args.loss)))
    if weights is not None:
        print("computing error rate")
        error_rate = eval_adaboost_error_rate(train_loader, G, F, class_list, weights)
        model_importance = torch.log((1 - error_rate) / error_rate) / 2
        # now update the weights
        print("updating weights")
        update_weights_adaboost(train_loader, G, F, class_list, weights, model_importance)
        return error_rate, model_importance
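train() returns an error rate and a model importance so that several classifiers can be combined AdaBoost-style, with misclassified samples re-weighted between training rounds. The repository's eval_adaboost_error_rate and update_weights_adaboost are not shown here; the sketch below is a minimal, self-contained illustration of the standard AdaBoost re-weighting rule that such helpers typically implement (the function name and toy tensors are assumptions, not code from the repository).

import torch

def adaboost_reweight(weights, correct):
    # weights: per-sample weights summing to 1; correct: bool tensor of per-sample hits
    error_rate = weights[~correct].sum().clamp(1e-8, 1 - 1e-8)
    # same formula as model_importance in train() above
    model_importance = torch.log((1 - error_rate) / error_rate) / 2
    # up-weight misclassified samples, down-weight correct ones, then renormalize
    signs = torch.where(correct, -torch.ones_like(weights), torch.ones_like(weights))
    new_weights = weights * torch.exp(model_importance * signs)
    return new_weights / new_weights.sum(), error_rate, model_importance

weights = torch.full((6,), 1.0 / 6)
correct = torch.tensor([True, True, False, True, False, True])
weights, err, alpha = adaboost_reweight(weights, correct)

In the training loop above, weights is indexed per batch (weights[step % len_train]), so in practice the per-sample weights appear to be stored batch by batch.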
Example #2
else:
    raise ValueError('Model cannot be recognized.')

if "resnet" in args.net:
    F1 = Predictor_deep_latent(num_class=len(class_list), inc=inc)
else:
    F1 = Predictor_latent(num_class=len(class_list), inc=inc, temp=args.T)
G = torch.nn.DataParallel(G).cuda()
F1 = torch.nn.DataParallel(F1).cuda()

G_dict = os.path.join(
    args.checkpath,
    "G_{}_{}_to_{}_step_{}.pth.tar".format(args.dataset, args.source,
                                           args.target, args.steps))
pretrained_dict = torch.load(G_dict)
model_dict = G.state_dict()
model_dict.update(pretrained_dict)
G.load_state_dict(model_dict)

F_dict = os.path.join(
    args.checkpath,
    "F1_{}_{}_to_{}_step_{}.pth.tar".format(args.dataset, args.source,
                                            args.target, args.steps))
pretrained_dict = torch.load(F_dict)
model_dict = F1.state_dict()
model_dict.update(pretrained_dict)
F1.load_state_dict(model_dict)


def test(base, classifier, loader):
    base.eval()
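Both checkpoints in Example #2 are restored by merging the saved weights into the model's current state_dict before calling load_state_dict, which tolerates checkpoints that only cover part of the model. A minimal, self-contained sketch of that pattern follows; the toy model and file name are hypothetical, not taken from the repository.

import torch
import torch.nn as nn

def make_model():
    # stand-in for G / F1
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

model = make_model()
torch.save(model.state_dict(), "checkpoint.pth.tar")  # hypothetical path

restored = make_model()
pretrained_dict = torch.load("checkpoint.pth.tar")
model_dict = restored.state_dict()
model_dict.update(pretrained_dict)  # overwrite matching keys, keep the rest unchanged
restored.load_state_dict(model_dict)
restored.eval()

Note that G and F1 in Example #2 are wrapped in torch.nn.DataParallel, so their state_dict keys carry a "module." prefix; a checkpoint saved from an unwrapped model would need that prefix added (or the wrapped keys stripped) before the merge lines up.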