Beispiel #1
0
def main():
    """Command-line entry point: train, prune and/or evaluate a classifier.

    The function parses the command line, builds the requested model and the
    data loaders for the requested dataset, creates an SGD optimizer over all
    parameters whose name does not contain "gate", optionally sets up the
    pruning engine, and then runs the per-epoch train / validate / checkpoint
    loop.

    Side effects visible in this function:
        * sets the module-level globals ``global_iteration`` and
          ``group_wd_optimizer``;
        * creates ``<name>/`` and ``<name>/models/`` on disk;
        * replaces ``sys.stdout`` with a ``Logger`` writing to a timestamped
          text file inside ``<name>/``.

    Raises:
        NotImplementedError: if ``--model resnet20`` is requested.
        ValueError: if the model / dataset combination does not produce a
            model (e.g. resnet34 or resnet101 together with CIFAR10).
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='number of GPUs to use')
    # The two options below are read by dist.init_process_group() further
    # down; without them any run with --world_size > 1 failed with an
    # AttributeError on args.dist_backend / args.dist_url.
    parser.add_argument('--dist-url',
                        default='tcp://224.66.41.62:23456',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend',
                        default='nccl',
                        type=str,
                        help='distributed backend')

    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd',
                        type=float,
                        default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--lr-decay-every',
                        type=int,
                        default=100,
                        help='learning rate decay by 10 every X epochs')
    parser.add_argument('--lr-decay-scalar',
                        type=float,
                        default=0.1,
                        help='--')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--run_test',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='run test only')

    parser.add_argument(
        '--limit_training_batches',
        type=int,
        default=-1,
        help='how many batches to do per training, -1 means as many as possible'
    )

    parser.add_argument('--no_grad_clip',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='turn off gradient clipping')

    parser.add_argument('--get_flops',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='add hooks to compute flops')

    parser.add_argument(
        '--get_inference_time',
        default=False,
        type=str2bool,
        nargs='?',
        help='runs valid multiple times and reports the result')

    parser.add_argument('--mgpu',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='use data paralization via multiple GPUs')

    parser.add_argument('--dataset',
                        default="MNIST",
                        type=str,
                        help='dataset for experiment, choice: MNIST, CIFAR10',
                        choices=["MNIST", "CIFAR10", "Imagenet"])

    parser.add_argument('--data',
                        metavar='DIR',
                        default='/imagenet',
                        help='path to imagenet dataset')

    parser.add_argument(
        '--model',
        default="lenet3",
        type=str,
        help='model selection, choices: lenet3, vgg, mobilenetv2, resnet18',
        choices=[
            "lenet3", "vgg", "mobilenetv2", "resnet18", "resnet152",
            "resnet50", "resnet50_noskip", "resnet20", "resnet34", "resnet101",
            "resnet101_noskip", "densenet201_imagenet", 'densenet121_imagenet'
        ])

    parser.add_argument('--tensorboard',
                        type=str2bool,
                        nargs='?',
                        help='Log progress to TensorBoard')

    parser.add_argument(
        '--save_models',
        default=True,
        type=str2bool,
        nargs='?',
        help='if True, models will be saved to the local folder')

    # ============================PRUNING added
    parser.add_argument(
        '--pruning_config',
        default=None,
        type=str,
        help=
        'path to pruning configuration file, will overwrite all pruning parameters in arguments'
    )

    parser.add_argument('--group_wd_coeff',
                        type=float,
                        default=0.0,
                        help='group weight decay')
    parser.add_argument('--name',
                        default='test',
                        type=str,
                        help='experiment name(folder) to store logs')

    parser.add_argument(
        '--augment',
        default=False,
        type=str2bool,
        nargs='?',
        help=
        'enable or not augmentation of training dataset, only for CIFAR, def False'
    )

    parser.add_argument('--load_model',
                        default='',
                        type=str,
                        help='path to model weights')

    parser.add_argument('--pruning',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='enable or not pruning, def False')

    parser.add_argument(
        '--pruning-threshold',
        '--pt',
        default=100.0,
        type=float,
        help=
        'Max error perc on validation set while pruning (default: 100.0 means always prune)'
    )

    parser.add_argument(
        '--pruning-momentum',
        default=0.0,
        type=float,
        help=
        'Use momentum on criteria between pruning iterations, def 0.0 means no momentum'
    )

    parser.add_argument('--pruning-step',
                        default=15,
                        type=int,
                        help='How often to check loss and do pruning step')

    parser.add_argument('--prune_per_iteration',
                        default=10,
                        type=int,
                        help='How many neurons to remove at each iteration')

    parser.add_argument(
        '--fixed_layer',
        default=-1,
        type=int,
        help='Prune only a given layer with index, use -1 to prune all')

    parser.add_argument('--start_pruning_after_n_iterations',
                        default=0,
                        type=int,
                        help='from which iteration to start pruning')

    # argparse does not run ``type`` on non-string defaults, so the default
    # is written as an int literal (it used to be the float 1e8).
    parser.add_argument('--maximum_pruning_iterations',
                        default=100000000,
                        type=int,
                        help='maximum pruning iterations')

    parser.add_argument('--starting_neuron',
                        default=0,
                        type=int,
                        help='starting position for oracle pruning')

    parser.add_argument('--prune_neurons_max',
                        default=-1,
                        type=int,
                        help='prune_neurons_max')

    parser.add_argument('--pruning-method',
                        default=0,
                        type=int,
                        help='pruning method to be used, see readme.md')

    parser.add_argument('--pruning_fixed_criteria',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='enable or not criteria reevaluation, def False')

    parser.add_argument('--fixed_network',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='fix network for oracle or criteria computation')

    parser.add_argument(
        '--zero_lr_for_epochs',
        default=-1,
        type=int,
        help='Learning rate will be set to 0 for given number of updates')

    parser.add_argument(
        '--dynamic_network',
        default=False,
        type=str2bool,
        nargs='?',
        help=
        'Creates a new network graph from pruned model, works with ResNet-101 only'
    )

    parser.add_argument('--use_test_as_train',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='use testing dataset instead of training')

    parser.add_argument('--pruning_mask_from',
                        default='',
                        type=str,
                        help='path to mask file precomputed')

    parser.add_argument(
        '--compute_flops',
        default=True,
        type=str2bool,
        nargs='?',
        help=
        'if True, will run dummy inference of batch 1 before training to get conv sizes'
    )

    # ============================END pruning added

    best_prec1 = 0
    global global_iteration
    global group_wd_optimizer
    global_iteration = 0

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    args.distributed = args.world_size > 1
    if args.distributed:
        # NOTE: the rank is hard-coded to 0 here.
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=0)

    device = torch.device("cuda" if use_cuda else "cpu")

    # ---- model selection --------------------------------------------------
    # ``model`` stays None when the requested model/dataset pair has no
    # implementation; this is turned into an explicit error below instead of
    # an UnboundLocalError at the first later use of ``model``.
    model = None
    if args.model == "lenet3":
        model = LeNet(dataset=args.dataset)
    elif args.model == "vgg":
        model = vgg11_bn(pretrained=True)
    elif args.model == "resnet18":
        model = PreActResNet18()
    elif (args.model == "resnet50") or (args.model == "resnet50_noskip"):
        if args.dataset == "CIFAR10":
            model = PreActResNet50(dataset=args.dataset)
        else:
            from models.resnet import resnet50
            skip_gate = True
            if "noskip" in args.model:
                skip_gate = False

            # skip gates are only kept for pruning methods 22 and 40
            if args.pruning_method not in [22, 40]:
                skip_gate = False
            model = resnet50(skip_gate=skip_gate)
    elif args.model == "resnet34":
        if not (args.dataset == "CIFAR10"):
            from models.resnet import resnet34
            model = resnet34()
    elif "resnet101" in args.model:
        # The class count is only known for Imagenet; any other dataset falls
        # through to the "unsupported combination" error below.
        if args.dataset == "Imagenet":
            from models.resnet import resnet101
            classes = 1000

            if "noskip" in args.model:
                model = resnet101(num_classes=classes, skip_gate=False)
            else:
                model = resnet101(num_classes=classes)

    elif args.model == "resnet20":
        # The exception used to be constructed but never raised.
        raise NotImplementedError(
            "resnet20 is not implemented in the current project")
        # from models.resnet_cifar import resnet20
        # model = resnet20()
    elif args.model == "resnet152":
        model = PreActResNet152()
    elif args.model == "densenet201_imagenet":
        from models.densenet_imagenet import DenseNet201
        model = DenseNet201(gate_types=['output_bn'], pretrained=True)
    elif args.model == "densenet121_imagenet":
        from models.densenet_imagenet import DenseNet121
        model = DenseNet121(gate_types=['output_bn'], pretrained=True)
    else:
        raise ValueError("{} model is not supported".format(args.model))

    if model is None:
        raise ValueError(
            "model {} is not supported for dataset {}".format(
                args.model, args.dataset))

    # ---- dataset loading section -------------------------------------------
    # Only the Imagenet branch may create a distributed sampler; keep the name
    # defined for every dataset so the epoch loop can test it safely.
    train_sampler = None

    if args.dataset == "MNIST":
        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data',
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "CIFAR10":
        # Data loading code
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        if args.augment:
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        kwargs = {'num_workers': 8, 'pin_memory': True}
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            '../data', train=True, download=True, transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   drop_last=True,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('../data', train=False, transform=transform_test),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "Imagenet":
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        else:
            train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        if args.use_test_as_train:
            # replace the training loader by one reading the validation folder
            train_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                **kwargs)

        test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  **kwargs)

    ####end dataset preparation

    if args.dynamic_network:
        # attempts to load pruned model and modify it be removing pruned channels
        # works for resnet101 only
        if len(args.load_model) > 0:
            if os.path.isfile(args.load_model):
                load_model_pytorch(model, args.load_model, args.model)

            else:
                print("=> no checkpoint found at '{}'".format(args.load_model))
                exit()

        dynamic_network_change_local(model)

        # save the model
        log_save_folder = "%s" % args.name
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        if not os.path.exists("%s/models" % (log_save_folder)):
            os.makedirs("%s/models" % (log_save_folder))

        model_save_path = "%s/models/pruned.weights" % (log_save_folder)
        model_state_dict = model.state_dict()
        if args.save_models:
            save_checkpoint({'state_dict': model_state_dict},
                            False,
                            filename=model_save_path)

    print("model is defined")

    # aux function to get size of feature maps
    # First it adds hooks for each conv layer
    # Then runs inference with 1 image
    output_sizes = get_conv_sizes(args, model)

    if use_cuda and not args.mgpu:
        model = model.to(device)
    elif args.distributed:
        model.cuda()
        print(
            "\n\n WARNING: distributed pruning was not verified and might not work correctly"
        )
        model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.mgpu:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.to(device)

    print(
        "model is set to device: use_cuda {}, args.mgpu {}, agrs.distributed {}"
        .format(use_cuda, args.mgpu, args.distributed))

    weight_decay = args.wd
    if args.fixed_network:
        weight_decay = 0.0

    # remove updates from gate layers, because we want them to be 0 or 1 constantly
    parameters_for_update = []
    parameters_for_update_named = []
    for name, m in model.named_parameters():
        if "gate" not in name:
            parameters_for_update.append(m)
            parameters_for_update_named.append((name, m))
        else:
            print("skipping parameter", name, "shape:", m.shape)

    total_size_params = sum(
        np.prod(par.shape) for par in parameters_for_update)
    print("Total number of parameters, w/o usage of bn consts: ",
          total_size_params)

    optimizer = optim.SGD(parameters_for_update,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=weight_decay)

    # helping optimizer to implement group lasso (with very small weight that doesn't affect training)
    # will be used to calculate number of remaining flops and parameters in the network
    group_wd_optimizer = group_lasso_decay(
        parameters_for_update,
        group_lasso_weight=args.group_wd_coeff,
        named_parameters=parameters_for_update_named,
        output_sizes=output_sizes)

    cudnn.benchmark = True

    # define objective
    criterion = nn.CrossEntropyLoss()

    ###=======================added for pruning
    # logging part
    log_save_folder = "%s" % args.name
    if not os.path.exists(log_save_folder):
        os.makedirs(log_save_folder)

    if not os.path.exists("%s/models" % (log_save_folder)):
        os.makedirs("%s/models" % (log_save_folder))

    train_writer = None
    if args.tensorboard:
        try:
            # tensorboardX v1.6
            train_writer = SummaryWriter(log_dir="%s" % (log_save_folder))
        except TypeError:
            # tensorboardX v1.7 names the keyword ``logdir``; an unknown
            # keyword surfaces as TypeError, so only that is caught here.
            train_writer = SummaryWriter(logdir="%s" % (log_save_folder))

    # From here on everything printed goes through Logger(textfile).
    time_point = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    textfile = "%s/log_%s.txt" % (log_save_folder, time_point)
    stdout = Logger(textfile)
    sys.stdout = stdout
    print(" ".join(sys.argv))

    # initializing parameters for pruning
    # we can add weights of different layers or we can add gates (multiplies output with 1, useful only for gradient computation)
    pruning_engine = None
    if args.pruning:
        pruning_settings = dict()
        if not (args.pruning_config is None):
            pruning_settings_reader = PruningConfigReader()
            pruning_settings_reader.read_config(args.pruning_config)
            pruning_settings = pruning_settings_reader.get_parameters()

        # overwrite parameters from config file with those from command line
        # needs manual entry here
        # user_specified = [key for key in vars(default_args).keys() if not (vars(default_args)[key]==vars(args)[key])]
        # argv_of_interest = ['pruning_threshold', 'pruning-momentum', 'pruning_step', 'prune_per_iteration',
        #                     'fixed_layer', 'start_pruning_after_n_iterations', 'maximum_pruning_iterations',
        #                     'starting_neuron', 'prune_neurons_max', 'pruning_method']

        # True when the given text occurs inside any raw command-line token
        has_attribute = lambda x: any([x in a for a in sys.argv])

        if has_attribute('pruning-momentum'):
            pruning_settings['pruning_momentum'] = vars(
                args)['pruning_momentum']
        if has_attribute('pruning-method'):
            pruning_settings['method'] = vars(args)['pruning_method']

        pruning_parameters_list = prepare_pruning_list(
            pruning_settings,
            model,
            model_name=args.model,
            pruning_mask_from=args.pruning_mask_from,
            name=args.name)
        print("Total pruning layers:", len(pruning_parameters_list))

        folder_to_write = "%s" % log_save_folder + "/"
        log_folder = folder_to_write

        pruning_engine = pytorch_pruning(pruning_parameters_list,
                                         pruning_settings=pruning_settings,
                                         log_folder=log_folder)

        pruning_engine.connect_tensorboard(train_writer)
        pruning_engine.dataset = args.dataset
        pruning_engine.model = args.model
        pruning_engine.pruning_mask_from = args.pruning_mask_from
        pruning_engine.load_mask()
        gates_to_params = connect_gates_with_parameters_for_flops(
            args.model, parameters_for_update_named)
        pruning_engine.gates_to_params = gates_to_params

    ###=======================end for pruning
    # loading model file
    if (len(args.load_model) > 0) and (not args.dynamic_network):
        if os.path.isfile(args.load_model):
            load_model_pytorch(model, args.load_model, args.model)
        else:
            print("=> no checkpoint found at '{}'".format(args.load_model))
            exit()

    # Graph export is deliberately switched off by the trailing ``and 0``.
    if args.tensorboard and 0:
        if args.dataset == "CIFAR10":
            dummy_input = torch.rand(1, 3, 32, 32).to(device)
        elif args.dataset == "Imagenet":
            dummy_input = torch.rand(1, 3, 224, 224).to(device)

        train_writer.add_graph(model, dummy_input)

    for epoch in range(1, args.epochs + 1):
        # a sampler only exists on the distributed Imagenet path
        if args.distributed and train_sampler is not None:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(args, optimizer, epoch, args.zero_lr_for_epochs,
                             train_writer)

        if not args.run_test and not args.get_inference_time:
            train(args,
                  model,
                  device,
                  train_loader,
                  optimizer,
                  epoch,
                  criterion,
                  train_writer=train_writer,
                  pruning_engine=pruning_engine)

        if args.pruning:
            # skip validation error calculation and model saving
            if pruning_engine.method == 50:
                continue

        # evaluate on validation set
        prec1, _ = validate(args,
                            test_loader,
                            model,
                            device,
                            criterion,
                            epoch,
                            train_writer=train_writer)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        model_save_path = "%s/models/checkpoint.weights" % (log_save_folder)
        model_state_dict = model.state_dict()
        if args.save_models:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=model_save_path)
Beispiel #2
0
def run(args):
    """Prepare the teacher (and optional verifier/student) and generate a batch.

    Reads from ``args``: local_rank, no_cuda, arch_name, fp16, verifier,
    verifier_arch, adi_scale, exp_name, bs, do_flip, random_label,
    store_best_images, r_feature, tv_l1, tv_l2, l2, lr,
    main_loss_multiplier and setting_id.

    Writes to ``args``: iterations (2000), start_noise (True) and
    resolution (224).

    The function returns nothing; its result is whatever
    ``DeepInversionClass.generate_batch`` produces as a side effect.
    """
    torch.manual_seed(args.local_rank)

    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    is_resnet50v15 = args.arch_name == "resnet50v15"

    # ---- teacher network ---------------------------------------------------
    if is_resnet50v15:
        from models.resnetv15 import build_resnet
        net = build_resnet("resnet50", "classic")
    else:
        print(
            "loading torchvision model for inversion with the name: {}".format(
                args.arch_name))
        net = models.__dict__[args.arch_name](pretrained=True)

    net = net.to(device)

    use_fp16 = args.fp16
    if use_fp16:
        net, _ = amp.initialize(net, [], opt_level="O2")

    print('==> Resuming from checkpoint..')

    ### load models
    if is_resnet50v15:
        load_model_pytorch(net,
                           "./models/resnet50v15/model_best.pth.tar",
                           gpu_n=torch.cuda.current_device())

    ### load feature statistics
    # Each entry is (running_var, running_mean) of one BatchNorm2d layer.
    # The list is not consumed anywhere else in this function.
    print('==> Getting BN params as feature statistics')
    feature_statistics = []
    for bn_layer in net.modules():
        if not isinstance(bn_layer, nn.BatchNorm2d):
            continue
        running_mean = bn_layer.running_mean.data
        running_var = bn_layer.running_var.data
        if use_fp16:
            entry = (running_var.type(torch.float16),
                     running_mean.type(torch.float16))
        else:
            entry = (running_var, running_mean)
        feature_statistics.append(entry)

    net.to(device)
    net.eval()

    # ---- verifier / student network ----------------------------------------
    # reserved to compute test accuracy on generated images by different networks
    net_verifier = None
    # if multiple GPUs are used then we can change code to load different
    # verifiers to different GPUs; for now only local rank 0 loads one
    if args.verifier and args.adi_scale == 0 and args.local_rank == 0:
        print("loading verifier: ", args.verifier_arch)
        verifier_factory = models.__dict__[args.verifier_arch]
        net_verifier = verifier_factory(pretrained=True).to(device)
        net_verifier.eval()

        if use_fp16:
            net_verifier = net_verifier.half()

    if args.adi_scale != 0.0:
        student_factory = models.__dict__["resnet18"]
        net_verifier = student_factory(pretrained=True).to(device)
        net_verifier.eval()

        if use_fp16:
            net_verifier, _ = amp.initialize(net_verifier, [], opt_level="O2")

        net_verifier = net_verifier.to(device)
        net_verifier.train()

        if use_fp16:
            # batch-norm layers are put back to eval mode and cast to half
            for student_module in net_verifier.modules():
                if isinstance(student_module, nn.BatchNorm2d):
                    student_module.eval().half()

    from deepinversion import DeepInversionClass

    # final images will be stored here:
    final_images_dir = "./final_images/%s" % args.exp_name
    # temporal data and generations will be stored here
    generations_dir = "generations/%s" % args.exp_name

    # fixed settings written back onto the namespace
    args.iterations = 2000
    args.start_noise = True
    # args.detach_student = False
    args.resolution = 224

    flip = args.do_flip
    random_label = args.random_label
    store_best = args.store_best_images
    parameters = {
        "resolution": 224,
        "random_label": random_label,
        "start_noise": True,
        "detach_student": False,
        "do_flip": flip,
        "store_best_images": store_best,
    }

    coefficients = {
        "r_feature": args.r_feature,
        "tv_l1": args.tv_l1,
        "tv_l2": args.tv_l2,
        "l2": args.l2,
        "lr": args.lr,
        "main_loss_multiplier": args.main_loss_multiplier,
        "adi_scale": args.adi_scale,
    }

    # check accuracy of verifier
    hook_for_display = None
    if args.verifier:
        hook_for_display = lambda x, y: validate_one(x, y, net_verifier)

    DeepInversionEngine = DeepInversionClass(
        net_teacher=net,
        final_data_path=final_images_dir,
        path=generations_dir,
        parameters=parameters,
        setting_id=args.setting_id,
        bs=args.bs,
        use_fp16=args.fp16,
        jitter=30,
        criterion=nn.CrossEntropyLoss(),
        coefficients=coefficients,
        network_output_function=lambda x: x,
        hook_for_display=hook_for_display)

    # the student is only handed over when adaptive inversion is enabled
    net_student = net_verifier if args.adi_scale != 0 else None
    DeepInversionEngine.generate_batch(net_student=net_student)