Exemplo n.º 1
0
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        """
        Record the mask attributes of nx_node without producing an output mask.

        The masks returned by get_input_masks are stored under 'input_masks',
        'output_mask' is set to None and 'accept_pruned_input' is set to False.
        :param model: not used by this implementation
        :param nx_node: node attributes, updated in place
        :param graph: not used by this implementation
        :param nx_graph: networkx graph handed to get_input_masks
        """
        incoming_masks = get_input_masks(nx_node, nx_graph)

        node_attributes = (
            ('input_masks', incoming_masks),
            ('output_mask', None),
            ('accept_pruned_input', False),
        )
        for attr_name, attr_value in node_attributes:
            nx_node[attr_name] = attr_value
Exemplo n.º 2
0
    def all_inputs_from_convs(self, nx_node, nx_graph, graph):
        """
        Return whether no input of nx_node originates from a stop-mask-forward
        operation.

        An input that already has a mask is accepted without further checks.
        For an input without a mask, its sources are looked up with
        get_sources_of_node among convolution, stop-mask-forward and input
        operations; the result is False as soon as one of those sources is a
        stop-mask-forward operation.
        :param nx_node: node to determine the sources of; its 'key' entry is read
        :param nx_graph: networkx graph to work with
        :param graph: NNCF graph to work with
        :return: False if a stop-mask-forward operation is among the sources of
            any input without a mask, True otherwise (also when there are no
            inputs)
        """
        inputs = [u for u, _ in nx_graph.in_edges(nx_node['key'])]
        input_masks = get_input_masks(nx_node, nx_graph)

        # The alias lists do not depend on the inspected input, so they are
        # built once instead of on every iteration.
        # NOTE(review): assumes get_all_op_aliases() has no side effects and
        # get_sources_of_node does not mutate the list it receives - confirm.
        stop_aliases = StopMaskForwardOps.get_all_op_aliases()
        searched_aliases = (Convolution.get_all_op_aliases() +
                            stop_aliases +
                            Input.get_all_op_aliases())

        for i, inp in enumerate(inputs):
            # If input has mask -> it went from convolution
            # (source of this node is a convolution)
            if input_masks[i] is not None:
                continue
            nncf_input_node = graph._nx_node_to_nncf_node(nx_graph.nodes[inp])
            source_nodes = get_sources_of_node(nncf_input_node, graph,
                                               searched_aliases)
            sources_types = [
                node.op_exec_context.operator_name for node in source_nodes
            ]
            # Generator form avoids building a temporary list just for any().
            if any(op_name in stop_aliases for op_name in sources_types):
                return False
        return True
Exemplo n.º 3
0
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        """
        Compute and store the mask attributes of a node whose module exposes
        ``pre_ops``, ``groups`` and ``weight`` (presumably a convolution -
        TODO confirm).

        The output mask starts as the binary filter pruning mask of the module's
        pre-op '0' when the module has pre-ops, otherwise as None. A module with
        groups != 1 overrides it:
          - weight.size(1) == 1 (handled as the depthwise case): the first input
            mask becomes the output mask;
          - otherwise: no output mask and pruned input is not accepted.
        :param model: network used to resolve the module by its scope
        :param nx_node: node attributes, updated in place with 'input_masks',
            'output_mask', 'accept_pruned_input' and 'is_depthwise'
        :param graph: NNCF graph used to convert nx_node into an NNCF node
        :param nx_graph: networkx graph handed to get_input_masks
        """
        input_masks = get_input_masks(nx_node, nx_graph)

        scope = graph._nx_node_to_nncf_node(
            nx_node).op_exec_context.scope_in_model
        node_module = model.get_module_by_scope(scope)

        if node_module.pre_ops:
            pruning_op = node_module.pre_ops['0'].op
            output_mask = pruning_op.binary_filter_pruning_mask
        else:
            output_mask = None

        accept_pruned_input = True
        is_depthwise = False

        # Grouped modules cannot be pruned by output filters
        is_grouped = node_module.groups != 1
        if is_grouped and node_module.weight.size(1) == 1:
            # Depthwise case: the input mask is passed through unchanged
            is_depthwise = True
            output_mask = input_masks[0]
        elif is_grouped:
            accept_pruned_input = False
            output_mask = None

        nx_node['input_masks'] = input_masks
        nx_node['output_mask'] = output_mask
        nx_node['accept_pruned_input'] = accept_pruned_input
        nx_node['is_depthwise'] = is_depthwise
Exemplo n.º 4
0
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        """
        Forward the first input mask as the output mask of nx_node.

        When the first input mask is not None, every input mask is compared with
        it via torch.allclose and an AssertionError is raised if any comparison
        fails.
        :param model: not used by this implementation
        :param nx_node: node attributes, updated in place with 'input_masks'
            and 'output_mask'
        :param graph: not used by this implementation
        :param nx_graph: networkx graph handed to get_input_masks
        """
        input_masks = get_input_masks(nx_node, nx_graph)
        nx_node['input_masks'] = input_masks

        reference_mask = input_masks[0]
        if reference_mask is not None:
            # Every comparison is evaluated before the result is checked
            comparisons = [
                torch.allclose(reference_mask, mask) for mask in input_masks
            ]
            assert all(comparisons)
        nx_node['output_mask'] = reference_mask
Exemplo n.º 5
0
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        """
        Store the mask attributes of a node, taking the output mask from the
        node's module.

        The output mask is the binary filter pruning mask of the module's pre-op
        '0' when the module has pre-ops, otherwise None. Pruned input is always
        accepted.
        :param model: network used to resolve the module by its scope
        :param nx_node: node attributes, updated in place with 'input_masks',
            'output_mask' and 'accept_pruned_input'
        :param graph: NNCF graph used to convert nx_node into an NNCF node
        :param nx_graph: networkx graph handed to get_input_masks
        """
        input_masks = get_input_masks(nx_node, nx_graph)

        scope = graph._nx_node_to_nncf_node(
            nx_node).op_exec_context.scope_in_model
        node_module = model.get_module_by_scope(scope)

        if node_module.pre_ops:
            pruning_op = node_module.pre_ops['0'].op
            output_mask = pruning_op.binary_filter_pruning_mask
        else:
            output_mask = None

        nx_node['input_masks'] = input_masks
        nx_node['output_mask'] = output_mask
        nx_node['accept_pruned_input'] = True