Exemplo n.º 1
0
def get_models(args, train=True, as_ensemble=False, model_file=None, leaky_relu=False):
    """Build the CUDA models described by ``args``, optionally restoring weights.

    Each model is a ``ResNet`` wrapped in a ``ModelWrapper`` together with a
    per-channel mean/std normalizer, so input normalization is part of the
    model itself.

    Args:
        args: Namespace providing ``arch``, ``depth`` and, when ``model_file``
            is not given, ``model_num`` (the number of models to create).
        train: If true, models are put in training mode; otherwise eval mode.
        as_ensemble: If true, return the models wrapped in a single
            ``Ensemble`` (only allowed together with ``train=False``).
        model_file: Optional checkpoint path. When given, one model is built
            per key of the loaded object and initialised from that entry.
        leaky_relu: Forwarded to the ``ResNet`` constructor.

    Returns:
        A list of models, or an ``Ensemble`` in eval mode when ``as_ensemble``
        is true.

    Raises:
        ValueError: If ``args.arch`` is not ``'resnet'`` (case-insensitive).
        AssertionError: If ``as_ensemble`` is requested together with
            ``train=True``.
    """
    # Validate the request up front so that a bad configuration fails before
    # the checkpoint is read or anything is allocated on the GPU.
    if args.arch.lower() != 'resnet':
        raise ValueError(
            '[{:s}] architecture is not supported yet...'.format(args.arch))
    if as_ensemble:
        assert not train, 'Must be in eval mode when getting models to form an ensemble'

    models = []

    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=torch.float32).cuda()
    std = torch.tensor([0.2023, 0.1994, 0.2010], dtype=torch.float32).cuda()
    normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

    state_dict = None
    if model_file:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(model_file)
        if train:
            print('Loading pre-trained models...')

    # One model per checkpoint entry, or ``model_num`` fresh models.
    model_keys = state_dict.keys() if model_file else range(args.model_num)

    for key in model_keys:
        model = ResNet(depth=args.depth, leaky_relu=leaky_relu)
        # we include input normalization as a part of the model
        model = ModelWrapper(model, normalizer)
        if model_file:
            model.load_state_dict(state_dict[key])
        if train:
            model.train()
        else:
            model.eval()
        model = model.cuda()
        models.append(model)

    if as_ensemble:
        ensemble = Ensemble(models)
        ensemble.eval()
        return ensemble
    return models
Exemplo n.º 2
0
def main(args):
    '''
    main function of FALCON

    Builds the network selected by ``args``, moves it to the GPU, then either
    trains it (optionally saving the result) or restores a checkpoint and
    evaluates it. Parameter/FLOP counts are printed before and after, and
    finally 100 forward passes over a random batch are run under a timer.

    :param args: arguments for a model
    :raises ValueError: if ``args.datasets``, ``args.model`` or
        ``args.convolution`` does not name a supported option
    '''
    num_classes = _falcon_num_classes(args.datasets)
    net = _falcon_build_net(args, num_classes)

    net = net.cuda()

    print_model_parm_nums(net)
    print_model_parm_flops(net)

    if args.is_train:
        # training
        best = train(net,
                     lr=args.learning_rate,
                     optimizer_option=args.optimizer,
                     epochs=args.epochs,
                     batch_size=args.batch_size,
                     is_train=args.is_train,
                     data=args.datasets,
                     lrd=args.lr_decay_rate)
        if not args.not_save:
            save_specific_model(best, args)
        test(net, batch_size=args.batch_size, data=args.datasets)
    else:
        # testing
        load_specific_model(net, args, input_path=args.restore_path)
        test(net, batch_size=args.batch_size, data=args.datasets)

    # calculate number of parameters & FLOPs
    print_model_parm_nums(net)
    print_model_parm_flops(net)

    # time of forwarding 100 data sample (ms)
    x = torch.rand(100, 3, 32, 32)
    x = Variable(x.cuda())
    net(x)  # one untimed forward pass before the timed loop
    timer = Timer()
    timer.tic()
    for _ in range(100):
        net(x)
    timer.toc()
    # NOTE(review): the measured time is neither printed nor returned here;
    # confirm whether that is intentional.


def _falcon_num_classes(dataset):
    '''Return the number of target classes for ``dataset`` or raise ValueError.'''
    # NOTE(review): "cifar10" is not accepted here; confirm that is intended.
    if dataset == "svhn":
        return 10
    if dataset == "cifar100":
        return 100
    raise ValueError("unsupported dataset: %r" % (dataset,))


def _falcon_build_net(args, num_classes):
    '''Construct (and, where requested, initialise/compress) the network for ``args``.'''
    # "ResNet" is tested first, matching the original precedence.
    if "ResNet" in args.model:
        is_resnet = True
    elif "VGG" in args.model:
        is_resnet = False
    else:
        raise ValueError("unsupported model: %r" % (args.model,))

    convolution = args.convolution
    if convolution not in ("FALCON", "StConvBranch", "FALCONBranch",
                           "StandardConv"):
        raise ValueError("unsupported convolution: %r" % (convolution,))

    def plain(**extra):
        # Standard-convolution backbone; ``extra`` carries ``alpha`` when set.
        if is_resnet:
            return ResNet(layer_num=str(args.layer_num),
                          num_classes=num_classes,
                          **extra)
        return VGG(num_classes=num_classes, which=args.model, **extra)

    def branch():
        # Backbone variant with the StConv branch.
        if is_resnet:
            return ResNet_StConv_branch(layer_num=str(args.layer_num),
                                        num_classes=num_classes,
                                        alpha=args.alpha)
        return VGG_StConv_branch(num_classes=num_classes,
                                 which=args.model,
                                 alpha=args.alpha)

    def compress(model, init):
        # Single place for the FALCON conversion call.
        model.falcon(rank=args.rank,
                     init=init,
                     bn=args.bn,
                     relu=args.relu,
                     groups=args.groups)

    if convolution == "StandardConv":
        return plain()

    if convolution == "StConvBranch":
        return branch()

    if convolution == "FALCONBranch":
        net = branch()
        if args.init:
            load_specific_model(net,
                                args,
                                convolution='StConvBranch',
                                input_path=args.stconv_path)
        # NOTE(review): ResNet passes init=False while VGG passes
        # init=args.is_train; kept as in the original, confirm it is intended.
        compress(net, False if is_resnet else args.is_train)
        return net

    # convolution == "FALCON"
    if not args.is_train:
        # Evaluation: weights are restored by the caller afterwards.
        net = plain(alpha=args.alpha)
        compress(net, False)
        return net

    if args.alpha == 1:
        net = plain()
        if args.init:
            load_specific_model(net,
                                args,
                                convolution='StandardConv',
                                input_path=args.stconv_path)
    elif args.init:
        # Load the pre-trained standard model, then transfer it into the
        # alpha-scaled model.
        pretrained = plain()
        scaled = plain(alpha=args.alpha)
        load_specific_model(pretrained,
                            args,
                            convolution='StandardConv',
                            input_path=args.stconv_path)
        transfer = init_with_alpha_resnet if is_resnet else init_with_alpha_vgg
        net = transfer(pretrained, scaled, args.alpha)
    else:
        net = plain(alpha=args.alpha)

    # ``args.init`` itself is forwarded when truthy, ``False`` otherwise.
    compress(net, args.init if args.init else False)
    return net
Exemplo n.º 3
0
                                         num_workers=4,
                                         pin_memory=True,
                                         sampler=None)

    assert pred['wnids'][:1000] == train_wnids

    model = ResNet('resnet50', 1000)
    sd = model.resnet_base.state_dict()
    sd.update(torch.load('materials/resnet50-base.pth'))
    model.resnet_base.load_state_dict(sd)

    fcw = pred['pred'][:1000].cpu()
    model.fc.weight = nn.Parameter(fcw[:, :-1])
    model.fc.bias = nn.Parameter(fcw[:, -1])

    model = model.cuda()
    model.train()

    optimizer = torch.optim.SGD(model.resnet_base.parameters(),
                                lr=0.0001,
                                momentum=0.9)
    loss_fn = nn.CrossEntropyLoss().cuda()

    keep_ratio = 0.9975
    trlog = {}
    trlog['loss'] = []
    trlog['acc'] = []

    for epoch in range(1, 9999):

        ave_loss = None
Exemplo n.º 4
0
def get_model(config, num_class=10, bn_types=None, data_parallel=True):
    """Instantiate the network named by ``config.model`` and place it on CUDA.

    Args:
        config: Object whose ``model`` attribute selects the architecture; for
            ``'pyramid'`` its ``pyramidnet_depth`` and ``pyramidnet_alpha``
            attributes are read as well.
        num_class: Number of output classes passed to the constructor.
        bn_types: When not ``None``, the ``*MultiBN`` variant of the
            architecture is built and receives this value.
        data_parallel: If true, the model is moved to CUDA and wrapped in
            ``DataParallel``; otherwise it is moved to the CUDA device given
            by the Horovod local rank.

    Returns:
        The constructed model.

    Raises:
        Exception: If ``bn_types`` is given for an architecture that has no
            multi-BN variant here.
        NameError: If ``config.model`` is not a known model name.
    """
    name = config.model
    print('model name: {}'.format(name))
    print('bn_types: {}'.format(bn_types))

    use_multi_bn = bn_types is not None

    # Families that differ only by a single number.
    imagenet_resnet_depth = {'resnet50': 50, 'resnet200': 200}
    shake_resnet_width = {
        'shakeshake26_2x32d': 32,
        'shakeshake26_2x64d': 64,
        'shakeshake26_2x96d': 96,
        'shakeshake26_2x112d': 112,
    }

    if name in imagenet_resnet_depth:
        depth = imagenet_resnet_depth[name]
        if use_multi_bn:
            model = ResNetMultiBN(dataset='imagenet',
                                  depth=depth,
                                  num_classes=num_class,
                                  bn_types=bn_types,
                                  bottleneck=True)
        else:
            model = ResNet(dataset='imagenet',
                           depth=depth,
                           num_classes=num_class,
                           bottleneck=True)
    elif name == 'wresnet40_2':
        if use_multi_bn:
            raise Exception('unimplemented error')
        model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class)
    elif name == 'wresnet28_10':
        if use_multi_bn:
            model = WideResNetMultiBN(28,
                                      10,
                                      dropout_rate=0.0,
                                      num_classes=num_class,
                                      bn_types=bn_types)
        else:
            model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class)
    elif name in shake_resnet_width:
        width = shake_resnet_width[name]
        if use_multi_bn:
            model = ShakeResNetMultiBN(26, width, num_class, bn_types)
        else:
            model = ShakeResNet(26, width, num_class)
    elif name == 'shakeshake26_2x96d_next':
        if use_multi_bn:
            raise Exception('unimplemented error')
        model = ShakeResNeXt(26, 96, 4, num_class)
    elif name == 'pyramid':
        if use_multi_bn:
            model = PyramidNetMultiBN('cifar10',
                                      depth=config.pyramidnet_depth,
                                      alpha=config.pyramidnet_alpha,
                                      num_classes=num_class,
                                      bottleneck=True,
                                      bn_types=bn_types)
        else:
            model = PyramidNet('cifar10',
                               depth=config.pyramidnet_depth,
                               alpha=config.pyramidnet_alpha,
                               num_classes=num_class,
                               bottleneck=True)
    else:
        raise NameError('no model named, %s' % name)

    if data_parallel:
        model = DataParallel(model.cuda())
    else:
        # Imported lazily so Horovod is only needed on this path.
        import horovod.torch as hvd
        model = model.to(torch.device('cuda', hvd.local_rank()))
    cudnn.benchmark = True
    return model
Exemplo n.º 5
0
def main(args):
    """Build the selected network, train or evaluate it, then time forward passes.

    The network is chosen from ``args.model`` (a name containing "ResNet" or
    "VGG") and ``args.convolution`` ("StandardConv", "FALCON", "StConvBranch"
    or "FALCONBranch"). In training mode the model is trained, optionally
    saved, and tested; otherwise a checkpoint is restored from
    ``args.restore_path`` and the value returned by ``test`` is printed as the
    inference time. Parameter/FLOP counts are printed before and after, and
    the average time of 100 forward passes over a random batch is printed.

    :param args: parsed command-line arguments
    :raises ValueError: if ``args.datasets``, ``args.model`` or
        ``args.convolution`` does not name a supported option
    """
    # choose dataset
    if args.datasets in ("cifar10", "svhn", "mnist"):
        num_classes = 10
    elif args.datasets == "cifar100":
        num_classes = 100
    else:
        raise ValueError("unsupported dataset: %r" % (args.datasets,))

    # Validate the remaining selectors before constructing anything, instead
    # of failing later on an unbound ``net``.
    is_resnet = "ResNet" in args.model
    if not is_resnet and "VGG" not in args.model:
        raise ValueError("unsupported model: %r" % (args.model,))
    if args.convolution not in ("FALCON", "StConvBranch", "FALCONBranch",
                                "StandardConv"):
        raise ValueError("unsupported convolution: %r" % (args.convolution,))

    def apply_falcon(model, init):
        # Single place for the FALCON conversion call.
        model.falcon(rank=args.rank,
                     init=init,
                     bn=args.bn,
                     relu=args.relu,
                     groups=args.groups)

    # choose model
    if args.convolution in ("FALCON", "StandardConv"):
        if is_resnet:
            net = ResNet(layer_num="34", num_classes=num_classes)
        else:
            net = VGG(num_classes=num_classes, which=args.model)
        if args.convolution == "FALCON":
            if args.is_train:
                # Start from the pre-trained standard-convolution weights.
                load_specific_model(net,
                                    args,
                                    convolution='StandardConv',
                                    input_path=args.stconv_path)
                apply_falcon(net, args.init)
            else:
                apply_falcon(net, False)
    else:
        # "StConvBranch" or "FALCONBranch"
        if is_resnet:
            net = ResNet_StConv_branch(layer_num='34',
                                       num_classes=num_classes,
                                       alpha=args.alpha)
        else:
            net = VGG_StConv_branch(num_classes=num_classes,
                                    which=args.model,
                                    alpha=args.alpha)
        if args.convolution == 'FALCONBranch':
            if args.is_train:
                load_specific_model(net,
                                    args,
                                    convolution='StConvBranch',
                                    input_path=args.stconv_path)
            # NOTE(review): ResNet passes init=False while VGG passes
            # init=args.is_train; kept as in the original, confirm intent.
            apply_falcon(net, False if is_resnet else args.is_train)

    net = net.cuda()

    print_model_parm_nums(net)
    print_model_parm_flops(net)

    if args.is_train:
        # training
        best = train(net,
                     lr=args.learning_rate,
                     optimizer_option=args.optimizer,
                     epochs=args.epochs,
                     batch_size=args.batch_size,
                     is_train=args.is_train,
                     data=args.datasets,
                     lrd=args.lr_decay_rate)
        if not args.not_save:
            save_specific_model(best, args)
        test(net, batch_size=args.batch_size, data=args.datasets)
    else:
        # testing
        load_specific_model(net, args, input_path=args.restore_path)
        inference_time = test(net,
                              batch_size=args.batch_size,
                              data=args.datasets)
        print("Average Inference Time: %f" % float(inference_time))

    # calculate number of parameters & FLOPs
    print_model_parm_nums(net)
    print_model_parm_flops(net)

    # time of forwarding 100 data sample (ms)
    x = torch.rand(100, 3, 32, 32)
    x = Variable(x.cuda())
    net(x)  # one untimed forward pass before the timed loop
    timer = Timer()
    timer.tic()
    for _ in range(100):
        net(x)
    timer.toc()
    print('Do once forward need %.3f ms.' % (timer.total_time * 1000 / 100.0))