# Example 1
def main_stage1():
    """Stage-1 training: initialize the backbone, both branches, and centroids.

    Builds a DFPNet, optionally resumes from ``args.stage1_resume``, trains for
    ``args.stage1_es`` epochs (skipped when ``args.evaluate``), optionally plots
    features, computes per-class distance statistics, evaluates on the test set,
    and returns a dict with the trained ``net`` and the ``distance`` results.
    """
    print("\nStart Stage-1 training ...\n")
    # for initializing backbone, two branches, and centroids.
    start_epoch = 0  # start from epoch 0 or from the last checkpoint epoch

    # Model
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 distance=args.distance, scaled=args.scaled)
    criterion = DFPLoss(alpha=args.alpha, beta=args.beta)
    optimizer = optim.SGD(net.parameters(), lr=args.stage1_lr, momentum=0.9, weight_decay=5e-4)

    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    if args.stage1_resume and os.path.isfile(args.stage1_resume):
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage1_resume)
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
    else:
        if args.stage1_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
        # Fix: always create a logger; previously a missing checkpoint file left
        # `logger` undefined and crashed at logger.append() below.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Softmax Loss', 'Within Loss', 'Between Loss', 'Train Acc.'])

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage1_es):
            adjust_learning_rate(optimizer, epoch, args.stage1_lr, step=15)
            print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage1_train(net, trainloader, optimizer, criterion, device)
            save_model(net, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
            logger.append([epoch + 1, train_out["train_loss"], train_out["cls_loss"], train_out["dis_loss_within"],
                           train_out["dis_loss_between"], train_out["accuracy"]])
            if args.plot:
                plot_feature(net, trainloader, device, args.plotfolder1, epoch=epoch,
                             plot_class_num=args.train_class_num, maximum=args.plot_max,
                             plot_quality=args.plot_quality, normalized=args.plot_normalized)
    if args.plot:
        # plot the test set (one extra class slot for the unknown class)
        plot_feature(net, testloader, device, args.plotfolder1, epoch="test",
                     plot_class_num=args.train_class_num + 1, maximum=args.plot_max,
                     plot_quality=args.plot_quality, normalized=args.plot_normalized)

    # calculating distances for the last epoch
    distance_results = plot_distance(net, trainloader, device, args)

    logger.close()
    print("\nFinish Stage-1 training...\n")
    print("===> Evaluating ...")
    stage1_test(net, testloader, device)

    return {"net": net,
            "distance": distance_results
            }
def main_stage2(net, mid_energy):
    """Stage-2 fine-tuning of an already-trained ``net``.

    Optionally resumes from ``args.stage2_resume``, then trains (skipped when
    ``args.evaluate``) while logging loss/accuracy and optionally plotting.
    ``mid_energy`` is currently unused in this body — kept for interface
    compatibility with callers.
    """
    print("Starting stage-2 fine-tuning ...")
    # Fix: start_epoch was only assigned on the resume path, causing a
    # NameError in range(start_epoch, ...) when not resuming.
    start_epoch = 0
    if args.stage2_resume and os.path.isfile(args.stage2_resume):
        # Load checkpoint. Fix: previously tested args.stage1_resume and loaded
        # the stage-1 checkpoint even though args.stage2_resume was requested.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage2_resume)
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'),
                        resume=True)
    else:
        if args.stage2_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
        # Fix: always create a logger (was undefined when the checkpoint file
        # was missing) and log to a stage-2 file so the stage-1 log survives.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Train Acc.'])

    # after resume
    criterion = DFPLoss(temperature=args.temperature)
    # NOTE(review): stage-1 hyper-parameters (stage1_lr / stage1_es / ...) are
    # used below for stage-2 fine-tuning — confirm whether stage2_* were intended.
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.stage1_lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    if not args.evaluate:
        for epoch in range(start_epoch, args.stage1_es):
            adjust_learning_rate(optimizer,
                                 epoch,
                                 args.stage1_lr,
                                 factor=args.stage1_lr_factor,
                                 step=args.stage1_lr_step)
            # Fix: label the log line as stage-2 (copy-paste said Stage_1).
            print('\nStage_2 Epoch: %d | Learning rate: %f ' %
                  (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage1_train(net, trainloader, optimizer, criterion,
                                     device)
            # Fix: save under a stage-2 name so the stage-1 snapshot is not clobbered.
            save_model(net, epoch,
                       os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append(
                [epoch + 1, train_out["train_loss"], train_out["accuracy"]])
            if args.plot:
                plot_feature(net,
                             args,
                             trainloader,
                             device,
                             args.plotfolder,
                             epoch=epoch,
                             plot_class_num=args.train_class_num,
                             plot_quality=args.plot_quality)
                plot_feature(net,
                             args,
                             testloader,
                             device,
                             args.plotfolder,
                             epoch="test" + str(epoch),
                             plot_class_num=args.train_class_num + 1,
                             plot_quality=args.plot_quality,
                             testmode=True)
        logger.close()
        print("\nFinish Stage-2 training...\n")
def main_stage1():
    """Stage-1 training: build a DFPNet, train with DFPLoss, then evaluate.

    Optionally resumes net + optimizer state from ``args.stage1_resume``.
    Returns a dict with the trained net's ``state_dict`` and the known/unknown
    mid-point values estimated by ``stage_valmixup``.
    """
    print("\nStart Stage-1 training ...\n")
    start_epoch = 0  # start from epoch 0 or from the last checkpoint epoch
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim, p=args.p)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    criterion = DFPLoss(temperature=args.temperature)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.stage1_lr, momentum=0.9, weight_decay=5e-4)

    if args.stage1_resume and os.path.isfile(args.stage1_resume):
        # Load checkpoint (restores net, optimizer, and epoch counter).
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage1_resume)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
    else:
        if args.stage1_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
        # Fix: always create a logger; a missing checkpoint previously left it undefined.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Train Acc.'])

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage1_es):
            adjust_learning_rate(optimizer, epoch, args.stage1_lr,
                                 factor=args.stage1_lr_factor, step=args.stage1_lr_step)
            print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage1_train(net, trainloader, optimizer, criterion, device)
            save_model(net, optimizer, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
            logger.append([epoch + 1, train_out["train_loss"], train_out["accuracy"]])
            if args.plot:
                plot_feature(net, args, trainloader, device, args.plotfolder, epoch=epoch,
                             plot_class_num=args.train_class_num, plot_quality=args.plot_quality)
                plot_feature(net, args, testloader, device, args.plotfolder, epoch="test" + str(epoch),
                             plot_class_num=args.train_class_num + 1, plot_quality=args.plot_quality, testmode=True)
        logger.close()
        print("\nFinish Stage-1 training...\n")

    print("===> Evaluating stage-1 ...")
    stage_test(net, testloader, device)
    mid_dict = stage_valmixup(net, trainloader, device)
    print("===> stage1 energy based classification")
    stage_evaluate(net, testloader, mid_dict["mid_unknown"].item(), mid_dict["mid_known"].item(), feature="energy")
    print("===> stage1 softmax based classification")
    stage_evaluate(net, testloader, 0., 1., feature="normweight_fea2cen")
    return {
        "net": net.state_dict(),
        "mid_known": mid_dict["mid_known"],
        "mid_unknown": mid_dict["mid_unknown"]
    }
# Example 4
def main():
    """Train a Deep-SVDD style model: build net, init centroid, train, plot.

    Optionally resumes from ``args.resume``; logs per-epoch train loss and,
    when ``args.plot`` is set, plots train/test features each epoch.
    """
    print(f"------------Running on {device}------------")
    print('==> Building model..')
    net = NetBuilder(backbone=args.arch, embed_dim=args.embed_dim)
    net = net.to(device)
    # Initialize the hypersphere centroid from the training data statistics.
    center = calcuate_center(net, trainloader, device)
    net._init_centroid(center)
    start_epoch = 0  # start from epoch 0 or from the last checkpoint epoch
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    # Fix: the existence check previously tested args.stage1_resume while the
    # checkpoint was loaded from args.resume — use args.resume consistently.
    if args.resume and os.path.isfile(args.resume):
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.resume)
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        resume=True)
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(args.resume))
        # Fix: always create a logger; a missing checkpoint previously left it undefined.
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'Train Loss'])

    criterion = DSVDDLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)

    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr, step=20)
        print('\nStage_1 Epoch: %d | Learning rate: %f ' %
              (epoch + 1, optimizer.param_groups[0]['lr']))
        train_loss = train(net, trainloader, optimizer, criterion, device)
        save_model(net, epoch, os.path.join(args.checkpoint, 'last_model.pth'))
        logger.append([epoch + 1, train_loss])
        if args.plot:
            # plot training set
            plot_feature(net,
                         args,
                         trainloader,
                         device,
                         args.plotfolder,
                         epoch=epoch,
                         plot_quality=150)
            # plot testing set
            plot_feature(net,
                         args,
                         testloader,
                         device,
                         args.plotfolder,
                         epoch='test_' + str(epoch),
                         plot_quality=150)
# Example 5
def main_stage1():
    """Stage-1 training: initialize the backbone, both branches, and centroids.

    Builds a DFPNet, optionally resumes from ``args.stage1_resume``, trains with
    Adam for ``args.stage1_es`` epochs, evaluates, and returns ``{"net": net}``.
    """
    print("\nStart Stage-1 training ...\n")
    # for initializing backbone, two branches, and centroids.
    start_epoch = 0  # start from epoch 0 or from the last checkpoint epoch

    # Model
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim)

    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    if args.stage1_resume and os.path.isfile(args.stage1_resume):
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage1_resume)
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
    else:
        if args.stage1_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
        # Fix: always create a logger; a missing checkpoint previously left it undefined.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Train Acc.'])

    # after resume
    criterion = DFPLoss(scaling=args.scaling)
    optimizer = optim.Adam(net.parameters(), lr=args.stage1_lr)

    for epoch in range(start_epoch, args.stage1_es):
        adjust_learning_rate(optimizer, epoch, args.stage1_lr, factor=0.2, step=20)
        print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
        train_out = stage1_train(net, trainloader, optimizer, criterion, device)
        save_model(net, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
        logger.append([epoch + 1, train_out["train_loss"], train_out["accuracy"]])
        if args.plot:
            plot_feature(net, args, trainloader, device, args.plotfolder, epoch=epoch,
                         plot_class_num=args.train_class_num, plot_quality=args.plot_quality)
            plot_feature(net, args, testloader, device, args.plotfolder, epoch="test" + str(epoch),
                         plot_class_num=args.train_class_num + 1, plot_quality=args.plot_quality, testmode=True)

    logger.close()
    print("\nFinish Stage-1 training...\n")
    print("===> Evaluating ...")
    stage1_test(net, testloader, device)

    return {
        "net": net
    }
# Example 6
def main_stage1():
    """Stage-1 training: initialize the backbone, both branches, and centroids.

    Builds a DFPNet with distance/cosine options, trains with the distance-based
    DFPLoss for ``args.stage1_es`` epochs (resuming if requested), and returns
    the trained net.
    """
    print("\nStart Stage-1 training ...\n")
    # for initializing backbone, two branches, and centroids.
    start_epoch = 0  # start from epoch 0 or from the last checkpoint epoch

    # Model
    print('==> Building model..')
    net = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                 distance=args.distance, scaled=args.scaled, cosine_weight=args.cosine_weight)
    criterion_dis = DFPLoss(beta=args.beta, sigma=args.sigma)
    optimizer = optim.SGD(net.parameters(), lr=args.stage1_lr, momentum=0.9, weight_decay=5e-4)

    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    if args.stage1_resume and os.path.isfile(args.stage1_resume):
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage1_resume)
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'), resume=True)
    else:
        if args.stage1_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage1_resume))
        # Fix: always create a logger; a missing checkpoint previously left it undefined.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage1.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Softmax Loss', 'Distance Loss',
                          'Within Loss', 'Between Loss', 'Cen2cen Loss', 'Train Acc.'])

    # NOTE: this loop runs stage1_es epochs *beyond* the resumed epoch,
    # unlike the other stages which train up to an absolute epoch count.
    for epoch in range(start_epoch, start_epoch + args.stage1_es):
        print('\nStage_1 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
        adjust_learning_rate(optimizer, epoch, args.stage1_lr, step=15)
        train_out = stage1_train(net, trainloader, optimizer, criterion_dis, device)
        save_model(net, epoch, os.path.join(args.checkpoint, 'stage_1_last_model.pth'))
        # Columns: ['Epoch', 'Train Loss', 'Softmax Loss', 'Distance Loss',
        # 'Within Loss', 'Between Loss', 'Cen2cen loss', 'Train Acc.']
        # (Softmax Loss is logged as 0.0: no classification criterion is used here.)
        logger.append([epoch + 1, train_out["train_loss"], 0.0,
                       train_out["dis_loss_total"], train_out["dis_loss_within"],
                       train_out["dis_loss_between"], train_out["dis_loss_cen2cen"], train_out["accuracy"]])
        if args.plot:
            plot_feature(net, trainloader, device, args.plotfolder, epoch=epoch,
                         plot_class_num=args.train_class_num, maximum=args.plot_max, plot_quality=args.plot_quality)
    logger.close()
    print("\nFinish Stage-1 training...\n")
    return net
# Example 7
def main_stage2(stage1_dict):
    """Stage-2 training seeded from stage-1 results.

    Args:
        stage1_dict: dict with "net" (the stage-1 model) and
            "distance"]["thresholds"] (per-class distance thresholds).

    Builds a fresh DFPNet (net2) initialized from net1, optionally resumes from
    ``args.stage2_resume``, trains with DFPLossGeneral, and returns net2.
    """
    net1 = stage1_dict["net"]
    thresholds = stage1_dict["distance"]["thresholds"]

    print("\n===> Start Stage-2 training...\n")
    start_epoch = 0
    print('==> Building model..')
    net2 = DFPNet(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim,
                  distance=args.distance, scaled=args.scaled, cosine_weight=args.cosine_weight, thresholds=thresholds)
    net2 = net2.to(device)

    criterion_dis = DFPLossGeneral(beta=args.beta, sigma=args.sigma, gamma=args.gamma)
    optimizer = optim.SGD(net2.parameters(), lr=args.stage2_lr, momentum=0.9, weight_decay=5e-4)

    # Copy stage-1 weights into net2 before (possibly) overwriting via resume.
    if not args.evaluate:
        init_stage2_model(net1, net2)
    if device == 'cuda':
        net2 = torch.nn.DataParallel(net2)
        cudnn.benchmark = True

    if args.stage2_resume and os.path.isfile(args.stage2_resume):
        # Load checkpoint. Fix: previously loaded args.stage1_resume even
        # though args.stage2_resume was the path that was checked.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage2_resume)
        net2.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
    else:
        if args.stage2_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
        # Fix: always create a logger; a missing checkpoint previously left it undefined.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Within Loss', 'Between Loss', 'Within-Gen Loss', 'Between-Gen Loss',
                          'Random loss', 'Train Acc.'])

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer, epoch, args.stage2_lr, step=20)
            print('\nStage_2 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net2, trainloader, optimizer, criterion_dis, device)
            save_model(net2, epoch, os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append([epoch + 1, train_out["dis_loss_total"], train_out["dis_loss_within"],
                           train_out["dis_loss_between"], train_out["dis_loss_within_gen"],
                           train_out["dis_loss_between_gen"], train_out["dis_loss_cen2cen"], train_out["accuracy"]])
            if args.plot:
                plot_feature(net2, trainloader, device, args.plotfolder2, epoch=epoch,
                             plot_class_num=args.train_class_num, maximum=args.plot_max, plot_quality=args.plot_quality)
    if args.plot:
        # plot the test set (one extra class slot for the unknown class)
        plot_feature(net2, testloader, device, args.plotfolder2, epoch="test",
                     plot_class_num=args.train_class_num + 1, maximum=args.plot_max, plot_quality=args.plot_quality)

    logger.close()
    print("\nFinish Stage-2 training...\n")
    print("===> Evaluating ...")
    stage1_test(net2, testloader, device)
    return net2
# Example 8
def main_stage2(stage1_dict):
    """Stage-2 training seeded from stage-1 results.

    Args:
        stage1_dict: dict with "net" (stage-1 model), "distance"]["thresholds"]
            (per-class thresholds), and "estimator".

    Builds net2 with those thresholds/estimator, initializes it from net1 unless
    resuming, optionally resumes from ``args.stage2_resume``, trains with
    DFPLoss2, and returns net2.
    """
    net1 = stage1_dict['net']
    thresholds = stage1_dict['distance']['thresholds']
    estimator = stage1_dict['estimator']
    print("\n===> Start Stage-2 training...\n")
    start_epoch = 0  # start from epoch 0 or from the last checkpoint epoch
    print('==> Building model..')
    net2 = DFPNet(backbone=args.arch,
                  num_classes=args.train_class_num,
                  embed_dim=args.embed_dim,
                  distance=args.distance,
                  similarity=args.similarity,
                  scaled=args.scaled,
                  thresholds=thresholds,
                  norm_centroid=args.norm_centroid,
                  amplifier=args.amplifier,
                  estimator=estimator)
    net2 = net2.to(device)
    # NOTE(review): os.path.isdir on a checkpoint path looks suspicious — the
    # resume branch below uses os.path.isfile; confirm which is intended.
    if not args.evaluate and not os.path.isdir(args.stage2_resume):
        init_stage2_model(net1, net2)

    if device == 'cuda':
        net2 = torch.nn.DataParallel(net2)
        cudnn.benchmark = True

    if args.stage2_resume and os.path.isfile(args.stage2_resume):
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage2_resume)
        net2.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'),
                        resume=True)
    else:
        if args.stage2_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
        # Fix: always create a logger; a missing checkpoint previously left it undefined.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names([
            'Epoch', 'Train Loss', 'Similarity Loss', 'Distance in',
            'Distance out', 'Generate within', 'Generate 2origin', 'Train Acc.'
        ])

    # after resume
    criterion = DFPLoss2(alpha=args.alpha, beta=args.beta, theta=args.theta)
    # NOTE(review): stage1_lr is used as the initial LR here while the schedule
    # below adjusts toward args.stage2_lr — confirm this is intentional.
    optimizer = optim.SGD(net2.parameters(),
                          lr=args.stage1_lr,
                          momentum=0.9,
                          weight_decay=5e-4)

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            print('\nStage_2 Epoch: %d   Learning rate: %f' %
                  (epoch + 1, optimizer.param_groups[0]['lr']))
            # Here, I didn't set optimizers respectively, just for simplicity. Performance did not vary a lot.
            adjust_learning_rate(optimizer, epoch, args.stage2_lr, step=10)
            train_out = stage2_train(net2, trainloader, optimizer, criterion,
                                     device)
            save_model(net2, epoch,
                       os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append([
                epoch + 1, train_out["train_loss"],
                train_out["loss_similarity"], train_out["distance_in"],
                train_out["distance_out"], train_out["generate_within"],
                train_out["generate_2orign"], train_out["accuracy"]
            ])
            if args.plot:
                plot_feature(net2,
                             args,
                             trainloader,
                             device,
                             args.plotfolder2,
                             epoch=epoch,
                             plot_class_num=args.train_class_num,
                             maximum=args.plot_max,
                             plot_quality=args.plot_quality,
                             norm_centroid=args.norm_centroid,
                             thresholds=thresholds)
                plot_feature(net2,
                             args,
                             testloader,
                             device,
                             args.plotfolder2,
                             epoch="test_" + str(epoch),
                             plot_class_num=args.train_class_num + 1,
                             maximum=args.plot_max,
                             plot_quality=args.plot_quality,
                             norm_centroid=args.norm_centroid,
                             thresholds=thresholds,
                             testmode=True)
        if args.plot:
            # plot the final test set (one extra class slot for the unknown class)
            plot_feature(net2,
                         args,
                         testloader,
                         device,
                         args.plotfolder2,
                         epoch="test",
                         plot_class_num=args.train_class_num + 1,
                         maximum=args.plot_max,
                         plot_quality=args.plot_quality,
                         norm_centroid=args.norm_centroid,
                         thresholds=thresholds,
                         testmode=True)
        print("\nFinish Stage-2 training...\n")

    logger.close()

    return net2
def main_stage2(net, mid_known, mid_unknown):
    """Stage-2 fine-tuning with an energy-based DFPNormLoss.

    Args:
        net: stage-1 trained model to fine-tune in place.
        mid_known: mid-point energy of known samples (scaled by 1.3 for the loss).
        mid_unknown: mid-point energy of unknown samples (scaled by 0.7 for the loss).

    Optionally resumes net + optimizer from ``args.stage2_resume``, trains for
    ``args.stage2_es`` epochs (skipped when ``args.evaluate``), then evaluates.
    """
    print("Starting stage-2 fine-tuning ...")
    start_epoch = 0
    criterion = DFPNormLoss(mid_known=1.3 * mid_known,
                            mid_unknown=0.7 * mid_unknown,
                            alpha=args.alpha,
                            temperature=args.temperature,
                            feature='energy')
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.stage2_lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    if args.stage2_resume and os.path.isfile(args.stage2_resume):
        # Load checkpoint (restores net, optimizer, and epoch counter).
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.stage2_resume)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'),
                        resume=True)
    else:
        if args.stage2_resume:
            # Fix: report the path actually checked (message previously printed args.resume).
            print("=> no checkpoint found at '{}'".format(args.stage2_resume))
        # Fix: always create a logger; a missing checkpoint previously left it undefined.
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names([
            'Epoch', 'Train Loss', 'Class Loss', 'Energy Loss', 'Energy Known',
            'Energy Unknown', 'Train Acc.'
        ])

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer,
                                 epoch,
                                 args.stage2_lr,
                                 factor=args.stage2_lr_factor,
                                 step=args.stage2_lr_step)
            print('\nStage_2 Epoch: %d | Learning rate: %f ' %
                  (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net, trainloader, optimizer, criterion,
                                     device)
            save_model(net, optimizer, epoch,
                       os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append([
                epoch + 1, train_out["train_loss"],
                train_out["loss_classification"], train_out["loss_energy"],
                train_out["loss_energy_known"],
                train_out["loss_energy_unknown"], train_out["accuracy"]
            ])
            if args.plot:
                plot_feature(net,
                             args,
                             trainloader,
                             device,
                             args.plotfolder,
                             epoch="stage2_" + str(epoch),
                             plot_class_num=args.train_class_num,
                             plot_quality=args.plot_quality)
                plot_feature(net,
                             args,
                             testloader,
                             device,
                             args.plotfolder,
                             epoch="stage2_test" + str(epoch),
                             plot_class_num=args.train_class_num + 1,
                             plot_quality=args.plot_quality,
                             testmode=True)
        logger.close()
        print("\nFinish Stage-2 training...\n")

        print("===> Evaluating stage-2 ...")
        stage_test(net, testloader, device, name="stage2_test_doublebar")
        stage_valmixup(net, trainloader, device, name="stage2_mixup_result")