Пример #1
0
def get_classifier(mode, n_classes=10):
    """Build a ResNet classifier selected by name.

    Args:
        mode: Architecture key. ``'resnet18'``, ``'resnet34'`` and
            ``'resnet50'`` use the ``ResNet18``/``ResNet34``/``ResNet50``
            constructors; ``'resnet18_imagenet'`` and ``'resnet50_imagenet'``
            use the lowercase ``resnet18``/``resnet50`` constructors.
        n_classes: Forwarded to the constructor as ``num_classes``.

    Returns:
        The newly constructed classifier.

    Raises:
        NotImplementedError: If ``mode`` is not one of the keys above.
    """
    if mode == 'resnet18':
        classifier = ResNet18(num_classes=n_classes)
    elif mode == 'resnet34':
        classifier = ResNet34(num_classes=n_classes)
    elif mode == 'resnet50':
        classifier = ResNet50(num_classes=n_classes)
    elif mode == 'resnet18_imagenet':
        classifier = resnet18(num_classes=n_classes)
    elif mode == 'resnet50_imagenet':
        classifier = resnet50(num_classes=n_classes)
    else:
        # Name the offending mode so a typo in a config is easy to spot.
        raise NotImplementedError(
            'Unsupported classifier mode: {!r}'.format(mode))

    return classifier
Пример #2
0
def get_classifier(mode, n_classes=10):
    """Build a classifier selected by name.

    Args:
        mode: Architecture key. ``'resnet18'``, ``'resnet34'`` and
            ``'resnet50'`` use the ``ResNet18``/``ResNet34``/``ResNet50``
            constructors; ``'resnet18_imagenet'`` and ``'resnet50_imagenet'``
            use the lowercase ``resnet18``/``resnet50`` constructors;
            ``'live'`` builds ``FeatherNet(input_size=128, se=True,
            avgdown=True)``.
        n_classes: Forwarded as ``num_classes`` to every ResNet constructor.
            It is not used for ``'live'``.

    Returns:
        The newly constructed classifier.

    Raises:
        NotImplementedError: If ``mode`` is not one of the keys above.
    """
    if mode == 'resnet18':
        classifier = ResNet18(num_classes=n_classes)
    elif mode == 'resnet34':
        classifier = ResNet34(num_classes=n_classes)
    elif mode == 'resnet50':
        classifier = ResNet50(num_classes=n_classes)
    elif mode == 'resnet18_imagenet':
        classifier = resnet18(num_classes=n_classes)
    elif mode == 'resnet50_imagenet':
        classifier = resnet50(num_classes=n_classes)
    elif mode == 'live':
        # NOTE(review): n_classes is not passed to FeatherNet; confirm its
        # output size matches what callers expect.
        classifier = FeatherNet(input_size=128, se=True, avgdown=True)
    else:
        # Name the offending mode so a typo in a config is easy to spot.
        raise NotImplementedError(
            'Unsupported classifier mode: {!r}'.format(mode))

    return classifier
Пример #3
0
    def get_classifier(self, is_multi_class):
        """Build the classifier named by ``self.params.model``.

        Each branch imports its model module lazily, so only the selected
        architecture's module is loaded.

        Args:
            is_multi_class: Only consulted for ``'resnet18_imagenet'``. When
                true, ``resnet18`` is taken from
                ``models.resnet_imagenet_multiclass_infer``; otherwise it is
                taken from ``models.resnet_imagenet``.

        Returns:
            The classifier, constructed with
            ``num_classes=self.params.n_classes``.

        Raises:
            NotImplementedError: If ``self.params.model`` is not a supported
                architecture name.
        """
        model_name = self.params.model
        if model_name == 'resnet18':
            from models.resnet import ResNet18
            classifier = ResNet18(num_classes=self.params.n_classes)
        elif model_name == 'resnet34':
            from models.resnet import ResNet34
            classifier = ResNet34(num_classes=self.params.n_classes)
        elif model_name == 'resnet50':
            from models.resnet import ResNet50
            classifier = ResNet50(num_classes=self.params.n_classes)
        elif model_name == 'resnet18_imagenet':
            if is_multi_class:
                from models.resnet_imagenet_multiclass_infer import resnet18
            else:
                from models.resnet_imagenet import resnet18
            classifier = resnet18(num_classes=self.params.n_classes)
        elif model_name == 'resnet50_imagenet':
            from models.resnet_imagenet import resnet50
            classifier = resnet50(num_classes=self.params.n_classes)
        else:
            # Name the offending model so a typo in a config is easy to spot.
            raise NotImplementedError(
                'Unsupported classifier model: {!r}'.format(model_name))

        return classifier
Пример #4
0
                      top1=top1,
                      top5=top5) + time_string())
    print(
        '  **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, losses.avg


print("=> creating model '{}'".format(args.arch))

# Instantiate the network for the requested dataset / architecture pair.
# Unsupported combinations raise immediately, so `net` is always bound
# after this section.
if args.dataset == 'imagenet':

    if args.arch == 'resnet101':
        net = resnet_imagenet.resnet101()
    elif args.arch == 'resnet50':
        net = resnet_imagenet.resnet50()
    elif args.arch == 'resnet34':
        net = resnet_imagenet.resnet34()
    elif args.arch == 'resnet18':
        net = resnet_imagenet.resnet18()
    else:
        raise ValueError("Unsupported arch '{}' for dataset '{}'".format(
            args.arch, args.dataset))
else:
    # NOTE(review): num_classes is fixed at 10 for every non-ImageNet
    # dataset; confirm this matches the datasets this script accepts.
    if args.arch == 'resnet110':
        net = models.resnet110(num_classes=10)
    elif args.arch == 'resnet56':
        net = models.resnet56(num_classes=10)
    elif args.arch == 'resnet32':
        net = models.resnet32(num_classes=10)
    elif args.arch == 'resnet20':
        net = models.resnet20(num_classes=10)
    else:
        raise ValueError("Unsupported arch '{}' for dataset '{}'".format(
            args.arch, args.dataset))

if args.dataset == 'imagenet':
Пример #5
0
from models.resnet_imagenet import resnet50
model = resnet50(pretrained=True)
import torch
#out = model.forward(torch.FloatTensor(np
import numpy as np
out = model.forward(torch.FloatTensor(np.random.randn(1,3,224,224)))
out = model.forward(torch.FloatTensor(np.random.randn(1,3,112,112)))
%history -f tmp.py
Пример #6
0
                                   shuffle=False,
                                   num_workers=12)
testloader = torchdata.DataLoader(testset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=12)

# Use the test split when `args.test` is truthy, otherwise the train split.
loader = testloader if args.test else trainloader

if dataset == Data.imagenet:
    # ImageNet: use the pretrained ResNet-50 weights directly.
    rnet = resnet50(pretrained=True)
else:
    rnet = utils.get_resnet_model(dataset, [18, 18, 18])
    # NOTE(review): the checkpoint directory is always resolved for
    # Data.cifar10, whatever `dataset` is; confirm this is intended.
    rnetpath = os.path.join(utils.get_path(Data.cifar10, iter=args.iter),
                            'rnet.t7')
    # Load tensors onto the CPU first, so a checkpoint saved on a GPU also
    # loads on CPU-only machines. rnet.to(device) below moves the weights
    # to the target device.
    rnet_dict = torch.load(rnetpath, map_location='cpu')
    rnet.load_state_dict(rnet_dict['net'])

# Only the ImageNet model is wrapped for multi-GPU data parallelism.
if dataset == Data.imagenet and device == 'cuda':
    rnet = torch.nn.DataParallel(rnet)

rnet.to(device)
num_scales = 3
scale_search()
Пример #7
0
def load_model(opt):
    """Create or load the model described by the options object ``opt``.

    The sources below are checked in this order:

    1. ``opt.from_modelzoo``: instantiate ``models.__dict__[opt.model_def]``,
       passing ``pretrained=True`` when ``opt.pretrained`` is set.
    2. ``opt.pretrained_file != ''``: deserialize a whole model with
       ``torch.load(opt.pretrained_filedir)``.
    3. Otherwise: build the project-local architecture named by
       ``opt.model_def`` (see ``_build_local_model``). Move it to the GPU
       with ``.cuda()`` when ``opt.cuda`` is set.

    Only path 3 applies ``opt.cuda``. Models from paths 1 and 2 are returned
    exactly as produced.

    Args:
        opt: Options namespace (presumably parsed command-line arguments;
            confirm against the caller).

    Returns:
        The model instance.

    Raises:
        ValueError: If path 3 is taken and ``opt.model_def`` is not a known
            local architecture.
    """
    if opt.from_modelzoo:
        if opt.pretrained:
            print("=> using pre-trained model '{}'".format(opt.arch))
            model = models.__dict__[opt.model_def](pretrained=True)
        else:
            print("=> creating model '{}'".format(opt.arch))
            model = models.__dict__[opt.model_def]()
        return model

    if opt.pretrained_file != '':
        # NOTE(review): the guard tests `pretrained_file` but the path loaded
        # is `pretrained_filedir`; confirm both options are meant to be set
        # together.
        # SECURITY: torch.load unpickles the file; only load trusted files.
        return torch.load(opt.pretrained_filedir)

    model = _build_local_model(opt)
    if opt.cuda:
        model = model.cuda()
    return model


def _build_local_model(opt):
    """Instantiate the project-local architecture named by ``opt.model_def``.

    Constructor hyper-parameters are read from ``opt`` only for the
    architectures that take them.

    Raises:
        ValueError: If ``opt.model_def`` is not a known architecture name.
    """
    # Zero-argument factories: only the selected constructor, and only the
    # `opt` fields it needs, are evaluated.
    factories = {
        'alexnet': lambda: alexnet.Net(),
        'mobilenet': lambda: mobilenet.Net(nClasses=opt.nclasses,
                                           width_mult=opt.widthmult,
                                           gtp=opt.grouptype,
                                           gsz=opt.sp,
                                           expsz=opt.exp),
        'alexnetexpander': lambda: alexnetexpander.Net(),
        'vgg16cifar': lambda: vggcifar.vgg16(),
        'vgg16cifar_bn': lambda: vggcifar.vgg16_bn(),
        'vgg16cifarexpander': lambda: vggcifarexpander.vgg16(),
        'vgg16cifar_bnexpander': lambda: vggcifarexpander.vgg16_bn(),
        'densenet_cifar': lambda: densenet_cifar.DenseNet3(
            opt.layers,
            opt.nclasses,
            opt.growth,
            reduction=opt.reduce,
            bottleneck=opt.bottleneck,
            dropRate=opt.droprate),
        'densenetgrouped_cifar': lambda: densenetgrouped_cifar.DenseNet3(
            opt.layers,
            opt.nclasses,
            opt.growth,
            reduction=opt.reduce,
            bottleneck=opt.bottleneck,
            dropRate=opt.droprate),
        'densenetexpander_cifar': lambda: densenetexpander_cifar.DenseNet3(
            opt.layers,
            opt.nclasses,
            opt.growth,
            reduction=opt.reduce,
            bottleneck=opt.bottleneck,
            dropRate=opt.droprate,
            expandSize=opt.expandSize),
        'densenet121': lambda: densenet.densenet121(),
        'densenet169': lambda: densenet.densenet169(),
        'densenet161': lambda: densenet.densenet161(),
        'densenet201': lambda: densenet.densenet201(),
        'densenetexpander121': lambda: densenetexpander.densenet121(
            expandSize=opt.expandSize),
        'densenetexpander169': lambda: densenetexpander.densenet169(
            expandSize=opt.expandSize),
        'densenetexpander161': lambda: densenetexpander.densenet161(
            expandSize=opt.expandSize),
        'densenetexpander201': lambda: densenetexpander.densenet201(
            expandSize=opt.expandSize),
        'resnet18': lambda: resnet.resnet18(),
        'resnet34': lambda: resnet.resnet34(),
        'resnet50': lambda: resnet.resnet50(),
        'resnet101': lambda: resnet.resnet101(),
        'resnet152': lambda: resnet.resnet152(),
        'resnetexpander34': lambda: resnetexpander.resnet34(opt.expandSize),
        'resnetexpander50': lambda: resnetexpander.resnet50(opt.expandSize),
        'resnetexpander101': lambda: resnetexpander.resnet101(opt.expandSize),
        'resnetexpander152': lambda: resnetexpander.resnet152(opt.expandSize),
    }
    factory = factories.get(opt.model_def)
    if factory is None:
        raise ValueError('Unknown model_def: {!r}'.format(opt.model_def))
    return factory()
Пример #8
0
# Presumably the epoch index training (re)starts from.
# TODO(review): confirm whether the resume branch further down updates it.
start_epoch = 0

if dataset_type == utils.Data.imagenet:
    num_scales = 2
    pretrained_agent = resnet_imagenet.resnet18(pretrained=True)
    # agent = resnet_imagenet.ResNet(resnet_imagenet.BasicBlock, [2, 2, 2, 2], num_scales)
    agent = resnet_imagenet.scalenet18(pretrained=True,
                                       scalelist=[1, 0.5, 0.25])

    # pretrained_dict = pretrained_agent.state_dict()
    # model_dict = agent.state_dict()
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and not k.startswith('fc')}
    # model_dict.update(pretrained_dict)
    # agent.load_state_dict(model_dict)

    rnet = resnet_imagenet.resnet50(pretrained=True)

else:
    # ScaleNet
    num_scales = 2
    agent = Networks.ScaleNet10(num_scales)

    # configuration of the ResNet
    rnet = resnet.FlatResNet32(base.BasicBlock, [18, 18, 18], num_classes=10)
    rnet_dict = torch.load(os.path.join(rootpath, 'rnet.t7'))
    rnet.load_state_dict(rnet_dict['net'])

    if args.resume:
        # test with trained check-point
        print('load the check point weights...')
        ckpt = torch.load(os.path.join(rootpath, 'scale_latest.t7'))