Example #1
import os

import torch
import torch.backends.cudnn as cudnn

from peleenet import PeleeNet  # assumed module path for the PeleeNet class


def create_model(num_classes, engine='torch'):
    # `args` is the module-level namespace produced by the script's argparse parser.
    if engine == 'torch':
        if args.arch == 'peleenet':
            model = PeleeNet(num_classes=num_classes)
        else:
            print("=> unsupported model '{}'. creating PeleeNet by default.".
                  format(args.arch))
            model = PeleeNet(num_classes=num_classes)

        # print(model)

        model = torch.nn.DataParallel(model).cuda()

        if args.weights:
            if os.path.isfile(args.weights):
                print("=> loading checkpoint '{}'".format(args.weights))
                checkpoint = torch.load(args.weights)
                # `model` is DataParallel-wrapped, so it accepts the
                # 'module.'-prefixed keys of a checkpoint saved the same way.
                model.load_state_dict(checkpoint['state_dict'])

            else:
                print("=> no checkpoint found at '{}'".format(args.weights))

        # Let cuDNN benchmark and cache the fastest kernels for fixed input sizes.
        cudnn.benchmark = True

    else:
        # create caffe model
        import caffe
        caffe.set_mode_gpu()
        caffe.set_device(0)

        model_def = args.deploy
        model_weights = args.weights

        model = caffe.Net(
            model_def,  # defines the structure of the model
            model_weights,  # contains the trained weights
            caffe.TEST)  # use test mode (e.g., don't perform dropout)

    return model
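
A minimal usage sketch (hypothetical): create_model reads the module-level args namespace, which the original script fills in via argparse, so only the attributes the function actually touches are stubbed here.

import argparse

# Hypothetical stub for the script's parsed arguments; an empty `weights`
# skips checkpoint loading. engine='torch' requires a CUDA-capable GPU.
args = argparse.Namespace(arch='peleenet', weights='')
model = create_model(num_classes=1000, engine='torch')
model.eval()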
Example #2
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from peleenet import PeleeNet  # assumed module path for the PeleeNet class


def main():
    # `parser`, `best_acc1`, `train`, `validate` and `save_checkpoint` are
    # defined at module level elsewhere in this script.
    global args, best_acc1
    args = parser.parse_args()
    print('args:', args)

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Val data loading
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Standard ImageNet evaluation pipeline: resize, center-crop, normalize.
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(args.input_dim + 32),
            transforms.CenterCrop(args.input_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    num_classes = len(val_dataset.classes)
    print('Total classes: ', num_classes)

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch == 'peleenet':
        model = PeleeNet(num_classes=num_classes)
    else:
        print(
            "=> unsupported model '{}'. creating PeleeNet by default.".format(
                args.arch))
        model = PeleeNet(num_classes=num_classes)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        # DataParallel will divide and allocate batch_size across all
        # available GPUs.
        model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.pretrained:
        if os.path.isfile(args.weights):
            checkpoint = torch.load(args.weights,
                                    map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint['state_dict'])

            print("=> loaded checkpoint '{}' (epoch {}, acc@1 {})".format(
                args.pretrained, checkpoint['epoch'], checkpoint['best_acc1']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    if args.tune:
        model.eval()
        model.module.fuse_model()  # fuse conv/bn/relu blocks ahead of quantization
        # ilit is Intel's low-precision tuning tool (later renamed Intel
        # Neural Compressor); ./conf.yaml holds its tuning configuration.
        import ilit
        tuner = ilit.Tuner("./conf.yaml")
        q_model = tuner.tune(model)
        exit(0)

    # Training data loading
    traindir = os.path.join(args.data, 'train')

    # Standard ImageNet training augmentation: random resized crop + horizontal flip.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.input_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion)

        # remember best Acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
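
For reference, train, validate and save_checkpoint are defined elsewhere in the script. A minimal save_checkpoint in the style of the PyTorch ImageNet example, which this code otherwise follows, might look like the sketch below; the file names are illustrative assumptions, not taken from the source.

import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Persist the latest training state; when this epoch achieved the best
    # top-1 accuracy so far, keep a separate copy of it.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')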