Beispiel #1
0
def test_graph_calc(mnist_graph, mnist_images):
    """Check that quantization records survive a dump_state/load_state round trip.

    The graph is copied to a temporary location and its activation and filter
    statistics are collected from one MNIST image. It is quantized to 8 bits
    and its state is dumped next to the copied graph. The graph is then
    reloaded, and its quantization records must equal the originals.
    """
    graph_copy = create_temporary_copy(mnist_graph)
    G = create_graph(graph_copy, opts={"load_tensors": True})
    G.add_dimensions()

    # import data always returns C, H, W. We need H, W, C.
    # (With a single channel a reshape is sufficient; no transpose is needed.)
    image = import_data(
        mnist_images[0], height=28, width=28, divisor=128, offset=-1
    ).reshape(28, 28, 1)

    activation_collector = ActivationStatsCollector()
    activation_collector.collect_stats(G, [image])
    astats = activation_collector.reduce_stats()

    filter_collector = FilterStatsCollector()
    fstats = filter_collector.collect_stats(G)

    expected = SimpleQuantizer(astats, fstats, force_width=8).quantize(G)
    G.quantization = expected

    # No explicit path: per the save_state help text, the state is written
    # beside the graph file, which is where load_state reads it back from.
    dump_state(G)
    reloaded = load_state(graph_copy)

    # Per-node comparison first, so a mismatch names the offending node.
    for node_key, record in reloaded.quantization.items():
        assert record == expected[node_key], "problem with " + str(node_key)

    assert reloaded.quantization == expected
Beispiel #2
0
    def do_save_state(self, args):
        """
Save the state of the transforms and quantization of the graph.
This state file can be used to generate the model file as part of
a build script. If no argument is given then the state files
will be saved in the same directory as the graph. If a directory is
given then the state files will be saved in it with the graph
basename. If a filename is given, its basename will be used to
save the state files."""
        self._check_graph()
        self._check_quantized()
        gen_opts = {k: self.settings[k] for k in DEFAULT_GEN_OPTS}
        dump_state(self.G, state_path=args.output, extra=gen_opts)
Beispiel #3
0
def save_state(temp_dir, width, fusions=False, adjust=False):
    """Quantize the MNIST graph and dump its state file into *temp_dir*.

    Args:
        temp_dir: Directory in which the file named "state_file" is written.
        width: Bit width forced on the SimpleQuantizer.
        fusions: When true, the standard fusion match group is applied and
            dimensions are recomputed.
        adjust: When true, ``G.adjust_order()`` is called and the imported
            images are fed without reshaping.

    Returns:
        The path of the written state file (parameters included).
    """
    state_path = os.path.join(temp_dir, "state_file")

    G = create_graph(MNIST_GRAPH, opts={"load_tensors": True})
    G.add_dimensions()
    if adjust:
        G.adjust_order()
    if fusions:
        get_std_match_group().match(G)
        # Fusion rewrites the graph, so dimensions must be recomputed.
        G.add_dimensions()

    activation_collector = ActivationStatsCollector()
    for image_path in MNIST_IMAGES:
        sample = import_data(image_path, offset=0, divisor=255)
        # NOTE(review): presumably adjust_order() lets the graph accept
        # import_data's native layout; without it the image is reshaped to
        # (28, 28, 1) — confirm.
        if not adjust:
            sample = sample.reshape((28, 28, 1))
        activation_collector.collect_stats(G, [sample])
    astats = activation_collector.reduce_stats()

    fstats = FilterStatsCollector().collect_stats(G)

    G.quantization = SimpleQuantizer(astats, fstats, force_width=width).quantize(G)
    dump_state(G, include_parameters=True, state_path=state_path)
    return state_path