def main():
    """Build a DFPNet, restore a stage-1 checkpoint (required), and plot class distances.

    Evaluation-only entry point: no training happens here. The network must be
    restored from ``args.stage1_resume`` before ``plot_distance`` is called,
    otherwise the plotted distances come from random weights.
    """
    print(device)
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 distance=args.distance, scaled=args.scaled)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    if args.stage1_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage1_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage1_resume)
            net.load_state_dict(checkpoint['net'])
            print("=> checkpoint loaded!")
        else:
            # BUG FIX: previously formatted args.resume, which is not the
            # attribute checked above (and may not even exist on args).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
    else:
        print("Resume is required")
    plot_distance(net, trainloader, device, args)
def main_stage1():
    """Stage-1 training (energy-based variant).

    Trains DFPNet with DFPLoss, optionally resuming optimizer/epoch state from
    ``args.stage1_resume``, then evaluates with energy- and softmax-based
    open-set classification. Returns a dict with the trained ``state_dict`` and
    the mixup-derived known/unknown energy midpoints used by stage 2.
    """
    print(f"\nStart Stage-1 training ...\n")
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim, p=args.p)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    criterion = DFPLoss(temperature=args.temperature)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.stage1_lr, momentum=0.9, weight_decay=5e-4)
    if args.stage1_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage1_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage1_resume)
            net.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage1_resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
            # NOTE(review): `logger` stays unbound on this path; logger.append
            # below would raise NameError — confirm whether this path should abort.
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Train Acc.'])
    if not args.evaluate:
        for epoch in range(start_epoch, args.stage1_es):
            adjust_learning_rate(optimizer, epoch, args.stage1_lr,
                                 factor=args.stage1_lr_factor, step=args.stage1_lr_step)
            print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage1_train(net, trainloader, optimizer, criterion, device)
            save_model(net, optimizer, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
            logger.append([epoch + 1, train_out["train_loss"], train_out["accuracy"]])
            if args.plot:
                plot_feature(net, args, trainloader, device, args.plotfolder, epoch=epoch,
                             plot_class_num=args.train_class_num, plot_quality=args.plot_quality)
                plot_feature(net, args, testloader, device, args.plotfolder, epoch="test" + str(epoch),
                             plot_class_num=args.train_class_num + 1, plot_quality=args.plot_quality,
                             testmode=True)
        logger.close()
        print(f"\nFinish Stage-1 training...\n")
    print("===> Evaluating stage-1 ...")
    stage_test(net, testloader, device)
    mid_dict = stage_valmixup(net, trainloader, device)
    print("===> stage1 energy based classification")
    stage_evaluate(net, testloader, mid_dict["mid_unknown"].item(), mid_dict["mid_known"].item(),
                   feature="energy")
    print("===> stage1 softmax based classification")
    stage_evaluate(net, testloader, 0., 1., feature="normweight_fea2cen")
    return {
        "net": net.state_dict(),
        "mid_known": mid_dict["mid_known"],
        "mid_unknown": mid_dict["mid_unknown"]
    }
def main_stage1():
    """Stage-1 training (alpha/beta DFPLoss variant).

    Initializes backbone, branches, and centroids; logs per-component losses;
    computes per-class distance statistics for the final model. Returns a dict
    with the trained ``net`` and the ``distance`` results (incl. thresholds)
    consumed by stage 2.
    """
    print(f"\nStart Stage-1 training ...\n")
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    # Model
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 distance=args.distance, scaled=args.scaled)
    criterion = DFPLoss(alpha=args.alpha, beta=args.beta)
    optimizer = optim.SGD(net.parameters(), lr=args.stage1_lr, momentum=0.9, weight_decay=5e-4)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    if args.stage1_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage1_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage1_resume)
            net.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage1_resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Softmax Loss', 'Within Loss', 'Between Loss',
                          'Train Acc.'])
    if not args.evaluate:
        for epoch in range(start_epoch, args.stage1_es):
            adjust_learning_rate(optimizer, epoch, args.stage1_lr, step=15)
            print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage1_train(net, trainloader, optimizer, criterion, device)
            save_model(net, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
            logger.append([epoch + 1, train_out["train_loss"], train_out["cls_loss"],
                           train_out["dis_loss_within"], train_out["dis_loss_between"],
                           train_out["accuracy"]])
            if args.plot:
                plot_feature(net, trainloader, device, args.plotfolder1, epoch=epoch,
                             plot_class_num=args.train_class_num, maximum=args.plot_max,
                             plot_quality=args.plot_quality, normalized=args.plot_normalized)
    if args.plot:
        # plot the test set
        plot_feature(net, testloader, device, args.plotfolder1, epoch="test",
                     plot_class_num=args.train_class_num + 1, maximum=args.plot_max,
                     plot_quality=args.plot_quality, normalized=args.plot_normalized)
    # calculating distances for last epoch
    distance_results = plot_distance(net, trainloader, device, args)
    logger.close()
    print(f"\nFinish Stage-1 training...\n")
    print("===> Evaluating ...")
    stage1_test(net, testloader, device)
    return {"net": net,
            "distance": distance_results
            }
def main():
    """Build DFPNet plus a stateful DFPLoss, restore both from a stage-1 checkpoint, and plot features.

    Evaluation-only entry point: the distance criterion holds learnable state
    (centroids), so it is restored from the checkpoint alongside the network.
    """
    print(device)
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 embed_reduction=args.embed_reduction)
    # Fall back to the backbone's native feature dim when no embed_dim is given.
    embed_dim = net.feat_dim if not args.embed_dim else args.embed_dim
    # (removed unused local criterion_cls = nn.CrossEntropyLoss(); never referenced)
    criterion_dis = DFPLoss(num_classes=args.train_class_num, feat_dim=embed_dim, beta=args.beta,
                            distance=args.distance, scaled=args.scaled)
    net = net.to(device)
    criterion_dis = criterion_dis.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        criterion_dis = torch.nn.DataParallel(criterion_dis)
        cudnn.benchmark = True
    if args.stage1_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage1_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage1_resume)
            net.load_state_dict(checkpoint['net'])
            criterion_dis.load_state_dict(checkpoint['criterion'])
            print("=> checkpoint loaded!")
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage1_resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
    else:
        print("Resume is required")
    plot_feature(net, criterion_dis, trainloader, device, args.plotter, epoch=0, plot_class_num=10,
                 maximum=args.plot_max)
def main_stage1():
    """Stage-1 training (Adam / scaling-loss variant).

    Trains DFPNet with DFPLoss(scaling=...), optionally resuming weights and
    epoch from ``args.stage1_resume``, evaluates, and returns ``{"net": net}``.
    """
    print(f"\nStart Stage-1 training ...\n")
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    # Model
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    if args.stage1_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage1_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage1_resume)
            net.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage1_resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Train Acc.'])
    # after resume
    criterion = DFPLoss(scaling=args.scaling)
    optimizer = optim.Adam(net.parameters(), lr=args.stage1_lr)
    for epoch in range(start_epoch, args.stage1_es):
        adjust_learning_rate(optimizer, epoch, args.stage1_lr, factor=0.2, step=20)
        print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
        train_out = stage1_train(net, trainloader, optimizer, criterion, device)
        save_model(net, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
        logger.append([epoch + 1, train_out["train_loss"], train_out["accuracy"]])
        if args.plot:
            plot_feature(net, args, trainloader, device, args.plotfolder, epoch=epoch,
                         plot_class_num=args.train_class_num, plot_quality=args.plot_quality)
            plot_feature(net, args, testloader, device, args.plotfolder, epoch="test" + str(epoch),
                         plot_class_num=args.train_class_num + 1, plot_quality=args.plot_quality,
                         testmode=True)
    logger.close()
    print(f"\nFinish Stage-1 training...\n")
    print("===> Evaluating ...")
    stage1_test(net, testloader, device)
    return {"net": net}
def main_stage1():
    """Stage-1 training (cosine-weight / sigma variant).

    Trains DFPNet with the distance-only criterion (no softmax term — the
    'Softmax Loss' log column is padded with 0.0) and returns the trained net.
    """
    print(f"\nStart Stage-1 training ...\n")
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    # Model
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 distance=args.distance, scaled=args.scaled, cosine_weight=args.cosine_weight)
    criterion_dis = DFPLoss(beta=args.beta, sigma=args.sigma)
    optimizer = optim.SGD(net.parameters(), lr=args.stage1_lr, momentum=0.9, weight_decay=5e-4)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    if args.stage1_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage1_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage1_resume)
            net.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage1_resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Softmax Loss', 'Distance Loss',
                          'Within Loss', 'Between Loss', 'Cen2cen Loss', 'Train Acc.'])
    # NOTE(review): unlike sibling variants, resuming extends training by a full
    # stage1_es epochs beyond the restored epoch rather than up to a fixed total
    # — confirm this is intended.
    for epoch in range(start_epoch, start_epoch + args.stage1_es):
        # CONSISTENCY FIX: adjust the learning rate *before* printing it, so the
        # displayed value matches the rate used this epoch (as in all siblings).
        adjust_learning_rate(optimizer, epoch, args.stage1_lr, step=15)
        print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
        train_out = stage1_train(net, trainloader, optimizer, criterion_dis, device)
        save_model(net, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
        # Columns: ['Epoch', 'Train Loss', 'Softmax Loss', 'Distance Loss',
        #           'Within Loss', 'Between Loss', 'Cen2cen loss', 'Train Acc.']
        # Softmax loss is not computed by this variant; log a constant 0.0.
        logger.append([epoch + 1, train_out["train_loss"], 0.0, train_out["dis_loss_total"],
                       train_out["dis_loss_within"], train_out["dis_loss_between"],
                       train_out["dis_loss_cen2cen"], train_out["accuracy"]])
        if args.plot:
            plot_feature(net, trainloader, device, args.plotfolder, epoch=epoch,
                         plot_class_num=args.train_class_num, maximum=args.plot_max,
                         plot_quality=args.plot_quality)
    logger.close()
    print(f"\nFinish Stage-1 training...\n")
    return net
def main_stage2(stage1_dict):
    """Stage-2 training: warm-start from the stage-1 net and its distance thresholds.

    Args:
        stage1_dict: dict returned by stage 1 with keys ``"net"`` (trained model)
            and ``"distance"]["thresholds"]`` (per-class distance thresholds).

    Returns:
        The trained stage-2 network.
    """
    net1 = stage1_dict["net"]
    thresholds = stage1_dict["distance"]["thresholds"]
    print(f"\n===> Start Stage-2 training...\n")
    start_epoch = 0
    print('==> Building model..')
    net2 = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                  distance=args.distance, scaled=args.scaled, cosine_weight=args.cosine_weight,
                  thresholds=thresholds)
    net2 = net2.to(device)
    criterion_dis = DFPLossGeneral(beta=args.beta, sigma=args.sigma, gamma=args.gamma)
    optimizer = optim.SGD(net2.parameters(), lr=args.stage2_lr, momentum=0.9, weight_decay=5e-4)
    if not args.evaluate:
        # Copy stage-1 weights into the freshly built stage-2 network.
        init_stage2_model(net1, net2)
    if device == 'cuda':
        net2 = torch.nn.DataParallel(net2)
        cudnn.benchmark = True
    if args.stage2_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            # BUG FIX: previously loaded args.stage1_resume here, restoring the
            # wrong (stage-1) checkpoint during a stage-2 resume.
            checkpoint = torch.load(args.stage2_resume)
            net2.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage2_resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Within Loss', 'Between Loss',
                          'Within-Gen Loss', 'Between-Gen Loss', 'Random loss', 'Train Acc.'])
    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer, epoch, args.stage2_lr, step=20)
            print('\nStage_2 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net2, trainloader, optimizer, criterion_dis, device)
            save_model(net2, epoch, os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append([epoch + 1, train_out["dis_loss_total"], train_out["dis_loss_within"],
                           train_out["dis_loss_between"], train_out["dis_loss_within_gen"],
                           train_out["dis_loss_between_gen"], train_out["dis_loss_cen2cen"],
                           train_out["accuracy"]])
            if args.plot:
                plot_feature(net2, trainloader, device, args.plotfolder2, epoch=epoch,
                             plot_class_num=args.train_class_num, maximum=args.plot_max,
                             plot_quality=args.plot_quality)
    if args.plot:
        # plot the test set
        plot_feature(net2, testloader, device, args.plotfolder2, epoch="test",
                     plot_class_num=args.train_class_num + 1, maximum=args.plot_max,
                     plot_quality=args.plot_quality)
    logger.close()
    print(f"\nFinish Stage-2 training...\n")
    print("===> Evaluating ...")
    stage1_test(net2, testloader, device)
    return net2
def main_stage2(stage1_dict):
    """Stage-2 training using stage-1 thresholds and the density estimator.

    Args:
        stage1_dict: stage-1 results with keys ``'net'``, ``'distance'`` (with
            ``'thresholds'``), and ``'estimator'``.

    Returns:
        The trained stage-2 network.
    """
    net1 = stage1_dict['net']
    thresholds = stage1_dict['distance']['thresholds']
    estimator = stage1_dict['estimator']
    print(f"\n===> Start Stage-2 training...\n")
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    print('==> Building model..')
    net2 = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                  distance=args.distance, similarity=args.similarity, scaled=args.scaled,
                  thresholds=thresholds, norm_centroid=args.norm_centroid, amplifier=args.amplifier,
                  estimator=estimator)
    net2 = net2.to(device)
    # BUG FIX: the resume path points at a checkpoint *file*, so os.path.isdir
    # was always False and stage-1 weights overwrote net2 even when a resume
    # checkpoint existed. Sibling variants use os.path.isfile here.
    if not args.evaluate and not os.path.isfile(args.stage2_resume):
        init_stage2_model(net1, net2)
    if device == 'cuda':
        net2 = torch.nn.DataParallel(net2)
        cudnn.benchmark = True
    if args.stage2_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage2_resume)
            net2.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage2_resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names([
            'Epoch', 'Train Loss', 'Similarity Loss', 'Distance in', 'Distance out',
            'Generate within', 'Generate 2origin', 'Train Acc.'
        ])
    # after resume
    criterion = DFPLoss2(alpha=args.alpha, beta=args.beta, theta=args.theta)
    # BUG FIX: stage-2 optimizer was seeded with args.stage1_lr; the schedule
    # below is driven by args.stage2_lr, so seed it consistently.
    optimizer = optim.SGD(net2.parameters(), lr=args.stage2_lr, momentum=0.9, weight_decay=5e-4)
    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            print('\nStage_2 Epoch: %d Learning rate: %f' % (epoch + 1, optimizer.param_groups[0]['lr']))
            # Here, I didn't set optimizers respectively, just for simplicity.
            # Performance did not vary a lot.
            adjust_learning_rate(optimizer, epoch, args.stage2_lr, step=10)
            train_out = stage2_train(net2, trainloader, optimizer, criterion, device)
            save_model(net2, epoch, os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append([
                epoch + 1, train_out["train_loss"], train_out["loss_similarity"],
                train_out["distance_in"], train_out["distance_out"],
                train_out["generate_within"], train_out["generate_2orign"],
                train_out["accuracy"]
            ])
            if args.plot:
                plot_feature(net2, args, trainloader, device, args.plotfolder2, epoch=epoch,
                             plot_class_num=args.train_class_num, maximum=args.plot_max,
                             plot_quality=args.plot_quality, norm_centroid=args.norm_centroid,
                             thresholds=thresholds)
                plot_feature(net2, args, testloader, device, args.plotfolder2,
                             epoch="test_" + str(epoch), plot_class_num=args.train_class_num + 1,
                             maximum=args.plot_max, plot_quality=args.plot_quality,
                             norm_centroid=args.norm_centroid, thresholds=thresholds, testmode=True)
    if args.plot:
        # plot the test set
        plot_feature(net2, args, testloader, device, args.plotfolder2, epoch="test",
                     plot_class_num=args.train_class_num + 1, maximum=args.plot_max,
                     plot_quality=args.plot_quality, norm_centroid=args.norm_centroid,
                     thresholds=thresholds, testmode=True)
    print(f"\nFinish Stage-2 training...\n")
    logger.close()
    # test2(net2, testloader, device)
    return net2
def main_stage2(stage1_dict):
    """Stage-2 training (decorrelation variant): continue from the stage-1 net.

    When neither evaluating nor resuming, the stage-1 network object is reused
    directly and its distance thresholds installed via ``set_threshold``.

    Args:
        stage1_dict: stage-1 results with ``'net'`` and ``'distance']['thresholds']``.

    Returns:
        The trained stage-2 network.
    """
    print('==> Building stage2 model..')
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 distance=args.distance, similarity=args.similarity, scaled=args.scaled,
                 norm_centroid=args.norm_centroid, decorrelation=args.decorrelation)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    if not args.evaluate and not os.path.isfile(args.stage2_resume):
        # No checkpoint to resume: reuse the stage-1 network object directly.
        # NOTE(review): net.module assumes the stage-1 net is DataParallel-wrapped
        # (i.e. device == 'cuda') — confirm for CPU runs.
        net = stage1_dict['net']
        net = net.to(device)
        thresholds = stage1_dict['distance']['thresholds']
        net.module.set_threshold(thresholds.to(device))
    if args.stage2_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage2_resume)
            net.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            # BUG FIX: narrowed the bare `except:` — the key name depends on
            # whether the checkpoint was saved DataParallel-wrapped ('module.'
            # prefix); any error other than a missing key should propagate.
            try:
                thresholds = checkpoint['net']['thresholds']
            except KeyError:
                thresholds = checkpoint['net']['module.thresholds']
            net.module.set_threshold(thresholds.to(device))
            logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage2_resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names([
            'Epoch', 'Train Loss', 'Similarity Loss', 'Distance in', 'Distance out',
            'Distance Center', 'Train Acc.'
        ])
    if args.evaluate:
        stage2_test(net, testloader, device)
        return net
    # after resume
    criterion = DFPLoss2(alpha=args.alpha, beta=args.beta, theta=args.theta)
    # BUG FIX: stage-2 optimizer was seeded with args.stage1_lr; the schedule
    # below is driven by args.stage2_lr, so seed it consistently.
    optimizer = optim.SGD(net.parameters(), lr=args.stage2_lr, momentum=0.9, weight_decay=5e-4)
    for epoch in range(start_epoch, args.stage2_es):
        print('\nStage_2 Epoch: %d Learning rate: %f' % (epoch + 1, optimizer.param_groups[0]['lr']))
        # Here, I didn't set optimizers respectively, just for simplicity.
        # Performance did not vary a lot.
        adjust_learning_rate(optimizer, epoch, args.stage2_lr, step=20)
        train_out = stage2_train(net, trainloader, optimizer, criterion, device)
        save_model(net, epoch, os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
        stage2_test(net, testloader, device)
        logger.append([
            epoch + 1, train_out["train_loss"], train_out["loss_similarity"],
            train_out["distance_in"], train_out["distance_out"],
            train_out["distance_center"], train_out["accuracy"]
        ])
    print(f"\nFinish Stage-2 training...\n")
    logger.close()
    stage2_test(net, testloader, device)
    return net
def main_stage1():
    """Stage-1 training (similarity/decorrelation variant).

    Trains DFPNet with DFPLoss(alpha=...), computes per-class distance results
    for the final model, evaluates, and returns ``{"net", "distance"}`` for
    stage 2.
    """
    print(f"\nStart Stage-1 training ...\n")
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    # Model
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 distance=args.distance, similarity=args.similarity, scaled=args.scaled,
                 norm_centroid=args.norm_centroid, decorrelation=args.decorrelation)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    if args.stage1_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage1_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage1_resume)
            net.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage1_resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names([
            'Epoch', 'Train Loss', 'Similarity Loss', 'Distance Loss', 'Train Acc.'
        ])
    # after resume
    criterion = DFPLoss(alpha=args.alpha)
    optimizer = optim.SGD(net.parameters(), lr=args.stage1_lr, momentum=0.9, weight_decay=5e-4)
    for epoch in range(start_epoch, args.stage1_es):
        adjust_learning_rate(optimizer, epoch, args.stage1_lr, step=20)
        print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
        train_out = stage1_train(net, trainloader, optimizer, criterion, device)
        save_model(net, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
        logger.append([
            epoch + 1, train_out["train_loss"], train_out["loss_similarity"],
            train_out["loss_distance"], train_out["accuracy"]
        ])
    # calculating distances for last epoch
    distance_results = plot_distance(net, trainloader, device, args)
    logger.close()
    print(f"\nFinish Stage-1 training...\n")
    print("===> Evaluating ...")
    stage1_test(net, testloader, device)
    return {
        "net": net,
        "distance": distance_results,
    }
def main_stage2(stage1_dict):
    """Stage-2 fine-tuning (energy-based variant).

    Loads the stage-1 ``state_dict`` and energy midpoints, fine-tunes with
    DFPNormLoss (known/unknown energy targets scaled by 1.3/0.7), then runs the
    stage-2 open-set evaluation.

    Args:
        stage1_dict: dict from stage 1 with keys ``"net"`` (state_dict),
            ``"mid_known"`` and ``"mid_unknown"`` (energy midpoints).
    """
    print("Starting stage-2 fine-tuning ...")
    start_epoch = 0
    # get key values from stage1_dict
    mid_known = stage1_dict["mid_known"]
    mid_unknown = stage1_dict["mid_unknown"]
    net_state_dict = stage1_dict["net"]
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim, p=args.p)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    optimizer = torch.optim.SGD(net.parameters(), lr=args.stage2_lr, momentum=0.9, weight_decay=5e-4)
    if args.stage2_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage2_resume)
            net.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            # Midpoints travel with the checkpoint so the loss targets match.
            mid_known = checkpoint["mid_known"]
            mid_unknown = checkpoint["mid_unknown"]
            logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
        else:
            # BUG FIX: message previously formatted args.resume instead of the
            # attribute actually checked (args.stage2_resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
    else:
        net.load_state_dict(net_state_dict)
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names([
            'Epoch', 'Train Loss', 'Class Loss', 'Energy Loss', 'Energy Known',
            'Energy Unknown', 'Train Acc.'
        ])
    criterion = DFPNormLoss(mid_known=1.3 * mid_known, mid_unknown=0.7 * mid_unknown,
                            alpha=args.alpha, temperature=args.temperature, feature='energy')
    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer, epoch, args.stage2_lr,
                                 factor=args.stage2_lr_factor, step=args.stage2_lr_step)
            print('\nStage_2 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net, trainloader, optimizer, criterion, device)
            save_model(net, optimizer, epoch, os.path.join(args.checkpoint, 'stage_2_last_model.pth'),
                       mid_known=mid_known, mid_unknown=mid_unknown)
            logger.append([
                epoch + 1, train_out["train_loss"], train_out["loss_classification"],
                train_out["loss_energy"], train_out["loss_energy_known"],
                train_out["loss_energy_unknown"], train_out["accuracy"]
            ])
            if args.plot:
                plot_feature(net, args, trainloader, device, args.plotfolder,
                             epoch="stage2_" + str(epoch), plot_class_num=args.train_class_num,
                             plot_quality=args.plot_quality)
                plot_feature(net, args, testloader, device, args.plotfolder,
                             epoch="stage2_test" + str(epoch),
                             plot_class_num=args.train_class_num + 1,
                             plot_quality=args.plot_quality, testmode=True)
        logger.close()
    print(f"\nFinish Stage-2 training...\n")
    print("===> Evaluating stage-2 ...")
    stage_test(net, testloader, device, name="stage2_test_doublebar")
    stage_valmixup(net, trainloader, device, name="stage2_mixup_result")
    stage_evaluate(net, testloader, mid_unknown.item(), mid_known.item(), feature="energy")