def train(args, weights=None):
    """Train a feature generator ``G`` and classifier ``F`` on the source dataset.

    Builds the backbone selected by ``args.net``, trains for ``args.steps``
    iterations with an inverse-LR schedule, periodically evaluates on the
    validation set and checkpoints the best model.  When per-sample AdaBoost
    ``weights`` are supplied, the chosen loss is weighted per batch and, after
    training, the AdaBoost error rate and model importance are computed and
    the sample weights are updated in place.

    Args:
        args: parsed CLI namespace; fields used here include ``net``, ``loss``,
            ``checkpath``, ``multi``, ``T``, ``lr``, ``steps``,
            ``log_interval``, ``save_interval`` and ``save_check``.
        weights: optional indexable collection of per-batch sample weights
            (indexed by ``step % len(train_loader)``) — presumably one entry
            per training batch; TODO confirm against the AdaBoost driver.

    Returns:
        ``(error_rate, model_importance)`` when ``weights`` is not None,
        otherwise ``None``.

    Raises:
        ValueError: if ``args.net`` or ``args.loss`` is not recognized.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(args.checkpath, exist_ok=True)

    # 1. dataset
    train_loader, val_loader, test_loader, class_list = return_dataset(args)

    # 2. generator backbone; `inc` is the feature dimension fed to the classifier.
    if args.net == 'resnet50':
        G = ResBase50()
        inc = 2048
    elif args.net == 'resnet101':
        G = ResBase101()
        inc = 2048
    elif args.net == "alexnet":
        G = AlexNetBase()
        inc = 4096
    elif args.net == "vgg":
        G = VGGBase()
        inc = 4096
    elif args.net == "inception_v3":
        G = models.inception_v3(pretrained=True)
        inc = 1000
    elif args.net == "googlenet":
        G = models.googlenet(pretrained=True)
        inc = 1000
    elif args.net == "densenet":
        G = models.densenet161(pretrained=True)
        inc = 1000
    elif args.net == "resnext":
        G = models.resnext50_32x4d(pretrained=True)
        inc = 1000
    elif args.net == "squeezenet":
        G = models.squeezenet1_0(pretrained=True)
        inc = 1000
    else:
        raise ValueError('Model cannot be recognized.')

    # Per-parameter LR groups: classifier layers get 10x the base multiplier.
    params = []
    for key, value in dict(G.named_parameters()).items():
        if value.requires_grad:
            if 'classifier' not in key:
                params += [{'params': [value], 'lr': args.multi,
                            'weight_decay': 0.0005}]
            else:
                params += [{'params': [value], 'lr': args.multi * 10,
                            'weight_decay': 0.0005}]
    G.cuda()
    G.train()

    # 3. classifier head on top of the backbone features
    F = Predictor(num_class=len(class_list), inc=inc, temp=args.T)
    weights_init(F)
    F.cuda()
    F.train()

    # 4. optimizers; base LRs are remembered so inv_lr_scheduler can rescale
    # them every step.
    optimizer_g = optim.SGD(params, momentum=0.9, weight_decay=0.0005,
                            nesterov=True)
    optimizer_f = optim.SGD(list(F.parameters()), lr=1.0, momentum=0.9,
                            weight_decay=0.0005, nesterov=True)
    optimizer_g.zero_grad()
    optimizer_f.zero_grad()
    param_lr_g = [group["lr"] for group in optimizer_g.param_groups]
    param_lr_f = [group["lr"] for group in optimizer_f.param_groups]

    # 5. training loop
    data_iter_train = iter(train_loader)
    len_train = len(train_loader)
    best_acc = 0
    for step in range(args.steps):
        # update learning rates for this step
        optimizer_g = inv_lr_scheduler(param_lr_g, optimizer_g, step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f, optimizer_f, step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']
        # restart the loader iterator at every epoch boundary
        if step % len_train == 0:
            data_iter_train = iter(train_loader)
        data = next(data_iter_train)
        im_data = data[0].cuda()
        gt_label = data[1].cuda()
        feature = G(im_data)
        if args.net == 'inception_v3':
            # inception_v3 returns an InceptionOutputs namedtuple, not a tensor
            feature = feature.logits

        # FIX: was `weights == None`, which is an elementwise comparison for
        # array/tensor weights (ambiguous truth value); identity check intended.
        # Hoisted once instead of repeating the expression in every branch.
        batch_weights = None if weights is None else weights[step % len_train]
        if args.loss == 'CrossEntropy':
            loss = crossentropy(F, feature, gt_label, batch_weights)
        elif args.loss == 'FocalLoss':
            loss = focal_loss(F, feature, gt_label, batch_weights)
        elif args.loss == 'ASoftmaxLoss':
            loss = asoftmax_loss(F, feature, gt_label, batch_weights)
        elif args.loss == 'SmoothCrossEntropy':
            loss = smooth_crossentropy(F, feature, gt_label, batch_weights)
        else:
            # FIX: previously printed 'To add new loss' and fell through to
            # loss.backward() with `loss` unbound -> NameError. Fail loudly.
            raise ValueError('Unrecognized loss: {}'.format(args.loss))

        loss.backward()
        optimizer_g.step()
        optimizer_f.step()
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        G.zero_grad()
        F.zero_grad()

        if step % args.log_interval == 0:
            # FIX: loss.data is deprecated; .item() extracts the Python float.
            log_train = ('Train iter: {} lr{} Loss Classification: {:.6f}\n'
                         .format(step, lr, loss.item()))
            print(log_train)

        if step and step % args.save_interval == 0:
            # evaluate, then restore train mode (eval() switches it off)
            acc_val = eval(val_loader, G, F, class_list)
            G.train()
            F.train()
            if args.save_check and acc_val >= best_acc:
                best_acc = acc_val
                print('saving model')
                print('best_acc: ' + str(best_acc) + ' acc_val: ' + str(acc_val))
                torch.save(G.state_dict(),
                           os.path.join(args.checkpath,
                                        "G_net_{}_loss_{}.pth".format(args.net, args.loss)))
                torch.save(F.state_dict(),
                           os.path.join(args.checkpath,
                                        "F_net_{}_loss_{}.pth".format(args.net, args.loss)))

    if weights is not None:
        # AdaBoost bookkeeping: weak-learner error, importance, weight update.
        print("computing error rate")
        error_rate = eval_adaboost_error_rate(train_loader, G, F, class_list,
                                              weights)
        model_importance = torch.log((1 - error_rate) / error_rate) / 2
        print("updating weights")
        update_weights_adaboost(train_loader, G, F, class_list, weights,
                                model_importance)
        return error_rate, model_importance
else: raise ValueError('Model cannot be recognized.') if "resnet" in args.net: F1 = Predictor_deep_latent(num_class=len(class_list), inc=inc) else: F1 = Predictor_latent(num_class=len(class_list), inc=inc, temp=args.T) G = torch.nn.DataParallel(G).cuda() F1 = torch.nn.DataParallel(F1).cuda() G_dict = os.path.join( args.checkpath, "G_{}_{}_to_{}_step_{}.pth.tar".format(args.dataset, args.source, args.target, args.steps)) pretrained_dict = torch.load(G_dict) model_dict = G.state_dict() model_dict.update(pretrained_dict) G.load_state_dict(model_dict) F_dict = os.path.join( args.checkpath, "F1_{}_{}_to_{}_step_{}.pth.tar".format(args.dataset, args.source, args.target, args.steps)) pretrained_dict = torch.load(F_dict) model_dict = F1.state_dict() model_dict.update(pretrained_dict) F1.load_state_dict(model_dict) def test(base, classifier, loader): base.eval()