def main_stage2(net, vae, mid_energy):
    """Stage-2 fine-tuning of `net` with the DFP norm loss, then evaluation.

    Args:
        net: model from stage 1; trained in place.
        vae: auxiliary model passed through to `stage2_train`.
        mid_energy: dict with "mid_known" / "mid_unknown" energy midpoints
            (estimated in stage 1) that parameterize the loss.

    Side effects: writes `stage_2_last_model.pth` and `log_stage2.txt` under
    `args.checkpoint`, optionally saves feature plots, and finally runs
    `stage2_test`.

    NOTE(review): a second `main_stage2` defined later in this file shadows
    this one at module level — confirm which version callers actually use.
    """
    print("Starting stage-2 fine-tuning ...")
    start_epoch = 0
    criterion = DFPNormLoss(mid_known=mid_energy["mid_known"],
                            mid_unknown=mid_energy["mid_unknown"],
                            alpha=args.alpha,
                            temperature=args.temperature)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.stage2_lr,
                                momentum=0.9, weight_decay=5e-4)

    # Resume from a checkpoint when requested; otherwise — including the case
    # where the checkpoint file is missing — start a fresh log.
    resumed = False
    if args.stage2_resume:
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage2_resume)
            net.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            resumed = True
        else:
            # BUG FIX: original printed args.resume (wrong flag) and left
            # `logger` unbound, causing a NameError later in the epoch loop.
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
    if resumed:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Class Loss', 'Energy Loss',
                          'Energy Known', 'Energy Unknown', 'Train Acc.'])

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer, epoch, args.stage2_lr,
                                 factor=args.stage2_lr_factor,
                                 step=args.stage2_lr_step)
            print('\nStage_2 Epoch: %d | Learning rate: %f '
                  % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net, trainloader, vae, optimizer, criterion, device)
            # Overwrite the same "last model" checkpoint every epoch.
            save_model(net, optimizer, epoch,
                       os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append([epoch + 1,
                           train_out["train_loss"],
                           train_out["loss_classification"],
                           train_out["loss_energy"],
                           train_out["loss_energy_known"],
                           train_out["loss_energy_unknown"],
                           train_out["accuracy"]])
            if args.plot:
                plot_feature(net, args, trainloader, device, args.plotfolder,
                             epoch="stage2_" + str(epoch),
                             plot_class_num=args.train_class_num,
                             plot_quality=args.plot_quality)
                # Test-set plot reserves one extra class slot for "unknown".
                plot_feature(net, args, testloader, device, args.plotfolder,
                             epoch="stage2_test" + str(epoch),
                             plot_class_num=args.train_class_num + 1,
                             plot_quality=args.plot_quality,
                             testmode=True)
    logger.close()
    print("\nFinish Stage-2 training...\n")
    print("===> Evaluating stage-2 ...")
    stage2_test(net, testloader, trainloader, device)
def main_stage2(stage1_dict):
    """Stage-2 fine-tuning driven by the results of stage 1.

    Args:
        stage1_dict: dict carrying stage-1 outputs — "net" (state dict),
            "mid_known" and "mid_unknown" (energy midpoint tensors).

    Builds a fresh `DFPNet`, loads either the stage-1 weights or a stage-2
    resume checkpoint, fine-tunes with `DFPNormLoss` (midpoints scaled by
    1.3 / 0.7 to widen the known/unknown energy gap — presumably a tuned
    margin; confirm against the paper/config), then runs the stage-2
    evaluation suite.

    NOTE(review): this definition shadows an earlier `main_stage2` in the
    same file — confirm the earlier one is dead code.
    """
    print("Starting stage-2 fine-tuning ...")
    start_epoch = 0

    # Unpack stage-1 results.
    mid_known = stage1_dict["mid_known"]
    mid_unknown = stage1_dict["mid_unknown"]
    net_state_dict = stage1_dict["net"]

    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num,
                 embed_dim=args.embed_dim, p=args.p)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    optimizer = torch.optim.SGD(net.parameters(), lr=args.stage2_lr,
                                momentum=0.9, weight_decay=5e-4)

    # Resume from a stage-2 checkpoint when requested; on any other path
    # (no resume flag, or missing file) fall back to the stage-1 weights
    # and a fresh log.
    resumed = False
    if args.stage2_resume:
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage2_resume)
            net.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            # Checkpointed midpoints supersede the stage-1 estimates.
            mid_known = checkpoint["mid_known"]
            mid_unknown = checkpoint["mid_unknown"]
            resumed = True
        else:
            # BUG FIX: original printed args.resume (wrong flag), never
            # bound `logger`, and never loaded the stage-1 weights here.
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
    if resumed:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
    else:
        net.load_state_dict(net_state_dict)
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Class Loss', 'Energy Loss',
                          'Energy Known', 'Energy Unknown', 'Train Acc.'])

    # Criterion is built after resume so checkpointed midpoints take effect.
    criterion = DFPNormLoss(mid_known=1.3 * mid_known,
                            mid_unknown=0.7 * mid_unknown,
                            alpha=args.alpha,
                            temperature=args.temperature,
                            feature='energy')

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer, epoch, args.stage2_lr,
                                 factor=args.stage2_lr_factor,
                                 step=args.stage2_lr_step)
            print('\nStage_2 Epoch: %d | Learning rate: %f '
                  % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net, trainloader, optimizer, criterion, device)
            # Persist midpoints alongside the weights so a resume restores
            # the exact loss configuration.
            save_model(net, optimizer, epoch,
                       os.path.join(args.checkpoint, 'stage_2_last_model.pth'),
                       mid_known=mid_known, mid_unknown=mid_unknown)
            logger.append([epoch + 1,
                           train_out["train_loss"],
                           train_out["loss_classification"],
                           train_out["loss_energy"],
                           train_out["loss_energy_known"],
                           train_out["loss_energy_unknown"],
                           train_out["accuracy"]])
            if args.plot:
                plot_feature(net, args, trainloader, device, args.plotfolder,
                             epoch="stage2_" + str(epoch),
                             plot_class_num=args.train_class_num,
                             plot_quality=args.plot_quality)
                # Test-set plot reserves one extra class slot for "unknown".
                plot_feature(net, args, testloader, device, args.plotfolder,
                             epoch="stage2_test" + str(epoch),
                             plot_class_num=args.train_class_num + 1,
                             plot_quality=args.plot_quality,
                             testmode=True)
    logger.close()
    print("\nFinish Stage-2 training...\n")
    print("===> Evaluating stage-2 ...")
    stage_test(net, testloader, device, name="stage2_test_doublebar")
    stage_valmixup(net, trainloader, device, name="stage2_mixup_result")
    stage_evaluate(net, testloader, mid_unknown.item(), mid_known.item(), feature="energy")