Example #1
def __init__(self,
             n_classes,
             backbone='resnet50',
             dilated=True,
             norm_layer=None,
             multi_grid=False,
             multi_dilation=None):
    super(BaseNet, self).__init__()
    self.n_classes = n_classes
    # copying modules from pretrained models
    if backbone == 'resnet50':
        self.backbone = resnet.resnet50(pretrained=True,
                                        dilated=dilated,
                                        norm_layer=norm_layer,
                                        multi_grid=multi_grid,
                                        multi_dilation=multi_dilation)
        self.backbone2 = resnet.resnet50(pretrained=True,
                                         dilated=dilated,
                                         norm_layer=norm_layer,
                                         multi_grid=multi_grid,
                                         multi_dilation=multi_dilation)
    elif backbone == 'resnet101':
        self.backbone = resnet.resnet101(pretrained=True,
                                         dilated=dilated,
                                         norm_layer=norm_layer,
                                         multi_grid=multi_grid,
                                         multi_dilation=multi_dilation)
    elif backbone == 'resnet152':
        self.backbone = resnet.resnet152(pretrained=True,
                                         dilated=dilated,
                                         norm_layer=norm_layer,
                                         multi_grid=multi_grid,
                                         multi_dilation=multi_dilation)
    else:
        raise RuntimeError('unknown backbone: {}'.format(backbone))
    # bilinear upsample options
    self._up_kwargs = up_kwargs
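
This __init__ is excerpted from a BaseNet-style segmentation base class: it builds a pretrained (optionally dilated) ResNet feature extractor and stores bilinear-upsampling options for the head. Note that only the resnet50 branch also builds a second copy (self.backbone2). A minimal instantiation sketch follows, assuming the surrounding module provides BaseNet, resnet, and up_kwargs as in the snippet; the argument values are illustrative only:

import torch.nn as nn

# illustrative values; a multi_dilation list such as [4, 8, 16] commonly
# accompanies multi_grid=True in dilated-ResNet segmentation backbones
net = BaseNet(n_classes=21,
              backbone='resnet101',
              dilated=True,
              norm_layer=nn.BatchNorm2d,
              multi_grid=True,
              multi_dilation=[4, 8, 16])
features = net.backbone  # pretrained feature extractor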
Example #2
from datetime import datetime

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from torchvision.models import resnet50

# get_ilsvrc2012_train_dataset is a project-local helper and is not shown here.


def train(gpu, args):

    # single-node setup: the process rank equals the local gpu index
    rank = gpu

    if args.gpus > 1:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.gpus,
                                rank=rank)

    torch.manual_seed(0)
    model = resnet50()
    optimizer = torch.optim.SGD(model.parameters(), 1e-1)

    epoch = 0
    if args.load_checkpoint is not None:
        # load to CPU first so the checkpoint is not pinned to the GPU it was saved from
        checkpoint = torch.load(args.load_checkpoint, map_location='cpu')
        epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    batch_size = 250

    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss().cuda(gpu)

    if args.gpus > 1:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    transform = torch.nn.Sequential(transforms.Pad(224),
                                    transforms.CenterCrop((224, 224)))

    train_dataset = get_ilsvrc2012_train_dataset(transform)
    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=args.gpus,
                                       rank=gpu,
                                       shuffle=True)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              num_workers=8,
                              pin_memory=True,
                              sampler=train_sampler)

    total_step = len(train_loader)
    for epoch in range(epoch, args.epochs):
        start = datetime.now()
        # reseed the sampler's shuffle; without this every epoch sees the same order
        train_sampler.set_epoch(epoch)

        for i, (images, labels) in enumerate(train_loader):

            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 10 == 0 and gpu == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, args.epochs, i + 1, total_step, loss.item()))

        if gpu == 0:
            print('Epoch completed in: ' + str(datetime.now() - start))
            # unwrap DistributedDataParallel so the checkpoint loads into a bare model
            model_to_save = model.module if args.gpus > 1 else model
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, args.dump_checkpoint + '_{}'.format(epoch))

    if args.gpus > 1:
        dist.destroy_process_group()
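
train(gpu, args) is written to be run once per GPU. A minimal single-node launcher sketch, given as an assumption about the surrounding script: the argument names mirror the fields the function reads (gpus, epochs, load_checkpoint, dump_checkpoint), and MASTER_ADDR/MASTER_PORT must be set because init_process_group uses init_method='env://':

import os
import argparse
import torch.multiprocessing as mp

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpus', default=1, type=int)
    parser.add_argument('--epochs', default=90, type=int)
    parser.add_argument('--load_checkpoint', default=None, type=str)
    parser.add_argument('--dump_checkpoint', default='checkpoint', type=str)
    args = parser.parse_args()

    # rendezvous info consumed by init_method='env://'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # spawn one training process per GPU; each receives its gpu index
    # as the first positional argument, i.e. train(gpu, args)
    mp.spawn(train, nprocs=args.gpus, args=(args,))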
Example #3
import argparse
import os

import torch

# resnet, prepare_dataloaders, count_parameters, load_checkpoint,
# save_checkpoint, run_epoch and adjust_learning_rate are project-local
# helpers and are not shown here.


def main():
    ''' Main function '''
    parser = argparse.ArgumentParser(description='Image classification with attention modules, in PyTorch')
    parser.add_argument('--arch', default='bam', type=str, help='attention module (bam, cbam)')
    parser.add_argument('--backbone', default='resnet50', type=str, help='backbone classification model (resnet18/34/50/101/152)')
    parser.add_argument('--epoch', default=1, type=int, help='start epoch')
    parser.add_argument('--n_epochs', default=350, type=int, help='number of total epochs to run')
    parser.add_argument('--batch', default=256, type=int, help='mini batch size (default: 256)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--save_directory', default='trained.chkpt', type=str, help='path to latest checkpoint')
    parser.add_argument('--workers', default=0, type=int, help='num_workers')
    parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
    parser.add_argument('--datasets', default='CIFAR100', type=str, help='classification dataset (CIFAR10, CIFAR100, ImageNet)')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight_decay')
    parser.add_argument('--save', default='trained', type=str, help='trained.chkpt')
    parser.add_argument('--save_multi', default='trained_multi', type=str, help='trained_multi.chkpt')
    parser.add_argument('--evaluate', action='store_true', help='run validation only')
    parser.add_argument('--reduction_ratio', default=16, type=int, help='reduction_ratio')
    parser.add_argument('--dilation_value', default=4, type=int, help='dilation value')
    args = parser.parse_args()
    args.arch = args.arch.lower()
    args.backbone = args.backbone.lower()
    args.datasets = args.datasets.lower()

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')
    # TODO: add RNG seeding for reproducibility.

    # use gpu or multi-gpu or not.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_multi_gpu = torch.cuda.device_count() > 1
    print('[Info] device:{} use_multi_gpu:{}'.format(device, use_multi_gpu))

    if args.datasets == 'cifar10':
        num_classes = 10
    elif args.datasets == 'cifar100':
        num_classes = 100
    elif args.datasets == 'imagenet':
        num_classes = 1000
    else:
        raise ValueError('unknown dataset: {}'.format(args.datasets))

    # load the data.
    print('[Info] Load the data.')
    train_loader, valid_loader, train_size, valid_size = prepare_dataloaders(args)

    # load the model.
    print('[Info] Load the model.')

    if args.backbone == 'resnet18':
        model = resnet.resnet18(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet34':
        model = resnet.resnet34(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet50':
        model = resnet.resnet50(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet101':
        model = resnet.resnet101(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet152':
        model = resnet.resnet152(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    else:
        raise ValueError('unknown backbone: {}'.format(args.backbone))

    model = model.to(device)
    if use_multi_gpu:
        model = torch.nn.DataParallel(model)
    print('[Info] Total parameters {} '.format(count_parameters(model)))
    # define loss function.
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        # Load the checkpoint (the original branched on device count here,
        # but both branches loaded the same file).
        print('[Info] Loading checkpoint.')
        checkpoint = load_checkpoint(args.save)

        backbone = checkpoint['backbone']
        args.epoch = checkpoint['epoch']
        state_dict = checkpoint['state_dict']
        model.load_state_dict(state_dict)
        print('[Info] epoch {} backbone {}'.format(args.epoch, backbone))

    # run evaluate.
    if args.evaluate:
        _ = run_epoch(model, 'valid', [args.epoch, args.epoch], criterion, optimizer, valid_loader, valid_size, device)
        return

    # run train.
    best_acc1 = 0.
    for e in range(args.epoch, args.n_epochs + 1):
        adjust_learning_rate(optimizer, e, args)

        # train for one epoch
        _ = run_epoch(model, 'train', [e, args.n_epochs], criterion, optimizer, train_loader, train_size, device)

        # evaluate on validation set
        with torch.no_grad():
            acc1 = run_epoch(model, 'valid', [e, args.n_epochs], criterion, optimizer, valid_loader, valid_size, device)

        # Save checkpoint.
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint({
            'epoch': e,
            'backbone': args.backbone,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)

        if use_multi_gpu:
            save_checkpoint({
                'epoch': e,
                'backbone': args.backbone,
                'state_dict': model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_multi)

        print('[Info] acc1 {} best@acc1 {}'.format(acc1, best_acc1))
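
adjust_learning_rate is called at the top of every epoch but is not part of this example. A common step-decay sketch, given purely as an assumption about what it might look like (the project's real schedule may differ):

def adjust_learning_rate(optimizer, epoch, args):
    # assumed schedule: divide the initial lr by 10 every 30 epochs
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr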