Beispiel #1
0
def main():
    """Run the complete training procedure: set up, train, evaluate, checkpoint.

    All settings are read from the module-level ``args`` object.
    ``args.save_dir`` is extended in place with an ``<nExemplars>-shot``
    sub-directory, and ``sys.stdout`` is replaced by a ``Logger`` constructed
    with the path of ``log_train.txt`` inside that directory.

    The model is evaluated after the first epoch, after every tenth epoch and
    after every epoch whose 0-based index is at least ``args.stepsize[0]``.
    A checkpoint is written after each evaluation.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()

    args.save_dir = osp.join(args.save_dir, str(args.nExemplars) + '-shot')
    # Make sure the output directory exists before the log file and the
    # checkpoints are written into it.
    os.makedirs(args.save_dir, exist_ok=True)
    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print('Initializing image data manager')
    dm = DataManager(args, use_gpu)
    trainloader, testloader = dm.return_dataloaders()

    model = Model(args)
    # Move the model to the GPU *before* building the optimizer so that the
    # optimizer is constructed from the parameters that are actually trained.
    if use_gpu:
        model = model.cuda()
    criterion = CrossEntropyLoss()
    optimizer = init_optimizer(args.optim, model.parameters(), args.lr,
                               args.weight_decay)

    start_time = time.time()
    train_time = 0
    best_acc = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(args.max_epoch):
        learning_rate = adjust_learning_rate(optimizer, epoch, args.LUT_lr)

        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, learning_rate,
              use_gpu)
        train_time += round(time.time() - start_train_time)

        # Evaluate after the first epoch, every tenth epoch, and every epoch
        # from index args.stepsize[0] onwards.
        should_evaluate = (epoch == 0
                           or epoch > (args.stepsize[0] - 1)
                           or (epoch + 1) % 10 == 0)
        if not should_evaluate:
            continue

        acc = test(model, testloader, use_gpu)
        is_best = acc > best_acc
        if is_best:
            best_acc = acc
            best_epoch = epoch + 1  # reported 1-based

        # The stored 'epoch' is 0-based while the file name is 1-based.
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }, is_best,
            osp.join(args.save_dir,
                     'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
        print("==> Test 5-way Best accuracy {:.2%}, achieved at epoch {}".
              format(best_acc, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    print("==========\nArgs:{}\n==========".format(args))
Beispiel #2
0
    # Write the validation banner to the F_txt file object. This happens
    # unconditionally, even when the evaluation below is skipped.
    print('============ validation on the val set ============', file=F_txt)
    # Evaluate only when epoch_item is a multiple of 10 or greater than 60.
    if epoch_item % 10 == 0 or epoch_item > 60:
        # test() returns a pair; only its first element is kept, as prec1.
        prec1, _ = test(val_dataloader, model, criterion, epoch_item,
                        best_prec1, F_txt)

        # Track the best prec1 seen so far; is_best is True only on a strict
        # improvement over the previous best.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # On improvement, store model and optimizer state under the fixed
        # name 'model_best.pth.tar' in outf.
        if is_best:
            save_checkpoint(
                {
                    'epoch_index': epoch_item,
                    'arch': opt.basemodel,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(outf, 'model_best.pth.tar'))

        # Independently of is_best, store a snapshot named after the epoch
        # whenever epoch_item is a multiple of 10.
        if epoch_item % 10 == 0:
            filename = os.path.join(outf, 'epoch_%d.pth.tar' % epoch_item)
            save_checkpoint(
                {
                    'epoch_index': epoch_item,
                    'arch': opt.basemodel,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, filename)
Beispiel #3
0
def main():
    """Run training with a warm-up schedule, evaluating and checkpointing.

    All settings are read from the module-level ``args`` object. Outputs go
    to a run directory under ``args.save_dir`` whose name is built from the
    dataset, shot count, kernel and group settings (plus ``args.suffix`` when
    it is not ``None``). ``sys.stdout`` is replaced by a ``Logger``
    constructed with the path of ``log_train.txt`` inside that directory.

    Training runs for ``args.max_epoch`` epochs plus 5 warm-up epochs. The
    model is evaluated after the first epoch, after every tenth epoch and
    after every epoch whose 0-based index is at least ``args.stepsize[0]``.
    A checkpoint is written after an evaluation only when the accuracy
    improved or the epoch index is greater than ``args.max_epoch``.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()

    model_name = '{}-{}shot-{}kernel-{}group'.format(args.dataset,
                                                     args.nExemplars,
                                                     args.kernel, args.groups)

    if args.suffix is not None:
        model_name = model_name + '-{}'.format(args.suffix)

    save_dir = os.path.join(args.save_dir, model_name)
    os.makedirs(save_dir, exist_ok=True)

    sys.stdout = Logger(osp.join(save_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print('Initializing image data manager')
    dm = DataManager(args, use_gpu)
    trainloader, testloader = dm.return_dataloaders()

    model = Model(num_classes=args.num_classes,
                  groups=args.groups,
                  kernel=args.kernel)
    #model = nn.DataParallel(model)
    # Move the model to the GPU *before* building the optimizer so that the
    # optimizer is constructed from the parameters that are actually trained.
    if use_gpu:
        model = model.cuda()
    criterion = CrossEntropyLoss()
    optimizer = init_optimizer(args.optim, model.parameters(), args.lr,
                               args.weight_decay)

    start_time = time.time()
    train_time = 0
    best_acc = -np.inf
    best_epoch = 0
    print("==> Start training")

    warmup_epoch = 5
    # Warm-up epochs are run in addition to the configured max_epoch; compute
    # the total once so the scheduler and the loop cannot drift apart.
    total_epochs = args.max_epoch + warmup_epoch
    scheduler = warmup_scheduler(base_lr=args.lr,
                                 iter_per_epoch=len(trainloader),
                                 max_epoch=total_epochs,
                                 multi_step=[],
                                 warmup_epoch=warmup_epoch)

    for epoch in range(total_epochs):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, scheduler,
              use_gpu)
        train_time += round(time.time() - start_train_time)

        # Evaluate after the first epoch, every tenth epoch, and every epoch
        # from index args.stepsize[0] onwards.
        should_evaluate = (epoch == 0
                           or epoch > (args.stepsize[0] - 1)
                           or (epoch + 1) % 10 == 0)
        if not should_evaluate:
            continue

        acc = test(model, testloader, use_gpu)
        is_best = acc > best_acc

        if is_best:
            best_acc = acc
            best_epoch = epoch + 1  # reported 1-based

        # Only keep checkpoints for improvements or for the final epochs
        # (index greater than args.max_epoch).
        if is_best or epoch > args.max_epoch:
            # The stored 'epoch' is 0-based while the file name is 1-based.
            save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'acc': acc,
                    'epoch': epoch,
                }, is_best,
                osp.join(save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

        print("==> Test 5-way Best accuracy {:.2%}, achieved at epoch {}".
              format(best_acc, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    print("==========\nArgs:{}\n==========".format(args))