Example #1
    def run_opt(self,
                model,
                max_num_params,
                max_num_ops,
                strategy="MAX_COST",
                **kwargs):
        """ Run the actual optimization.
    
      Model is already loaded in this case. All the modules that should be grouped 
      are already replaced by MaskConv2d.

      We will update the G value in each module in place.
    """
        logging.info("Finding the optimal group configuration ...")
        logging.info("Max # param: {:.2f} M".format(max_num_params))
        logging.info("Max # ops:   {:.2f} M".format(max_num_ops))

        c_dict = OrderedDict()  # candidate dictionary
        state = OrderedDict()  # G index for each group

        logging.debug("==> Constructing the candidate dictionary ...")
        for name, mod in model.named_modules():
            if isinstance(mod, MaskConv2d):
                Gs, costs = list(self.find_group_candidates(mod, **kwargs))
                c_dict[name] = (Gs, costs, self.get_mod_norm(mod))
                state[name] = len(Gs) - 1  # the last group

        # Initial state
        self.assign_state_to_model(model, state, c_dict)

        logging.debug(
            "Initial state: {:.2f} M params {:.2f} M ops cost: {:.2f}%".format(
                model_utils.get_model_num_params(model),
                utils.get_model_num_ops(model, self.args.dataset),
                (1 - self.get_cost(state, c_dict)) * 100,
            ))

        # Now, depending on the strategy, update the group setting
        if strategy == "ILP":
            # HACK - need to refactorise the code
            # state = self.ilp_solve(c_dict)
            raise NotImplementedError(
                "ILP as an optimisation strategy is not implemented")
        elif strategy == "MAX_COST":
            state, cost = self.max_cost_solver(model, state, c_dict,
                                               max_num_params, max_num_ops)
        else:
            raise ValueError("Cannot recognise strategy: {}".format(strategy))

        # finalise
        self.assign_state_to_model(model, state, c_dict)
        num_params = model_utils.get_model_num_params(model)
        num_ops = utils.get_model_num_ops(model, self.args.dataset)

        logging.debug(
            "Final state - params: {:.2f} M  ops: {:6.2f} M cost: {:.2f}%".
            format(num_params, num_ops, (1 - cost) * 100))
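
The two OrderedDicts above are easiest to see on a toy instance. A minimal sketch (module names and values are hypothetical): c_dict maps a module name to its candidate group counts, their costs, and a weight norm; state holds, per module, the index of the currently chosen candidate.

from collections import OrderedDict

# Hypothetical candidate dictionary: name -> (Gs, costs, norm).
c_dict = OrderedDict()
c_dict["layer1.conv"] = ([1, 2, 4], [0.0, 0.4, 0.7], 1.23)
c_dict["layer2.conv"] = ([1, 2], [0.0, 0.5], 0.98)

# Initial state: every module starts at its last (most grouped) candidate.
state = OrderedDict((name, len(Gs) - 1) for name, (Gs, _, _) in c_dict.items())
assert state == OrderedDict([("layer1.conv", 2), ("layer2.conv", 1)])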
Example #2
    def max_cost_solver(self, model, state, c_dict, max_num_params,
                        max_num_ops):
        num_params = model_utils.get_model_num_params(model)
        num_ops = utils.get_model_num_ops(model, self.args.dataset)
        cost = self.get_cost(state, c_dict)

        if num_params > max_num_params or num_ops > max_num_ops:
            raise ValueError(
                'The constraints are too restrictive: the initial model '
                'already has params {:.2f} M ops {:.2f} M.'.format(
                    num_params, num_ops))

        step = 0
        while num_params <= max_num_params and num_ops <= max_num_ops:
            # we can have a try
            costs = []
            for name, gi in state.items():
                # try reduce
                if gi == 0:
                    continue

                state_ = state.copy()
                state_[name] = gi - 1  # this will definitely increase the cost
                cost_ = self.get_cost(state_, c_dict)
                costs.append((name, cost_))

            if not costs:
                logging.debug('No more choices for MAX_COST, exiting ...')
                break

            # find the max cost update
            max_cost = max(costs, key=lambda k: k[1])

            # update the state
            state[max_cost[0]] -= 1
            cost = max_cost[1]

            # update
            self.assign_state_to_model(model, state, c_dict)
            if use_cuda:  # module-level flag; move the updated model back to GPU
                model.cuda()
            num_params = model_utils.get_model_num_params(model)
            num_ops = utils.get_model_num_ops(model, self.args.dataset)

            step += 1
            logging.debug(
                '[{:4d}] Current state -  params: {:.2f} M ops: {:6.2f} M cost: {:.2f}%'
                .format(step, num_params, num_ops, (1 - cost) * 100))

        return state, cost
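
One iteration of the greedy step above, isolated as a self-contained toy; get_cost here is a stand-in for self.get_cost(state, c_dict), not the repository's implementation.

# Toy MAX_COST step: try decrementing every decrementable group index,
# then commit the single move that yields the highest cost.
state = {"conv1": 2, "conv2": 0, "conv3": 1}
norms = {"conv1": 0.3, "conv2": 0.9, "conv3": 0.7}  # stand-in importance

def get_cost(s):
    return sum(norms[n] * (2 - gi) for n, gi in s.items()) / 4.0

candidates = []
for name, gi in state.items():
    if gi == 0:  # cannot decrement further
        continue
    trial = dict(state, **{name: gi - 1})
    candidates.append((name, get_cost(trial)))

name, cost = max(candidates, key=lambda k: k[1])
state[name] -= 1  # commit the highest-cost move
assert (name, state["conv3"]) == ("conv3", 0)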
Example #3
    def test_ctor(self):
        """ Build the model """
        # CondenseNet-86 (G=4)
        model = condensenet.CondenseNet([14, 14, 14], [8, 16, 32], groups=4)
        self.assertEqual(model_utils.get_num_conv2d_layers(model), 86)
        self.assertAlmostEqual(model_utils.get_model_num_params(model),
                               0.52,
                               places=2)
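
A note on the places argument used throughout these tests: assertAlmostEqual(a, b, places=n) passes when round(a - b, n) == 0, so places=2 tolerates roughly ±0.005.

import unittest

tc = unittest.TestCase()
tc.assertAlmostEqual(0.52, 0.524, places=2)  # passes: round(-0.004, 2) == 0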
Example #4
    def update_model_by_group_cfg(self, model, g_cfg):
        """ Post update the G value of each GroupConv2d. """
        for name, mod in model.named_modules():
            if name in g_cfg:
                assert isinstance(mod, GroupConv2d)
                G = g_cfg[name]["G"]
                mod.setup_conv2d(G)

        logging.info("Updated G of each model by group_cfg.")
        logging.info("    Total params: {:.2f}M".format(
            model_utils.get_model_num_params(model)))
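
From the g_cfg[name]["G"] lookup above, the group config is keyed by module name and each entry carries at least a "G" field. A hypothetical minimal instance (module names invented for illustration):

g_cfg = {
    "features.0.conv": {"G": 4},
    "features.1.conv": {"G": 8},
}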
Example #5
  def create_model(self, args, **kwargs):
    """ Create model only. """
    num_classes = model_utils.get_num_classes(args)
    model = models.__dict__[args.arch](num_classes=num_classes, **kwargs)
    model = torch.nn.DataParallel(model).cuda()

    logging.info('Created model by arch={} and num_classes={} on GPU={}'.format(
        args.arch, num_classes, args.gpu_id))
    logging.info('    Total params: {:.2f}M'.format(
        model_utils.get_model_num_params(model)))

    return model
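
The models.__dict__[args.arch] idiom resolves a constructor by name from a module namespace. A standalone sketch of the same pattern using torchvision (any module that exposes model constructors works the same way):

import torchvision.models as models

arch = "resnet18"
model_fn = models.__dict__[arch]  # equivalent to getattr(models, arch)
model = model_fn(num_classes=10)  # extra kwargs pass straight to the constructor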
Example #6
    def create_model(self):
        """ Create a new model """
        model = models.__dict__[self.args.arch](num_classes=self.num_classes)
        model = torch.nn.DataParallel(model).cuda()

        logging.debug(
            "Created model by arch={} and num_classes={} on GPU={}".format(
                self.args.arch, self.num_classes, self.args.gpu_id))
        logging.debug("    Total params: {:.2f}M".format(
            model_utils.get_model_num_params(model)))

        return model
Example #7
    def test_ctor(self):
        """ Build BasicBlocks and BottleneckBlocks, and the model. """
        # the network model itself
        model = densenet.DenseNet(depth=40,
                                  Block=densenet.BasicBlock,
                                  growth_rate=12,
                                  compression_rate=1.0,
                                  mask=True,
                                  num_classes=100)
        num_params = model_utils.get_model_num_params(model)

        self.assertAlmostEqual(num_params, 1.06, places=1)  # around 1.06
        self.assertEqual(model_utils.get_num_conv2d_layers(model), 40)
Example #8
def main():
    parser = create_cli_parser()
    args = parser.parse_args()

    set_random_seed(42)

    cfg = GumiConfig(**vars(args))
    cfg = patch_dataset(cfg)

    logging.info("==> Initializing ModelPruner")
    model_pruner = ModelPruner(cfg)

    model = model_pruner.load_model(update_state_dict_fn=update_state_dict)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Replacing Conv2d in model by MaskConv2d ...")
    utils.apply_mask(model)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Pruning model ...")
    get_num_groups = group_utils.create_get_num_groups_fn(
        G=cfg.num_groups, MCPG=cfg.mcpg, group_cfg=cfg.group_cfg)

    model_pruner.prune(
        model,
        get_num_groups_fn=get_num_groups,
        perm=cfg.perm,
        no_weight=cfg.no_weight,
        num_iters=cfg.num_sort_iters,
        keep_mask=False,
    )
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Fine-tuning model ...")
    model_pruner.fine_tune(model, record_top5=False)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Exporting the model ...")
    model = GroupExporter.export(model)
    logging.debug("Total params: {:.2f}M FLOPS: {:.2f}M".format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset),
    ))

    logging.info("==> Saving exported model ...")
    torch.save(model, os.path.join(cfg.checkpoint, "pruned.pth.tar"))
Example #9
def main():
    use_cuda = not args.cpu
    print("==> Loading model {}".format(args.arch))
    model = utils.load_model(args.arch, "imagenet", use_cuda=use_cuda, pretrained=True)

    print("==> Loading group config {}".format(args.group_cfg))
    with open(args.group_cfg, "r") as f:
        group_cfg = json.load(f)

    print("==> Updating model ...")
    model = update_model(model, group_cfg, use_cuda=use_cuda)

    print(model)

    print(
        "==> Model size: {:.2f} M ops: {:.2f} M".format(
            model_utils.get_model_num_params(model),
            utils.get_model_num_ops(model, "imagenet"),
        )
    )

    torch.save(model, os.path.join(os.path.dirname(args.resume), "peng.pth.tar"))
Example #10
    def profile(self, iters, use_cuda=False):
        """ Run the profile function """
        groups = [1, 2, 4, 8, 16, 32, 64]
        sparse = (True, False)

        data = []
        columns = ['G', 'is_sparse', 'num_params', 'num_ops', 'time']

        for g in groups:
            for is_sparse in sparse:
                if g != 1 and not is_sparse:
                    continue

                model = MobileNet(groups=g,
                                  sparse=is_sparse,
                                  num_channels=1000)
                update_sparse_weight(model, block=True)

                num_params = model_utils.get_model_num_params(model)
                num_ops = utils.get_model_num_ops(model, self.args.dataset)

                x = torch.rand((1, 3, 224, 224))

                if use_cuda:
                    x = x.cuda()
                    model.cuda()

                logging.info(
                    '==> Profiling G={} sparse={} for {} iters ...'.format(
                        g, is_sparse, iters))
                elapsed = self.profile_case(x, model, iters=iters)

                logging.info(
                    '\t# params. {:.2f}M # ops {:.2f}M Time elapsed: {:.2f} ms '
                    .format(num_params, num_ops, elapsed))

                data.append([g, is_sparse, num_params, num_ops, elapsed])

        return pd.DataFrame(data, columns=columns)
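
A hypothetical call site for the sweep above (profiler stands for whatever object hosts the method); the returned DataFrame makes the configurations easy to rank:

df = profiler.profile(iters=100, use_cuda=False)
print(df.sort_values("time"))  # fastest configuration first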
Example #11
    def test_ctor(self):
        """ Build BasicBlocks and BottleneckBlocks, and the model. """

        # check masked BasicBlock
        basic_block = preresnet.BasicBlock(32, 32, mask=True)
        self.assertIsInstance(basic_block.conv1, MaskConv2d)
        self.assertIsInstance(basic_block.conv2, MaskConv2d)

        # masked BottleneckBlock
        bottleneck_block = preresnet.BottleneckBlock(32, 32, mask=True)
        self.assertIsInstance(bottleneck_block.conv1, MaskConv2d)
        self.assertIsInstance(bottleneck_block.conv2, MaskConv2d)
        self.assertIsInstance(bottleneck_block.conv3, MaskConv2d)

        # the network model itself
        model = preresnet.PreResNet(164,
                                    mask=True,
                                    block_name='BottleneckBlock',
                                    num_classes=100)
        num_params = model_utils.get_model_num_params(model)

        self.assertAlmostEqual(num_params, 1.70, places=1)  # around 1.7
        self.assertEqual(model_utils.get_num_conv2d_layers(model), 164)
Example #12
def main():
    """ Main """
    logging.info("==> Initializing ModelPruner ...")
    model_pruner = ModelPruner(args)

    # load model
    logging.info("==> Loading model ...")

    update_model_fn = create_update_model_fn(
        args.arch,
        args.dataset,
        args.pretrained,
        args.resume,
        apply_mask=args.apply_mask,
        condensenet=args.condensenet,
    )
    model = model_pruner.load_model(
        update_model_fn=update_model_fn,
        update_state_dict_fn=create_update_state_dict_fn(
            no_mask=not args.resume, condensenet=args.condensenet),
        fine_tune=args.fine_tune,
    )
    # evaluate the performance of the model in the beginning
    if not args.skip_validation:
        logging.info("==> Validating the loaded model ...")
        loss1, acc1 = model_pruner.validate(model)

    #################################################
    # Pruning                                       #
    #                                               #
    #################################################
    if not args.apply_mask:
        # NOTE: the mask has not been applied yet
        logging.info("==> Replacing Conv2d in model by MaskConv2d ...")
        # TODO: duplicated with update_model_fn? Not quite: without
        # resume, the model won't have been updated there.
        utils.apply_mask(model)

        if not args.skip_validation:
            logging.info("==> Validating the masked model ...")
            loss2, acc2 = model_pruner.validate(model)
            assert torch.allclose(acc1, acc2)

    # run pruning (update the content of mask)
    logging.info("==> Pruning model ...")
    if not args.skip_prune:
        get_num_groups = create_get_num_groups_fn(G=args.num_groups,
                                                  MCPG=args.mcpg,
                                                  group_cfg=args.group_cfg)

        logging.debug("Pruning configuration:")
        logging.debug("PERM:        {}".format(args.perm))
        logging.debug("NS:          {}".format(args.num_sort_iters))
        logging.debug("No weight:   {}".format(args.no_weight))
        logging.debug("Keep mask:   {}".format(args.keep_mask))
        logging.debug("")

        model_pruner.prune(
            model,
            get_num_groups_fn=get_num_groups,
            perm=args.perm,
            no_weight=args.no_weight,
            num_iters=args.num_sort_iters,
            keep_mask=args.keep_mask,
        )

        if not args.skip_validation:
            logging.info("==> Validating the pruned model ...")
            loss3, acc3 = model_pruner.validate(model)

    else:
        logging.info("Pruning has been skipped, you have the original model.")

    #################################################
    # Fine-tuning                                   #
    #                                               #
    #################################################
    logging.info("==> Fine-tuning the pruned model ...")
    if args.train_from_scratch:
        logging.info("==> Training the pruned topology from scratch ...")

        # reset weight parameters
        # TODO: refactor
        for name, mod in model.named_modules():
            # re-initialize weight tensors of rank >= 2
            if hasattr(mod, "weight") and len(mod.weight.shape) >= 2:
                torch.nn.init.kaiming_normal_(mod.weight, nonlinearity="relu")
            if hasattr(mod, "bias") and mod.bias is not None:
                mod.bias.data.fill_(0.0)

    if not args.skip_fine_tune:
        model_pruner.fine_tune(model)

        if not args.skip_validation:
            logging.info("==> Validating the fine-tuned pruned model ...")
            loss4, acc4 = model_pruner.validate(model)
            logging.info(
                "==> Final validation accuracy of the pruned model: {:.2f}%".
                format(acc4))
    else:
        logging.info("Fine-tuning has been skipped.")

    #################################################
    # Export                                        #
    #                                               #
    #################################################
    logging.info("==> Exporting the model ...")
    model = GroupExporter.export(model)
    if use_cuda:
        model.cuda()
    logging.debug("Total params: {:.2f}M FLOPS: {:.2f}M".format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset),
    ))

    if not args.skip_validation:
        logging.info("==> Validating the exported pruned model ...")
        loss5, acc5 = model_pruner.validate(model)
        logging.info(
            "==> Final validation accuracy of the exported model: {:.2f}%".
            format(acc5))
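
One caveat in the re-initialization loop above: modules that register weight as None (e.g. torch.nn.BatchNorm2d(8, affine=False)) still pass the hasattr check, so reading the weight's shape would raise. A defensive variant of the same reset, as a sketch:

import torch

def reset_weights(model):
    """ Kaiming-init rank>=2 weights, zero biases; skip None parameters. """
    for mod in model.modules():
        w = getattr(mod, "weight", None)
        if w is not None and w.dim() >= 2:
            torch.nn.init.kaiming_normal_(w, nonlinearity="relu")
        b = getattr(mod, "bias", None)
        if b is not None:
            b.data.fill_(0.0)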
Example #13
def main():
    """ Load the model and prune it """
    logging.info('==> Initializing GroupModelExporter ...')
    exporter = GroupModelExporter(args)

    # load model
    logging.info('==> Loading model ...')
    model = exporter.load_model(
        use_cuda=True,
        data_parallel=False,
        update_model_fn=create_update_model_fn(),
        update_state_dict_fn=create_update_state_dict_fn())

    # prune it to setup each MaskConv2d
    get_num_groups_fn = utils.create_get_num_groups_fn(
        G=args.num_groups,
        MCPG=args.mcpg,
        group_cfg=args.group_cfg,
        use_cuda=True,
        data_parallel=False)

    exporter.prune(model,
                   get_num_groups_fn=get_num_groups_fn,
                   keep_mask=True,
                   use_cuda=True)
    for name, mod in model.named_modules():
        if isinstance(mod, MaskConv2d):
            print(name, mod.G)

    # export
    logging.info('==> Exporting the model ...')
    model = GroupExporter.export(model,
                                 use_cuda=True,
                                 mm=args.mm,
                                 sparse=args.sparse,
                                 std=args.std,
                                 min_sparse_channels=args.min_sparse_channels)

    # evaluate
    if args.val:
        exporter.validate(model)
    # move back to CPU
    model.cpu()
    logging.debug('Total params: {:.2f}M FLOPS: {:.2f}M'.format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset)))

    logging.debug('==> Densify model ...')
    utils.apply_dense(model)

    # save model
    if not args.onnx:
        suffix = ''
        if args.mm:
            suffix += '_mm'
        if args.sparse:
            suffix += '_sparse'
        if args.std:
            suffix += '_std'
        if args.min_sparse_channels > 0:
            suffix += '_min{}'.format(args.min_sparse_channels)

        fp = os.path.join(os.path.dirname(args.resume),
                          'gconv{}.pth.tar'.format(suffix))
        logging.info('==> Saving the model to {} ...'.format(fp))
        torch.save(model, fp)
    else:
        fp = os.path.join(os.path.dirname(args.resume), 'gconv.onnx')
        logging.info('==> Exporting model to ONNX {} ...'.format(fp))

        if args.dataset in utils.IMAGENET_DATASETS:
            x = torch.randn(1, 3, 224, 224, requires_grad=True)
        else:
            raise RuntimeError(
                'Dataset {} is not supported'.format(args.dataset))

        onnx = torch.onnx._export(model, x, fp, export_params=True)
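
torch.onnx._export is a private entry point (kept here presumably for its return value); the supported public call is otherwise equivalent:

# Public API equivalent of the private _export call above.
torch.onnx.export(model, x, fp, export_params=True)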
Example #14
def main():
  use_cuda = not args.cpu

  if args.pretrained:
    logging.info('==> Loading pre-trained model ...')
    model = utils.load_model(
        args.arch, args.dataset, pretrained=args.pretrained, use_cuda=use_cuda)
  else:
    logging.info('==> Loading GConv model directly from Pickle ...')
    model = torch.load(args.resume)
    if not use_cuda:
      model.cpu()

    model.eval()
    logging.info('==> Sparsify model ...')
    utils.apply_sparse(model)
    print(model)

    logging.debug('Total params: {:.2f}M FLOPS: {:.2f}M'.format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset)))

  if args.dataset in utils.IMAGENET_DATASETS:
    x = torch.rand((args.test_batch, 3, 224, 224))
  else:
    x = torch.rand((args.test_batch, 3, 32, 32))

  if not use_cuda:
    x = x.cpu()

  # setup the input and model
  if use_cuda:
    x = x.cuda()
    model.cuda()

  logging.info('==> Dry running ...')
  dry_run_iters = 10 if not use_cuda else 100
  for _ in range(dry_run_iters):
    y = model.forward(x)

  logging.info('==> Print profiling info ...')
  with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
    y = model.forward(x)
  print(prof)

  # start timing
  logging.info('==> Start timing ...')
  if use_cuda:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(args.iters):
      y = model.forward(x)
    end.record()

    # synchronize
    torch.cuda.synchronize()

    elapsed = start.elapsed_time(end)
  else:
    start = time.time()
    for _ in range(args.iters):
      y = model.forward(x)
    end = time.time()
    elapsed = (end - start) * 1e3

  print('Elapsed time: {:10.2f} sec (total) {:6.2f} ms (per run) {:6.2f} FPS.'.
        format(elapsed * 1e-3, elapsed / args.iters,
               args.iters * args.test_batch / elapsed * 1e3))
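
The timing pattern above distills into a small helper; this is a sketch extracted from the code, not an API of the repository:

import time
import torch

def time_forward(model, x, iters, use_cuda):
    """ Return elapsed milliseconds for `iters` forward passes. """
    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            model(x)
        end.record()
        torch.cuda.synchronize()  # flush queued kernels before reading timers
        return start.elapsed_time(end)  # already in milliseconds
    start = time.time()
    for _ in range(iters):
        model(x)
    return (time.time() - start) * 1e3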
Example #15
def load_model(arch,
               dataset,
               resume=None,
               pretrained=False,
               update_model_fn=None,
               update_state_dict_fn=None,
               use_cuda=True,
               fine_tune=False,
               data_parallel=True,
               checkpoint_file_name='checkpoint.pth.tar',
               **kwargs):
    """ Load a model.
  
    You can either load a CIFAR model from gumi.models
    or an ImageNet model from torchvision.models

    Additional parameters in kwargs are passed only to cifar models.

    You can use update_model_fn and update_state_dict_fn
    to configure those models undesirable.

    NOTE: fine_tune won't control whether to run replace_classifier.
  """
    # construct the model
    num_classes = get_num_classes(dataset)
    if dataset.startswith('cifar'):
        model = cifar_models.__dict__[arch](num_classes=num_classes, **kwargs)
    elif dataset in IMAGENET_DATASETS:
        # NOTE: a pretrained model arrives fully initialised here,
        # so it will not take the resume branch below.
        if arch in imagenet_models.__dict__:
            model = imagenet_models.__dict__[arch](pretrained=pretrained)
        else:
            model = custom_imagenet_models.__dict__[arch](**kwargs)

        replace_classifier(arch, model, dataset, fine_tune=fine_tune)

    logging.debug('Total params: {:.2f}M FLOPS: {:.2f}M'.format(
        model_utils.get_model_num_params(model),
        get_model_num_ops(model, dataset)))

    # update model if required
    if update_model_fn:
        logging.debug('update_model_fn is provided.')
        model = update_model_fn(model)

    if resume:  # load from checkpoint
        if pretrained:
            raise ValueError(
                'You cannot set pretrained=True and also pass resume.')

        assert isinstance(resume, str)

        # update the resume if it points to a directory
        if os.path.isdir(resume):
            resume = os.path.join(resume, checkpoint_file_name)
            logging.debug(
                'Resume was given as a directory, updated to: {}'.format(
                    resume))

        # now resume should be a valid file.
        assert os.path.isfile(resume)

        checkpoint = torch.load(resume)  # load

        # get the state dict
        state_dict = checkpoint['state_dict']
        if update_state_dict_fn:
            state_dict = update_state_dict_fn(state_dict)

        # initialize model
        model.load_state_dict(state_dict, strict=not fine_tune)

    if use_cuda:
        if data_parallel:
            model = torch.nn.DataParallel(model)
        model = model.cuda()

    return model
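
Two hypothetical invocations covering the main paths through load_model (the architecture names assume torchvision and the repository's CIFAR model zoo, respectively):

# Pretrained ImageNet model on CPU.
model = load_model("resnet18", "imagenet", pretrained=True, use_cuda=False)

# CIFAR model restored from a checkpoint directory.
model = load_model("preresnet", "cifar10",
                   resume="checkpoints/preresnet-164",
                   use_cuda=False)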
Example #16
    def get_model_size(self, model):
        """ Return model number of parameters and ops """
        num_params = model_utils.get_model_num_params(model)
        num_ops = utils.get_model_num_ops(model, self.args.dataset)

        return num_params, num_ops
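
Every example above formats the return value with "{:.2f} M" or "{:.2f}M", so get_model_num_params evidently reports the parameter count in millions. A minimal sketch of such a helper (an assumption, not the repository's actual implementation):

def get_model_num_params(model):
    # Total parameter count, scaled to millions to match "{:.2f} M" logging.
    return sum(p.numel() for p in model.parameters()) / 1e6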