# Example no. 1
def main(args):
    """Restore a trained ResNet-18 backbone from a checkpoint and evaluate
    it on the validation dataset given by ``args.test_dataset``.

    Expects ``args.model_path`` to point at a checkpoint containing a
    ``'resnet_backbone'`` state dict; requires a CUDA device.
    """
    state = torch.load(args.model_path)

    # 42-class head — presumably matches the training setup; confirm against trainer.
    backbone = resnet18(num_classes=42).cuda()
    backbone.load_state_dict(state['resnet_backbone'])

    to_tensor = transforms.Compose([transforms.ToTensor()])
    val_dataset = MyDatasets(args.test_dataset, to_tensor)
    val_loader = DataLoader(
        val_dataset,
        batch_size=5,
        shuffle=True,
        num_workers=0,
    )

    validate(val_loader, backbone)
# Example no. 2
def main():
    """Parse CLI options, build an attention-augmented ResNet (BAM/CBAM),
    then either evaluate a checkpoint or run the full training loop,
    saving the best checkpoint after every epoch.
    """
    parser = argparse.ArgumentParser(description='Implement image classification on ImageNet dataset using pytorch')
    parser.add_argument('--arch', default='bam', type=str, help='Attention Model (bam, cbam)')
    parser.add_argument('--backbone', default='resnet50', type=str, help='backbone classification model (resnet(18, 34, 50, 101, 152)')
    parser.add_argument('--epoch', default=1, type=int, help='start epoch')
    parser.add_argument('--n_epochs', default=350, type=int, help='number of total epochs to run')
    parser.add_argument('--batch', default=256, type=int, help='mini batch size (default: 1024)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--save_directory', default='trained.chkpt', type=str, help='path to latest checkpoint')
    parser.add_argument('--workers', default=0, type=int, help='num_workers')
    # BUG FIX: `type=bool` treats ANY non-empty string (even "False") as True.
    # store_true flags keep the same False default but parse correctly.
    parser.add_argument('--resume', action='store_true', help='resume')
    parser.add_argument('--datasets', default='CIFAR100', type=str, help='classification dataset  (CIFAR10, CIFAR100, ImageNet)')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight_decay')
    parser.add_argument('--save', default='trained', type=str, help='trained.chkpt')
    parser.add_argument('--save_multi', default='trained_multi', type=str, help='trained_multi.chkpt')
    parser.add_argument('--evaluate', action='store_true', help='evaluate')
    parser.add_argument('--reduction_ratio', default=16, type=int, help='reduction_ratio')
    parser.add_argument('--dilation_value', default=4, type=int, help='dilation_value')
    args = parser.parse_args()
    # Normalize the string options so matching below is case-insensitive.
    args.arch = args.arch.lower()
    args.backbone = args.backbone.lower()
    args.datasets = args.datasets.lower()

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')
    # TODO: add seed handling for reproducibility.

    # Use GPU when available; DataParallel when more than one GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_multi_gpu = torch.cuda.device_count() > 1
    print('[Info] device:{} use_multi_gpu:{}'.format(device, use_multi_gpu))

    # BUG FIX: an unrecognized dataset previously left num_classes undefined,
    # raising a confusing NameError later; fail fast with a clear message.
    dataset_num_classes = {'cifar10': 10, 'cifar100': 100, 'imagenet': 1000}
    if args.datasets not in dataset_num_classes:
        raise ValueError('Unsupported dataset: {!r} (expected CIFAR10, CIFAR100 or ImageNet)'.format(args.datasets))
    num_classes = dataset_num_classes[args.datasets]

    # load the data.
    print('[Info] Load the data.')
    train_loader, valid_loader, train_size, valid_size = prepare_dataloaders(args)

    # load the model.
    print('[Info] Load the model.')

    # BUG FIX: an unrecognized backbone previously left `model` undefined;
    # dict dispatch replaces the repetitive if/elif chain and fails fast.
    backbone_factory = {
        'resnet18': resnet.resnet18,
        'resnet34': resnet.resnet34,
        'resnet50': resnet.resnet50,
        'resnet101': resnet.resnet101,
        'resnet152': resnet.resnet152,
    }
    if args.backbone not in backbone_factory:
        raise ValueError('Unsupported backbone: {!r}'.format(args.backbone))
    model = backbone_factory[args.backbone](
        num_classes=num_classes,
        atte=args.arch,
        ratio=args.reduction_ratio,
        dilation=args.dilation_value,
    )

    model = model.to(device)
    if use_multi_gpu:
        model = torch.nn.DataParallel(model)
    print('[Info] Total parameters {} '.format(count_parameters(model)))

    # define loss function.
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        # Load the checkpoint. (The original branched on device count here,
        # but both branches were identical -- collapsed to one call.)
        print('[Info] Loading checkpoint.')
        checkpoint = load_checkpoint(args.save)

        backbone = checkpoint['backbone']
        args.epoch = checkpoint['epoch']
        state_dict = checkpoint['state_dict']
        model.load_state_dict(state_dict)
        print('[Info] epoch {} backbone {}'.format(args.epoch, backbone))

    # run evaluate: one validation pass, then exit.
    if args.evaluate:
        _ = run_epoch(model, 'valid', [args.epoch, args.epoch], criterion, optimizer, valid_loader, valid_size, device)
        return

    # run train.
    best_acc1 = 0.
    for e in range(args.epoch, args.n_epochs + 1):
        adjust_learning_rate(optimizer, e, args)

        # train for one epoch
        _ = run_epoch(model, 'train', [e, args.n_epochs], criterion, optimizer, train_loader, train_size, device)

        # evaluate on validation set
        with torch.no_grad():
            acc1 = run_epoch(model, 'valid', [e, args.n_epochs], criterion, optimizer, valid_loader, valid_size, device)

        # Save checkpoint; track the best top-1 accuracy seen so far.
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint({
            'epoch': e,
            'backbone': args.backbone,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)

        # Under DataParallel also save the unwrapped module's weights so the
        # checkpoint can be loaded on a single-GPU setup.
        if use_multi_gpu:
            save_checkpoint({
                'epoch': e,
                'backbone': args.backbone,
                'state_dict': model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_multi)

        print('[Info] acc1 {} best@acc1 {}'.format(acc1, best_acc1))