Example no. 1
0
    def update_model(model):
        """Apply the pruning mask in place for CIFAR-style datasets.

        Models trained on non-CIFAR datasets are returned untouched.
        """
        if not dataset.startswith('cifar'):
            return model
        # Masking must happen before the model is used downstream.
        utils.apply_mask(model, excludes=excludes_for_applying_mask)
        return model
Example no. 2
0
    def update_model(model):
        """Apply the pruning mask to a resumed model when requested.

        No-op unless both ``resume`` and ``apply_mask`` are set.
        """
        # Only a resumed checkpoint with masking enabled gets masked;
        # the model object is returned in every case.
        if resume and apply_mask:
            utils.apply_mask(model)
        return model
Example no. 3
0
def main():
    """Run the full pruning pipeline from CLI arguments.

    Pipeline: load model -> replace Conv2d with MaskConv2d -> prune
    -> fine-tune -> export grouped model -> save. A validation pass
    runs after every stage to track accuracy.
    """
    parser = create_cli_parser()
    args = parser.parse_args()

    # Fixed seed so pruning/sorting decisions are reproducible.
    set_random_seed(42)

    cfg = GumiConfig(**vars(args))
    # patch_dataset may rewrite dataset-related fields; use cfg from here on.
    cfg = patch_dataset(cfg)

    logging.info("==> Initializing ModelPruner")
    model_pruner = ModelPruner(cfg)

    model = model_pruner.load_model(update_state_dict_fn=update_state_dict)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Replacing Conv2d in model by MaskConv2d ...")
    utils.apply_mask(model)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Pruning model ...")
    get_num_groups = group_utils.create_get_num_groups_fn(
        G=cfg.num_groups, MCPG=cfg.mcpg, group_cfg=cfg.group_cfg)

    model_pruner.prune(
        model,
        get_num_groups_fn=get_num_groups,
        perm=cfg.perm,
        no_weight=cfg.no_weight,
        num_iters=cfg.num_sort_iters,
        keep_mask=False,
    )
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Fine-tuning model ...")
    model_pruner.fine_tune(model, record_top5=False)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Exporting the model ...")
    model = GroupExporter.export(model)
    # FIX: use cfg.dataset (post patch_dataset) instead of the raw
    # args.dataset so the FLOP count matches the dataset actually used,
    # consistent with the rest of this function.
    logging.debug("Total params: {:.2f}M FLOPS: {:.2f}M".format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, cfg.dataset),
    ))

    logging.info("==> Saving exported model ...")
    torch.save(model, os.path.join(cfg.checkpoint, "pruned.pth.tar"))
Example no. 4
0
def main():
    """Load a model, optionally prune it, fine-tune it, and export it.

    Driven entirely by the module-level ``args`` namespace; each major
    stage can be skipped via ``skip_validation`` / ``skip_prune`` /
    ``skip_fine_tune`` flags. Validation accuracies are logged along
    the way.
    """
    logging.info("==> Initializing ModelPruner ...")
    model_pruner = ModelPruner(args)

    # load model
    logging.info("==> Loading model ...")

    # Build the hooks that adjust the model / checkpoint state dict at
    # load time (masking, condensenet-specific renames, etc.).
    update_model_fn = create_update_model_fn(
        args.arch,
        args.dataset,
        args.pretrained,
        args.resume,
        apply_mask=args.apply_mask,
        condensenet=args.condensenet,
    )
    model = model_pruner.load_model(
        update_model_fn=update_model_fn,
        # no_mask when not resuming: a fresh checkpoint has no mask keys.
        update_state_dict_fn=create_update_state_dict_fn(
            no_mask=not args.resume, condensenet=args.condensenet),
        fine_tune=args.fine_tune,
    )
    # evaluate the performance of the model in the beginning
    if not args.skip_validation:
        logging.info("==> Validating the loaded model ...")
        loss1, acc1 = model_pruner.validate(model)

    #################################################
    # Pruning                                       #
    #                                               #
    #################################################
    if not args.apply_mask:
        # NOTE: we have not applied mask yet
        # # major pruning function
        logging.info("==> Replacing Conv2d in model by MaskConv2d ...")
        # TODO - duplicated with update_model_fn?
        # not quite, if not resume the model won't be updated
        utils.apply_mask(model)

        if not args.skip_validation:
            logging.info("==> Validating the masked model ...")
            loss2, acc2 = model_pruner.validate(model)
            # Masking alone must not change accuracy (identity mask).
            # NOTE(review): `assert` is stripped under -O; consider an
            # explicit check. Also assumes acc1/acc2 are tensors — confirm.
            assert torch.allclose(acc1, acc2)

    # run pruning (update the content of mask)
    logging.info("==> Pruning model ...")
    if not args.skip_prune:
        get_num_groups = create_get_num_groups_fn(G=args.num_groups,
                                                  MCPG=args.mcpg,
                                                  group_cfg=args.group_cfg)

        logging.debug("Pruning configuration:")
        logging.debug("PERM:        {}".format(args.perm))
        logging.debug("NS:          {}".format(args.num_sort_iters))
        logging.debug("No weight:   {}".format(args.no_weight))
        logging.debug("Keep mask:   {}".format(args.keep_mask))
        logging.debug("")

        # Prune in place: updates mask contents, keeps the module tree.
        model_pruner.prune(
            model,
            get_num_groups_fn=get_num_groups,
            perm=args.perm,
            no_weight=args.no_weight,
            num_iters=args.num_sort_iters,
            keep_mask=args.keep_mask,
        )

        if not args.skip_validation:
            logging.info("==> Validating the pruned model ...")
            loss3, acc3 = model_pruner.validate(model)

    else:
        logging.info("Pruning has been skipped, you have the original model.")

    #################################################
    # Fine-tuning                                   #
    #                                               #
    #################################################
    logging.info("==> Fine-tuning the pruned model ...")
    if args.train_from_scratch:
        logging.info("==> Training the pruned topology from scratch ...")

        # reset weight parameters
        # TODO: refactorize
        # Re-initialize every conv/linear-like weight (ndim >= 2) and
        # zero all biases, so only the pruned topology is kept.
        for name, mod in model.named_modules():
            if hasattr(
                    mod,
                    "weight") and len(mod.weight.shape) >= 2:  # re-initialize
                torch.nn.init.kaiming_normal_(mod.weight, nonlinearity="relu")
                # if hasattr(mod, 'G'):
                #   mod.weight.data.mul_(mod.G)
            if hasattr(mod, "bias") and mod.bias is not None:
                mod.bias.data.fill_(0.0)

    if not args.skip_fine_tune:
        model_pruner.fine_tune(model)

        if not args.skip_validation:
            logging.info("==> Validating the fine-tuned pruned model ...")
            loss4, acc4 = model_pruner.validate(model)
            logging.info(
                "==> Final validation accuracy of the pruned model: {:.2f}%".
                format(acc4))
    else:
        logging.info("Fine-tuning has been skipped.")

    #################################################
    # Export                                        #
    #                                               #
    #################################################
    logging.info("==> Exporting the model ...")
    # Convert masked convs into true grouped convolutions.
    model = GroupExporter.export(model)
    # NOTE(review): `use_cuda` is a free (module-level) name — presumably
    # set during argument parsing; verify it exists before this runs.
    if use_cuda:
        model.cuda()
    logging.debug("Total params: {:.2f}M FLOPS: {:.2f}M".format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset),
    ))

    if not args.skip_validation:
        logging.info("==> Validating the exported pruned model ...")
        loss5, acc5 = model_pruner.validate(model)
        logging.info(
            "==> Final validation accuracy of the exported model: {:.2f}%".
            format(acc5))
Example no. 5
0
 def update_model(model):
     """Apply the pruning mask on CPU, then hand the model back."""
     # use_cuda=False keeps the masking step off the GPU.
     utils.apply_mask(model, use_cuda=False)
     return model