Example #1
def settings(args):
    if args.save_folder and not os.path.isdir(args.save_folder):
        os.makedirs(args.save_folder)

    if args.log_path:
        set_logger(args.log_path)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        args.cuda = False

    cudnn.benchmark = True

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    args.unfreeze_layers = [
        'layer.0.', 'layer.1.', 'layer.2.', 'layer.3.', 'layer.4.', 'layer.5.',
        'layer.6.', 'layer.7.', 'layer.8.', 'layer.9.', 'layer.10.',
        'layer.11.', 'pooler'
    ]
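    # Unfreeze all listed layers (layer.0 through layer.11 plus the pooler);
    # the same list is reused below as the shared layers.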
    args.shared_layers = args.unfreeze_layers
    # preprocess
    if args.build_data_seperate:
        build_data_seperate()

    if args.mode == 'finetune' and args.build_data_file:
        build_data_file(args)
Example #2
def main():
    """Do stuff."""
    args = parser.parse_args()

    # Don't scale the batch size (or the learning rate) linearly with the
    # number of GPUs; it lowers accuracy.
    # args.batch_size = args.batch_size * torch.cuda.device_count()
    args.network_width_multiplier = math.sqrt(args.network_width_multiplier)
    args.max_allowed_network_width_multiplier = math.sqrt(
        args.max_allowed_network_width_multiplier)
    if args.mode == 'prune':
        args.save_folder = os.path.join(args.save_folder,
                                        str(args.target_sparsity))
        if args.initial_sparsity != 0.0:
            args.load_folder = os.path.join(args.load_folder,
                                            str(args.initial_sparsity))

    if args.save_folder and not os.path.isdir(args.save_folder):
        os.makedirs(args.save_folder)

    if args.log_path:
        set_logger(args.log_path)

    if args.pruning_ratio_to_acc_record_file and not os.path.isdir(
            args.pruning_ratio_to_acc_record_file.rsplit('/', 1)[0]):
        os.makedirs(args.pruning_ratio_to_acc_record_file.rsplit('/', 1)[0])

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        args.cuda = False

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.benchmark = True

    # If set > 0, will resume training from a given checkpoint.
    resume_from_epoch = 0
    resume_folder = args.load_folder
    for try_epoch in range(200, 0, -1):
        if os.path.exists(
                args.checkpoint_format.format(save_folder=resume_folder,
                                              epoch=try_epoch)):
            resume_from_epoch = try_epoch
            break

    if args.restore_epoch:
        resume_from_epoch = args.restore_epoch

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    if resume_from_epoch:
        print("Resume from epoch: ", resume_from_epoch)
        filepath = args.checkpoint_format.format(save_folder=resume_folder,
                                                 epoch=resume_from_epoch)
        checkpoint = torch.load(filepath)
        checkpoint_keys = checkpoint.keys()
        dataset_history = checkpoint['dataset_history']
        dataset2num_classes = checkpoint['dataset2num_classes']
        masks = checkpoint['masks']
        shared_layer_info = checkpoint['shared_layer_info']
        # shared_layer_info[args.dataset]['network_width_multiplier'] = 1.0
        if 'num_for_construct' in checkpoint_keys:
            num_for_construct = checkpoint['num_for_construct']
        if args.mode == 'inference' and 'network_width_multiplier' in shared_layer_info[
                args.dataset]:  # TODO, temporary solution
            args.network_width_multiplier = shared_layer_info[
                args.dataset]['network_width_multiplier']
    else:
        dataset_history = []
        dataset2num_classes = {}
        masks = {}
        shared_layer_info = {}

    if args.baseline_acc_file is None or not os.path.isfile(
            args.baseline_acc_file):
        sys.exit(3)
    with open(args.baseline_acc_file, 'r') as jsonfile:
        json_data = json.load(jsonfile)
        baseline_acc = float(json_data[args.dataset])

    if args.mode == 'prune' and not args.pruning_ratio_to_acc_record_file:
        sys.exit(-1)

    if args.arch == 'resnet18':
        model = models.__dict__[args.arch](
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes,
            network_width_multiplier=args.network_width_multiplier,
            shared_layer_info=shared_layer_info)
    elif 'vgg' in args.arch:
        custom_cfg = [
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
            512, 512, 512, 'M'
        ]
        model = models.__dict__[args.arch](
            custom_cfg,
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes,
            network_width_multiplier=args.network_width_multiplier,
            shared_layer_info=shared_layer_info)
    else:
        print('Error!')
        sys.exit(1)

    # Add and set the model dataset.
    model.add_dataset(args.dataset, args.num_classes)
    model.set_dataset(args.dataset)

    model = nn.DataParallel(model)
    model = model.cuda()
    if not masks:
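        # No masks were loaded: start every sharable conv/linear layer with an all-zero byte mask.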
        for name, module in model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(
                    module, nl.SharableLinear):
                mask = torch.ByteTensor(module.weight.data.size()).fill_(0)
                if 'cuda' in module.weight.data.type():
                    mask = mask.cuda()
                masks[name] = mask
    else:
        # When we expand the network, we need to allocate new masks.
        NEED_ADJUST_MASK = False
        for name, module in model.named_modules():
            if isinstance(module, nl.SharableConv2d):
                if masks[name].size(1) < module.weight.data.size(1):
                    assert args.mode == 'finetune'
                    NEED_ADJUST_MASK = True
                elif masks[name].size(1) > module.weight.data.size(1):
                    assert args.mode == 'inference'
                    NEED_ADJUST_MASK = True

        if NEED_ADJUST_MASK:
            if args.mode == 'finetune':
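                # Finetune: the network grew, so copy the smaller saved masks
                # into the top-left corner of new, larger zero masks.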
                for name, module in model.named_modules():
                    if isinstance(module, nl.SharableConv2d):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        if 'cuda' in module.weight.data.type():
                            mask = mask.cuda()
                        mask[:masks[name].size(0), :masks[name].
                             size(1), :, :].copy_(masks[name])
                        masks[name] = mask
                    elif isinstance(module, nl.SharableLinear):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        if 'cuda' in module.weight.data.type():
                            mask = mask.cuda()
                        mask[:masks[name].size(0), :masks[name].size(1)].copy_(
                            masks[name])
                        masks[name] = mask
            elif args.mode == 'inference':
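                # Inference: the saved masks are larger than the current model,
                # so crop them down to the current weight shapes.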
                for name, module in model.named_modules():
                    if isinstance(module, nl.SharableConv2d):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        if 'cuda' in module.weight.data.type():
                            mask = mask.cuda()
                        mask[:, :, :, :].copy_(
                            masks[name][:mask.size(0), :mask.size(1), :, :])
                        masks[name] = mask
                    elif isinstance(module, nl.SharableLinear):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        if 'cuda' in module.weight.data.type():
                            mask = mask.cuda()
                        mask[:, :].copy_(
                            masks[name][:mask.size(0), :mask.size(1)])
                        masks[name] = mask

    if args.dataset not in shared_layer_info:

        shared_layer_info[args.dataset] = {
            'bias': {},
            'bn_layer_running_mean': {},
            'bn_layer_running_var': {},
            'bn_layer_weight': {},
            'bn_layer_bias': {},
            'piggymask': {}
        }

        piggymasks = {}
        task_id = model.module.datasets.index(args.dataset) + 1
        if task_id > 1:
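            # For every task after the first, attach a trainable piggyback mask
            # (initialized to 0.01) to each sharable layer.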
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(
                        module, nl.SharableLinear):
                    piggymasks[name] = torch.zeros_like(masks['module.' +
                                                              name],
                                                        dtype=torch.float32)
                    piggymasks[name].fill_(0.01)
                    piggymasks[name] = Parameter(piggymasks[name])
                    module.piggymask = piggymasks[name]
    elif args.finetune_again:
        # reinitialize piggymask
        piggymasks = {}
        for name, module in model.module.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(
                    module, nl.SharableLinear):
                piggymasks[name] = torch.zeros_like(masks['module.' + name],
                                                    dtype=torch.float32)
                piggymasks[name].fill_(0.01)
                piggymasks[name] = Parameter(piggymasks[name])
                module.piggymask = piggymasks[name]
    else:
        # try:
        piggymasks = shared_layer_info[args.dataset]['piggymask']
        # except:
        #    piggymasks = {}
        task_id = model.module.datasets.index(args.dataset) + 1
        if task_id > 1:
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(
                        module, nl.SharableLinear):
                    module.piggymask = piggymasks[name]
    shared_layer_info[args.dataset][
        'network_width_multiplier'] = args.network_width_multiplier

    if args.num_classes == 2:
        train_loader = dataset.cifar100_train_loader_two_class(
            args.dataset, args.batch_size)
        val_loader = dataset.cifar100_val_loader_two_class(
            args.dataset, args.val_batch_size)
    elif args.num_classes == 5:
        train_loader = dataset.cifar100_train_loader(args.dataset,
                                                     args.batch_size)
        val_loader = dataset.cifar100_val_loader(args.dataset,
                                                 args.val_batch_size)
    else:
        print("num_classes should be either 2 or 5")
        sys.exit(1)

    # If checkpoints are saved to a different folder than they are loaded from, restart the epoch count.
    if args.save_folder != args.load_folder:
        start_epoch = 0
    else:
        start_epoch = resume_from_epoch

    curr_prune_step = begin_prune_step = start_epoch * len(train_loader)
    end_prune_step = curr_prune_step + args.pruning_interval * len(
        train_loader)
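    # begin_prune_step/end_prune_step express the pruning schedule in gradient steps rather than epochs.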

    manager = Manager(args, model, shared_layer_info, masks, train_loader,
                      val_loader, begin_prune_step, end_prune_step)
    if args.mode == 'inference':
        manager.load_checkpoint_only_for_evaluate(resume_from_epoch,
                                                  resume_folder)
        manager.validate(resume_from_epoch - 1)
        return

    lr = args.lr
    lr_mask = args.lr_mask
    # update all layers
    named_params = dict(model.named_parameters())
    params_to_optimize_via_SGD = []
    named_of_params_to_optimize_via_SGD = []
    masks_to_optimize_via_Adam = []
    named_of_masks_to_optimize_via_Adam = []
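    # The current task's classifier head and the shared weights are optimized
    # with SGD; piggyback masks are optimized with Adam.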

    for name, param in named_params.items():
        if 'classifiers' in name:
            if '.{}.'.format(model.module.datasets.index(
                    args.dataset)) in name:
                params_to_optimize_via_SGD.append(param)
                named_of_params_to_optimize_via_SGD.append(name)
            continue
        elif 'piggymask' in name:
            masks_to_optimize_via_Adam.append(param)
            named_of_masks_to_optimize_via_Adam.append(name)
        else:
            params_to_optimize_via_SGD.append(param)
            named_of_params_to_optimize_via_SGD.append(name)

    optimizer_network = optim.SGD(params_to_optimize_via_SGD,
                                  lr=lr,
                                  weight_decay=0.0,
                                  momentum=0.9,
                                  nesterov=True)
    optimizers = Optimizers()
    optimizers.add(optimizer_network, lr)

    if masks_to_optimize_via_Adam:
        optimizer_mask = optim.Adam(masks_to_optimize_via_Adam, lr=lr_mask)
        optimizers.add(optimizer_mask, lr_mask)

    manager.load_checkpoint(optimizers, resume_from_epoch, resume_folder)
    """Performs training."""
    curr_lrs = []
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            curr_lrs.append(param_group['lr'])
            break

    if args.mode == 'prune':
        if 'gradual_prune' in args.load_folder and args.save_folder == args.load_folder:
            args.epochs = 20 + resume_from_epoch
        logging.info('')
        logging.info('Before pruning: ')
        logging.info('Sparsity range: {} -> {}'.format(args.initial_sparsity,
                                                       args.target_sparsity))

        must_pruning_ratio_for_curr_task = 0.0

        json_data = {}
        if os.path.isfile(args.pruning_ratio_to_acc_record_file):
            with open(args.pruning_ratio_to_acc_record_file, 'r') as json_file:
                json_data = json.load(json_file)

        if args.network_width_multiplier >= args.max_allowed_network_width_multiplier and json_data[
                '0.0'] < baseline_acc:
            # We reached the width upper bound but still have not hit the target
            # accuracy on the current task, so prune anyway.
            logging.info(
                'Reached the network width upper bound without meeting the '
                'target accuracy on the current task')
            remain_num_tasks = args.total_num_tasks - len(dataset_history)
            logging.info('remain_num_tasks: {}'.format(remain_num_tasks))
            ratio_allow_for_curr_task = round(1.0 / (remain_num_tasks + 1), 1)
            logging.info('ratio_allow_for_curr_task: {:.4f}'.format(
                ratio_allow_for_curr_task))
            must_pruning_ratio_for_curr_task = 1.0 - ratio_allow_for_curr_task
            if args.initial_sparsity >= must_pruning_ratio_for_curr_task:
                sys.exit(6)

        manager.validate(start_epoch - 1)
        logging.info('')
    elif args.mode == 'finetune':
        if not args.finetune_again:
            manager.pruner.make_finetuning_mask()
            logging.info('Finetune stage...')
        else:
            logging.info('Piggymask Retrain...')
            history_best_avg_val_acc_when_retraining = manager.validate(
                start_epoch - 1)
            num_epochs_that_criterion_does_not_get_better = 0

        stop_lr_mask = True
        if manager.pruner.calculate_curr_task_ratio() == 0.0:
            logging.info(
                'No space left in the convolutional layers for the current task; '
                'we will rely on prior experience as much as possible')
            stop_lr_mask = False

    for epoch_idx in range(start_epoch, args.epochs):
        avg_train_acc, curr_prune_step = manager.train(optimizers, epoch_idx,
                                                       curr_lrs,
                                                       curr_prune_step)

        avg_val_acc = manager.validate(epoch_idx)

        # if args.mode == 'prune' and (epoch_idx+1) >= (args.pruning_interval + start_epoch) and (
        #     avg_val_acc > history_best_avg_val_acc_when_prune):
        #     pass
        if args.finetune_again:
            if avg_val_acc > history_best_avg_val_acc_when_retraining:
                history_best_avg_val_acc_when_retraining = avg_val_acc

                num_epochs_that_criterion_does_not_get_better = 0
                if args.save_folder is not None:
                    print("Removing saved checkpoint")
                    for path in os.listdir(args.save_folder):
                        if '.pth.tar' in path:
                            os.remove(os.path.join(args.save_folder, path))
                else:
                    print('Something is wrong! Block the program with pdb')
                    pdb.set_trace()

                history_best_avg_val_acc = avg_val_acc
                manager.save_checkpoint(optimizers, epoch_idx,
                                        args.save_folder)
            else:
                num_epochs_that_criterion_does_not_get_better += 1

            if args.finetune_again and num_epochs_that_criterion_does_not_get_better == 5:
                saved = False
                for try_epoch in range(200, 0, -1):
                    if os.path.exists(
                            args.checkpoint_format.format(
                                save_folder=args.save_folder,
                                epoch=try_epoch)):
                        saved = True
                        print("Found saved checkpoint")
                        break
                if not saved:
                    print("No saved checkpoint..")
                    manager.save_checkpoint(optimizers, epoch_idx,
                                            args.save_folder)
                logging.info("stop retraining")
                sys.exit(0)

        if args.mode == 'finetune':
            if epoch_idx + 1 == 50 or epoch_idx + 1 == 80:
                for param_group in optimizers[0].param_groups:
                    param_group['lr'] *= 0.1
                curr_lrs[0] = param_group['lr']
            if len(optimizers.lrs) == 2:
                if epoch_idx + 1 == 50:
                    for param_group in optimizers[1].param_groups:
                        param_group['lr'] *= 0.2
                if stop_lr_mask and epoch_idx + 1 == 70:
                    for param_group in optimizers[1].param_groups:
                        param_group['lr'] *= 0.0

                curr_lrs[1] = param_group['lr']

    if args.save_folder is not None:
        pass
    #     paths = os.listdir(args.save_folder)
    #     if paths and '.pth.tar' in paths[0]:
    #         for checkpoint_file in paths:
    #             os.remove(os.path.join(args.save_folder, checkpoint_file))
    else:
        print('Something is wrong! Block the program with pdb')
        pdb.set_trace()

    if avg_train_acc > 0.95:
        manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)
    else:
        logging.info(f"Training Accuracy goal not met ({avg_train_acc})!")
        if args.dataset == "aquatic_mammals" and avg_train_acc > 0.85:
            logging.info("Saving model...")
            manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)
        else:
            logging.info("Not saving model...")

    logging.info('-' * 16)

    if args.pruning_ratio_to_acc_record_file:
        json_data = {}
        if os.path.isfile(args.pruning_ratio_to_acc_record_file):
            with open(args.pruning_ratio_to_acc_record_file, 'r') as json_file:
                json_data = json.load(json_file)

        if args.mode == 'finetune' and not args.test_piggymask:
            json_data[0.0] = round(avg_val_acc, 4)
            with open(args.pruning_ratio_to_acc_record_file, 'w') as json_file:
                json.dump(json_data, json_file)
            if avg_train_acc > 0.95 and avg_val_acc >= baseline_acc:
                print("Pass!")
                pass
            elif args.network_width_multiplier >= args.max_allowed_network_width_multiplier and avg_val_acc < baseline_acc:
                print("Option 2")
                if manager.pruner.calculate_curr_task_ratio() == 0.0:
                    sys.exit(5)
                else:
                    sys.exit(0)
            else:
                print("Option 3")
                if args.network_width_multiplier >= args.max_allowed_network_width_multiplier:
                    print("Network Cannot Expand Anymore!")
                    logging.info("Network Cannot Expand Anymore!")
                    if manager.pruner.calculate_curr_task_ratio() == 0.0:
                        sys.exit(5)
                    else:
                        sys.exit(0)

                else:
                    logging.info("It's time to expand the Network")
                    logging.info('Auto expand network')
                    sys.exit(2)

            if manager.pruner.calculate_curr_task_ratio() == 0.0:
                logging.info(
                    'No space left in the convolutional layers for the current '
                    'task, so there is nothing to prune')
                sys.exit(5)

        elif args.mode == 'prune':
            if avg_train_acc > 0.95:
                json_data[args.target_sparsity] = round(avg_val_acc, 4)
                with open(args.pruning_ratio_to_acc_record_file,
                          'w') as json_file:
                    json.dump(json_data, json_file)
            else:
                sys.exit(6)

            must_pruning_ratio_for_curr_task = 0.0

            if args.network_width_multiplier >= args.max_allowed_network_width_multiplier and json_data[
                    '0.0'] < baseline_acc:
                # We reached the width upper bound but still have not hit the target
                # accuracy on the current task, so prune anyway.
                logging.info(
                    'Reached the network width upper bound without meeting the '
                    'target accuracy on the current task')
                remain_num_tasks = args.total_num_tasks - len(dataset_history)
                logging.info('remain_num_tasks: {}'.format(remain_num_tasks))
                ratio_allow_for_curr_task = round(1.0 / (remain_num_tasks + 1),
                                                  1)
                logging.info('ratio_allow_for_curr_task: {:.4f}'.format(
                    ratio_allow_for_curr_task))
                must_pruning_ratio_for_curr_task = 1.0 - ratio_allow_for_curr_task
                if args.target_sparsity >= must_pruning_ratio_for_curr_task:
                    sys.exit(6)
Example #3
def main():
    """Do stuff."""
    args = parser.parse_args()
    if args.save_folder and not os.path.isdir(args.save_folder):
        os.makedirs(args.save_folder)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        args.cuda = False

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.benchmark = True

    # If set > 0, will resume training from a given checkpoint.
    resume_from_epoch = 0
    resume_folder = args.load_folder
    for try_epoch in range(200, 0, -1):
        if os.path.exists(
                args.checkpoint_format.format(save_folder=resume_folder,
                                              epoch=try_epoch)):
            resume_from_epoch = try_epoch
            break

    if args.restore_epoch:
        resume_from_epoch = args.restore_epoch

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    if resume_from_epoch:
        filepath = args.checkpoint_format.format(save_folder=resume_folder,
                                                 epoch=resume_from_epoch)
        checkpoint = torch.load(filepath)
        checkpoint_keys = checkpoint.keys()
        dataset_history = checkpoint['dataset_history']
        dataset2num_classes = checkpoint['dataset2num_classes']
        masks = checkpoint['masks']
        if 'shared_layer_info' in checkpoint_keys:
            shared_layer_info = checkpoint['shared_layer_info']
        else:
            shared_layer_info = {}

        if 'num_for_construct' in checkpoint_keys:
            num_for_construct = checkpoint['num_for_construct']
    else:
        dataset_history = []
        dataset2num_classes = {}
        masks = {}
        shared_layer_info = {}

    if args.arch == 'vgg16_bn_cifar100':
        model = packnet_models.__dict__[args.arch](
            pretrained=False,
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes)
    elif args.arch == 'resnet18':
        model = packnet_models.__dict__[args.arch](
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes)
    else:
        print('Error!')
        sys.exit(0)

    # Add and set the model dataset
    model.add_dataset(args.dataset, args.num_classes)
    model.set_dataset(args.dataset)

    if args.dataset not in shared_layer_info:
        shared_layer_info[args.dataset] = {
            'conv_bias': {},
            'bn_layer_running_mean': {},
            'bn_layer_running_var': {},
            'bn_layer_weight': {},
            'bn_layer_bias': {},
            'fc_bias': {}
        }

    model = nn.DataParallel(model)
    model = model.cuda()
    if args.initial_from_task and 'None' not in args.initial_from_task:
        filepath = ''
        for try_epoch in range(200, 0, -1):
            if os.path.exists(
                    args.checkpoint_format.format(
                        save_folder=args.initial_from_task, epoch=try_epoch)):
                filepath = args.checkpoint_format.format(
                    save_folder=args.initial_from_task, epoch=try_epoch)
                break
        if filepath == '':
            pdb.set_trace()
            print('Something is wrong')
        checkpoint = torch.load(filepath)
        state_dict = checkpoint['model_state_dict']
        curr_model_state_dict = model.module.state_dict()

        for name, param in state_dict.items():
            if 'num_batches_tracked' in name:
                continue
            try:
                curr_model_state_dict[name][:].copy_(param)
            except Exception:
                pdb.set_trace()
                print('here')

    if not masks:
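        # Fresh run: create an all-zero byte mask for every conv/linear layer, skipping the classifier heads.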
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                if 'classifiers' in name:
                    continue
                mask = torch.ByteTensor(module.weight.data.size()).fill_(0)
                if 'cuda' in module.weight.data.type():
                    mask = mask.cuda()
                masks[name] = mask

    if args.num_classes == 5:
        train_loader = dataset.cifar100_train_loader(args.dataset,
                                                     args.batch_size)
        val_loader = dataset.cifar100_val_loader(args.dataset,
                                                 args.val_batch_size)
    else:
        print("num_classes should be 5")
        sys.exit(1)

    # If checkpoints are saved to a different folder than they are loaded from, restart the epoch count.
    if args.save_folder != args.load_folder:
        start_epoch = 0
    else:
        start_epoch = resume_from_epoch

    manager = Manager(args, model, shared_layer_info, masks, train_loader,
                      val_loader)

    if args.mode == 'inference':
        print(f"Evaluating for task: {args.dataset}")
        manager.load_checkpoint_for_inference(resume_from_epoch, resume_folder)
        val_acc = manager.validate(resume_from_epoch - 1)
        if args.logfile:
            json_data = {}
            if os.path.isfile(args.logfile):
                with open(args.logfile) as json_file:
                    json_data = json.load(json_file)
            json_data[args.dataset] = '{:.4f}'.format(val_acc)
            with open(args.logfile, 'w') as json_file:
                json.dump(json_data, json_file)
        return

    lr = args.lr
    # update all layers
    named_params = dict(model.named_parameters())
    params_to_optimize_via_SGD = []
    named_params_to_optimize_via_SGD = []
    masks_to_optimize_via_SGD = []
    named_masks_to_optimize_via_SGD = []
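    # Only the classifier head belonging to the current task is optimized
    # together with the shared layers; other tasks' heads are skipped.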

    for tuple_ in named_params.items():
        if 'classifiers' in tuple_[0]:
            if '.{}.'.format(model.module.datasets.index(
                    args.dataset)) in tuple_[0]:
                params_to_optimize_via_SGD.append(tuple_[1])
                named_params_to_optimize_via_SGD.append(tuple_)
            continue
        else:
            params_to_optimize_via_SGD.append(tuple_[1])
            named_params_to_optimize_via_SGD.append(tuple_)

    # Weight decay must be 0.0 here: the built-in weight decay in the optimizer's
    # step() would modify every element of the weight tensors, hurting previous
    # tasks' accuracy. (Instead, we apply weight decay ourselves in `prune.py`.)
    optimizer_network = optim.SGD(params_to_optimize_via_SGD,
                                  lr=lr,
                                  weight_decay=0.0,
                                  momentum=0.9,
                                  nesterov=True)

    optimizers = Optimizers()
    optimizers.add(optimizer_network, lr)

    manager.load_checkpoint(optimizers, resume_from_epoch, resume_folder)
    """Performs training."""
    curr_lrs = []
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            curr_lrs.append(param_group['lr'])
            break

    if args.mode == 'prune':
        print()
        print('Sparsity ratio: {}'.format(args.one_shot_prune_perc))
        print('Before pruning: ')
        baseline_acc = manager.validate(start_epoch - 1)
        print('Execute one shot pruning ...')
        manager.one_shot_prune(args.one_shot_prune_perc)
    elif args.mode == 'finetune':
        manager.pruner.make_finetuning_mask()

    for epoch_idx in range(start_epoch, args.epochs):
        avg_train_acc = manager.train(optimizers, epoch_idx, curr_lrs)
        avg_val_acc = manager.validate(epoch_idx)

        if args.mode == 'finetune':
            if epoch_idx + 1 == 50 or epoch_idx + 1 == 80:
                for param_group in optimizers[0].param_groups:
                    param_group['lr'] *= 0.1
                curr_lrs[0] = param_group['lr']

        if args.mode == 'prune':
            if epoch_idx + 1 == 25:
                for param_group in optimizers[0].param_groups:
                    param_group['lr'] *= 0.1
                curr_lrs[0] = param_group['lr']

    if args.save_folder is not None:
        #     paths = os.listdir(args.save_folder)
        #     if paths and '.pth.tar' in paths[0]:
        #         for checkpoint_file in paths:
        #             os.remove(os.path.join(args.save_folder, checkpoint_file))
        pass
    else:
        print('Something is wrong! Block the program with pdb')
        pdb.set_trace()

    if args.mode == 'finetune':
        manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)
        if args.logfile:
            json_data = {}
            if os.path.isfile(args.logfile):
                with open(args.logfile) as json_file:
                    json_data = json.load(json_file)
            json_data[args.dataset] = '{:.4f}'.format(avg_val_acc)
            with open(args.logfile, 'w') as json_file:
                json.dump(json_data, json_file)

        if avg_train_acc < 0.97:
            print('Cannot prune any more!')

    elif args.mode == 'prune':
        #if avg_train_acc > 0.97 and (avg_val_acc - baseline_acc) >= -0.01:
        #if avg_train_acc > 0.97:
        print(f"Avg Train Acc: {avg_train_acc}")
        manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)
        #else:
        #   print('Pruning too much!')

    print('-' * 16)
Example #4
def main():
    """Do stuff."""
    args = FLAGS.parse_args()

    if args.mode == 'pack':
        assert args.packlist and args.maskloc
        dataset2masks = {}
        dataset2classifiers = {}
        net_type = None

        # Location to output stats.
        fout = open(args.maskloc[:-2] + 'txt', 'w')

        # Load models one by one and store their masks.
        fin = open(args.packlist, 'r')
        counter = 1
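        # Each non-comment line of the pack list has the form 'dataset:checkpoint_path'.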
        for idx, line in enumerate(fin):
            if not line or not line.strip() or line[0] == '#':
                continue
            dataset, loadname = line.split(':')
            loadname = loadname.strip()

            # Can't have the same dataset twice.
            if dataset in dataset2masks:
                raise ValueError('Repeated datasets as input...')
            print('Loading model #%d for dataset "%s"' % (counter, dataset))
            counter += 1
            ckpt = torch.load(loadname)
            model = ckpt['model']
            # Ensure all inputs are for same model type.
            if net_type is None:
                net_type = str(type(model))
            else:
                assert net_type == str(type(model)), '%s != %s' % (
                    net_type, str(type(model)))

            # Gather masks and store in dictionary.
            fout.write('Dataset: %s\n' % (dataset))
            total_params, neg_params, zerod_params = [], [], []
            masks = {}
            for module_idx, module in enumerate(model.shared.modules()):
                if 'ElementWise' in str(type(module)):
                    mask = module.threshold_fn(module.mask_real)
                    mask = mask.data.cpu()

                    # Make sure mask values are in {0, 1} or {-1, 0, 1}.
                    num_zero = mask.eq(0).sum()
                    num_one = mask.eq(1).sum()
                    num_mone = mask.eq(-1).sum()
                    total = mask.numel()
                    threshold_type = module.threshold_fn.__class__.__name__
                    if threshold_type == 'Binarizer':
                        assert num_mone == 0
                        assert num_zero + num_one == total
                    elif threshold_type == 'Ternarizer':
                        assert num_mone + num_zero + num_one == total
                    masks[module_idx] = mask.type(torch.ByteTensor)

                    # Count total and zerod out params.
                    total_params.append(total)
                    zerod_params.append(num_zero)
                    neg_params.append(num_mone)
                    fout.write('%d\t%.2f%%\t%.2f%%\n' % (
                        module_idx,
                        neg_params[-1] / total_params[-1] * 100,
                        zerod_params[-1] / total_params[-1] * 100))
            print('Check Passed: Masks only have binary/ternary values.')
            dataset2masks[dataset] = masks
            dataset2classifiers[dataset] = model.classifier

            fout.write('Total -1: %d/%d = %.2f%%\n' % (
                sum(neg_params), sum(total_params), sum(neg_params) / sum(total_params) * 100))
            fout.write('Total 0: %d/%d = %.2f%%\n' % (
                sum(zerod_params), sum(total_params), sum(zerod_params) / sum(total_params) * 100))
            fout.write('-' * 20 + '\n')

        # Clean up and save masks to file.
        fin.close()
        fout.close()
        torch.save({
            'dataset2masks': dataset2masks,
            'dataset2classifiers': dataset2classifiers,
        }, args.maskloc)

    elif args.mode == 'eval':
        assert args.arch and args.maskloc and args.dataset

        # Set default train and test path if not provided as input.
        utils.set_dataset_paths(args)

        # Load masks and classifier for this dataset.
        info = torch.load(args.maskloc)
        if args.dataset not in info['dataset2masks']:
            raise ValueError('%s not found in masks.' % (args.dataset))
        masks = info['dataset2masks'][args.dataset]
        classifier = info['dataset2classifiers'][args.dataset]

        # Create the vanilla model and apply masking.
        model = None
        if args.arch == 'vgg16':
            model = net.ModifiedVGG16(original=True)
        elif args.arch == 'vgg16bn':
            model = net.ModifiedVGG16BN(original=True)
        elif args.arch == 'resnet50':
            model = net.ModifiedResNet(original=True)
        elif args.arch == 'densenet121':
            model = net.ModifiedDenseNet(original=True)
        elif args.arch == 'resnet50_diff':
            assert args.source
            model = net.ResNetDiffInit(args.source, original=True)
        model.eval()

        print('Applying masks.')
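        # Zero out pruned weights and flip the sign of weights whose ternary mask value is -1.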
        for module_idx, module in enumerate(model.shared.modules()):
            if module_idx in masks:
                mask = masks[module_idx]
                module.weight.data[mask.eq(0)] = 0
                module.weight.data[mask.eq(-1)] *= -1
        print('Applied masks.')

        # Override model.classifier with saved one.
        model.add_dataset(args.dataset, classifier.weight.size(0))
        model.set_dataset(args.dataset)
        model.classifier = classifier
        if args.cuda:
            model = model.cuda()

        # Create the manager and run eval.
        manager = Manager(args, model)
        manager.eval()
Example #5
def main():
    """Do stuff."""
    args = parser.parse_args()

    # args.batch_size = args.batch_size * torch.cuda.device_count()
    args.network_width_multiplier = math.sqrt(args.network_width_multiplier)

    if args.mode == 'prune':
        args.save_folder = os.path.join(args.save_folder,
                                        str(args.target_sparsity))
        if args.initial_sparsity != 0.0:
            args.load_folder = os.path.join(args.load_folder,
                                            str(args.initial_sparsity))

    if args.pruning_ratio_to_acc_record_file and not os.path.isdir(
            args.pruning_ratio_to_acc_record_file.rsplit('/', 1)[0]):
        os.makedirs(args.pruning_ratio_to_acc_record_file.rsplit('/', 1)[0])

    if args.save_folder and not os.path.isdir(args.save_folder):
        os.makedirs(args.save_folder)

    if args.log_path:
        set_logger(args.log_path)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        args.cuda = False

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.benchmark = True

    # If set > 0, will resume training from a given checkpoint.
    resume_from_epoch = 0
    resume_folder = args.load_folder
    for try_epoch in range(200, 0, -1):
        if os.path.exists(
                args.checkpoint_format.format(save_folder=resume_folder,
                                              epoch=try_epoch)):
            resume_from_epoch = try_epoch
            break

    if args.restore_epoch:
        resume_from_epoch = args.restore_epoch

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    if resume_from_epoch:
        filepath = args.checkpoint_format.format(save_folder=resume_folder,
                                                 epoch=resume_from_epoch)
        checkpoint = torch.load(filepath)
        checkpoint_keys = checkpoint.keys()
        dataset_history = checkpoint['dataset_history']
        dataset2num_classes = checkpoint['dataset2num_classes']
        masks = checkpoint['masks']
        shared_layer_info = checkpoint['shared_layer_info']
        if 'num_for_construct' in checkpoint_keys:
            num_for_construct = checkpoint['num_for_construct']
        if args.mode == 'inference' and 'network_width_multiplier' in shared_layer_info[
                args.dataset]:
            args.network_width_multiplier = shared_layer_info[
                args.dataset]['network_width_multiplier']
    else:
        dataset_history = []
        dataset2num_classes = {}
        masks = {}
        shared_layer_info = {}

    if args.arch == 'resnet50':
        # num_for_construct = [64, 64, 64*4, 128, 128*4, 256, 256*4, 512, 512*4]
        model = models.__dict__[args.arch](
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes,
            network_width_multiplier=args.network_width_multiplier,
            shared_layer_info=shared_layer_info)
    elif 'vgg' in args.arch:
        custom_cfg = [
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
            512, 512, 512, 'M'
        ]
        model = models.__dict__[args.arch](
            custom_cfg,
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes,
            network_width_multiplier=args.network_width_multiplier,
            shared_layer_info=shared_layer_info,
            progressive_init=args.progressive_init)
    else:
        print('Error!')
        sys.exit(1)

    # Add and set the model dataset.
    model.add_dataset(args.dataset, args.num_classes)
    model.set_dataset(args.dataset)

    # Move model to GPU
    model = nn.DataParallel(model)
    model = model.cuda()

    # For datasets whose image_size is 224 and also the first task
    if args.use_imagenet_pretrained and model.module.datasets.index(
            args.dataset) == 0:
        curr_model_state_dict = model.state_dict()
        if args.arch == 'custom_vgg':
            state_dict = model_zoo.load_url(model_urls['vgg16_bn'])
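            # Copy the pretrained conv/BN weights directly; the first two
            # classifier FC layers map onto features.45 and features.48.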
            for name, param in state_dict.items():
                if 'classifier' not in name:
                    curr_model_state_dict['module.' + name].copy_(param)
            curr_model_state_dict['module.features.45.weight'].copy_(
                state_dict['classifier.0.weight'])
            curr_model_state_dict['module.features.45.bias'].copy_(
                state_dict['classifier.0.bias'])
            curr_model_state_dict['module.features.48.weight'].copy_(
                state_dict['classifier.3.weight'])
            curr_model_state_dict['module.features.48.bias'].copy_(
                state_dict['classifier.3.bias'])
            if args.dataset == 'imagenet':
                curr_model_state_dict['module.classifiers.0.weight'].copy_(
                    state_dict['classifier.6.weight'])
                curr_model_state_dict['module.classifiers.0.bias'].copy_(
                    state_dict['classifier.6.bias'])
        elif args.arch == 'resnet50':
            state_dict = model_zoo.load_url(model_urls['resnet50'])
            for name, param in state_dict.items():
                if 'fc' not in name:
                    curr_model_state_dict['module.' + name].copy_(param)
            if args.dataset == 'imagenet':
                curr_model_state_dict['module.classifiers.0.weight'].copy_(
                    state_dict['fc.weight'])
                curr_model_state_dict['module.classifiers.0.bias'].copy_(
                    state_dict['fc.bias'])
        else:
            print(
                'No mapping between the ImageNet pretrained weights and our '
                'model is defined for arch {}'.format(args.arch))
            sys.exit(5)

    if not masks:
        for name, module in model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(
                    module, nl.SharableLinear):
                mask = torch.ByteTensor(module.weight.data.size()).fill_(0)
                mask = mask.cuda()
                masks[name] = mask
    else:
        # When we expand the network, we need to allocate new masks.
        NEED_ADJUST_MASK = False
        for name, module in model.named_modules():
            if isinstance(module, nl.SharableConv2d):
                if masks[name].size(1) < module.weight.data.size(1):
                    assert args.mode == 'finetune'
                    NEED_ADJUST_MASK = True
                elif masks[name].size(1) > module.weight.data.size(1):
                    assert args.mode == 'inference'
                    NEED_ADJUST_MASK = True

        if NEED_ADJUST_MASK:
            if args.mode == 'finetune':
                for name, module in model.named_modules():
                    if isinstance(module, nl.SharableConv2d):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        mask = mask.cuda()
                        mask[:masks[name].size(0), :masks[name].
                             size(1), :, :].copy_(masks[name])
                        masks[name] = mask
                    elif isinstance(module, nl.SharableLinear):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        mask = mask.cuda()
                        mask[:masks[name].size(0), :masks[name].size(1)].copy_(
                            masks[name])
                        masks[name] = mask
            elif args.mode == 'inference':
                for name, module in model.named_modules():
                    if isinstance(module, nl.SharableConv2d):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        mask = mask.cuda()
                        mask[:, :, :, :].copy_(
                            masks[name][:mask.size(0), :mask.size(1), :, :])
                        masks[name] = mask
                    elif isinstance(module, nl.SharableLinear):
                        mask = torch.ByteTensor(
                            module.weight.data.size()).fill_(0)
                        mask = mask.cuda()
                        mask[:, :].copy_(
                            masks[name][:mask.size(0), :mask.size(1)])
                        masks[name] = mask

    if args.dataset not in shared_layer_info:

        shared_layer_info[args.dataset] = {
            'bias': {},
            'bn_layer_running_mean': {},
            'bn_layer_running_var': {},
            'bn_layer_weight': {},
            'bn_layer_bias': {},
            'piggymask': {}
        }

        piggymasks = {}
        task_id = model.module.datasets.index(args.dataset) + 1
        if task_id > 1:
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(
                        module, nl.SharableLinear):
                    piggymasks[name] = torch.zeros_like(masks['module.' +
                                                              name],
                                                        dtype=torch.float32)
                    piggymasks[name].fill_(0.01)
                    piggymasks[name] = Parameter(piggymasks[name])
                    module.piggymask = piggymasks[name]
    else:
        piggymasks = shared_layer_info[args.dataset]['piggymask']
        task_id = model.module.datasets.index(args.dataset) + 1
        if task_id > 1:
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(
                        module, nl.SharableLinear):
                    module.piggymask = piggymasks[name]

    shared_layer_info[args.dataset][
        'network_width_multiplier'] = args.network_width_multiplier

    if 'cropped' in args.dataset:
        train_loader = dataset.train_loader_cropped(args.train_path,
                                                    args.batch_size)
        val_loader = dataset.val_loader_cropped(args.val_path,
                                                args.val_batch_size)
    else:
        train_loader = dataset.train_loader(args.train_path, args.batch_size)
        val_loader = dataset.val_loader(args.val_path, args.val_batch_size)

    # If checkpoints are saved to a different folder than they are loaded from, restart the epoch count.
    if args.save_folder != args.load_folder:
        start_epoch = 0
    else:
        start_epoch = resume_from_epoch

    curr_prune_step = begin_prune_step = start_epoch * len(train_loader)
    end_prune_step = curr_prune_step + args.pruning_interval * len(
        train_loader)

    manager = Manager(args, model, shared_layer_info, masks, train_loader,
                      val_loader, begin_prune_step, end_prune_step)

    if args.mode == 'inference':
        manager.load_checkpoint_only_for_evaluate(resume_from_epoch,
                                                  resume_folder)
        manager.validate(resume_from_epoch - 1)
        return

    lr = args.lr
    lr_mask = args.lr_mask
    # update all layers
    named_params = dict(model.named_parameters())
    params_to_optimize_via_SGD = []
    named_of_params_to_optimize_via_SGD = []
    masks_to_optimize_via_Adam = []
    named_of_masks_to_optimize_via_Adam = []

    for name, param in named_params.items():
        if 'classifiers' in name:
            if '.{}.'.format(model.module.datasets.index(
                    args.dataset)) in name:
                params_to_optimize_via_SGD.append(param)
                named_of_params_to_optimize_via_SGD.append(name)
            continue
        elif 'piggymask' in name:
            masks_to_optimize_via_Adam.append(param)
            named_of_masks_to_optimize_via_Adam.append(name)
        else:
            params_to_optimize_via_SGD.append(param)
            named_of_params_to_optimize_via_SGD.append(name)

    optimizer_network = optim.SGD(params_to_optimize_via_SGD,
                                  lr=lr,
                                  weight_decay=0.0,
                                  momentum=0.9,
                                  nesterov=True)
    optimizers = Optimizers()
    optimizers.add(optimizer_network, lr)

    if masks_to_optimize_via_Adam:
        optimizer_mask = optim.Adam(masks_to_optimize_via_Adam, lr=lr_mask)
        optimizers.add(optimizer_mask, lr_mask)

    manager.load_checkpoint(optimizers, resume_from_epoch, resume_folder)
    """Performs training."""
    curr_lrs = []
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            curr_lrs.append(param_group['lr'])
            break

    if args.jsonfile is None or not os.path.isfile(args.jsonfile):
        sys.exit(3)
    with open(args.jsonfile, 'r') as jsonfile:
        json_data = json.load(jsonfile)
        baseline_acc = float(json_data[args.dataset])

    if args.mode == 'prune':
        if args.dataset != 'imagenet':
            history_best_avg_val_acc_when_prune = 0.0
            #history_best_avg_val_acc_when_prune = baseline_acc - 0.005
        else:
            if 'vgg' in args.arch:
                baseline_acc = 0.7336
                history_best_avg_val_acc_when_prune = baseline_acc - 0.005
            elif 'resnet50' in args.arch:
                baseline_acc = 0.7616
                history_best_avg_val_acc_when_prune = baseline_acc - 0.005
            else:
                print('Something is wrong')
                exit(1)

        stop_prune = True

        if 'gradual_prune' in args.load_folder and args.save_folder == args.load_folder:
            if args.dataset == 'imagenet':
                args.epochs = 10 + resume_from_epoch
            else:
                args.epochs = 20 + resume_from_epoch
        logging.info('')
        logging.info('Before pruning: ')
        logging.info('Sparsity range: {} -> {}'.format(args.initial_sparsity,
                                                       args.target_sparsity))
        manager.validate(start_epoch - 1)
        logging.info('')

    elif args.mode == 'finetune':
        manager.pruner.make_finetuning_mask()

        if args.dataset == 'imagenet':
            manager.validate(0)
            manager.save_checkpoint(optimizers, 0, args.save_folder)
            return

        history_best_avg_val_acc = 0.0
        num_epochs_that_criterion_does_not_get_better = 0
        times_of_decaying_learning_rate = 0

    for epoch_idx in range(start_epoch, args.epochs):
        avg_train_acc, curr_prune_step = manager.train(optimizers, epoch_idx,
                                                       curr_lrs,
                                                       curr_prune_step)
        avg_val_acc = manager.validate(epoch_idx)

        if args.mode == 'prune' and (epoch_idx + 1) >= (
                args.pruning_interval + start_epoch) and (
                    avg_val_acc > history_best_avg_val_acc_when_prune):
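            # Validation accuracy improved after this pruning step: drop the old
            # checkpoints and save the new best.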
            stop_prune = False
            history_best_avg_val_acc_when_prune = avg_val_acc
            if args.save_folder is not None:
                paths = os.listdir(args.save_folder)
                if paths and '.pth.tar' in paths[0]:
                    for checkpoint_file in paths:
                        os.remove(
                            os.path.join(args.save_folder, checkpoint_file))
            else:
                print('Something is wrong! Block the program with pdb')
                pdb.set_trace()

            manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)

        if args.mode == 'finetune':

            if avg_val_acc > history_best_avg_val_acc:
                num_epochs_that_criterion_does_not_get_better = 0
                if args.save_folder is not None:
                    paths = os.listdir(args.save_folder)
                    if paths and '.pth.tar' in paths[0]:
                        for checkpoint_file in paths:
                            os.remove(
                                os.path.join(args.save_folder,
                                             checkpoint_file))
                else:
                    print('Something is wrong! Block the program with pdb')
                    pdb.set_trace()

                history_best_avg_val_acc = avg_val_acc
                manager.save_checkpoint(optimizers, epoch_idx,
                                        args.save_folder)
            else:
                num_epochs_that_criterion_does_not_get_better += 1

            if times_of_decaying_learning_rate >= 3:
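                # The learning rate has already been decayed three times without improvement; stop training.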
                print()
                print(
                    "times_of_decaying_learning_rate reach {}, stop training".
                    format(times_of_decaying_learning_rate))
                break

            if num_epochs_that_criterion_does_not_get_better >= 5:
                times_of_decaying_learning_rate += 1
                for param_group in optimizers[0].param_groups:
                    param_group['lr'] *= 0.1
                curr_lrs[0] = param_group['lr']
                print()
                # Report before resetting the counter so the message shows the
                # actual number of stagnant epochs.
                print("No improvement for {} consecutive epochs, "
                      "decaying the learning rate by 0.1".format(
                          num_epochs_that_criterion_does_not_get_better))
                num_epochs_that_criterion_does_not_get_better = 0

                if times_of_decaying_learning_rate == 1 and len(
                        optimizers.lrs) == 2:
                    for param_group in optimizers[1].param_groups:
                        param_group['lr'] *= 0.2
                    curr_lrs[1] = param_group['lr']

    print('-' * 16)

    if args.pruning_ratio_to_acc_record_file:
        json_data = {}
        if os.path.isfile(args.pruning_ratio_to_acc_record_file):
            with open(args.pruning_ratio_to_acc_record_file, 'r') as json_file:
                json_data = json.load(json_file)

    if args.mode == 'finetune' and not args.test_piggymask:
        if args.pruning_ratio_to_acc_record_file:
            json_data[0.0] = round(history_best_avg_val_acc, 4)
            with open(args.pruning_ratio_to_acc_record_file, 'w') as json_file:
                json.dump(json_data, json_file)

        if history_best_avg_val_acc - baseline_acc > -0.005:  # TODO
            #json_data = {}
            #json_data['acc_before_prune'] = '{:.4f}'.format(history_best_avg_val_acc)
            #with open(args.tmp_benchmark_file, 'w') as jsonfile:
            #    json.dump(json_data, jsonfile)
            pass
        else:
            print("It's time to expand the Network")
            print('Auto expand network')
            sys.exit(2)

        if manager.pruner.calculate_curr_task_ratio() == 0.0:
            print(
                'No space left in the convolutional layers for the current '
                'task, so there is nothing to prune')
            sys.exit(5)

    elif args.mode == 'prune':
        #        if stop_prune:
        #            print('Acc too low, stop pruning.')
        #            sys.exit(4)
        if args.pruning_ratio_to_acc_record_file:
            json_data[args.target_sparsity] = round(
                history_best_avg_val_acc_when_prune, 4)
            with open(args.pruning_ratio_to_acc_record_file, 'w') as json_file:
                json.dump(json_data, json_file)
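
The block above follows a simple read-update-write pattern for the sparsity-to-accuracy record: load the JSON file if it already exists, add or overwrite one entry, and write the whole dict back. A minimal self-contained sketch of that pattern (the function name and file path are illustrative, not part of the original code):

import json
import os


def update_acc_record(record_path, key, acc):
    """Load a JSON record if it exists, add or overwrite one entry, and write it back."""
    data = {}
    if os.path.isfile(record_path):
        with open(record_path, 'r') as f:
            data = json.load(f)
    data[str(key)] = round(acc, 4)
    with open(record_path, 'w') as f:
        json.dump(data, f)


# Hypothetical usage mirroring the finetune and prune branches above:
# update_acc_record('ratio_to_acc.json', 0.0, history_best_avg_val_acc)
# update_acc_record('ratio_to_acc.json', args.target_sparsity, history_best_avg_val_acc_when_prune)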
Exemple #6
0
def main():
    """Do stuff."""
    args = parser.parse_args()
    if args.save_folder and not os.path.isdir(args.save_folder):
        os.makedirs(args.save_folder)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        args.cuda = False

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.benchmark = True

    # If set > 0, will resume training from a given checkpoint.
    resume_from_epoch = 0
    resume_folder = args.load_folder
    for try_epoch in range(200, 0, -1):
        if os.path.exists(
                args.checkpoint_format.format(save_folder=resume_folder,
                                              epoch=try_epoch)):
            resume_from_epoch = try_epoch
            break

    if args.restore_epoch:
        resume_from_epoch = args.restore_epoch

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    if resume_from_epoch:
        filepath = args.checkpoint_format.format(save_folder=resume_folder,
                                                 epoch=resume_from_epoch)
        checkpoint = torch.load(filepath)
        checkpoint_keys = checkpoint.keys()
        dataset_history = checkpoint['dataset_history']
        dataset2num_classes = checkpoint['dataset2num_classes']
        masks = checkpoint['masks']
        if 'shared_layer_info' in checkpoint_keys:
            shared_layer_info = checkpoint['shared_layer_info']
        else:
            shared_layer_info = {}
    else:
        dataset_history = []
        dataset2num_classes = {}
        masks = {}
        shared_layer_info = {}

    if args.arch == 'resnet50':
        model = packnet_models.__dict__[args.arch](
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes)
    elif 'vgg' in args.arch:
        model = packnet_models.__dict__[args.arch](
            pretrained=args.use_imagenet_pretrained,
            dataset_history=dataset_history,
            dataset2num_classes=dataset2num_classes)
    else:
        print('Error: architecture {} is not supported'.format(args.arch))
        sys.exit(1)

    # Add and set the model dataset
    model.add_dataset(args.dataset, args.num_classes)
    model.set_dataset(args.dataset)

    # Move model to GPU
    model = nn.DataParallel(model)
    model = model.cuda()

    # For datasets whose image size is 224, and only when this is the first task
    if args.use_imagenet_pretrained and model.module.datasets.index(
            args.dataset) == 0:
        curr_model_state_dict = model.state_dict()
        if args.arch == 'vgg16_bn':
            state_dict = model_zoo.load_url(model_urls['vgg16_bn'])
            curr_model_state_dict = model.state_dict()
            for name, param in state_dict.items():
                if 'classifier' not in name:
                    curr_model_state_dict['module.' + name].copy_(param)
            curr_model_state_dict['module.features.45.weight'].copy_(
                state_dict['classifier.0.weight'])
            curr_model_state_dict['module.features.45.bias'].copy_(
                state_dict['classifier.0.bias'])
            curr_model_state_dict['module.features.48.weight'].copy_(
                state_dict['classifier.3.weight'])
            curr_model_state_dict['module.features.48.bias'].copy_(
                state_dict['classifier.3.bias'])
            if args.dataset == 'imagenet':
                curr_model_state_dict['module.classifiers.0.weight'].copy_(
                    state_dict['classifier.6.weight'])
                curr_model_state_dict['module.classifiers.0.bias'].copy_(
                    state_dict['classifier.6.bias'])
        elif args.arch == 'resnet50':
            state_dict = model_zoo.load_url(model_urls['resnet50'])
            for name, param in state_dict.items():
                if 'fc' not in name:
                    curr_model_state_dict['module.' + name].copy_(param)
            if args.dataset == 'imagenet':
                curr_model_state_dict['module.classifiers.0.weight'].copy_(
                    state_dict['fc.weight'])
                curr_model_state_dict['module.classifiers.0.bias'].copy_(
                    state_dict['fc.bias'])
        else:
            print(
                "Currently, no mapping is defined between the ImageNet pretrained weights and our model for {}"
                .format(args.arch))
            sys.exit(5)

    if not masks:
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                if 'classifiers' in name:
                    continue
                mask = torch.ByteTensor(module.weight.data.size()).fill_(0)
                if 'cuda' in module.weight.data.type():
                    mask = mask.cuda()
                masks[name] = mask

    if args.dataset not in shared_layer_info:
        shared_layer_info[args.dataset] = {
            'conv_bias': {},
            'bn_layer_running_mean': {},
            'bn_layer_running_var': {},
            'bn_layer_weight': {},
            'bn_layer_bias': {},
            'fc_bias': {}
        }

    if 'cropped' in args.dataset:
        train_loader = dataset.train_loader_cropped(args.train_path,
                                                    args.batch_size)
        val_loader = dataset.val_loader_cropped(args.val_path,
                                                args.val_batch_size)
    else:
        train_loader = dataset.train_loader(args.train_path, args.batch_size)
        val_loader = dataset.val_loader(args.val_path, args.val_batch_size)

    # If checkpoints are saved to a different folder than they are loaded from, restart the epoch count.
    if args.save_folder != args.load_folder:
        start_epoch = 0
    else:
        start_epoch = resume_from_epoch

    manager = Manager(args, model, shared_layer_info, masks, train_loader,
                      val_loader)

    if args.mode == 'inference':
        manager.load_checkpoint_for_inference(resume_from_epoch, resume_folder)
        manager.validate(resume_from_epoch - 1)
        return

    lr = args.lr
    # update all layers
    named_params = dict(model.named_parameters())
    params_to_optimize_via_SGD = []
    named_of_params_to_optimize_via_SGD = []

    for name, param in named_params.items():
        if 'classifiers' in name:
            if '.{}.'.format(model.module.datasets.index(
                    args.dataset)) in name:
                params_to_optimize_via_SGD.append(param)
                named_of_params_to_optimize_via_SGD.append(name)
            continue
        else:
            params_to_optimize_via_SGD.append(param)
            named_of_params_to_optimize_via_SGD.append(name)

    # We must set weight_decay to 0.0 here, because the weight-decay step built into
    # the optimizer's step() function would modify every weight element in the tensor,
    # which would hurt previous tasks' accuracy. (Instead, weight decay is applied manually in `prune.py`.)
    optimizer_network = optim.SGD(params_to_optimize_via_SGD,
                                  lr=lr,
                                  weight_decay=0.0,
                                  momentum=0.9,
                                  nesterov=True)

    optimizers = Optimizers()
    optimizers.add(optimizer_network, lr)

    manager.load_checkpoint(optimizers, resume_from_epoch, resume_folder)
    """Performs training."""
    curr_lrs = []
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            curr_lrs.append(param_group['lr'])
            break

    if start_epoch != 0:
        curr_best_accuracy = manager.validate(start_epoch - 1)
    elif args.mode == 'prune':
        print()
        print('Sparsity ratio: {}'.format(args.one_shot_prune_perc))
        print('Before pruning: ')
        with open(args.jsonfile, 'r') as jsonfile:
            json_data = json.load(jsonfile)
            baseline_acc = float(json_data[args.dataset])
        # baseline_acc = manager.validate(start_epoch-1)
        print('Execute one shot pruning ...')
        manager.one_shot_prune(args.one_shot_prune_perc)
    else:
        curr_best_accuracy = 0.0

    if args.mode == 'finetune':
        manager.pruner.make_finetuning_mask()
        if args.dataset == 'imagenet':
            avg_val_acc = manager.validate(0)
            manager.save_checkpoint(optimizers, 0, args.save_folder)
            if args.logfile:
                json_data = {}
                if os.path.isfile(args.logfile):
                    with open(args.logfile) as json_file:
                        json_data = json.load(json_file)

                json_data[args.dataset] = '{:.4f}'.format(avg_val_acc)

                with open(args.logfile, 'w') as json_file:
                    json.dump(json_data, json_file)
            return

        history_best_val_acc = 0.0
        num_epochs_that_criterion_does_not_get_better = 0
        times_of_decaying_learning_rate = 0

    for epoch_idx in range(start_epoch, args.epochs):
        avg_train_acc = manager.train(optimizers, epoch_idx, curr_lrs)
        avg_val_acc = manager.validate(epoch_idx)

        if args.mode == 'finetune':
            if avg_val_acc > history_best_val_acc:
                num_epochs_that_criterion_does_not_get_better = 0
                history_best_val_acc = avg_val_acc
                if args.save_folder is not None:
                    paths = os.listdir(args.save_folder)
                    if paths and '.pth.tar' in paths[0]:
                        for checkpoint_file in paths:
                            os.remove(
                                os.path.join(args.save_folder,
                                             checkpoint_file))
                else:
                    print('Something is wrong: args.save_folder is not set; dropping into pdb')
                    pdb.set_trace()

                manager.save_checkpoint(optimizers, epoch_idx,
                                        args.save_folder)

                if args.logfile:
                    json_data = {}
                    if os.path.isfile(args.logfile):
                        with open(args.logfile) as json_file:
                            json_data = json.load(json_file)

                    json_data[args.dataset] = '{:.4f}'.format(avg_val_acc)

                    with open(args.logfile, 'w') as json_file:
                        json.dump(json_data, json_file)
            else:
                num_epochs_that_criterion_does_not_get_better += 1

            if times_of_decaying_learning_rate >= 3:
                print()
                print("times_of_decaying_learning_rate reached {}, "
                      "stopping training".format(times_of_decaying_learning_rate))

                break

            if num_epochs_that_criterion_does_not_get_better >= 10:
                times_of_decaying_learning_rate += 1
                print()
                print("No improvement in accuracy for {} consecutive epochs, "
                      "decaying learning rate by a factor of 0.1".format(
                          num_epochs_that_criterion_does_not_get_better))
                num_epochs_that_criterion_does_not_get_better = 0
                for param_group in optimizers[0].param_groups:
                    param_group['lr'] *= 0.1
                curr_lrs[0] = param_group['lr']

        if args.mode == 'prune':
            if epoch_idx + 1 == 40:
                for param_group in optimizers[0].param_groups:
                    param_group['lr'] *= 0.1
                curr_lrs[0] = param_group['lr']

    if args.mode == 'prune':
        if avg_train_acc > 0.97 and (avg_val_acc - baseline_acc) >= -0.01:
            manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)
        else:
            print('Pruning too much!')
    elif args.mode == 'finetune':
        if avg_train_acc < 0.97:
            print('Cannot prune any more!')

    print('-' * 16)
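
The finetune loop above implements a patience scheme: when validation accuracy fails to improve for 10 consecutive epochs (5 in the previous example), the learning rate is multiplied by 0.1, and training stops once the rate has been decayed 3 times. A minimal standalone sketch of the same scheme (class and method names are illustrative, not from the original code):

class PatienceLRDecay:
    """Decay the learning rate after `patience` epochs without improvement;
    signal a stop after `max_decays` decays."""

    def __init__(self, optimizer, patience=10, factor=0.1, max_decays=3):
        self.optimizer = optimizer
        self.patience = patience
        self.factor = factor
        self.max_decays = max_decays
        self.best_acc = float('-inf')
        self.bad_epochs = 0
        self.num_decays = 0

    def step(self, val_acc):
        """Call once per epoch with the validation accuracy; returns True when training should stop."""
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                self.bad_epochs = 0
                self.num_decays += 1
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] *= self.factor
        return self.num_decays >= self.max_decays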
Exemple #7
0
def main():
    """Do stuff."""
    args = FLAGS.parse_args()

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    # Load the required model.
    if args.arch == 'vgg16':
        model = net.ModifiedVGG16(mask_init=args.mask_init,
                                  mask_scale=args.mask_scale,
                                  threshold_fn=args.threshold_fn,
                                  original=args.no_mask)
    elif args.arch == 'vgg16bn':
        model = net.ModifiedVGG16BN(mask_init=args.mask_init,
                                    mask_scale=args.mask_scale,
                                    threshold_fn=args.threshold_fn,
                                    original=args.no_mask)
    elif args.arch == 'resnet50':
        model = net.ModifiedResNet(mask_init=args.mask_init,
                                   mask_scale=args.mask_scale,
                                   threshold_fn=args.threshold_fn,
                                   original=args.no_mask)
    elif args.arch == 'densenet121':
        model = net.ModifiedDenseNet(mask_init=args.mask_init,
                                     mask_scale=args.mask_scale,
                                     threshold_fn=args.threshold_fn,
                                     original=args.no_mask)
    elif args.arch == 'resnet50_diff':
        assert args.source
        model = net.ResNetDiffInit(args.source,
                                   mask_init=args.mask_init,
                                   mask_scale=args.mask_scale,
                                   threshold_fn=args.threshold_fn,
                                   original=args.no_mask)
    else:
        raise ValueError('Architecture %s not supported.' % (args.arch))

    # Add and set the model dataset.
    model.add_dataset(args.dataset, args.num_outputs)
    model.set_dataset(args.dataset)
    if args.cuda:
        model = model.cuda()

    # Initialize with weight based method, if necessary.
    if not args.no_mask and args.mask_init == 'weight_based_1s':
        print('Are you sure you want to try this?')
        assert args.mask_scale_gradients == 'none'
        assert not args.mask_scale
        for idx, module in enumerate(model.shared.modules()):
            if 'ElementWise' in str(type(module)):
                weight_scale = module.weight.data.abs().mean()
                module.mask_real.data.fill_(weight_scale)

    # Create the manager object.
    manager = Manager(args, model)

    # Perform necessary mode operations.
    if args.mode == 'finetune':
        if args.no_mask:
            # No masking will be done; used to run the Classifier-Only and Individual Networks baselines.
            # Checks.
            assert args.lr and args.lr_decay_every
            assert not args.lr_mask and not args.lr_mask_decay_every
            assert not args.lr_classifier and not args.lr_classifier_decay_every
            print('No masking, running baselines.')

            # Get optimizer with correct params.
            if args.finetune_layers == 'all':
                params_to_optimize = model.parameters()
            elif args.finetune_layers == 'classifier':
                for param in model.shared.parameters():
                    param.requires_grad = False
                params_to_optimize = model.classifier.parameters()

            # optimizer = optim.Adam(params_to_optimize, lr=args.lr)
            optimizer = optim.SGD(params_to_optimize,
                                  lr=args.lr,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
            optimizers = Optimizers(args)
            optimizers.add(optimizer, args.lr, args.lr_decay_every)
            manager.train(args.finetune_epochs,
                          optimizers,
                          save=True,
                          savename=args.save_prefix)
        else:
            # Masking will be done.
            # Checks.
            assert not args.lr and not args.lr_decay_every
            assert args.lr_mask and args.lr_mask_decay_every
            assert args.lr_classifier and args.lr_classifier_decay_every
            print('Performing masking.')

            optimizer_masks = optim.Adam(model.shared.parameters(),
                                         lr=args.lr_mask)
            optimizer_classifier = optim.Adam(model.classifier.parameters(),
                                              lr=args.lr_classifier)

            optimizers = Optimizers(args)
            optimizers.add(optimizer_masks, args.lr_mask,
                           args.lr_mask_decay_every)
            optimizers.add(optimizer_classifier, args.lr_classifier,
                           args.lr_classifier_decay_every)
            manager.train(args.finetune_epochs,
                          optimizers,
                          save=True,
                          savename=args.save_prefix)
    elif args.mode == 'eval':
        # Just run the model on the eval set.
        manager.eval()
    elif args.mode == 'check':
        manager.check()
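
Both examples depend on an Optimizers container that is indexed like a list, iterated over, and filled with add(...); its definition does not appear in these snippets. A plausible minimal sketch of such a wrapper, written as an assumption rather than the project's actual implementation:

class Optimizers:
    """Minimal container holding several optimizers together with their base learning rates."""

    def __init__(self, args=None):
        self.args = args
        self.optimizers = []
        self.lrs = []
        self.decay_every = []

    def add(self, optimizer, lr, decay_every=None):
        self.optimizers.append(optimizer)
        self.lrs.append(lr)
        self.decay_every.append(decay_every)

    def __getitem__(self, index):
        return self.optimizers[index]

    def __iter__(self):
        return iter(self.optimizers)

    def __len__(self):
        return len(self.optimizers)

    def zero_grad(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def step(self):
        for optimizer in self.optimizers:
            optimizer.step()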