def main():
    """Build the test-mode program for the configured architecture and export it.

    In order: load the config named by ``FLAGS.config``, merge the
    ``FLAGS.opt`` overrides, run the config and version checks, build the
    model's test graph inside fresh programs, load parameters from
    ``cfg.weights`` with a CPU executor, then call ``dump_infer_config`` and
    ``save_infer_model``.
    """
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)

    check_version()

    arch_name = cfg.architecture

    # The export runs on a CPU place rather than a GPU one.
    cpu_place = fluid.CPUPlace()
    executor = fluid.Executor(cpu_place)

    detector = create(arch_name)

    startup_program = fluid.Program()
    export_program = fluid.Program()
    with fluid.program_guard(export_program, startup_program), \
            fluid.unique_name.guard():
        reader_inputs = cfg['TestReader']['inputs_def']
        # Note: this writes into the config's own 'inputs_def' mapping.
        reader_inputs['use_dataloader'] = False
        feed_vars, _ = detector.build_inputs(**reader_inputs)
        # FLAGS.exclude_nms is forwarded as-is; presumably it makes the model
        # leave NMS post-processing out of the graph -- confirm in model.test.
        test_fetches = detector.test(
            feed_vars, exclude_nms=FLAGS.exclude_nms)
    # Keep a for_test clone of the program for everything that follows.
    export_program = export_program.clone(for_test=True)
    check_py_func(export_program)

    # Run the startup program first, then load the weights over it.
    executor.run(startup_program)
    checkpoint.load_params(executor, export_program, cfg.weights)

    dump_infer_config(FLAGS, cfg)
    save_infer_model(FLAGS, executor, feed_vars, test_fetches, export_program)
# Example no. 2
# 0
def main():
    """Build the test-mode program, prune it, and export the pruned model.

    In order: load the config named by ``FLAGS.config``, merge the
    ``FLAGS.opt`` overrides, run the config and version checks, and build the
    model's test graph. The comma-separated ``FLAGS.pruned_params`` and
    ``FLAGS.pruned_ratios`` are then parsed and validated, and the program is
    passed through ``Pruner.prune``. The FLOPs reduction is logged, weights
    are loaded from ``cfg.weights``, and the result goes to
    ``dump_infer_config`` and ``save_infer_model``.

    Raises:
        AssertionError: if ``FLAGS.pruned_params`` or ``FLAGS.pruned_ratios``
            is unset, if they list a different number of entries, or if any
            ratio lies outside the open interval (0, 1).
        ValueError: if an entry of ``FLAGS.pruned_ratios`` is not a float.
    """
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    check_version()

    main_arch = cfg.architecture

    # Use CPU for exporting inference model instead of GPU
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    model = create(main_arch)

    startup_prog = fluid.Program()
    infer_prog = fluid.Program()
    with fluid.program_guard(infer_prog, startup_prog):
        with fluid.unique_name.guard():
            inputs_def = cfg['TestReader']['inputs_def']
            # Note: this writes into the config's own 'inputs_def' mapping.
            inputs_def['use_dataloader'] = False
            feed_vars, _ = model.build_inputs(**inputs_def)
            test_fetches = model.test(feed_vars)
    infer_prog = infer_prog.clone(True)

    assert (
        FLAGS.pruned_params is not None
    ), "FLAGS.pruned_params is empty!!! Please set it by '--pruned_params' option."
    # Checked explicitly so an unset flag gives a clear message instead of an
    # AttributeError from ``None.strip()`` below.
    assert (
        FLAGS.pruned_ratios is not None
    ), "FLAGS.pruned_ratios is empty!!! Please set the pruned ratios."
    pruned_params = FLAGS.pruned_params.strip().split(",")
    logger.info("pruned params: {}".format(pruned_params))
    pruned_ratios = [float(n) for n in FLAGS.pruned_ratios.strip().split(",")]
    logger.info("pruned ratios: {}".format(pruned_ratios))
    assert (len(pruned_params) == len(pruned_ratios)
            ), "The length of pruned params and pruned ratios should be equal."
    # Check every ratio on its own: comparing whole lists with ``<`` / ``>``
    # is lexicographic, so e.g. [0.5, 1.5] would wrongly pass.
    assert all(0 < ratio < 1 for ratio in pruned_ratios
               ), "The elements of pruned ratios should be in range (0, 1)."

    base_flops = flops(infer_prog)
    pruner = Pruner()
    # Pruning happens before any weights are loaded; ``only_graph=True``
    # presumably restricts it to the program structure -- confirm against the
    # Pruner documentation.
    infer_prog, _, _ = pruner.prune(
        infer_prog,
        fluid.global_scope(),
        params=pruned_params,
        ratios=pruned_ratios,
        place=place,
        only_graph=True)
    pruned_flops = flops(infer_prog)
    # Logged value is the fraction of FLOPs removed relative to the baseline.
    logger.info("pruned FLOPS: {}".format(
        float(base_flops - pruned_flops) / base_flops))

    # Run the startup program first, then load the checkpoint into the
    # pruned program.
    exe.run(startup_prog)
    checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)

    dump_infer_config(FLAGS, cfg)
    save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)