Esempio n. 1
0
def main(args):
    '''
    main function of FALCON

    Builds the network selected by ``args.model`` / ``args.convolution``,
    moves it to the GPU with ``.cuda()``, reports its parameter/FLOP counts,
    then either trains it (``args.is_train``) or restores it from
    ``args.restore_path`` and tests it. Finally runs 100 timed forward
    passes on a random ``100 x 3 x 32 x 32`` batch.

    :param args: arguments for a model. Attributes read here: ``datasets``,
        ``model``, ``convolution``, ``layer_num``, ``alpha``, ``init``,
        ``is_train``, ``rank``, ``bn``, ``relu``, ``groups``,
        ``stconv_path``, ``restore_path``, ``not_save``, ``learning_rate``,
        ``optimizer``, ``epochs``, ``batch_size``, ``lr_decay_rate``.
    :raises ValueError: if ``args.datasets``, ``args.model`` or
        ``args.convolution`` names an option this function does not handle.
        (These cases used to fall through silently and fail later with an
        unbound-variable error.)
    '''

    # choose dataset
    classes_by_dataset = {"svhn": 10, "cifar100": 100}
    if args.datasets not in classes_by_dataset:
        raise ValueError("unsupported dataset %r; expected one of %s"
                         % (args.datasets, sorted(classes_by_dataset)))
    num_classes = classes_by_dataset[args.datasets]

    # choose model family (ResNet is checked first, as before)
    if "ResNet" in args.model:
        is_resnet = True
    elif "VGG" in args.model:
        is_resnet = False
    else:
        raise ValueError("unsupported model %r; expected a name containing "
                         "'ResNet' or 'VGG'" % (args.model,))

    known_convolutions = ("FALCON", "StConvBranch", "FALCONBranch",
                          "StandardConv")
    if args.convolution not in known_convolutions:
        raise ValueError("unsupported convolution %r; expected one of %s"
                         % (args.convolution, list(known_convolutions)))

    def new_net(**extra):
        # Plain ResNet/VGG; ``extra`` carries ``alpha`` when it is requested.
        if is_resnet:
            return ResNet(layer_num=str(args.layer_num),
                          num_classes=num_classes,
                          **extra)
        return VGG(num_classes=num_classes, which=args.model, **extra)

    def new_branch_net():
        # Branch variant of the selected family.
        if is_resnet:
            return ResNet_StConv_branch(layer_num=str(args.layer_num),
                                        num_classes=num_classes,
                                        alpha=args.alpha)
        return VGG_StConv_branch(num_classes=num_classes,
                                 which=args.model,
                                 alpha=args.alpha)

    def load_pretrained(model, convolution):
        # Load weights saved for ``convolution`` from args.stconv_path.
        load_specific_model(model,
                            args,
                            convolution=convolution,
                            input_path=args.stconv_path)

    def apply_falcon(model, init):
        # Single place for the net.falcon(...) call shared by every path.
        model.falcon(rank=args.rank,
                     init=init,
                     bn=args.bn,
                     relu=args.relu,
                     groups=args.groups)

    if args.convolution == "FALCON":
        # The base network is always built first, exactly as before; the
        # alpha-specific paths below replace it. Kept so the sequence of
        # model constructions is unchanged.
        net = new_net()
        if not args.is_train:
            # Testing: same construction whatever alpha is.
            net = new_net(alpha=args.alpha)
            apply_falcon(net, False)
        else:
            if args.alpha == 1:
                if args.init:
                    load_pretrained(net, 'StandardConv')
            elif args.init:
                # Load the standard-conv weights into a reference net, then
                # hand both nets to the alpha-aware initializer.
                net2 = new_net()
                net = new_net(alpha=args.alpha)
                load_pretrained(net2, 'StandardConv')
                if is_resnet:
                    net = init_with_alpha_resnet(net2, net, args.alpha)
                else:
                    net = init_with_alpha_vgg(net2, net, args.alpha)
            else:
                net = new_net(alpha=args.alpha)
            apply_falcon(net, args.init if args.init else False)
    elif args.convolution == "StConvBranch":
        net = new_branch_net()
    elif args.convolution == "FALCONBranch":
        net = new_branch_net()
        if args.init:
            load_pretrained(net, 'StConvBranch')
        # NOTE(review): ResNet passes init=False here while VGG passes
        # init=args.is_train. Preserved as found; confirm which is intended.
        apply_falcon(net, False if is_resnet else args.is_train)
    else:
        # "StandardConv"
        net = new_net()

    net = net.cuda()

    print_model_parm_nums(net)
    print_model_parm_flops(net)

    if args.is_train:
        # training
        best = train(net,
                     lr=args.learning_rate,
                     optimizer_option=args.optimizer,
                     epochs=args.epochs,
                     batch_size=args.batch_size,
                     is_train=args.is_train,
                     data=args.datasets,
                     lrd=args.lr_decay_rate)
        if not args.not_save:
            save_specific_model(best, args)
        test(net, batch_size=args.batch_size, data=args.datasets)
    else:
        # testing
        load_specific_model(net, args, input_path=args.restore_path)
        test(net, batch_size=args.batch_size, data=args.datasets)

    # calculate number of parameters & FLOPs
    print_model_parm_nums(net)
    print_model_parm_flops(net)

    # time of forwarding 100 data sample (ms)
    x = torch.rand(100, 3, 32, 32)
    x = Variable(x.cuda())
    net(x)  # one untimed forward pass before the timer starts
    timer = Timer()
    timer.tic()
    for _ in range(100):
        net(x)
    # NOTE(review): the measured time is not reported in the code visible
    # here; confirm whether a print of timer.total_time is missing.
    timer.toc()
Esempio n. 2
0
def main(args):
    '''
    Build, then train or evaluate, the selected network and time inference.

    Builds the network selected by ``args.model`` / ``args.convolution``
    (ResNet variants are always created with ``layer_num`` "34"), moves it
    to the GPU with ``.cuda()``, reports its parameter/FLOP counts, then
    either trains it (``args.is_train``) or restores it from
    ``args.restore_path`` and tests it. Finally runs 100 timed forward
    passes on a random ``100 x 3 x 32 x 32`` batch and prints the mean
    time per pass.

    :param args: arguments for a model. Attributes read here: ``datasets``,
        ``model``, ``convolution``, ``alpha``, ``init``, ``is_train``,
        ``rank``, ``bn``, ``relu``, ``groups``, ``stconv_path``,
        ``restore_path``, ``not_save``, ``learning_rate``, ``optimizer``,
        ``epochs``, ``batch_size``, ``lr_decay_rate``.
    :raises ValueError: if ``args.datasets``, ``args.model`` or
        ``args.convolution`` names an option this function does not handle.
        (These cases used to fall through silently and fail later with an
        unbound-variable error.)
    '''

    # choose dataset
    classes_by_dataset = {"cifar10": 10, "svhn": 10, "mnist": 10,
                          "cifar100": 100}
    if args.datasets not in classes_by_dataset:
        raise ValueError("unsupported dataset %r; expected one of %s"
                         % (args.datasets, sorted(classes_by_dataset)))
    num_classes = classes_by_dataset[args.datasets]

    # choose model family (ResNet is checked first, as before)
    if "ResNet" in args.model:
        is_resnet = True
    elif "VGG" in args.model:
        is_resnet = False
    else:
        raise ValueError("unsupported model %r; expected a name containing "
                         "'ResNet' or 'VGG'" % (args.model,))

    known_convolutions = ("FALCON", "StConvBranch", "FALCONBranch",
                          "StandardConv")
    if args.convolution not in known_convolutions:
        raise ValueError("unsupported convolution %r; expected one of %s"
                         % (args.convolution, list(known_convolutions)))

    def new_net():
        # Plain ResNet/VGG for the selected family.
        if is_resnet:
            return ResNet(layer_num="34", num_classes=num_classes)
        return VGG(num_classes=num_classes, which=args.model)

    def new_branch_net():
        # Branch variant of the selected family.
        if is_resnet:
            return ResNet_StConv_branch(layer_num='34',
                                        num_classes=num_classes,
                                        alpha=args.alpha)
        return VGG_StConv_branch(num_classes=num_classes,
                                 which=args.model,
                                 alpha=args.alpha)

    def load_pretrained(model, convolution):
        # Load weights saved for ``convolution`` from args.stconv_path.
        load_specific_model(model,
                            args,
                            convolution=convolution,
                            input_path=args.stconv_path)

    def apply_falcon(model, init):
        # Single place for the net.falcon(...) call shared by every path.
        model.falcon(rank=args.rank,
                     init=init,
                     bn=args.bn,
                     relu=args.relu,
                     groups=args.groups)

    if args.convolution == "FALCON":
        net = new_net()
        if args.is_train:
            # Training always starts from the saved standard-conv weights.
            load_pretrained(net, 'StandardConv')
            apply_falcon(net, args.init)
        else:
            apply_falcon(net, False)
    elif args.convolution == "StConvBranch":
        net = new_branch_net()
    elif args.convolution == "FALCONBranch":
        net = new_branch_net()
        if args.is_train:
            load_pretrained(net, 'StConvBranch')
        # NOTE(review): ResNet passes init=False here while VGG passes
        # init=args.is_train. Preserved as found; confirm which is intended.
        apply_falcon(net, False if is_resnet else args.is_train)
    else:
        # "StandardConv"
        net = new_net()

    net = net.cuda()

    print_model_parm_nums(net)
    print_model_parm_flops(net)

    if args.is_train:
        # training
        best = train(net,
                     lr=args.learning_rate,
                     optimizer_option=args.optimizer,
                     epochs=args.epochs,
                     batch_size=args.batch_size,
                     is_train=args.is_train,
                     data=args.datasets,
                     lrd=args.lr_decay_rate)
        if not args.not_save:
            save_specific_model(best, args)
        test(net, batch_size=args.batch_size, data=args.datasets)
    else:
        # testing
        load_specific_model(net, args, input_path=args.restore_path)
        # test() is run num_runs times and its return values are averaged.
        num_runs = 1
        inference_time = sum(
            test(net, batch_size=args.batch_size, data=args.datasets)
            for _ in range(num_runs))
        print("Average Inference Time: %f" %
              (float(inference_time) / float(num_runs)))

    # calculate number of parameters & FLOPs
    print_model_parm_nums(net)
    print_model_parm_flops(net)

    # time of forwarding 100 data sample (ms)
    x = torch.rand(100, 3, 32, 32)
    x = Variable(x.cuda())
    net(x)  # one untimed forward pass before the timer starts
    timer = Timer()
    timer.tic()
    for _ in range(100):
        net(x)
    timer.toc()
    print('Do once forward need %.3f ms.' % (timer.total_time * 1000 / 100.0))