def main(args):
    '''
    main function of FALCON

    Builds the network selected by ``args.model`` / ``args.convolution``,
    moves it to the GPU and prints its parameter count and FLOPs. In
    training mode (``args.is_train``) it trains, optionally saves, and tests
    the model; otherwise it restores weights from ``args.restore_path``,
    tests the model, and prints the latency of one forward pass averaged
    over 100 passes on a random batch of 100 inputs.

    NOTE(review): this source also contains a second ``def main(args)``;
    if both live in the same module the later definition shadows this
    one -- confirm which is intended.

    :param args: arguments for a model (command-line namespace); reads
        datasets, model, convolution, is_train, alpha, init, layer_num,
        rank, bn, relu, groups, stconv_path, restore_path, learning_rate,
        optimizer, epochs, batch_size, lr_decay_rate and not_save
    :raises ValueError: if the dataset, model family or convolution type
        is not one of the supported choices
    '''
    # choose dataset
    if args.datasets == "svhn":
        num_classes = 10
    elif args.datasets == "cifar100":
        num_classes = 100
    else:
        # Falling through here used to leave num_classes unbound and fail
        # much later with an UnboundLocalError.
        raise ValueError("Unsupported dataset: %r" % (args.datasets,))

    def apply_falcon(model, init):
        # Call model.falcon with the FALCON hyper-parameters taken from args.
        model.falcon(rank=args.rank, init=init, bn=args.bn,
                     relu=args.relu, groups=args.groups)

    # choose model ResNet
    if "ResNet" in args.model:
        layer_num = str(args.layer_num)
        if args.convolution == "FALCON":
            if not args.is_train:
                # evaluation: weights are restored from args.restore_path below
                net = ResNet(layer_num=layer_num, num_classes=num_classes,
                             alpha=args.alpha)
                apply_falcon(net, False)
            elif args.alpha == 1:
                net = ResNet(layer_num=layer_num, num_classes=num_classes)
                if args.init:
                    # load StandardConv weights from args.stconv_path before falcon()
                    load_specific_model(net, args, convolution='StandardConv',
                                        input_path=args.stconv_path)
                    apply_falcon(net, args.init)
                else:
                    apply_falcon(net, False)
            elif args.init:
                # load a StandardConv model (built without alpha), then derive
                # the alpha-scaled network from it via init_with_alpha_resnet
                std_net = ResNet(layer_num=layer_num, num_classes=num_classes)
                net = ResNet(layer_num=layer_num, num_classes=num_classes,
                             alpha=args.alpha)
                load_specific_model(std_net, args, convolution='StandardConv',
                                    input_path=args.stconv_path)
                net = init_with_alpha_resnet(std_net, net, args.alpha)
                apply_falcon(net, args.init)
            else:
                net = ResNet(layer_num=layer_num, num_classes=num_classes,
                             alpha=args.alpha)
                apply_falcon(net, False)
        elif args.convolution == "StConvBranch":
            net = ResNet_StConv_branch(layer_num=layer_num,
                                       num_classes=num_classes,
                                       alpha=args.alpha)
        elif args.convolution == 'FALCONBranch':
            net = ResNet_StConv_branch(layer_num=layer_num,
                                       num_classes=num_classes,
                                       alpha=args.alpha)
            if args.init:
                load_specific_model(net, args, convolution='StConvBranch',
                                    input_path=args.stconv_path)
            apply_falcon(net, False)
        elif args.convolution == "StandardConv":
            net = ResNet(layer_num=layer_num, num_classes=num_classes)
        else:
            raise ValueError("Unsupported convolution: %r" % (args.convolution,))
    # choose model VGG
    elif "VGG" in args.model:
        if args.convolution == "FALCON":
            if not args.is_train:
                # evaluation: weights are restored from args.restore_path below
                net = VGG(num_classes=num_classes, which=args.model,
                          alpha=args.alpha)
                apply_falcon(net, False)
            elif args.alpha == 1:
                net = VGG(num_classes=num_classes, which=args.model)
                if args.init:
                    # load StandardConv weights from args.stconv_path before falcon()
                    load_specific_model(net, args, convolution='StandardConv',
                                        input_path=args.stconv_path)
                    apply_falcon(net, args.init)
                else:
                    apply_falcon(net, False)
            elif args.init:
                # load a StandardConv model (built without alpha), then derive
                # the alpha-scaled network from it via init_with_alpha_vgg
                std_net = VGG(num_classes=num_classes, which=args.model)
                net = VGG(num_classes=num_classes, which=args.model,
                          alpha=args.alpha)
                load_specific_model(std_net, args, convolution='StandardConv',
                                    input_path=args.stconv_path)
                net = init_with_alpha_vgg(std_net, net, args.alpha)
                apply_falcon(net, args.init)
            else:
                net = VGG(num_classes=num_classes, which=args.model,
                          alpha=args.alpha)
                apply_falcon(net, False)
        elif args.convolution == 'StConvBranch':
            net = VGG_StConv_branch(num_classes=num_classes, which=args.model,
                                    alpha=args.alpha)
        elif args.convolution == 'FALCONBranch':
            net = VGG_StConv_branch(num_classes=num_classes, which=args.model,
                                    alpha=args.alpha)
            if args.init:
                load_specific_model(net, args, convolution='StConvBranch',
                                    input_path=args.stconv_path)
            # NOTE(review): the ResNet FALCONBranch path passes init=False;
            # this one passes args.is_train -- confirm which is intended.
            apply_falcon(net, args.is_train)
        elif args.convolution == "StandardConv":
            net = VGG(num_classes=num_classes, which=args.model)
        else:
            raise ValueError("Unsupported convolution: %r" % (args.convolution,))
    else:
        raise ValueError("Unsupported model: %r" % (args.model,))

    net = net.cuda()
    print_model_parm_nums(net)
    print_model_parm_flops(net)

    if args.is_train:
        # training
        best = train(net, lr=args.learning_rate, optimizer_option=args.optimizer,
                     epochs=args.epochs, batch_size=args.batch_size,
                     is_train=args.is_train, data=args.datasets,
                     lrd=args.lr_decay_rate)
        if not args.not_save:
            save_specific_model(best, args)
        test(net, batch_size=args.batch_size, data=args.datasets)
    else:
        # testing
        load_specific_model(net, args, input_path=args.restore_path)
        test(net, batch_size=args.batch_size, data=args.datasets)

        # calculate number of parameters & FLOPs
        print_model_parm_nums(net)
        print_model_parm_flops(net)

        # time of forwarding 100 data samples (ms)
        x = Variable(torch.rand(100, 3, 32, 32).cuda())
        net(x)  # untimed warm-up pass so one-time setup cost is not measured
        timer = Timer()
        # CUDA kernels are launched asynchronously; synchronize on both sides
        # so the timer covers completed work, not just kernel launches.
        torch.cuda.synchronize()
        timer.tic()
        for _ in range(100):
            net(x)
        torch.cuda.synchronize()
        timer.toc()
        print('Do once forward need %.3f ms.' % (timer.total_time * 1000 / 100.0))
def main(args):
    '''
    Build a FALCON / standard-convolution model, then train or evaluate it.

    The network is chosen by ``args.model`` and ``args.convolution``; ResNet
    variants are always built with ``layer_num='34'``. The network is moved
    to the GPU and its parameter count and FLOPs are printed. With
    ``args.is_train`` it is trained, optionally saved, and tested. Otherwise
    weights are restored from ``args.restore_path``, the model is tested,
    the value returned by ``test`` is printed as the average inference
    time, and the latency of one forward pass (averaged over 100 passes on
    a random batch of 100 inputs) is printed.

    :param args: command-line namespace; reads datasets, model,
        convolution, is_train, alpha, init, rank, bn, relu, groups,
        stconv_path, restore_path, learning_rate, optimizer, epochs,
        batch_size, lr_decay_rate and not_save
    :raises ValueError: if the dataset, model family or convolution type
        is not one of the supported choices
    '''
    # choose dataset
    if args.datasets in ("cifar10", "svhn", "mnist"):
        num_classes = 10
    elif args.datasets == "cifar100":
        num_classes = 100
    else:
        # Falling through here used to leave num_classes unbound and fail
        # much later with an UnboundLocalError.
        raise ValueError("Unsupported dataset: %r" % (args.datasets,))

    def apply_falcon(model, init):
        # Call model.falcon with the FALCON hyper-parameters taken from args.
        model.falcon(rank=args.rank, init=init, bn=args.bn,
                     relu=args.relu, groups=args.groups)

    # choose model
    if "ResNet" in args.model:
        if args.convolution == "FALCON":
            net = ResNet(layer_num="34", num_classes=num_classes)
            if args.is_train:
                # load StandardConv weights from args.stconv_path before falcon()
                load_specific_model(net, args, convolution='StandardConv',
                                    input_path=args.stconv_path)
                apply_falcon(net, args.init)
            else:
                apply_falcon(net, False)
        elif args.convolution == "StConvBranch":
            net = ResNet_StConv_branch(layer_num='34', num_classes=num_classes,
                                       alpha=args.alpha)
        elif args.convolution == 'FALCONBranch':
            net = ResNet_StConv_branch(layer_num='34', num_classes=num_classes,
                                       alpha=args.alpha)
            if args.is_train:
                load_specific_model(net, args, convolution='StConvBranch',
                                    input_path=args.stconv_path)
            apply_falcon(net, False)
        elif args.convolution == "StandardConv":
            net = ResNet(layer_num="34", num_classes=num_classes)
        else:
            raise ValueError("Unsupported convolution: %r" % (args.convolution,))
    elif "VGG" in args.model:
        if args.convolution == "FALCON":
            net = VGG(num_classes=num_classes, which=args.model)
            if args.is_train:
                # load StandardConv weights from args.stconv_path before falcon()
                load_specific_model(net, args, convolution='StandardConv',
                                    input_path=args.stconv_path)
                apply_falcon(net, args.init)
            else:
                apply_falcon(net, False)
        elif args.convolution == 'StConvBranch':
            net = VGG_StConv_branch(num_classes=num_classes, which=args.model,
                                    alpha=args.alpha)
        elif args.convolution == 'FALCONBranch':
            net = VGG_StConv_branch(num_classes=num_classes, which=args.model,
                                    alpha=args.alpha)
            if args.is_train:
                load_specific_model(net, args, convolution='StConvBranch',
                                    input_path=args.stconv_path)
            # NOTE(review): the ResNet FALCONBranch path passes init=False;
            # this one passes args.is_train -- confirm which is intended.
            apply_falcon(net, args.is_train)
        elif args.convolution == "StandardConv":
            net = VGG(num_classes=num_classes, which=args.model)
        else:
            raise ValueError("Unsupported convolution: %r" % (args.convolution,))
    else:
        raise ValueError("Unsupported model: %r" % (args.model,))

    net = net.cuda()
    print_model_parm_nums(net)
    print_model_parm_flops(net)

    if args.is_train:
        # training
        best = train(net, lr=args.learning_rate, optimizer_option=args.optimizer,
                     epochs=args.epochs, batch_size=args.batch_size,
                     is_train=args.is_train, data=args.datasets,
                     lrd=args.lr_decay_rate)
        if not args.not_save:
            save_specific_model(best, args)
        test(net, batch_size=args.batch_size, data=args.datasets)
    else:
        # testing
        load_specific_model(net, args, input_path=args.restore_path)
        n_runs = 1  # number of test passes averaged into the printed value
        inference_time = sum(test(net, batch_size=args.batch_size, data=args.datasets)
                             for _ in range(n_runs))
        print("Average Inference Time: %f" % (float(inference_time) / n_runs))

        # calculate number of parameters & FLOPs
        print_model_parm_nums(net)
        print_model_parm_flops(net)

        # time of forwarding 100 data samples (ms)
        x = Variable(torch.rand(100, 3, 32, 32).cuda())
        net(x)  # untimed warm-up pass so one-time setup cost is not measured
        timer = Timer()
        # CUDA kernels are launched asynchronously; synchronize on both sides
        # so the timer covers completed work, not just kernel launches.
        torch.cuda.synchronize()
        timer.tic()
        for _ in range(100):
            net(x)
        torch.cuda.synchronize()
        timer.toc()
        print('Do once forward need %.3f ms.' % (timer.total_time * 1000 / 100.0))