Example #1
import torch
import torch.nn as nn

import net  # repo-local module providing the Modified* network wrappers


def init_dump(arch):
    """Dumps pretrained model in required format."""
    if arch == 'vgg16':
        model = net.ModifiedVGG16()
    elif arch == 'vgg16bn':
        model = net.ModifiedVGG16BN()
    elif arch == 'resnet50':
        model = net.ModifiedResNet()
    elif arch == 'densenet121':
        model = net.ModifiedDenseNet()
    else:
        raise ValueError('Architecture type not supported.')

    previous_masks = {}
    for module_idx, module in enumerate(model.shared.modules()):
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            mask = torch.ByteTensor(module.weight.data.size()).fill_(1)
            if 'cuda' in module.weight.data.type():
                mask = mask.cuda()
            previous_masks[module_idx] = mask
    torch.save({
        'dataset2idx': {'imagenet': 1},
        'previous_masks': previous_masks,
        'model': model,
    }, '../checkpoints/imagenet/%s.pt' % (arch))
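
A minimal usage sketch of the dump above (the checkpoint path and key names come from the snippet itself; the inspection lines are only illustrative and assume the checkpoint directory exists):

init_dump('vgg16')
ckpt = torch.load('../checkpoints/imagenet/vgg16.pt')
print(ckpt['dataset2idx'])          # {'imagenet': 1}
print(len(ckpt['previous_masks']))  # one all-ones mask per Conv2d/Linear layer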
Example #2
    def check(self):
        """Makes sure that the trained model weights match those of the pretrained model."""
        print('Making sure filter weights have not changed.')
        if self.args.arch == 'vgg16':
            pretrained = net.ModifiedVGG16(original=True)
        elif self.args.arch == 'vgg16bn':
            pretrained = net.ModifiedVGG16BN(original=True)
        elif self.args.arch == 'resnet50':
            pretrained = net.ModifiedResNet(original=True)
        elif self.args.arch == 'densenet121':
            pretrained = net.ModifiedDenseNet(original=True)
        elif self.args.arch == 'resnet50_diff':
            pretrained = net.ResNetDiffInit(self.args.source, original=True)
        else:
            raise ValueError('Architecture %s not supported.' %
                             (self.args.arch))

        for module, module_pretrained in zip(self.model.shared.modules(),
                                             pretrained.shared.modules()):
            module_type = str(type(module))
            if 'ElementWise' in module_type or 'BatchNorm' in module_type:
                weight = module.weight.data.cpu()
                weight_pretrained = module_pretrained.weight.data.cpu()
                # Using small threshold of 1e-8 for any floating point inconsistencies.
                # Note that threshold per element is even smaller as the 1e-8 threshold
                # is for sum of absolute differences.
                assert (weight - weight_pretrained).abs().sum() < 1e-8, \
                    'module %s failed check' % (module)
                if module.bias is not None:
                    bias = module.bias.data.cpu()
                    bias_pretrained = module_pretrained.bias.data.cpu()
                    assert (bias - bias_pretrained).abs().sum() < 1e-8
                if 'BatchNorm' in module_type:
                    rm = module.running_mean.cpu()
                    rm_pretrained = module_pretrained.running_mean.cpu()
                    assert (rm - rm_pretrained).abs().sum() < 1e-8
                    rv = module.running_var.cpu()
                    rv_pretrained = module_pretrained.running_var.cpu()
                    assert (rv - rv_pretrained).abs().sum() < 1e-8
        print('Passed checks...')
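
The tolerance above is a sum of absolute differences, so drift in any single element accumulates into the total. A standalone sketch of the same comparison (tensor shape is arbitrary, chosen only for illustration):

a = torch.zeros(64, 3, 3, 3)
b = a.clone()
print((a - b).abs().sum() < 1e-8)  # True: identical weights pass the check
b[0, 0, 0, 0] += 1e-6
print((a - b).abs().sum() < 1e-8)  # False: one perturbed element fails it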
Example #3
def main():
    """Do stuff."""
    args = FLAGS.parse_args()

    if args.mode == 'pack':
        assert args.packlist and args.maskloc
        dataset2masks = {}
        dataset2classifiers = {}
        net_type = None

        # Location to output stats (swaps maskloc's trailing 'pt' for 'txt').
        fout = open(args.maskloc[:-2] + 'txt', 'w')

        # Load models one by one and store their masks.
        fin = open(args.packlist, 'r')
        counter = 1
        for line in fin:
            if not line.strip() or line.startswith('#'):
                continue
            # Split on the first colon only, so checkpoint paths may contain ':'.
            dataset, loadname = line.split(':', 1)
            loadname = loadname.strip()

            # Can't have same dataset twice.
            if dataset in dataset2masks:
                raise ValueError('Repeated datasets as input...')
            print('Loading model #%d for dataset "%s"' % (counter, dataset))
            counter += 1
            ckpt = torch.load(loadname)
            model = ckpt['model']
            # Ensure all inputs are for same model type.
            if net_type is None:
                net_type = str(type(model))
            else:
                assert net_type == str(type(model)), '%s != %s' % (
                    net_type, str(type(model)))

            # Gather masks and store in dictionary.
            fout.write('Dataset: %s\n' % (dataset))
            total_params, neg_params, zerod_params = [], [], []
            masks = {}
            for module_idx, module in enumerate(model.shared.modules()):
                if 'ElementWise' in str(type(module)):
                    mask = module.threshold_fn(module.mask_real)
                    mask = mask.data.cpu()

                    # Make sure mask values are in {0, 1} or {-1, 0, 1}.
                    num_zero = mask.eq(0).sum()
                    num_one = mask.eq(1).sum()
                    num_mone = mask.eq(-1).sum()
                    total = mask.numel()
                    threshold_type = module.threshold_fn.__class__.__name__
                    if threshold_type == 'Binarizer':
                        assert num_mone == 0
                        assert num_zero + num_one == total
                    elif threshold_type == 'Ternarizer':
                        assert num_mone + num_zero + num_one == total
                    masks[module_idx] = mask.type(torch.ByteTensor)

                    # Count total and zerod out params.
                    total_params.append(total)
                    zerod_params.append(num_zero)
                    neg_params.append(num_mone)
                    fout.write('%d\t%.2f%%\t%.2f%%\n' % (
                        module_idx,
                        neg_params[-1] / total_params[-1] * 100,
                        zerod_params[-1] / total_params[-1] * 100))
            print('Check Passed: Masks only have binary/ternary values.')
            dataset2masks[dataset] = masks
            dataset2classifiers[dataset] = model.classifier

            fout.write('Total -1: %d/%d = %.2f%%\n' % (
                sum(neg_params), sum(total_params), sum(neg_params) / sum(total_params) * 100))
            fout.write('Total 0: %d/%d = %.2f%%\n' % (
                sum(zerod_params), sum(total_params), sum(zerod_params) / sum(total_params) * 100))
            fout.write('-' * 20 + '\n')

        # Clean up and save masks to file.
        fin.close()
        fout.close()
        torch.save({
            'dataset2masks': dataset2masks,
            'dataset2classifiers': dataset2classifiers,
        }, args.maskloc)

    elif args.mode == 'eval':
        assert args.arch and args.maskloc and args.dataset

        # Set default train and test path if not provided as input.
        utils.set_dataset_paths(args)

        # Load masks and classifier for this dataset.
        info = torch.load(args.maskloc)
        if args.dataset not in info['dataset2masks']:
            raise ValueError('%s not found in masks.' % (args.dataset))
        masks = info['dataset2masks'][args.dataset]
        classifier = info['dataset2classifiers'][args.dataset]

        # Create the vanilla model and apply masking.
        if args.arch == 'vgg16':
            model = net.ModifiedVGG16(original=True)
        elif args.arch == 'vgg16bn':
            model = net.ModifiedVGG16BN(original=True)
        elif args.arch == 'resnet50':
            model = net.ModifiedResNet(original=True)
        elif args.arch == 'densenet121':
            model = net.ModifiedDenseNet(original=True)
        elif args.arch == 'resnet50_diff':
            assert args.source
            model = net.ResNetDiffInit(args.source, original=True)
        else:
            raise ValueError('Architecture %s not supported.' % (args.arch))
        model.eval()

        print('Applying masks.')
        for module_idx, module in enumerate(model.shared.modules()):
            if module_idx in masks:
                mask = masks[module_idx]
                module.weight.data[mask.eq(0)] = 0
                module.weight.data[mask.eq(-1)] *= -1
        print('Applied masks.')

        # Override model.classifier with saved one.
        model.add_dataset(args.dataset, classifier.weight.size(0))
        model.set_dataset(args.dataset)
        model.classifier = classifier
        if args.cuda:
            model = model.cuda()

        # Create the manager and run eval.
        manager = Manager(args, model)
        manager.eval()
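
The 'pack' branch above reads --packlist as a plain-text file with one dataset:checkpoint pair per line; blank lines and lines starting with '#' are skipped, and a dataset may appear only once. A hypothetical packlist (all dataset names and paths are made up for illustration):

# one model per line, in dataset:checkpoint_path form
cubs:../checkpoints/cubs/vgg16_binary.pt
flowers:../checkpoints/flowers/vgg16_binary.pt
stanford_cars:../checkpoints/stanford_cars/vgg16_binary.pt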
Example #4
def main():
    """Do stuff."""
    args = FLAGS.parse_args()

    # Set default train and test path if not provided as input.
    utils.set_dataset_paths(args)

    # Load the required model.
    if args.arch == 'vgg16':
        model = net.ModifiedVGG16(mask_init=args.mask_init,
                                  mask_scale=args.mask_scale,
                                  threshold_fn=args.threshold_fn,
                                  original=args.no_mask)
    elif args.arch == 'vgg16bn':
        model = net.ModifiedVGG16BN(mask_init=args.mask_init,
                                    mask_scale=args.mask_scale,
                                    threshold_fn=args.threshold_fn,
                                    original=args.no_mask)
    elif args.arch == 'resnet50':
        model = net.ModifiedResNet(mask_init=args.mask_init,
                                   mask_scale=args.mask_scale,
                                   threshold_fn=args.threshold_fn,
                                   original=args.no_mask)
    elif args.arch == 'densenet121':
        model = net.ModifiedDenseNet(mask_init=args.mask_init,
                                     mask_scale=args.mask_scale,
                                     threshold_fn=args.threshold_fn,
                                     original=args.no_mask)
    elif args.arch == 'resnet50_diff':
        assert args.source
        model = net.ResNetDiffInit(args.source,
                                   mask_init=args.mask_init,
                                   mask_scale=args.mask_scale,
                                   threshold_fn=args.threshold_fn,
                                   original=args.no_mask)
    else:
        raise ValueError('Architecture %s not supported.' % (args.arch))

    # Add and set the model dataset.
    model.add_dataset(args.dataset, args.num_outputs)
    model.set_dataset(args.dataset)
    if args.cuda:
        model = model.cuda()

    # Initialize with weight based method, if necessary.
    if not args.no_mask and args.mask_init == 'weight_based_1s':
        print('Are you sure you want to try this?')
        assert args.mask_scale_gradients == 'none'
        assert not args.mask_scale
        for module in model.shared.modules():
            if 'ElementWise' in str(type(module)):
                weight_scale = module.weight.data.abs().mean()
                module.mask_real.data.fill_(weight_scale)

    # Create the manager object.
    manager = Manager(args, model)

    # Perform necessary mode operations.
    if args.mode == 'finetune':
        if args.no_mask:
            # No masking will be done; this path runs the Classifier-Only
            # and Individual-Networks baselines.
            # Checks.
            assert args.lr and args.lr_decay_every
            assert not args.lr_mask and not args.lr_mask_decay_every
            assert not args.lr_classifier and not args.lr_classifier_decay_every
            print('No masking, running baselines.')

            # Get optimizer with correct params.
            if args.finetune_layers == 'all':
                params_to_optimize = model.parameters()
            elif args.finetune_layers == 'classifier':
                for param in model.shared.parameters():
                    param.requires_grad = False
                params_to_optimize = model.classifier.parameters()
            else:
                raise ValueError('Unknown finetune_layers: %s' %
                                 (args.finetune_layers))

            # optimizer = optim.Adam(params_to_optimize, lr=args.lr)
            optimizer = optim.SGD(params_to_optimize,
                                  lr=args.lr,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
            optimizers = Optimizers(args)
            optimizers.add(optimizer, args.lr, args.lr_decay_every)
            manager.train(args.finetune_epochs,
                          optimizers,
                          save=True,
                          savename=args.save_prefix)
        else:
            # Masking will be done.
            # Checks.
            assert not args.lr and not args.lr_decay_every
            assert args.lr_mask and args.lr_mask_decay_every
            assert args.lr_classifier and args.lr_classifier_decay_every
            print('Performing masking.')

            optimizer_masks = optim.Adam(model.shared.parameters(),
                                         lr=args.lr_mask)
            optimizer_classifier = optim.Adam(model.classifier.parameters(),
                                              lr=args.lr_classifier)

            optimizers = Optimizers(args)
            optimizers.add(optimizer_masks, args.lr_mask,
                           args.lr_mask_decay_every)
            optimizers.add(optimizer_classifier, args.lr_classifier,
                           args.lr_classifier_decay_every)
            manager.train(args.finetune_epochs,
                          optimizers,
                          save=True,
                          savename=args.save_prefix)
    elif args.mode == 'eval':
        # Just run the model on the eval set.
        manager.eval()
    elif args.mode == 'check':
        manager.check()
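
The flag checks above make the two finetune variants mutually exclusive: a masked run sets only --lr_mask/--lr_classifier (plus their decay intervals), while a --no_mask baseline sets --lr instead. A hypothetical masked invocation, assuming the FLAGS parser exposes flags matching the attribute names used above (the script name and every value shown are made up for illustration):

python main.py --mode finetune --arch vgg16 --dataset cubs --num_outputs 200 \
    --mask_init 1s --mask_scale 1e-2 --threshold_fn binarizer \
    --lr_mask 1e-4 --lr_mask_decay_every 15 \
    --lr_classifier 1e-3 --lr_classifier_decay_every 15 \
    --finetune_epochs 30 --save_prefix ../checkpoints/cubs/vgg16_binary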