# Example 1
def main():
    """Train and evaluate a fine-grained classifier on CUB / Dog / Car.

    Reads hyper-parameters from the module-level ``parser``, builds the
    train/val data loaders and the model (with a re-initialized ``fc_m``
    head sized for the chosen dataset), then either validates a saved
    checkpoint (``--valid``) or runs the full training loop, tracking the
    global ``best_prec1`` / ``best_prec5``.
    """
    global args
    global best_prec1, best_prec5

    args = parser.parse_args()

    # simple args
    debug = args.debug
    if debug: cprint('=> WARN: Debug Mode', 'yellow')

    dataset = args.dataset
    if dataset == 'cub':
        num_classes = 200
    elif dataset == 'dog':
        num_classes = 120
    elif dataset == 'car':
        num_classes = 196

    base_size = 512
    pool_size = 14 if base_size == 512 else 7
    # debug mode shrinks the batch and disables worker processes
    workers = 0 if debug else 4
    batch_size = 2 if debug else 32
    if base_size == 512 and \
        args.arch == '152':
        batch_size = 128
    drop_ratio = 0.1
    if dataset == 'cub':
        lr_drop_epoch_list = [31, 51, 71]
    else:
        lr_drop_epoch_list = [51, 71, 91]
    epochs = 100
    eval_freq = 1
    # fix: the old `[0] if debug else [0]` conditional had identical branches
    gpu_ids = [0]
    crop_size = 448
    log_name = "CNLtrainlog.txt"
    # args for the nl and cgnl block
    arch = args.arch

    # warmup setting: ramp from the fully-decayed LR back up to args.lr
    WARMUP_LRS = [args.lr * (drop_ratio**len(lr_drop_epoch_list)), args.lr]
    WARMUP_EPOCHS = 10

    # data loader
    if dataset == 'cub':
        data_root = '/input/data/cub/CUB_200_2011'
        imgs_fold = os.path.join(data_root, 'images')
        train_ann_file = os.path.join(data_root, 'cub_train.list')
        valid_ann_file = os.path.join(data_root, 'cub_val.list')
    elif dataset == 'dog':
        data_root = '/input/data/Standford_dog'
        imgs_fold = os.path.join(data_root, 'images')
        train_ann_file = os.path.join(data_root, 'dog_train.list')
        valid_ann_file = os.path.join(data_root, 'dog_val.list')
    elif dataset == 'car':
        data_root = '/input/data/Standford_car'
        imgs_fold = os.path.join(data_root, 'images')
        train_ann_file = os.path.join(data_root, 'car_train.list')
        valid_ann_file = os.path.join(data_root, 'car_val.list')
    else:
        # fix: the placeholder was never filled before — .format(dataset)
        raise NameError(
            "WARN: The dataset '{}' is not supported yet.".format(dataset))

    train_dataset = dataloader.ImgLoader(
        root=imgs_fold,
        ann_file=train_ann_file,
        transform=transforms.Compose([
            transforms.RandomResizedCrop(size=crop_size, scale=(0.08, 1.25)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))

    val_dataset = dataloader.ImgLoader(
        root=imgs_fold,
        ann_file=valid_ann_file,
        transform=transforms.Compose([
            transforms.Resize(base_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=workers,
                                             pin_memory=True)

    # build model
    model = model_hub(arch, pretrained=True, pool_size=pool_size)

    # change the fc layer: replace the head with one sized for this dataset
    model._modules['fc_m'] = torch.nn.Linear(in_features=2048,
                                             out_features=num_classes)
    torch.nn.init.kaiming_normal_(model._modules['fc_m'].weight,
                                  mode='fan_out',
                                  nonlinearity='relu')
    print(model)
    # parallel
    model = torch.nn.DataParallel(model, device_ids=gpu_ids).cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    # optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=1e-4)

    # cudnn
    cudnn.benchmark = True

    # warmup: shift the LR-drop schedule past the warmup epochs
    if args.warmup:
        epochs += WARMUP_EPOCHS
        lr_drop_epoch_list = list(np.array(lr_drop_epoch_list) + WARMUP_EPOCHS)
        cprint(
            '=> WARN: warmup is used in the first {} epochs'.format(
                WARMUP_EPOCHS), 'yellow')

    # valid: evaluate the best checkpoint and exit
    if args.valid:
        cprint('=> WARN: Validation Mode', 'yellow')
        print('start validation ...')
        checkpoint_fold = args.checkpoints
        checkpoint_best = os.path.join(checkpoint_fold, 'model_best.pth.tar')
        print('=> loading state_dict from {}'.format(checkpoint_best))
        model.load_state_dict(torch.load(checkpoint_best)['state_dict'])
        prec1, prec5 = validate(val_loader, model, criterion)
        print(' * Final Accuracy: Prec@1 {:.3f}, Prec@5 {:.3f}'.format(
            prec1, prec5))
        exit(0)

    # train
    print('start training ...')

    # fix: use a context manager instead of manual open/close
    with open(log_name, 'w') as f:
        f.write(log_name)

    for epoch in range(0, epochs):
        current_lr = adjust_learning_rate(optimizer, drop_ratio, epoch,
                                          lr_drop_epoch_list, WARMUP_EPOCHS,
                                          WARMUP_LRS)
        # train one epoch
        train(train_loader, model, criterion, optimizer, epoch, epochs,
              current_lr)

        checkpoint_name = '{}-r-{}-w-CNL5-block.pth.tar'.format(dataset, arch)

        if (epoch + 1) % eval_freq == 0:
            prec1, prec5 = validate(val_loader, model, criterion, log_name)
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            best_prec5 = max(prec5, best_prec5)
            print(' * Best accuracy: Prec@1 {:.3f}, Prec@5 {:.3f}'.format(
                best_prec1, best_prec5))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename=checkpoint_name)
    def __init__(self, options, path):
        """Set up the network, criterion, solver, scheduler and data loaders.

        Args:
            options, dict: Hyperparameters (``base_lr``, ``weight_decay``,
                ``batch_size``).
            path, dict: Filesystem paths; ``path['model']`` is the
                checkpoint the network is restored from.
        """
        print('Prepare the network and data.')
        self._options = options
        self._path = path

        # Network: bilinear CNN wrapped for multi-GPU, restored from disk.
        self._net = torch.nn.DataParallel(BCNN()).cuda()
        self._net.load_state_dict(torch.load(self._path['model']))
        print(self._net)

        # Criterion.
        self._criterion = torch.nn.CrossEntropyLoss().cuda()

        # Solver over the bilinear weights, the classifier head, and the
        # backbone features (same parameter order as before).
        trainable = [
            self._net.module.w1, self._net.module.w2, self._net.module.w3,
            self._net.module.w4
        ]
        trainable += list(self._net.module.fc.parameters())
        trainable += list(self._net.module.features.parameters())
        self._solver = torch.optim.SGD(
            trainable,
            lr=self._options['base_lr'],
            momentum=0.9,
            weight_decay=self._options['weight_decay'])
        # Drop LR 10x when the monitored accuracy plateaus.
        self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self._solver,
            mode='max',
            factor=0.1,
            patience=7,
            verbose=True,
            threshold=1e-4)

        # ImageNet normalization, shared by both pipelines.
        normalize = torchvision.transforms.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        train_tf = torchvision.transforms.Compose([
            torchvision.transforms.Resize(size=448),  # Let smaller edge match
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomCrop(size=448),
            torchvision.transforms.ToTensor(),
            normalize,
        ])
        eval_tf = torchvision.transforms.Compose([
            torchvision.transforms.Resize(size=448),
            torchvision.transforms.CenterCrop(size=448),
            torchvision.transforms.ToTensor(),
            normalize,
        ])

        image_root = 'data/mit/Images/'
        train_list = 'data/mit/TrainImagesnew.txt'
        val_list = 'data/mit/TestImagesnew.txt'

        train_set = dataloader.ImgLoader(root=image_root,
                                         ann_file=train_list,
                                         transform=train_tf)
        val_set = dataloader.ImgLoader(root=image_root,
                                       ann_file=val_list,
                                       transform=eval_tf)

        self._train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=self._options['batch_size'],
            shuffle=True,
            num_workers=16,
            pin_memory=True)
        self._test_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=16,
            shuffle=False,
            num_workers=16,
            pin_memory=True)
# Example 3
def main():
    """Train/evaluate a (pre)ResNet with optional non-local (SNL) blocks.

    Per-dataset hyper-parameters (CUB / CIFAR-10 / CIFAR-100 / ImageNet)
    are configured here; training progress goes to a TensorBoard
    ``SummaryWriter`` and per-epoch checkpoints are written under the
    global ``checkpoint_fold``.
    """
    global args
    global best_prec1, best_prec5
    global checkpoint_fold, checkpoint_best

    # mirror stdout into a persistent training log
    sys.stdout = Logger(osp.join(checkpoint_fold, 'log_train.txt'))

    writer = SummaryWriter('save_model/{}/{}/{}'.format(args.dataset, args.arch, args.checkpoints))
    # simple args
    debug = args.debug
    if debug: cprint('=> WARN: Debug Mode', 'yellow')

    dataset = args.dataset

    # per-dataset hyper-parameters (some override the CLI args on purpose)
    if dataset == 'cub':
        num_classes = 200
        base_size = 512
        #batch_size = 60
        batch_size = 48
        crop_size = 448
        pool_size = 14
        args.warmup = True
        pretrain = args.pretrained
        args.backbone = 'resnet'
        args.arch = '50'

        epochs = 100
        eval_freq = 5
        args.lr = 0.01
        lr_drop_epoch_list = [31, 61, 81]
    elif dataset == 'cifar10':
        num_classes = 10
        base_size = 32
        batch_size = 128
        crop_size = 32
        args.warmup = True
        pretrain = args.pretrained
        #args.backbone = 'resnet'
        epochs = 300
        eval_freq = 5
        args.lr = 0.1
        lr_drop_epoch_list = [150, 250]

        if args.backbone == 'preresnet':
            pool_size = 8
        else:
            pool_size = 4

    elif dataset == 'cifar100':
        num_classes = 100
        base_size = 32
        batch_size = 128
        crop_size = 32
        args.warmup = True
        pretrain = args.pretrained
        epochs = 300
        eval_freq = 5
        args.lr = 0.1
        lr_drop_epoch_list = [150, 250]

        if args.backbone == 'preresnet':
            pool_size = 8
        else:
            pool_size = 4

    else: ##imagenet
        num_classes = 1000
        base_size = 256
        batch_size = 100
        crop_size = 224
        pool_size = 7
        args.warmup = True
        pretrain = args.pretrained
        args.backbone = 'resnet'

        epochs = 100
        eval_freq = 5
        args.lr = 0.01
        lr_drop_epoch_list = [31, 61, 81]

    workers = 4

    # debug mode shrinks the batch and disables worker processes
    if debug:
        batch_size = 2
        workers = 0

    if base_size == 512 and \
        args.arch == '152':
        batch_size = 128
    drop_ratio = 0.1
    gpu_ids = [0,1]

    # args for the nl and cgnl block
    arch = args.arch
    nl_type  = args.nl_type # 'cgnl' | 'cgnlx' | 'nl'
    nl_nums  = args.nl_nums # 1: stage res4

    # warmup setting: ramp from the fully-decayed LR back up to args.lr
    WARMUP_LRS = [args.lr * (drop_ratio**len(lr_drop_epoch_list)), args.lr]
    WARMUP_EPOCHS = 10

    # data loader
    if dataset == 'cub':
        data_root = os.path.join(args.data_dir, 'cub')
        imgs_fold = os.path.join(data_root, 'images')
        train_ann_file = os.path.join(data_root, 'cub_train.list')
        valid_ann_file = os.path.join(data_root, 'cub_val.list')
    elif dataset == 'imagenet':
        data_root = '/home/sheqi/lei/dataset/imagenet'
        imgs_fold = os.path.join(data_root)
        train_ann_file = os.path.join(data_root, 'imagenet_train.list')
        valid_ann_file = os.path.join(data_root, 'imagenet_val.list')
    elif dataset == 'cifar10':
        print("cifar10")
    elif dataset == 'cifar100':
        print("cifar100")
    else:
        # fix: the placeholder was never filled before — .format(dataset)
        raise NameError(
            "WARN: The dataset '{}' is not supported yet.".format(dataset))

    if dataset == 'cub' or dataset == 'imagenet':
        train_dataset = dataloader.ImgLoader(
                root = imgs_fold,
                ann_file = train_ann_file,
                transform = transforms.Compose([
                    transforms.RandomResizedCrop(
                        size=crop_size, scale=(0.08, 1.25)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
                    ]))

        val_dataset = dataloader.ImgLoader(
                root = imgs_fold,
                ann_file = valid_ann_file,
                transform = transforms.Compose([
                    transforms.Resize(base_size),
                    transforms.CenterCrop(crop_size),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
                    ]))

        train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size = batch_size,
                shuffle = True,
                num_workers = workers,
                pin_memory = True)

        val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size = batch_size,
                shuffle = False,
                num_workers = workers,
                pin_memory = True)

    elif dataset == 'cifar10':
        train_transform = transforms.Compose([
                    transforms.RandomCrop(crop_size, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.4914, 0.4822, 0.4465],
                        [0.2023, 0.1994, 0.2010])
                    ])
        val_transform = transforms.Compose([
                    #transforms.Resize(base_size),
                    #transforms.CenterCrop(crop_size),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.4914, 0.4822, 0.4465],
                        [0.2023, 0.1994, 0.2010])
                    ])
        trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=False , transform = train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
        testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=False , transform = val_transform)
        val_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
    elif dataset == 'cifar100':
        train_transform = transforms.Compose([
                    transforms.RandomCrop(crop_size, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.4914, 0.4822, 0.4465],
                        [0.2023, 0.1994, 0.2010])
                    ])
        val_transform = transforms.Compose([
                    #transforms.Resize(base_size),
                    #transforms.CenterCrop(crop_size),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.4914, 0.4822, 0.4465],
                        [0.2023, 0.1994, 0.2010])
                    ])
        trainset = torchvision.datasets.CIFAR100(root=args.data_dir, train=True, download=False , transform = train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
        testset = torchvision.datasets.CIFAR100(root=args.data_dir, train=False, download=False , transform = val_transform)
        val_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
    # build model

#####################################################
    if args.backbone == 'resnet':
        model = resnet_snl.model_hub(arch,
                                 pretrained=pretrain,
                                 nl_type=nl_type,
                                 nl_nums=nl_nums,
                                 stage_num=args.stage_nums,
                                 pool_size=pool_size, div=args.div, isrelu=args.relu)
    elif args.backbone == 'preresnet':
        model = preresnet_snl.model_hub(arch,
                                 pretrained=pretrain,
                                 nl_type=nl_type,
                                 nl_nums=nl_nums,
                                 stage_num=args.stage_nums,
                                 pool_size=pool_size, 
                                 div=args.div,
                                 nl_layer = args.nl_layer,
                                 relu = args.relu)
    else:
        # fix: the condition tests the backbone, but the old message
        # reported an "Unsupported nonlocal type"
        raise KeyError("Unsupported backbone: {}".format(args.backbone))
####################################################

    # change the first conv for CIFAR (3x3 stride-1 stem, no max-pool)
    if dataset == 'cifar10' or dataset == 'cifar100':
         model._modules['conv1'] = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
         model._modules['maxpool'] = torch.nn.Sequential()

    # change the fc layer to match the dataset's class count
    if dataset != 'imagenet':
        model._modules['fc'] = torch.nn.Linear(in_features=2048,
                                           out_features=num_classes)
        torch.nn.init.kaiming_normal_(model._modules['fc'].weight,
                                  mode='fan_out', nonlinearity='relu')
    print(model)

    # parallel
    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=gpu_ids).cuda()
    else:
        model = model.cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # optimizer
    optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=0.9,
            weight_decay=1e-4)

    # cudnn
    cudnn.benchmark = True

    # warmup: shift the LR-drop schedule past the warmup epochs
    if args.warmup:
        epochs += WARMUP_EPOCHS
        lr_drop_epoch_list = list(
                np.array(lr_drop_epoch_list) + WARMUP_EPOCHS)
        cprint('=> WARN: warmup is used in the first {} epochs'.format(
            WARMUP_EPOCHS), 'yellow')

    start_epoch = 0
    if args.isresume:
        # NOTE(review): `resume_path` is not defined anywhere in this
        # function — presumably a module-level global or it should come
        # from args; confirm before relying on the resume path.
        print('loading checkpoint {}'.format(resume_path))
        checkpoint = torch.load(resume_path)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        print("epoches: {}, best_prec1: {}". format(start_epoch, best_prec1 ))

    # valid: evaluate a checkpoint and exit
    if args.valid:
        cprint('=> WARN: Validation Mode', 'yellow')
        print('start validation ...')
        print('=> loading state_dict from {}'.format(args.check_path))
        model.load_state_dict(
                torch.load(args.check_path)['state_dict'], strict=True)
        prec1, prec5 = validate(val_loader, model, criterion)
        print(' * Final Accuracy: Prec@1 {:.3f}, Prec@5 {:.3f}'.format(prec1, prec5))
        exit(0)

    # train
    print('start training ...')
    for epoch in range(start_epoch, epochs):
        current_lr = adjust_learning_rate(optimizer, drop_ratio, epoch, lr_drop_epoch_list,
                                          WARMUP_EPOCHS, WARMUP_LRS)
        # train one epoch
        cur_loss = train(train_loader, model, criterion, optimizer, epoch, epochs, current_lr)
        writer.add_scalar("Train Loss", cur_loss, epoch + 1)

        if nl_nums > 0:
            checkpoint_name = '{}-{}-r-{}-w-{}{}-block.pth.tar'.format(epoch, dataset, arch, nl_nums, nl_type)
        else:
            checkpoint_name = '{}-r-{}-{}-base.pth.tar'.format(dataset, arch, epoch)

        checkpoint_name = os.path.join(checkpoint_fold, checkpoint_name)

        if (epoch + 1) % eval_freq == 0:
            prec1, prec5 = validate(val_loader, model, criterion)
##########################################################
            writer.add_scalar("Top1", prec1, epoch + 1)
            writer.add_scalar("Top5", prec5, epoch + 1)
##########################################################
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            best_prec5 = max(prec5, best_prec5)
            print(' * Best accuracy: Prec@1 {:.3f}, Prec@5 {:.3f}'.format(best_prec1, best_prec5))
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, filename=checkpoint_name)