Example #1
0
def main():
    """Train and/or visualise an MNIST network with softmax + center loss.

    All configuration is read from the module-level ``args`` namespace.
    ``args.checkpoint`` and ``args.plotfolder`` are set here as a side effect.

    Steps:
      * create ``./checkpoints/mnist/<arch>`` and its ``plotter`` sub-folder;
      * build MNIST train/test loaders using the class-split options
        ``args.train_class_num`` / ``args.test_class_num`` /
        ``args.includes_all_train_class``;
      * build ``Network`` plus a ``CenterLoss`` criterion, each with its own
        SGD optimizer;
      * optionally resume network and center-loss state from ``args.resume``
        (a missing checkpoint falls back to a fresh run);
      * unless ``args.evaluate`` is set, train for ``args.es`` epochs starting
        at the resumed epoch, saving ``last_model.pth``, appending a row to
        ``log.txt`` and calling ``test`` after every epoch;
      * if ``args.plot`` is set, plot train-set features each epoch and
        test-set features once at the end.
    """
    args.checkpoint = './checkpoints/mnist/' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # folder to save figures
    args.plotfolder = './checkpoints/mnist/' + args.arch + '/plotter'
    if not os.path.isdir(args.plotfolder):
        mkdir_p(args.plotfolder)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    print('==> Preparing data..')
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel mean / std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    trainset = MNIST(root='../../data', train=True, download=True, transform=transform,
                     train_class_num=args.train_class_num, test_class_num=args.test_class_num,
                     includes_all_train_class=args.includes_all_train_class)

    testset = MNIST(root='../../data', train=False, download=True, transform=transform,
                    train_class_num=args.train_class_num, test_class_num=args.test_class_num,
                    includes_all_train_class=args.includes_all_train_class)

    # data loader
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=4)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=4)

    print('==> Building model..')
    net = Network(backbone=args.arch, num_classes=args.train_class_num, embed_dim=args.embed_dim)
    # Read the feature size before DataParallel wraps the model: the wrapper
    # exposes the original model only through `.module`.
    fea_dim = net.classifier.in_features
    net = net.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    criterion_softmax = nn.CrossEntropyLoss()
    criterion_centerloss = CenterLoss(num_classes=args.train_class_num, feat_dim=fea_dim).to(device)
    optimizer_softmax = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # The center-loss parameters get their own optimizer and learning rate.
    optimizer_centerloss = torch.optim.SGD(criterion_centerloss.parameters(), lr=args.center_lr,
                                           momentum=0.9, weight_decay=5e-4)

    log_path = os.path.join(args.checkpoint, 'log.txt')
    logger = None
    if args.resume:
        # Load checkpoint.
        if os.path.isfile(args.resume):
            print('==> Resuming from checkpoint..')
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            checkpoint = torch.load(args.resume, map_location=device)
            net.load_state_dict(checkpoint['net'])
            criterion_centerloss.load_state_dict(checkpoint['centerloss'])
            start_epoch = checkpoint['epoch']
            logger = Logger(log_path, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if logger is None:
        # Fresh run. This path is also taken when the requested checkpoint is
        # missing, which previously left `logger` unbound (NameError later).
        logger = Logger(log_path)
        logger.set_names(['Epoch', 'Total Loss', 'Softmax Loss', 'Center Loss', 'train Acc.'])

    if not args.evaluate:
        # NOTE(review): the scheduler and both optimizers start from scratch even
        # when resuming, so LR decay restarts at args.lr -- confirm intended.
        scheduler = lr_scheduler.StepLR(optimizer_softmax, step_size=20, gamma=0.1)
        for epoch in range(start_epoch, start_epoch + args.es):
            print('\nEpoch: %d   Learning rate: %f' % (epoch + 1, optimizer_softmax.param_groups[0]['lr']))
            train_loss, softmax_loss, center_loss, train_acc = train(net, trainloader, optimizer_softmax,
                                                                     optimizer_centerloss, criterion_softmax,
                                                                     criterion_centerloss, device)
            save_model(net, criterion_centerloss, epoch, os.path.join(args.checkpoint, 'last_model.pth'))
            # plot the training data
            if args.plot:
                plot_feature(net, criterion_centerloss, trainloader, device, args.plotfolder, epoch=epoch,
                             plot_class_num=args.train_class_num, maximum=args.plot_max,
                             plot_quality=args.plot_quality)

            logger.append([epoch + 1, train_loss, softmax_loss, center_loss, train_acc])
            scheduler.step()
            test(net, testloader, device)

    if args.plot:
        # +1 class slot: presumably for test-only (unknown) classes -- confirm.
        plot_feature(net, criterion_centerloss, testloader, device, args.plotfolder, epoch="test",
                     plot_class_num=args.train_class_num + 1, maximum=args.plot_max,
                     plot_quality=args.plot_quality)
    logger.close()
def main():
    """Train and/or visualise an MNIST network with plain cross-entropy loss.

    All configuration is read from the module-level ``args`` namespace.
    ``args.checkpoint`` and ``args.plotfolder`` are set here as a side effect.

    Unless ``args.evaluate`` is set, this trains from the resumed epoch up to
    ``args.es`` (exclusive). Each epoch it:
      * sets the learning rate via ``adjust_learning_rate(..., step=20)``;
      * saves ``last_model.pth`` and appends a row to ``log.txt``;
      * plots train-set features and calls ``test``.

    Finally it runs ``test`` once more and plots test-set features. A missing
    ``args.resume`` checkpoint falls back to a fresh run.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # checkpoint
    args.checkpoint = './checkpoints/mnist/' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # folder to save figures
    args.plotfolder = './checkpoints/mnist/' + args.arch + '/plotter'
    if not os.path.isdir(args.plotfolder):
        mkdir_p(args.plotfolder)

    # Data
    print('==> Preparing data..')
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel mean / std.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    trainset = MNIST(root='../../data',
                     train=True,
                     download=True,
                     transform=transform,
                     train_class_num=args.train_class_num,
                     test_class_num=args.test_class_num,
                     includes_all_train_class=args.includes_all_train_class)
    testset = MNIST(root='../../data',
                    train=False,
                    download=True,
                    transform=transform,
                    train_class_num=args.train_class_num,
                    test_class_num=args.test_class_num,
                    includes_all_train_class=args.includes_all_train_class)
    # data loader
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.bs,
                                              shuffle=True,
                                              num_workers=4)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.bs,
                                             shuffle=False,
                                             num_workers=4)

    # Model
    net = Network(backbone=args.arch,
                  num_classes=args.train_class_num,
                  embed_dim=args.embed_dim)
    net = net.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    log_path = os.path.join(args.checkpoint, 'log.txt')
    logger = None
    if args.resume:
        # Load checkpoint.
        if os.path.isfile(args.resume):
            print('==> Resuming from checkpoint..')
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            checkpoint = torch.load(args.resume, map_location=device)
            net.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch']
            logger = Logger(log_path, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if logger is None:
        # Fresh run. This path is also taken when the requested checkpoint is
        # missing, which previously left `logger` unbound (NameError later).
        logger = Logger(log_path)
        logger.set_names([
            'Epoch', 'Learning Rate', 'Train Loss', 'Train Acc.', 'Test Loss',
            'Test Acc.'
        ])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)

    if not args.evaluate:
        for epoch in range(start_epoch, args.es):
            # Adjust before printing so the displayed rate matches the one used
            # for this epoch's training and written to the log below.
            # (Presumably adjust_learning_rate updates optimizer.param_groups.)
            adjust_learning_rate(optimizer, epoch, args.lr, step=20)
            print('\nEpoch: %d   Learning rate: %f' %
                  (epoch + 1, optimizer.param_groups[0]['lr']))
            train_loss, train_acc = train(net, trainloader, optimizer,
                                          criterion, device)
            save_model(net, None, epoch,
                       os.path.join(args.checkpoint, 'last_model.pth'))
            # NOTE(review): test metrics are not captured here; the 'Test Loss'
            # and 'Test Acc.' columns are always logged as 0.
            test_loss, test_acc = 0, 0
            logger.append([
                epoch + 1, optimizer.param_groups[0]['lr'], train_loss,
                train_acc, test_loss, test_acc
            ])
            plot_feature(net,
                         trainloader,
                         device,
                         args.plotfolder,
                         epoch=epoch,
                         plot_class_num=args.train_class_num,
                         maximum=args.plot_max,
                         plot_quality=args.plot_quality)
            test(epoch, net, trainloader, testloader, criterion, device)

    # Final evaluation pass; 99999 is passed where an epoch index is expected.
    test(99999, net, trainloader, testloader, criterion, device)
    # +1 class slot: presumably for test-only (unknown) classes -- confirm.
    plot_feature(net,
                 testloader,
                 device,
                 args.plotfolder,
                 epoch="test",
                 plot_class_num=args.train_class_num + 1,
                 maximum=args.plot_max,
                 plot_quality=args.plot_quality)
    logger.close()