Ejemplo n.º 1
0
 def __init__(self,
              n_classes,
              backbone='resnet50',
              dilated=True,
              norm_layer=None,
              multi_grid=False,
              multi_dilation=None):
     """Build the pretrained ResNet backbone(s) shared by BaseNet subclasses.

     Args:
         n_classes: number of output classes, stored on the instance.
         backbone: one of 'resnet50', 'resnet101', 'resnet152'.
         dilated, norm_layer, multi_grid, multi_dilation: forwarded verbatim
             to the resnet constructor.

     Raises:
         RuntimeError: if *backbone* is not a recognized name.
     """
     super(BaseNet, self).__init__()
     self.n_classes = n_classes
     # copying modules from pretrained models — every branch passes the exact
     # same construction arguments, so build them once and unpack per call.
     backbone_kwargs = dict(pretrained=True,
                            dilated=dilated,
                            norm_layer=norm_layer,
                            multi_grid=multi_grid,
                            multi_dilation=multi_dilation)
     if backbone == 'resnet50':
         # NOTE(review): only the resnet50 branch builds a second backbone —
         # presumably intentional for this model, kept as-is.
         self.backbone = resnet.resnet50(**backbone_kwargs)
         self.backbone2 = resnet.resnet50(**backbone_kwargs)
     elif backbone == 'resnet101':
         self.backbone = resnet.resnet101(**backbone_kwargs)
     elif backbone == 'resnet152':
         self.backbone = resnet.resnet152(**backbone_kwargs)
     else:
         raise RuntimeError('unknown backbone: {}'.format(backbone))
     # bilinear upsample options
     self._up_kwargs = up_kwargs
Ejemplo n.º 2
0
def main():
    """Train or evaluate an attention-augmented ResNet image classifier.

    Parses command-line arguments, builds the data loaders and model, then
    either evaluates a restored checkpoint (``--evaluate``) or runs the full
    training loop, saving a checkpoint after every epoch and tracking the
    best top-1 accuracy.
    """
    parser = argparse.ArgumentParser(description='Implement image classification on ImageNet dataset using pytorch')
    parser.add_argument('--arch', default='bam', type=str, help='Attention Model (bam, cbam)')
    parser.add_argument('--backbone', default='resnet50', type=str, help='backbone classification model (resnet(18, 34, 50, 101, 152)')
    parser.add_argument('--epoch', default=1, type=int, help='start epoch')
    parser.add_argument('--n_epochs', default=350, type=int, help='number of total epochs to run')
    parser.add_argument('--batch', default=256, type=int, help='mini batch size (default: 1024)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--save_directory', default='trained.chkpt', type=str, help='path to latest checkpoint')
    parser.add_argument('--workers', default=0, type=int, help='num_workers')
    # BUGFIX: `type=bool` never works with argparse — bool('False') is True,
    # so any value (even "False") enabled the flag. store_true keeps the
    # False default and parses the bare flag correctly.
    parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
    parser.add_argument('--datasets', default='CIFAR100', type=str, help='classification dataset  (CIFAR10, CIFAR100, ImageNet)')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight_decay')
    parser.add_argument('--save', default='trained', type=str, help='trained.chkpt')
    parser.add_argument('--save_multi', default='trained_multi', type=str, help='trained_multi.chkpt')
    parser.add_argument('--evaluate', action='store_true', help='evaluate only, no training')
    parser.add_argument('--reduction_ratio', default=16, type=int, help='reduction_ratio')
    parser.add_argument('--dilation_value', default=4, type=int, help='dilation_value')
    args = parser.parse_args()
    # Normalize user-supplied identifiers so the comparisons below are
    # case-insensitive.
    args.arch = args.arch.lower()
    args.backbone = args.backbone.lower()
    args.datasets = args.datasets.lower()

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')
    # To-do: Write a code relating to seed.

    # use gpu or multi-gpu or not.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_multi_gpu = torch.cuda.device_count() > 1
    print('[Info] device:{} use_multi_gpu:{}'.format(device, use_multi_gpu))

    # Map each supported dataset to its class count; fail fast on an unknown
    # name instead of crashing later with an unbound `num_classes`.
    dataset_to_classes = {'cifar10': 10, 'cifar100': 100, 'imagenet': 1000}
    if args.datasets not in dataset_to_classes:
        raise ValueError('unknown dataset: {}'.format(args.datasets))
    num_classes = dataset_to_classes[args.datasets]

    # load the data.
    print('[Info] Load the data.')
    train_loader, valid_loader, train_size, valid_size = prepare_dataloaders(args)

    # load the model.
    print('[Info] Load the model.')
    # Dispatch table replaces the repetitive if/elif chain; an unknown
    # backbone previously fell through and raised NameError on `model`.
    backbone_factories = {
        'resnet18': resnet.resnet18,
        'resnet34': resnet.resnet34,
        'resnet50': resnet.resnet50,
        'resnet101': resnet.resnet101,
        'resnet152': resnet.resnet152,
    }
    if args.backbone not in backbone_factories:
        raise ValueError('unknown backbone: {}'.format(args.backbone))
    model = backbone_factories[args.backbone](num_classes=num_classes,
                                              atte=args.arch,
                                              ratio=args.reduction_ratio,
                                              dilation=args.dilation_value)

    model = model.to(device)
    if use_multi_gpu:
        model = torch.nn.DataParallel(model)
    print('[Info] Total parameters {} '.format(count_parameters(model)))
    # define loss function.
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        # Load the checkpoint. (The original code branched on
        # torch.cuda.device_count() but both branches were identical.)
        print('[Info] Loading checkpoint.')
        checkpoint = load_checkpoint(args.save)

        backbone = checkpoint['backbone']
        args.epoch = checkpoint['epoch']
        state_dict = checkpoint['state_dict']
        model.load_state_dict(state_dict)
        print('[Info] epoch {} backbone {}'.format(args.epoch, backbone))

    # run evaluate.
    if args.evaluate:
        _ = run_epoch(model, 'valid', [args.epoch, args.epoch], criterion, optimizer, valid_loader, valid_size, device)
        return

    # run train.
    best_acc1 = 0.
    for e in range(args.epoch, args.n_epochs + 1):
        adjust_learning_rate(optimizer, e, args)

        # train for one epoch
        _ = run_epoch(model, 'train', [e, args.n_epochs], criterion, optimizer, train_loader, train_size, device)

        # evaluate on validation set
        with torch.no_grad():
            acc1 = run_epoch(model, 'valid', [e, args.n_epochs], criterion, optimizer, valid_loader, valid_size, device)

        # Save checkpoint; track the best validation top-1 so far.
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint({
            'epoch': e,
            'backbone': args.backbone,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)

        if use_multi_gpu:
            # Also save the unwrapped module's weights so the checkpoint can
            # be loaded without DataParallel's 'module.' key prefix.
            save_checkpoint({
                'epoch': e,
                'backbone': args.backbone,
                'state_dict': model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_multi)

        print('[Info] acc1 {} best@acc1 {}'.format(acc1, best_acc1))