Example #1
def main_stage2(net, vae, mid_energy):
    print("Starting stage-2 fine-tuning ...")
    start_epoch = 0
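    # Combined classification + energy loss, anchored at the mid-energy statistics
    # of known / unknown samples estimated in stage 1.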
    criterion = DFPNormLoss(mid_known=mid_energy["mid_known"], mid_unknown=mid_energy["mid_unknown"],
                            alpha=args.alpha, temperature=args.temperature)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.stage2_lr, momentum=0.9, weight_decay=5e-4)
    if args.stage2_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage2_resume)
            net.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names(['Epoch', 'Train Loss', 'Class Loss', 'Energy Loss', 'Energy Known', 'Energy Unknown', 'Train Acc.'])

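    # Training loop: step-decay LR schedule, per-epoch checkpointing, logging, and optional feature plots.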
    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer, epoch, args.stage2_lr,
                                 factor=args.stage2_lr_factor, step=args.stage2_lr_step)
            print('\nStage_2 Epoch: %d | Learning rate: %f ' % (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net, trainloader, vae, optimizer, criterion, device)
            save_model(net, optimizer, epoch, os.path.join(args.checkpoint, 'stage_2_last_model.pth'))
            logger.append([epoch + 1, train_out["train_loss"], train_out["loss_classification"],
                           train_out["loss_energy"], train_out["loss_energy_known"],
                           train_out["loss_energy_unknown"], train_out["accuracy"]])
            if args.plot:
                plot_feature(net, args, trainloader, device, args.plotfolder, epoch="stage2_"+str(epoch),
                             plot_class_num=args.train_class_num, plot_quality=args.plot_quality)
                plot_feature(net, args, testloader, device, args.plotfolder, epoch="stage2_test" + str(epoch),
                             plot_class_num=args.train_class_num + 1, plot_quality=args.plot_quality, testmode=True)
        logger.close()
        print(f"\nFinish Stage-2 training...\n")

        print("===> Evaluating stage-2 ...")
        stage2_test(net, testloader, trainloader, device)
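
Neither example shows adjust_learning_rate; a minimal sketch of a step-decay schedule consistent with the factor/step arguments used above (assuming step is an integer epoch interval; the repository's actual helper may differ) is:

def adjust_learning_rate(optimizer, epoch, init_lr, factor=0.1, step=30):
    # Step decay: multiply the initial LR by `factor` once every `step` epochs.
    lr = init_lr * (factor ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr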
Example #2
def main_stage2(stage1_dict):
    print("Starting stage-2 fine-tuning ...")
    start_epoch = 0

    # get key values from stage1_dict
    mid_known = stage1_dict["mid_known"]
    mid_unknown = stage1_dict["mid_unknown"]
    net_state_dict = stage1_dict["net"]

    net = DFPNet(backbone=args.arch,
                 num_classes=args.train_class_num,
                 embed_dim=args.embed_dim,
                 p=args.p)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.stage2_lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    if args.stage2_resume:
        # Load checkpoint.
        if os.path.isfile(args.stage2_resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.stage2_resume)
            net.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            mid_known = checkpoint["mid_known"]
            mid_unknown = checkpoint["mid_unknown"]
            logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'),
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        net.load_state_dict(net_state_dict)
        logger = Logger(os.path.join(args.checkpoint, 'log_stage2.txt'))
        logger.set_names([
            'Epoch', 'Train Loss', 'Class Loss', 'Energy Loss', 'Energy Known',
            'Energy Unknown', 'Train Acc.'
        ])

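    # Energy anchors for the loss: the stage-1 mid energies are rescaled
    # (1.3x known, 0.7x unknown), presumably to push the two distributions further apart.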
    criterion = DFPNormLoss(mid_known=1.3 * mid_known,
                            mid_unknown=0.7 * mid_unknown,
                            alpha=args.alpha,
                            temperature=args.temperature,
                            feature='energy')

    if not args.evaluate:
        for epoch in range(start_epoch, args.stage2_es):
            adjust_learning_rate(optimizer,
                                 epoch,
                                 args.stage2_lr,
                                 factor=args.stage2_lr_factor,
                                 step=args.stage2_lr_step)
            print('\nStage_2 Epoch: %d | Learning rate: %f ' %
                  (epoch + 1, optimizer.param_groups[0]['lr']))
            train_out = stage2_train(net, trainloader, optimizer, criterion,
                                     device)
            save_model(net,
                       optimizer,
                       epoch,
                       os.path.join(args.checkpoint, 'stage_2_last_model.pth'),
                       mid_known=mid_known,
                       mid_unknown=mid_unknown)
            logger.append([
                epoch + 1, train_out["train_loss"],
                train_out["loss_classification"], train_out["loss_energy"],
                train_out["loss_energy_known"],
                train_out["loss_energy_unknown"], train_out["accuracy"]
            ])
            if args.plot:
                plot_feature(net,
                             args,
                             trainloader,
                             device,
                             args.plotfolder,
                             epoch="stage2_" + str(epoch),
                             plot_class_num=args.train_class_num,
                             plot_quality=args.plot_quality)
                plot_feature(net,
                             args,
                             testloader,
                             device,
                             args.plotfolder,
                             epoch="stage2_test" + str(epoch),
                             plot_class_num=args.train_class_num + 1,
                             plot_quality=args.plot_quality,
                             testmode=True)
        logger.close()
        print(f"\nFinish Stage-2 training...\n")

    print("===> Evaluating stage-2 ...")
    stage_test(net, testloader, device, name="stage2_test_doublebar")
    stage_valmixup(net, trainloader, device, name="stage2_mixup_result")
    stage_evaluate(net,
                   testloader,
                   mid_unknown.item(),
                   mid_known.item(),
                   feature="energy")