Exemple #1
0
def objective(space):
    """Train and evaluate a MobileNet built from one set of hyper-parameters.

    Presumably used as the objective of a hyper-parameter search (the caller
    is not visible here).

    Args:
        space: Iterable of exactly seven values, in this order:
            block_kernel1, block_stride1, block_kernel2, block_stride2,
            kernel_size1, stride1, learning_rate. The first six are coerced
            with ``int`` and the last with ``float``.

    Returns:
        The sum of the values returned by ``test`` for each epoch, divided
        by ``args.num_epochs``.

    Raises:
        ValueError: If ``space`` does not contain exactly seven values.
        ZeroDivisionError: If ``args.num_epochs`` is 0.
    """
    block_kernel1, block_stride1, block_kernel2, block_stride2, kernel_size1, stride1, learning_rate = space

    # The sampled values may arrive as floats; the layer constructors are
    # given plain ints and the optimizer a plain float.
    block_kernel1 = int(block_kernel1)
    block_stride1 = int(block_stride1)
    block_kernel2 = int(block_kernel2)
    block_stride2 = int(block_stride2)
    kernel_size1 = int(kernel_size1)
    stride1 = int(stride1)
    learning_rate = float(learning_rate)

    block = Block(in_planes=64, out_planes=64, block_kernel1=block_kernel1, block_stride1=block_stride1,
                  block_kernel2=block_kernel2, block_stride2=block_stride2)
    net = MobileNet(block, kernel_size1=kernel_size1, stride1=stride1)

    if use_cuda:
        net.cuda()
        # DataParallel documents device_ids as a list, so materialize the range.
        net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
        cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

    total_loss = 0.0
    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, optimizer, criterion)
        total_loss += test(epoch, net)

    return total_loss / args.num_epochs
Exemple #2
0
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
    """Build a network, optionally load a checkpoint, and move it to GPU(s).

    Args:
        model: ``'mobilenet'``, ``'mobilenetv2'``, or a name starting with
            ``'resnet'`` that is looked up in ``resnet.__dict__``.
        dataset: ``'imagenet'`` (1000 classes) or ``'cifar10'`` (10 classes).
        checkpoint_path: Path of a checkpoint to load; a falsy value skips
            loading.
        n_gpu: Number of GPUs to use. With CUDA available, ``n_gpu > 0``
            moves the network to the GPU and ``n_gpu > 1`` additionally wraps
            it in ``DataParallel`` over devices ``0 .. n_gpu - 1``.

    Returns:
        The constructed network.

    Raises:
        ValueError: If ``dataset`` is not supported.
        NotImplementedError: If ``model`` is not supported.
    """
    if dataset == 'imagenet':
        n_class = 1000
    elif dataset == 'cifar10':
        n_class = 10
    else:
        raise ValueError('unsupported dataset')

    if model == 'mobilenet':
        from mobilenet import MobileNet
        net = MobileNet(n_class=n_class)
    elif model == 'mobilenetv2':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=n_class)
    elif model.startswith('resnet'):
        net = resnet.__dict__[model](pretrained=True)
        # Replace the classification head so its output size matches n_class.
        in_features = net.fc.in_features
        net.fc = nn.Linear(in_features, n_class)
    else:
        raise NotImplementedError('unsupported model: {}'.format(model))

    if checkpoint_path:
        print('loading {}...'.format(checkpoint_path))
        sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in sd:  # a checkpoint but not a state_dict
            sd = sd['state_dict']

        def strip_wrapper_scopes(key):
            # Drop whole ``module`` path components (added by DataParallel)
            # instead of replacing the substring 'module.', which would also
            # corrupt names such as 'submodule.weight'. The last component
            # (the parameter/buffer name) is never dropped.
            parts = key.split('.')
            scopes = [part for part in parts[:-1] if part != 'module']
            return '.'.join(scopes + parts[-1:])

        sd = {strip_wrapper_scopes(k): v for k, v in sd.items()}
        net.load_state_dict(sd)

    if torch.cuda.is_available() and n_gpu > 0:
        net = net.cuda()
        if n_gpu > 1:
            net = torch.nn.DataParallel(net, list(range(n_gpu)))

    return net
Exemple #3
0
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
    """Build a MobileNet variant, optionally load a checkpoint, move to GPU(s).

    Args:
        model: ``'mobilenet'`` or ``'mobilenetv2'``.
        dataset: ``'imagenet'`` (1000 classes) or ``'cifar10'`` (10 classes).
        checkpoint_path: Path of a checkpoint to load; a falsy value skips
            loading.
        n_gpu: Number of GPUs to use. With CUDA available, ``n_gpu > 0``
            moves the network to the GPU and ``n_gpu > 1`` additionally wraps
            it in ``DataParallel`` over devices ``0 .. n_gpu - 1``.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: If the model/dataset combination is unsupported.
    """
    if dataset == 'imagenet':
        n_class = 1000
    elif dataset == 'cifar10':
        n_class = 10
    else:
        raise NotImplementedError('unsupported dataset: {}'.format(dataset))

    # The project modules are imported only for the model actually requested.
    if model == 'mobilenet':
        from mobilenet import MobileNet
        net = MobileNet(n_class=n_class)
    elif model == 'mobilenetv2':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=n_class)
    else:
        raise NotImplementedError('unsupported model: {}'.format(model))

    if checkpoint_path:
        print('loading {}...'.format(checkpoint_path))
        sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in sd:  # a checkpoint but not a state_dict
            sd = sd['state_dict']

        def strip_wrapper_scopes(key):
            # Drop whole ``module`` path components (added by DataParallel)
            # instead of replacing the substring 'module.', which would also
            # corrupt names such as 'submodule.weight'. The last component
            # (the parameter/buffer name) is never dropped.
            parts = key.split('.')
            scopes = [part for part in parts[:-1] if part != 'module']
            return '.'.join(scopes + parts[-1:])

        sd = {strip_wrapper_scopes(k): v for k, v in sd.items()}
        net.load_state_dict(sd)

    if torch.cuda.is_available() and n_gpu > 0:
        net = net.cuda()
        if n_gpu > 1:
            net = torch.nn.DataParallel(net, list(range(n_gpu)))

    return net
Exemple #4
0
def get_network(args, use_gpu=True):
    """ build the network named by args.net

    Prints a message and calls sys.exit() when the name is not one of the
    supported networks. When use_gpu is truthy the network is moved to the
    GPU with .cuda() before being returned.
    """

    # Each entry is a zero-argument factory so that only the selected class
    # name is looked up and instantiated.
    factories = {
        'resnet18': lambda: ResNet18(),
        'resnetcbam18': lambda: ResNetCBAM18(),
        'resnetzam18': lambda: ResNetZAM18(),
        'mobilenet': lambda: MobileNet(),
        'mobilenetcbam': lambda: MobileNetCBAM(),
        'mobilenetzam': lambda: MobileNetZAM(),
    }

    make_net = factories.get(args.net)
    if make_net is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    network = make_net()
    return network.cuda() if use_gpu else network
Exemple #5
0
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys."""
    prefix = 'module.'
    stripped = OrderedDict()
    for key, value in state_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        stripped[name] = value
    return stripped


def _load_checkpoint(ckpt_file, model, optimizer, gpu_id=None):
    """Restore model and optimizer state from ``ckpt_file`` and return its epoch.

    Args:
        ckpt_file: Path of the checkpoint; it is read with ``torch.load`` and
            must contain the keys ``'model'``, ``'epoch'`` and ``'optimizer'``.
        model: Network to restore, either bare or wrapped in ``nn.DataParallel``.
        optimizer: Optimizer whose state is restored.
        gpu_id: CUDA device index the stored tensors are mapped to; ``None``
            keeps them on the CPU.

    Returns:
        The value stored under ``checkpoint['epoch']``.
    """
    if gpu_id is None:
        checkpoint = torch.load(
            ckpt_file, map_location=lambda storage, loc: storage)
    else:
        checkpoint = torch.load(
            ckpt_file,
            map_location=lambda storage, loc: storage.cuda(gpu_id))

    # Always load into the underlying network, so the same two attempts work
    # whether or not the model is wrapped in DataParallel.
    target = model.module if isinstance(model, nn.DataParallel) else model
    state_dict = checkpoint['model']
    try:
        target.load_state_dict(state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: the checkpoint was presumably saved from a
        # DataParallel-wrapped model, so retry without the `module.` prefix.
        target.load_state_dict(_strip_module_prefix(state_dict))

    epoch = checkpoint['epoch']
    optimizer.load_state_dict(checkpoint['optimizer'])
    return epoch


def main():
    """Train a MobileNet on CIFAR-10, or evaluate a saved checkpoint.

    Configuration is read from the module-level ``cfg`` object. With
    ``opt.resume`` the model and optimizer are restored from
    ``model/<opt.ckpt>`` before training continues; with ``opt.eval`` the
    checkpoint is loaded, validated once, and the function returns. In both
    modes a missing checkpoint file prints a message and returns. Otherwise
    the model is trained from ``start_epoch`` up to ``opt.epochs``, validated
    and saved after every epoch, and a timing summary is printed.

    Writes the globals ``opt``, ``start_epoch`` and ``best_prec1``.
    ``best_prec1`` is read before it is assigned here, so it must already be
    defined at module level.

    Raises:
        Exception: If ``opt.cuda`` is set but no CUDA device is available.
    """
    global opt, start_epoch, best_prec1
    opt = cfg
    opt.gpuids = list(map(int, opt.gpuids))

    if opt.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    model = MobileNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay,
                          nesterov=True)
    start_epoch = 0

    ckpt_file = join("model", opt.ckpt)

    if opt.cuda:
        torch.cuda.set_device(opt.gpuids[0])
        with torch.cuda.device(opt.gpuids[0]):
            model = model.cuda()
            criterion = criterion.cuda()
        model = nn.DataParallel(model,
                                device_ids=opt.gpuids,
                                output_device=opt.gpuids[0])
        cudnn.benchmark = True

    # Device that checkpoint tensors are mapped to (None keeps them on CPU).
    gpu_id = opt.gpuids[0] if opt.cuda else None

    # for resuming training
    if opt.resume:
        if not isfile(ckpt_file):
            print("==> no checkpoint found at '{}'".format(opt.ckpt))
            return

        print("==> Loading Checkpoint '{}'".format(opt.ckpt))
        start_epoch = _load_checkpoint(ckpt_file, model, optimizer, gpu_id)
        print("==> Loaded Checkpoint '{}' (epoch {})".format(
            opt.ckpt, start_epoch))

    # Download & Load Dataset
    print('==> Preparing data..')
    # Per-channel normalization constants shared by both pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.workers)

    valset = torchvision.datasets.CIFAR10(root='./data',
                                          train=False,
                                          download=True,
                                          transform=transform_val)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=opt.test_batch_size,
                                             shuffle=False,
                                             num_workers=opt.workers)

    # for evaluation
    if opt.eval:
        if not isfile(ckpt_file):
            print("==> no checkpoint found at '{}'".format(opt.ckpt))
            return

        print("==> Loading Checkpoint '{}'".format(opt.ckpt))
        start_epoch = _load_checkpoint(ckpt_file, model, optimizer, gpu_id)
        print("==> Loaded Checkpoint '{}' (epoch {})".format(
            opt.ckpt, start_epoch))

        # evaluate on validation set
        print("\n===> [ Evaluation ]")
        start_time = time.time()
        validate(val_loader, model, criterion)
        elapsed_time = time.time() - start_time
        print("====> {:.2f} seconds to evaluate this model\n".format(
            elapsed_time))
        return

    # train...
    train_time = 0.0
    validate_time = 0.0
    for epoch in range(start_epoch, opt.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\n==> Epoch: {}, lr = {}'.format(
            epoch, optimizer.param_groups[0]["lr"]))

        # train for one epoch
        print("===> [ Training ]")
        start_time = time.time()
        train(train_loader, model, criterion, optimizer, epoch)
        elapsed_time = time.time() - start_time
        train_time += elapsed_time
        print(
            "====> {:.2f} seconds to train this epoch\n".format(elapsed_time))

        # evaluate on validation set
        print("===> [ Validation ]")
        start_time = time.time()
        prec1 = validate(val_loader, model, criterion)
        elapsed_time = time.time() - start_time
        validate_time += elapsed_time
        print("====> {:.2f} seconds to validate this epoch\n".format(
            elapsed_time))

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        state = {
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_model(state, epoch, is_best)

    # Average over the epochs actually run in this call: after a resume the
    # loop starts at start_epoch, not 0. Clamp to 1 so that a run with no
    # epochs reports zero averages instead of dividing by zero.
    epochs_run = max(opt.epochs - start_epoch, 1)
    avg_train_time = train_time / epochs_run
    avg_valid_time = validate_time / epochs_run
    total_train_time = train_time + validate_time
    print("====> average training time per epoch: {}m {:.2f}s".format(
        int(avg_train_time // 60), avg_train_time % 60))
    print("====> average validation time per epoch: {}m {:.2f}s".format(
        int(avg_valid_time // 60), avg_valid_time % 60))
    print("====> training time: {}m {:.2f}s".format(int(train_time // 60),
                                                    train_time % 60))
    print("====> validation time: {}m {:.2f}s".format(int(validate_time // 60),
                                                      validate_time % 60))
    print("====> total training time: {}m {:.2f}s".format(
        int(total_train_time // 60), total_train_time % 60))