コード例 #1
0
def test_graph_kws_auto_quant(kws_graph, kws_sounds):
    """Quantize the KWS graph to 16 bits using statistics from sample sounds.

    The graph is loaded with its tensors, reordered, matched against the
    standard match group, and then quantized with ``SimpleQuantizer``. The
    resulting quantization records are attached to the graph. No assertion is
    made on the result; the test only checks that the pipeline runs.
    """
    graph = create_graph(kws_graph, opts={"load_tensors": True})
    graph.add_dimensions()
    graph.adjust_order()
    get_std_match_group().match(graph)
    # Dimensions are recomputed after matching (presumably because the match
    # group may rewrite/fuse nodes).
    graph.add_dimensions()

    act_collector = ActivationStatsCollector()
    for sound_path in kws_sounds:
        samples = import_data(sound_path, offset=0, divisor=256, nptype='int16')
        act_collector.collect_stats(graph, [samples])
    act_stats = act_collector.reduce_stats()

    filter_stats = FilterStatsCollector().collect_stats(graph)

    quantizer = SimpleQuantizer(act_stats, filter_stats, force_width=16)
    graph.quantization = quantizer.quantize(graph)
コード例 #2
0
def test_simple_quantization(mnist_graph, mnist_images):
    """Symmetric 8-bit quantization of the MNIST graph from a single image.

    Activation statistics come from the first sample image, reshaped to
    ``(28, 28, 1)``. The test asserts that 11 quantization records are
    produced and prints a text quantization report.
    """
    graph = create_graph(mnist_graph, opts={"load_tensors": True})
    graph.add_dimensions()
    image = import_data(mnist_images[0], height=28, width=28,
                        offset=0, divisor=255).reshape((28, 28, 1))

    act_collector = ActivationStatsCollector()
    act_collector.collect_stats(graph, [image])
    act_stats = act_collector.reduce_stats()
    filter_stats = FilterStatsCollector().collect_stats(graph)

    qrecs = SymmetricQuantizer(act_stats, filter_stats,
                               force_width=8).quantize(graph)
    # The original author notes the count includes one extra record for the
    # saved quantizer.
    assert len(qrecs) == 11

    report = QuantizationReporter().report(graph, qrecs)
    text_renderer = TextTableRenderer(maxwidth=200)
    print(report.render(text_renderer))
コード例 #3
0
    def do_aquant(self, args: argparse.Namespace):
        """
Attempt to calculate quantization for graph using one or more sample imput files."""
        self._check_graph()
        input_args = self._get_input_args(args)
        processed_input = False
        stats_collector = ActivationStatsCollector()
        for file_per_input in glob_input_files(args.input_files,
                                               self.G.num_inputs):
            LOG.info("input file %s", file_per_input)
            processed_input = True
            data = [
                import_data(input_file, **input_args)
                for input_file in file_per_input
            ]
            stats_collector.collect_stats(self.G, data)
        if not processed_input:
            self.perror("No imput files found")
            return
        astats = stats_collector.reduce_stats()
        if args.scheme == 'SQ8':
            quantizer = MultQuantizer(
                astats,
                8,
                quantized_dimension=args.quant_dimension,
                narrow_weights=not args.no_narrow_weights)
        else:
            stats_collector = FilterStatsCollector()
            fstats = stats_collector.collect_stats(self.G)
            quantizer = SymmetricQuantizer(astats,
                                           fstats,
                                           force_width=args.force_width,
                                           min_qsnr=args.qsnr)
        qrecs = quantizer.quantize(self.G)
        self.G.quantization = qrecs
        if args.scheme == 'SQ8':
            concats_matcher = EqualizeSymmetricMultiplicativeQuantivedConcats()
            concats_matcher.match(self.G, set_identity=False)
            softmax_qrec_matcher = PropagateSoftmaxSymQrec()
            softmax_qrec_matcher.match(self.G, set_identity=False)
        LOG.info("Quantization set. Use qshow command to see it.")
コード例 #4
0
    def do_stats(self, args: argparse.Namespace):
        """
Display statistics on weights and biases"""
        self._check_graph()
        fmt = ('tab' if args.output is None else args.output['fmt'])
        if args.detailed:
            stats_collector = FilterDetailedStatsCollector()
            stats = stats_collector.collect_stats(self.G)
            tab = FilterDetailedStatsReporter().report(self.G, stats)
        else:
            step_idx = args.step
            if step_idx is not None:
                if len(step_idx) == 1:
                    step_idx = step_idx[0]
                else:
                    step_idx = tuple(step_idx)
            stats_collector = FilterStatsCollector()
            stats = stats_collector.collect_stats(self.G, step_idx=step_idx)
            tab = FilterStatsReporter(do_totals=(fmt != "csv"), threshold=args.qsnr, step_idx=step_idx)\
                .report(self.G, stats)
        output_table(tab, args)
コード例 #5
0
ファイル: conftest.py プロジェクト: hasetz/gap_sdk
def save_state(temp_dir, width, fusions=False, adjust=False):
    """Quantize the MNIST graph and dump its state, with parameters, to disk.

    Args:
        temp_dir: directory in which a file named ``state_file`` is written.
        width: bit width forced on the ``SimpleQuantizer``.
        fusions: when true, apply the standard match group and recompute
            dimensions afterwards.
        adjust: when true, call ``adjust_order`` on the graph; when false,
            each sample image is reshaped to ``(28, 28, 1)`` before use.

    Returns:
        The path (``os.path.join`` result) of the written state file.
    """
    state_path = os.path.join(temp_dir, "state_file")
    graph = create_graph(MNIST_GRAPH, opts={"load_tensors":True})
    graph.add_dimensions()
    if adjust:
        graph.adjust_order()
    if fusions:
        get_std_match_group().match(graph)
        graph.add_dimensions()

    act_collector = ActivationStatsCollector()
    for image_path in MNIST_IMAGES:
        image = import_data(image_path, offset=0, divisor=255)
        sample = image if adjust else image.reshape((28, 28, 1))
        act_collector.collect_stats(graph, [sample])
    act_stats = act_collector.reduce_stats()
    filter_stats = FilterStatsCollector().collect_stats(graph)

    quantizer = SimpleQuantizer(act_stats, filter_stats, force_width=width)
    graph.quantization = quantizer.quantize(graph)
    dump_state(graph, include_parameters=True, state_path=state_path)
    return state_path