Code Example #1
    # path_data = '~/data'
    # train_dataset = l2l.vision.datasets.MiniImagenet(
    #     root=path_data, mode='train')
    # valid_dataset = l2l.vision.datasets.MiniImagenet(
    #     root=path_data, mode='validation')
    # test_dataset = l2l.vision.datasets.MiniImagenet(
    #     root=path_data, mode='test')

    # CURE-TSR from the CURE_TSR_OG directory layout
    data_transforms = transforms.Compose(
        [transforms.Resize([32, 32]),
         transforms.ToTensor()])  #, utils.l2normalize, utils.standardization])

    lvl0_train_dir = './CURE_TSR_OG/Real_Train/ChallengeFree/'
    lvl5_test_dir = './CURE_TSR_OG/Real_Train/LensBlur-5/'
    curetsr_lvl0 = utils.CURETSRDataset(lvl0_train_dir, data_transforms)
    curetsr_lvl5 = utils.CURETSRDataset(lvl5_test_dir, data_transforms)

    # lvl0_train_dir = './CURE_TSR_Yahan_Shortcut/Real_Train/ChallengeFree/'
    # lvl5_test_dir = './CURE_TSR_Yahan_Shortcut/Real_Train/Snow-5/'
    # curetsr_lvl0 = datasets.ImageFolder(lvl0_train_dir, transform=data_transforms)
    # print("first image, label is ", curetsr_lvl0[0])
    # curetsr_lvl5 = datasets.ImageFolder(lvl5_test_dir, transform=data_transforms)

    meta_curetsr_lvl0 = l2l.data.MetaDataset(curetsr_lvl0)
    meta_curetsr_lvl5 = l2l.data.MetaDataset(curetsr_lvl5)

    # Train and validation tasks are drawn from the same challenge-free data;
    # the test tasks come from the LensBlur level-5 data.
    train_dataset = meta_curetsr_lvl0
    valid_dataset = meta_curetsr_lvl0
    test_dataset = meta_curetsr_lvl5
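The utils.CURETSRDataset class used above is not part of this excerpt. As a rough, hypothetical sketch only (the class name, file extension, and label parsing below are assumptions, and the real implementation and CURE-TSR file-naming convention may differ), a compatible dataset just needs to yield (image, label) pairs so that l2l.data.MetaDataset can index labels for task sampling:

import os

from PIL import Image
from torch.utils.data import Dataset


class SimpleCURETSRDataset(Dataset):
    # Hypothetical stand-in for utils.CURETSRDataset.
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.bmp'))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        image = Image.open(os.path.join(self.root, fname)).convert('RGB')
        # Assumption: the sign-type field of the file name encodes the class;
        # adjust this parsing to the actual CURE-TSR naming scheme.
        label = int(fname.split('_')[1]) - 1
        if self.transform is not None:
            image = self.transform(image)
        return image, label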
Code Example #2
def main():
    global args
    args = parser.parse_args()
    traindir = os.path.join(args.data, 'RealChallengeFree/train')
    testdir = os.path.join(args.data, 'RealChallengeFree/Test')
    train_dataset = utils.CURETSRDataset(
        traindir,
        transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(), utils.l2normalize, utils.standardization
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_dataset = utils.CURETSRDataset(
        testdir,
        transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(), utils.l2normalize, utils.standardization
        ]))
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    model = models.AutoEncoder()
    model = torch.nn.DataParallel(model).cuda()
    print("=> creating model %s " % model.__class__.__name__)
    criterion = nn.MSELoss().cuda()

    savedir = 'AutoEncoder'
    checkpointdir = os.path.join('./checkpoints', savedir)
    os.makedirs(checkpointdir, exist_ok=True)
    print('log directory: %s' % os.path.join('./logs', savedir))
    print('checkpoints directory: %s' % checkpointdir)
    logger = Logger(os.path.join('./logs/', savedir))
    if args.evaluate:
        print("=> loading checkpoint ")
        checkpoint = torch.load(
            os.path.join(checkpointdir, 'model_best.pth.tar'))
        model.load_state_dict(checkpoint['AE_state_dict'], strict=False)
        modelCNN = models.Net()
        modelCNN = torch.nn.DataParallel(modelCNN).cuda()
        checkpoint2 = torch.load('./checkpoints/CNN_iter/model_best.pth.tar')
        modelCNN.load_state_dict(checkpoint2['state_dict'], strict=False)
        evaluate(test_loader, model, modelCNN, criterion)
        return
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)
    cudnn.benchmark = True

    timestart = time.time()

    if args.finetune:
        print("=> loading checkpoint ")
        checkpoint = torch.load(
            os.path.join(checkpointdir, 'model_best.pth.tar'))
        model.load_state_dict(checkpoint['AE_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer'])

    best_loss = 10e10
    # train_accs = []
    # test_accs = []
    loss_epochs = []

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        print('\n*** Start Training *** \n')
        loss_train = train(train_loader, test_loader, model, criterion,
                           optimizer, epoch)
        print('Training loss: %.4f' % loss_train)
        loss_epochs.append(loss_train)
        is_best = loss_train < best_loss
        best_loss = min(loss_train, best_loss)
        print('Best loss so far: %.4f' % best_loss)
        info = {
            'Loss': loss_train
            # 'Testing Accuracy': test_prec1
        }
        # if not debug:
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch + 1)
        if is_best:
            best_epoch = epoch + 1
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'AE_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, is_best, checkpointdir)
    generate_plots(range(args.start_epoch, args.epochs), loss_epochs)
    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
    print('Best loss:', best_loss)
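The helpers adjust_learning_rate and save_checkpoint are called above but not defined in this excerpt. A minimal sketch, assuming the standard PyTorch ImageNet-example pattern; the decay schedule and checkpoint file names are assumptions:

import os
import shutil

import torch


def adjust_learning_rate(optimizer, epoch):
    # Assumed schedule: decay the base learning rate by 10x every 30 epochs.
    # Relies on the module-level args parsed in main().
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def save_checkpoint(state, is_best, checkpointdir, filename='checkpoint.pth.tar'):
    # Save the latest state every epoch; copy it as model_best.pth.tar when it improves.
    path = os.path.join(checkpointdir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(checkpointdir, 'model_best.pth.tar'))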
Code Example #3
def main():
    global args
    args = parser.parse_args()

    debug = 0  # 0: normal mode 1: debug mode

    # Data loading code
    # args.data: path to the dataset
    traindir = os.path.join(args.data, 'RealChallengeFree/train')
    testdir = os.path.join(args.data, 'RealChallengeFree/Test')

    train_dataset = utils.CURETSRDataset(
        traindir,
        transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(), utils.l2normalize, utils.standardization
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    test_dataset = utils.CURETSRDataset(
        testdir,
        transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(), utils.l2normalize, utils.standardization
        ]))
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    model = models.Net()
    model = torch.nn.DataParallel(model).cuda()
    print("=> creating model %s " % model.__class__.__name__)

    savedir = 'CNN_iter'
    checkpointdir = os.path.join('./checkpoints', savedir)

    if not debug:
        os.makedirs(checkpointdir, exist_ok=True)
        print('log directory: %s' % os.path.join('./logs', savedir))
        print('checkpoints directory: %s' % checkpointdir)

    # Set the logger
    if not debug:
        logger = Logger(os.path.join('./logs/', savedir))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    best_prec1 = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(
                "=> loaded checkpoint '{}' (epoch {}, best_prec1 @ Source {})".
                format(args.resume, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        evaluate(test_loader, model, criterion)
        return

    cudnn.benchmark = True

    timestart = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        print('\n*** Start Training *** \n')
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        print('\n*** Start Testing *** \n')
        test_loss, test_prec1, _ = evaluate(test_loader, model, criterion)

        info = {'Testing loss': test_loss, 'Testing Accuracy': test_prec1}

        # remember best prec@1 and save checkpoint
        is_best = test_prec1 > best_prec1
        best_prec1 = max(test_prec1, best_prec1)

        if is_best:
            best_epoch = epoch + 1

        if not debug:
            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch + 1)

            # Save once per epoch, after logging all metrics.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'last_prec1': test_prec1,
                    'optimizer': optimizer.state_dict()
                }, is_best, checkpointdir)

    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
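The evaluate function used in this example returns the average test loss, the top-1 accuracy, and a third value that the caller ignores. A minimal sketch consistent with that call signature, assuming a standard classification evaluation loop (the third return value here is just a placeholder):

import torch


def evaluate(test_loader, model, criterion):
    # Average cross-entropy loss and top-1 accuracy over the test set.
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.cuda(), targets.cuda()
            outputs = model(images)
            total_loss += criterion(outputs, targets).item() * images.size(0)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += images.size(0)
    avg_loss = total_loss / total
    prec1 = 100.0 * correct / total
    return avg_loss, prec1, None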
Code Example #4
def cure_tsr_tasksets(
        train_ways,
        train_samples,
        test_ways,
        test_samples,
        #root,
        **kwargs):
    """
    Benchmark definition for CURE TSR.
    """
    data_transforms = transforms.Compose(
        [transforms.Resize([32, 32]),
         transforms.ToTensor()])  #, utils.l2normalize, utils.standardization])

    lvl0_train_dir = './CURE_TSR_OG/Real_Train/ChallengeFree/'
    lvl5_test_dir = './CURE_TSR_OG/Real_Train/Snow-5/'
    curetsr_lvl0 = utils.CURETSRDataset(lvl0_train_dir, data_transforms)
    curetsr_lvl5 = utils.CURETSRDataset(lvl5_test_dir, data_transforms)

    # lvl0_train_dir = './CURE_TSR_Yahan_Shortcut/Real_Train/ChallengeFree/'
    # lvl5_test_dir = './CURE_TSR_Yahan_Shortcut/Real_Train/Snow-5/'
    # curetsr_lvl0 = datasets.ImageFolder(lvl0_train_dir, transform=data_transforms)
    # print("first image, label is ", curetsr_lvl0[0])
    # curetsr_lvl5 = datasets.ImageFolder(lvl5_test_dir, transform=data_transforms)

    meta_curetsr_lvl0 = l2l.data.MetaDataset(curetsr_lvl0)
    meta_curetsr_lvl5 = l2l.data.MetaDataset(curetsr_lvl5)

    # Train and validation tasks are drawn from the same challenge-free data;
    # the test tasks come from the Snow level-5 data.
    train_dataset = meta_curetsr_lvl0
    validation_dataset = meta_curetsr_lvl0
    test_dataset = meta_curetsr_lvl5

    classes = list(range(14))  # the 14 traffic-sign classes in CURE-TSR
    random.shuffle(classes)
    train_transforms = [
        l2l.data.transforms.FusedNWaysKShots(
            train_dataset,
            n=train_ways,
            k=train_samples,
            filter_labels=classes[:8]),  # first 8 of the shuffled classes for meta-training
        l2l.data.transforms.LoadData(train_dataset),
        l2l.data.transforms.RemapLabels(train_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
    ]
    validation_transforms = [
        l2l.data.transforms.FusedNWaysKShots(
            validation_dataset,
            n=test_ways,
            k=test_samples,
            filter_labels=classes[8:14]),  # remaining 6 classes for meta-validation
        l2l.data.transforms.LoadData(validation_dataset),
        l2l.data.transforms.RemapLabels(validation_dataset),
        l2l.data.transforms.ConsecutiveLabels(validation_dataset),
    ]
    test_transforms = [
        l2l.data.transforms.FusedNWaysKShots(
            test_dataset,
            n=test_ways,
            k=test_samples,
            filter_labels=classes[8:14]),  # remaining 6 classes for meta-test (same held-out split as validation)
        l2l.data.transforms.LoadData(test_dataset),
        l2l.data.transforms.RemapLabels(test_dataset),
        l2l.data.transforms.ConsecutiveLabels(test_dataset),
    ]

    _datasets = (train_dataset, validation_dataset, test_dataset)
    _transforms = (train_transforms, validation_transforms, test_transforms)
    return _datasets, _transforms
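The datasets and transforms returned by cure_tsr_tasksets are meant to be wrapped in l2l.data.TaskDataset objects before sampling episodes. A usage sketch; the ways/shots and num_tasks values below are illustrative rather than taken from the original code:

import learn2learn as l2l

datasets, task_transforms = cure_tsr_tasksets(
    train_ways=5, train_samples=2, test_ways=5, test_samples=2)

train_tasks = l2l.data.TaskDataset(datasets[0],
                                   task_transforms=task_transforms[0],
                                   num_tasks=20000)
valid_tasks = l2l.data.TaskDataset(datasets[1],
                                   task_transforms=task_transforms[1],
                                   num_tasks=600)
test_tasks = l2l.data.TaskDataset(datasets[2],
                                  task_transforms=task_transforms[2],
                                  num_tasks=600)

# Each sample() call draws one n-way, k-shot episode as (images, labels) tensors.
images, labels = train_tasks.sample()

Naming the second return value task_transforms avoids shadowing the torchvision transforms module used inside cure_tsr_tasksets.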