Ejemplo n.º 1
0
def _get_opu_layer_graph(lgf, cal_data, max_nodes=32, opu_only=False):
    """Return a list of node output edges in reversed-BFS order.

    A BFS is started from the first graph output edge whose node is neither
    an UNKNOWN nor a CONST op. The BFS result is reversed (the original
    author describes this as execution order) and the first output edge of
    each visited node is collected.

    Params:
        lgf: graph object providing output_edges(), get_node_by_name() and
            bfs().
        cal_data: not used by this function.
        max_nodes: collection stops once more than max_nodes edges have been
            gathered.
            NOTE(review): the check runs after the append and uses ">", so
            up to max_nodes + 1 edges can be returned. Confirm whether that
            is intended.
        opu_only: if True, only nodes for which _is_opu() is true contribute
            an edge; otherwise every visited node does.

    Returns:
        A non-empty list of edges (node_obj.outputs[0] for each kept node).

    Raises:
        ValueError: if no edge was collected, including the case where no
            output node matches the filter.
    """
    node_filter = node_filters.and_filter(
        node_filters.not_filter(node_filters.op_filter(ops_pb2.UNKNOWN)),
        node_filters.not_filter(node_filters.op_filter(ops_pb2.CONST)))

    # Start empty so that "no matching output node" reaches the ValueError
    # below instead of failing with an unbound local variable.
    ordered = []
    for output_edge in lgf.output_edges():
        output_node = lgf.get_node_by_name(output_edge.name)
        if node_filter.matches(output_node, lgf):
            ordered = list(lgf.bfs(output_node, node_filter=node_filter))
            break

    # BFS walks from the output towards the inputs; flip it around.
    ordered.reverse()
    opu_edges = []

    for node_obj in ordered:
        if not opu_only or _is_opu(node_obj):
            opu_edges.append(node_obj.outputs[0])
            if len(opu_edges) > max_nodes:
                break

    if not opu_edges:
        raise ValueError("Could not find OPU node")

    return opu_edges
Ejemplo n.º 2
0
 def ignore_nodes_filter(self):
     """Return a filter matching every node outside the two resnet extractors.

     A node is kept (not ignored) when its name starts with either of the
     first-stage or second-stage resnet_v1_50 feature extractor prefixes;
     the returned filter is the negation of that condition.
     """
     kept_prefixes = (
         "FirstStageFeatureExtractor/resnet_v1_50",
         "SecondStageFeatureExtractor/resnet_v1_50",
     )
     prefix_filters = [
         node_filters.name_starts_with_filter(prefix)
         for prefix in kept_prefixes
     ]
     return node_filters.not_filter(node_filters.or_filter(*prefix_filters))
Ejemplo n.º 3
0
 def ignore_nodes_filter(self):
     """Return a filter matching every node that should be ignored.

     Nodes whose names start with the resnet_v1_50 feature extractor prefix
     or the weight-shared box predictor prefix are kept; the returned filter
     is the negation of "name starts with one of those prefixes".
     """
     extractor_filter = node_filters.name_starts_with_filter(
         "FeatureExtractor/resnet_v1_50")
     box_predictor_filter = node_filters.name_starts_with_filter(
         "WeightSharedConvolutionalBoxPredictor")
     return node_filters.not_filter(
         node_filters.or_filter(extractor_filter, box_predictor_filter))
Ejemplo n.º 4
0
    def prune_graph(self,
                    input_edges=None,
                    output_edges=None,
                    output_node_names=None,
                    include_inputs=True):
        """
        Build a new LightGraph holding only the nodes the outputs depend on.

        Params:
            input_edges: edges treated as graph inputs; a falsy value falls
                back to self.input_edges().
            output_edges: edges treated as graph outputs; a falsy value falls
                back to self.output_edges().
            output_node_names: names of extra output nodes; a falsy value
                falls back to self.output_node_names().
            include_inputs: if True, nodes named by input_edges that exist in
                this graph are added even if the search did not reach them.

        Returns:
            A LightGraph built from the collected nodes, with input and
            output edges re-fetched from this graph via get_edge().
        """
        # Fall back to this graph's own inputs and outputs
        if not input_edges:
            input_edges = self.input_edges()
        if not output_edges:
            output_edges = self.output_edges()
        if not output_node_names:
            output_node_names = self.output_node_names()

        # Filter that rejects any node named after one of the input edges
        exclude_inputs = node_filters.and_filter(*[
            node_filters.not_filter(node_filters.name_is_filter(edge.name))
            for edge in input_edges
        ])

        # Roots reached from the outputs are searched with the input filter,
        # required nodes are searched without it
        output_roots = [
            self.get_node_by_name(edge.name) for edge in output_edges
        ]
        output_roots += [
            self.get_node_by_name(name) for name in output_node_names
        ]
        required_roots = [
            self.get_node_by_name(name)
            for name in self._meta_graph_info.required_nodes
        ]
        searches = [(root, exclude_inputs) for root in output_roots]
        searches += [(root, None) for root in required_roots]

        # Collect every reachable node once, in discovery order
        kept_nodes = []
        kept_names = set()
        for root, root_filter in searches:
            for found in self.bfs(root, node_filter=root_filter):
                if found.name in kept_names:
                    continue
                kept_names.add(found.name)
                kept_nodes.append(found)

        # Re-fetch the edges so they come from the original graph
        input_edges = [
            self.get_edge(edge.name, edge.port) for edge in input_edges
        ]
        output_edges = [
            self.get_edge(edge.name, edge.port) for edge in output_edges
        ]

        # Optionally add input nodes the search did not reach
        if include_inputs:
            for edge in input_edges:
                already_kept = edge.name in kept_names
                if edge.name in self._node_dict and not already_kept:
                    kept_nodes.append(self._node_dict[edge.name])
                    kept_names.add(edge.name)

        return LightGraph(kept_nodes,
                          input_edges=input_edges,
                          output_edges=output_edges,
                          output_node_names=output_node_names,
                          meta_graph_info=self.meta_graph_info())
Ejemplo n.º 5
0
def add_op_transform(sw_config, graph_type, op, ignore_nodes_filter, tx_name):
    """Append a filter/transform pair for `op` to sw_config.

    The new entry's filter matches nodes with the given op that do not match
    ignore_nodes_filter. Its transform records graph_type, op and tx_name
    (stored as transform_module_name).
    """
    op_match = node_filters.op_filter(op)
    not_ignored = node_filters.not_filter(ignore_nodes_filter)
    combined = node_filters.and_filter(op_match, not_ignored)

    entry = sw_config.filter_transform_map.add()
    entry.filter.CopyFrom(combined.as_proto())
    entry.transform.graph_type = graph_type
    entry.transform.op = op
    entry.transform.transform_module_name = tx_name
Ejemplo n.º 6
0
def generate_standard_sw_config(
        graph_type,
        num_threads_scales=32,
        activation_scale_quantization_bias_type=common_pb2.QB_NONE,
        weight_quantization_type=common_pb2.QT_PER_COL_PER_TILE,
        weight_quantization_cutoff=0,
        adc_scale_quantization_type=common_pb2.QT_PER_COL_PER_TILE,
        activation_scale_quantization_method=common_pb2.QM_MIN_KL_DIVERGENCE,
        adc_scale_quantization_method=(
            common_pb2.QM_MIN_TOTAL_VARIATION_DISTANCE),
        cache_dir="",
        skip_adc_cal=False,
        skip_activation_cal=False,
        activation_scale_num_bins=4096,
        adc_scale_num_bins=4096,
        nodes_to_skip=None,
        float_type=create_dtype(dtypes_pb2.DT_BFLOAT, 16),
        convert_graph_to_debug_mode=False,
        debug_dir="",
        save_hist_html_files=False,
        ops_to_skip=None,
        fold_phasify=True,
        collect_bit_activity=False,
        collect_memory_layout=False,
        ignore_nodes_filter=None,
        ignore_empty_histograms=False,
        num_fine_tuning_epochs=0,
        py_batch_size=0,
        num_py_batches=0,
        use_unsigned_quant_scheme=False,
        quantized_electronic_nodes=None,
        allow_tmem_fall_back=False,
        tile_inputs_for_accumulators=True,
        dep_pc_distance_precision=16,
        num_opu_tiles_precision=16,
        num_batch_tiles_precision=16,
        protos=None,
        no_odd_image_dims_conv2d=False,
        disable_block_sparsity=True,
        hw_cfg=hardware_configs_pb2.VANGUARD):
    """
    Build and return a SoftwareConfig populated with the standard settings.

    Most keyword arguments are copied directly onto the field of the same
    name in the returned config (or in its sweep_info / debug_info
    sub-messages).

    Note that nodes_to_skip and ops_to_skip will be added to the
    ignore_nodes_filter: the final filter is the OR of the supplied
    ignore_nodes_filter, a name match on each entry of nodes_to_skip and an
    op match on each entry of ops_to_skip.

    Params:
        graph_type: graph type forwarded to the filter/transform map; the
            activation scale calibration stage is skipped for
            graph_types_pb2.TFLiteSavedModel.
        nodes_to_skip: node names to ignore; None means no names.
        ops_to_skip: ops to ignore; None means no ops.
        ignore_nodes_filter: base ignore filter; None means a filter that
            is the negation of true_filter().
        quantized_electronic_nodes: entries appended to
            node_types.quantized_electronic_nodes; None means none.
        protos: forwarded to set_instruction_formats(); None means an
            empty list.
        hw_cfg: forwarded to get_standard_filter_transform_map().

    Returns:
        The populated sw_config_pb2.SoftwareConfig.
    """
    # Use fresh lists per call instead of shared mutable default arguments
    nodes_to_skip = [] if nodes_to_skip is None else nodes_to_skip
    ops_to_skip = [] if ops_to_skip is None else ops_to_skip
    if quantized_electronic_nodes is None:
        quantized_electronic_nodes = []
    protos = [] if protos is None else protos

    # Default ignore_nodes_filter
    if ignore_nodes_filter is None:
        ignore_nodes_filter = node_filters.not_filter(
            node_filters.true_filter())

    # Update ignore_nodes_filter
    nodes_to_skip_filter = node_filters.or_filter(*[
        node_filters.name_is_filter(node_name) for node_name in nodes_to_skip
    ])
    ops_to_skip_filter = node_filters.or_filter(
        *[node_filters.op_filter(op) for op in ops_to_skip])
    ignore_nodes_filter = node_filters.or_filter(ignore_nodes_filter,
                                                 nodes_to_skip_filter,
                                                 ops_to_skip_filter)

    # Initialize sw config
    sw_config = sw_config_pb2.SoftwareConfig()

    # Transform stages, appended in the order they are listed here
    sw_config.standard_transform_stages.append(sw_config_pb2.BASE_TRANSFORMS)
    if (not skip_activation_cal
            and graph_type != graph_types_pb2.TFLiteSavedModel):
        sw_config.standard_transform_stages.append(
            sw_config_pb2.ACTIVATION_SCALE_CALIBRATION)
    if not skip_adc_cal:
        sw_config.standard_transform_stages.append(
            sw_config_pb2.ADC_SCALE_CALIBRATION)
    if fold_phasify:
        sw_config.standard_transform_stages.append(
            sw_config_pb2.FOLD_PHASIFY_CONSTANTS)

    sw_config.use_weight_sharing = False

    # CopyFrom keeps the shared default float_type object untouched
    sw_config.float_type.CopyFrom(float_type)
    sw_config.quantized_electronic_op_precision = 8

    sw_config.activation_scale_quantization_bias_type = \
        activation_scale_quantization_bias_type
    sw_config.weight_quantization_type = weight_quantization_type
    sw_config.weight_quantization_cutoff = weight_quantization_cutoff
    sw_config.adc_scale_quantization_type = adc_scale_quantization_type

    sw_config.activation_scale_quantization_method = \
        activation_scale_quantization_method
    sw_config.adc_scale_quantization_method = adc_scale_quantization_method
    sw_config.ignore_empty_histograms = ignore_empty_histograms

    sw_config.activation_scale_num_bins = activation_scale_num_bins
    sw_config.adc_scale_num_bins = adc_scale_num_bins

    get_standard_filter_transform_map(sw_config, graph_type, hw_cfg,
                                      ignore_nodes_filter)

    sw_config.node_types.opu_nodes.extend([
        lgf_pb2.LNF.matmul.DESCRIPTOR.name, lgf_pb2.LNF.conv2d.DESCRIPTOR.name,
        lgf_pb2.LNF.block_diagonal_depthwise_conv2d.DESCRIPTOR.name,
        lgf_pb2.LNF.distributed_depthwise_conv2d.DESCRIPTOR.name
    ])
    sw_config.node_types.quantized_electronic_nodes.extend(
        quantized_electronic_nodes)

    sw_config.num_threads_scales = num_threads_scales

    sw_config.debug_info.collect_checksums = False
    sw_config.debug_info.debug_dir = debug_dir

    sw_config.max_proto_size = int(1.8e9)

    sw_config.cache_dir = cache_dir

    sw_config.sweep_info.py_batch_size = py_batch_size
    sw_config.sweep_info.num_py_batches = num_py_batches
    sw_config.sweep_info.convert_graph_to_debug_mode = \
        convert_graph_to_debug_mode
    sw_config.sweep_info.save_hist_html_files = save_hist_html_files
    sw_config.sweep_info.collect_bit_activity = collect_bit_activity
    sw_config.sweep_info.collect_memory_layout = collect_memory_layout
    sw_config.sweep_info.num_fine_tuning_epochs = num_fine_tuning_epochs

    sw_config.ignore_nodes_filter.CopyFrom(ignore_nodes_filter.as_proto())

    sw_config.use_unsigned_quant_scheme = use_unsigned_quant_scheme

    _set_default_compiler_params(
        sw_config,
        allow_tmem_fall_back=allow_tmem_fall_back,
        tile_inputs_for_accumulators=tile_inputs_for_accumulators,
        dep_pc_distance_precision=dep_pc_distance_precision,
        num_opu_tiles_precision=num_opu_tiles_precision,
        num_batch_tiles_precision=num_batch_tiles_precision,
        no_odd_image_dims_conv2d=no_odd_image_dims_conv2d)

    set_instruction_formats(sw_config, protos)

    sw_config.disable_block_sparsity = disable_block_sparsity

    sw_config.const_transform_name = CONST_TRANSFORM

    return sw_config
Ejemplo n.º 7
0
    def bfs(self,
            root_node,
            bidirectional=False,
            node_filter=None,
            skip_control_inputs=False):
        """
        Does a BFS on the graph starting at the root_node

        This is a generator: nothing runs (including the argument check)
        until the result is first iterated.

        Params:
            root_node: starting node for the BFS
            bidirectional: If False, look at a nodes inputs when doing the BFS and
                discovering new nodes. If True do a bidirectional search, looking at
                a nodes inputs and outputs when discovering new nodes.
            node_filter: If provided, only add nodes to the frontier that match the
                filter with this graph. Note that if the root_node does not match the
                provided filter, no nodes will be returned. Nodes whose name
                starts with "^" are always excluded.
            skip_control_inputs: If True, child nodes are taken from the
                parent's input edges whose name is a node of this graph. If
                False, they are taken from the graph's input-node-name map
                for the parent.

        Yields:
            A copy (via self._copy_node) of each visited node, level by level.

        Raises:
            ValueError: if both bidirectional and skip_control_inputs are set.
        """
        # Check for unsupported cases
        if bidirectional and skip_control_inputs:
            raise ValueError("Bidirectional BFS is currently unsupported when "
                             "skipping control inputs")

        # Update node filter with defaults
        default_filter = node_filters.not_filter(
            node_filters.name_starts_with_filter("^"))
        if node_filter is None:
            node_filter = default_filter
        else:
            node_filter = node_filters.and_filter(default_filter, node_filter)

        # Special case when the root_node does not match node_filter: end the
        # generator without yielding anything
        if not node_filter.matches(root_node, self):
            return

        # BFS
        visited_node_names = {root_node.name}
        current_nodes = [root_node]
        frontier = []
        while current_nodes:
            for parent_node in current_nodes:
                yield self._copy_node(parent_node)

                # Default uses inputs for child nodes
                if skip_control_inputs:
                    # Skip control inputs
                    child_nodes = [
                        self._node_dict[e.name]
                        for e in parent_node.inputs
                        if self.has_node(e.name)
                    ]
                else:
                    # Include control inputs
                    child_nodes = [
                        self._node_dict[n]
                        for n in self._node_to_input_node_names[parent_node.name]
                    ]

                # Bidirectional adds outputs as well, currently always includes
                # control inputs
                if bidirectional:
                    child_nodes += [
                        self._node_dict[n]
                        for n in self._node_to_output_node_names[parent_node.name]
                    ]

                for child_node in child_nodes:
                    if (child_node.name not in visited_node_names
                            and node_filter.matches(child_node,
                                                    self)):
                        # Mark as visited when queued so a node reachable
                        # from several parents is only queued once
                        visited_node_names.add(child_node.name)
                        frontier.append(child_node)

            current_nodes = frontier
            frontier = []