Exemplo n.º 1
0
def _fused_batch_norm_decomposition(graph: Graph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
                                    mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
    """
    Decompose a fused batch normalization into a Mul->Add->Mul->Add subgraph.

    This is a common helper for the TF, Caffe and MXNet front ends. The created
    chain is ``tinput -> Mul1(mean) -> Add1(variance) -> Mul2(gamma) -> Add2(beta)``
    and its result is written into the existing ``toutput`` data node.

    NOTE: despite their names, ``mean`` and ``variance`` are used verbatim as the
    constant of the first Mul and the constant of the first Add respectively.
    Callers presumably pass pre-computed scale/shift values here (verify against
    callers); no normalization arithmetic is performed in this function.

    :param graph: graph to modify in place
    :param tinput: input data node of the batch normalization
    :param toutput: output data node; reused as the output of the final Add
    :param gamma: const data node consumed by the second Mul; its value is
        broadcast in place from its first element when shorter than its shape
    :param beta: const data node consumed by the second Add
    :param mean: constant for the first Mul
    :param variance: constant for the first Add
    :param can_be_fused: propagated to all four created operations
    """
    # First Mul & Add carry the caller-provided constants.
    mul1_node = Mul(graph, dict(name="Mul1_", can_be_fused=can_be_fused))
    add1_node = Add(graph, dict(name="Add1_", can_be_fused=can_be_fused))

    mul1_data = Op.create_input_data_node(graph, "data_mul_", np.array(mean))
    add1_data = Op.create_input_data_node(graph, "data_add_", np.array(variance))

    # Broadcast gamma from its first element when its stored value is shorter than
    # its declared shape (e.g. a one-element value). ndarray.resize() keeps the
    # leading data and zero-fills the tail, so fill() then replicates the preserved
    # first element over the whole array.
    if gamma.shape[0] != gamma.value.shape[0]:
        gamma.value.resize(gamma.shape)
        gamma.value.fill(gamma.value[0])

    # Second Mul & Add apply gamma and beta.
    mul2_node = Mul(graph, dict(name="Mul2_", can_be_fused=can_be_fused))
    add2_node = Add(graph, dict(name="Add2_", can_be_fused=can_be_fused))

    # Wire tinput -> Mul1 -> Add1 -> Mul2 -> Add2 -> toutput (same creation order
    # as a nested call would produce: innermost operation first).
    mul1_out = mul1_node.create_node_with_data(inputs=[tinput, mul1_data])
    add1_out = add1_node.create_node_with_data(inputs=[mul1_out, add1_data])
    mul2_out = mul2_node.create_node_with_data(inputs=[add1_out, gamma])
    add2_node.create_node_with_data(inputs=[mul2_out, beta], data_nodes=toutput)
Exemplo n.º 2
0
 def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
     """
     Attach an all-ones weights input to a matched L2Normalization node.

     L2Normalization is mapped to the Inference Engine Normalize layer, which
     requires a weights input that MXNet models do not always provide. When the
     matched node has fewer than two inputs, a float32 constant data node filled
     with 1.0 is created and connected to the node's input port 1, with the
     edge attribute 'bin' set to 'weights'.

     Parameters
     ----------
     graph : nx.MultiDiGraph
         Graph with the loaded model.
     match : dict
         Pattern match; must contain the 'l2_normalization' node.
     """
     norm_node = match['l2_normalization']

     # A second input already exists, so there is nothing to add.
     if len(norm_node.in_nodes()) >= 2:
         return

     # One weight per element of dimension 1 of the first input; presumably the
     # channel axis (NCHW-style) — TODO confirm for other layouts.
     channel_count = norm_node.in_node(0).shape[1]
     ones = np.full([channel_count], 1.0, dtype=np.float32)

     weights = Op.create_input_data_node(
         graph,
         name=norm_node['name'] + '_weights',
         value=ones)
     create_edge(weights, norm_node, out_port=0, in_port=1,
                 edge_attrs={'bin': 'weights'})
Exemplo n.º 3
0
    def apply_mean_value(graph: Graph, input_node: Node,
                         node_mean_scale_values: dict):
        """
        Insert an Add node that subtracts the configured mean from ``input_node``'s output.

        Nothing happens when the dict has no 'mean' entry, when that entry is None,
        or when every mean component is zero. Otherwise the edge from ``input_node``
        to its output data node is replaced by
        ``input_node -> new data node -> Add(-mean) -> original output data node``.

        :param graph: graph to modify in place; ``graph.graph['layout']`` is read
        :param input_node: node whose output receives the mean subtraction
        :param node_mean_scale_values: dict that may hold a 'mean' sequence
        :raises Error: if ``input_node`` has no valid 'shape' attribute
        """
        if 'mean' not in node_mean_scale_values:
            return
        mean = node_mean_scale_values['mean']
        if mean is None or all(component == 0 for component in mean):
            return

        out_node = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has not valid shape attribute".format(
                input_node.id))
        input_shape = input_node.shape

        # Detach the output data node; it becomes the Add result below.
        graph.remove_edge(input_node.id, out_node.id)

        # Adding the negated mean implements the subtraction.
        negated_mean = np.array(mean) * (-1)

        add_node = Add(graph, dict(name="Add_"))
        add_const = Op.create_input_data_node(graph, "data_add_",
                                              np.array(negated_mean))
        # For NCHW the constant is extended by len(shape) - 2 dims (presumably
        # trailing unit dims so it broadcasts over spatial axes — verify
        # Op.expand_node_shape); other layouts get no extension.
        extra_dims = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
        Op.expand_node_shape(add_const, extra_dims)
        add_input = Op.create_data_node(graph, input_node,
                                        {'shape': out_node.shape})

        add_node.create_node_with_data(inputs=[add_input, add_const],
                                       data_nodes=out_node)
Exemplo n.º 4
0
    def apply_scale(graph: Graph, input_node: Node,
                    node_mean_scale_values: dict):
        """
        Insert a Mul node that divides ``input_node``'s output by the configured scale.

        Nothing happens when the dict has no 'scale' entry, when that entry is None,
        or when every scale component equals 1. Otherwise the edge from
        ``input_node`` to its output data node is replaced by
        ``input_node -> new data node -> Mul(1 / scale) -> original output data node``.

        :param graph: graph to modify in place; ``graph.graph['layout']`` is read
        :param input_node: node whose output gets scaled
        :param node_mean_scale_values: dict that may hold a 'scale' sequence
        :raises Error: if ``input_node`` has no valid 'shape' attribute
        """
        if 'scale' not in node_mean_scale_values:
            return
        scale = node_mean_scale_values['scale']
        if scale is None or all(factor == 1 for factor in scale):
            return

        out_node = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has not valid shape attribute".format(
                input_node.id))
        input_shape = input_node.shape

        # Multiplying by the reciprocal implements the division.
        reciprocal = 1 / np.array(scale)
        # Detach the output data node; it becomes the Mul result below.
        graph.remove_edge(input_node.id, out_node.id)

        mul_node = Mul(graph, dict(name="Mul_"))
        mul_const = Op.create_input_data_node(graph, "data_mul_",
                                              np.array(reciprocal))
        # For NCHW the constant is extended by len(shape) - 2 dims (presumably
        # trailing unit dims so it broadcasts over spatial axes — verify
        # Op.expand_node_shape); other layouts get no extension.
        extra_dims = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
        Op.expand_node_shape(mul_const, extra_dims)
        mul_input = Op.create_data_node(graph, input_node,
                                        {'shape': out_node.shape})

        mul_node.create_node_with_data(inputs=[mul_input, mul_const],
                                       data_nodes=out_node)
Exemplo n.º 5
0
def duplicate_shared_weights(graph: nx.MultiDiGraph):
    """
    Give every extra consumer of a shared constant data node its own copy.

    For each data node with a non-None ``value`` and more than one consumer, the
    first consumer stays connected to the original node. Every other consumer is
    disconnected and re-connected to a fresh data node named ``Copy_<node id>``.
    The copy is created from the original node's attributes and value, and the
    new edge reuses the attributes of the edge it replaces.

    :param graph: graph to modify in place
    :raises Error: if a consumer is connected to the node by more than one edge
    """
    data_nodes = [node for node in (Node(graph, node_id) for node_id in graph.nodes())
                  if node.soft_get('kind') == 'data']
    for node in data_nodes:
        # Only constants (non-None value) with several consumers need duplicating.
        if len(node.out_nodes()) <= 1 or node.value is None:
            continue
        # Keep the first consumer on the original node and peel the others off one
        # at a time, connecting each to its own copy.
        while len(node.out_nodes()) > 1:
            out_node = node.out_node(1)

            edges = graph.get_edge_data(node.id, out_node.id)
            if len(edges) != 1:
                raise Error(
                    'There is more than one edge from {} node to {} node.'.
                    format(node.id, out_node.id))
            # Multi-edge keys are not guaranteed to be 0 (e.g. after an earlier
            # edge removal), so take the single edge's attributes by value.
            e_attrs = next(iter(edges.values()))

            graph.remove_edge(node.id, out_node.id)
            data = Op.create_input_data_node(graph,
                                             "Copy_{}".format(node.id),
                                             np.array(node.value),
                                             graph.node[node.id])

            graph.add_edges_from([(data.id, out_node.id, e_attrs)])
Exemplo n.º 6
0
def _scale_input_action_mul(graph: nx.MultiDiGraph, match: dict, scale: float):
    """
    Divide the matched placeholder's output by ``scale`` using an inserted Mul node.

    The edge ``placeholder -> data`` is replaced by
    ``placeholder -> new data node -> Mul(1 / scale) -> data``.

    :param graph: graph to modify in place; ``graph.graph['layout']`` is read
    :param match: pattern match holding the 'placeholder' and 'data' nodes
    :param scale: divisor applied to the placeholder output
    :raises Error: if the placeholder has no valid 'shape' attribute
    """
    placeholder = match['placeholder']
    assert (len(placeholder.out_nodes()))

    if not placeholder.has_valid('shape'):
        raise Error("Node {} has not valid shape attribute".format(placeholder.id))
    input_shape = placeholder.shape

    data_out = match['data']
    # Multiplying by the reciprocal implements the division.
    factor = np.array([1 / scale])

    # Detach the data node; it becomes the Mul result below.
    graph.remove_edge(placeholder.id, data_out.id)

    mul_op = Mul(graph, dict(name="Mul1_"))
    factor_node = Op.create_input_data_node(graph, "data_mul_scale_",
                                            np.array(factor))
    # For NCHW the constant is extended by len(shape) - 2 dims (presumably trailing
    # unit dims so it broadcasts over spatial axes — verify Op.expand_node_shape).
    extra_dims = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
    Op.expand_node_shape(factor_node, extra_dims)
    mul_input = Op.create_data_node(graph, placeholder, {'shape': data_out.shape})

    mul_op.create_node_with_data(inputs=[mul_input, factor_node],
                                 data_nodes=data_out)
Exemplo n.º 7
0
    def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
        """
        Apply the per-edge 'new_shape' requests found on data nodes.

        For every data node, consumers are grouped by requested shape when their
        incoming edge carries a 'new_shape' attribute that differs from the data
        node's own shape. Then, for each requested shape:
          * if the data node has a valid 'value' (constant), a reshaped copy of the
            value is created and the grouped consumers are re-connected to it;
          * otherwise a Reshape op is inserted after the data node and the grouped
            consumers are re-connected to the Reshape output.
        'new_shape' is deleted from every re-connected edge. Edges whose
        'new_shape' already equals the node's shape are left untouched.

        :param graph: graph to modify in place
        """

        def first_edge_attrs(src, dst):
            # Attribute dict of the first src -> dst edge. MultiDiGraph edge keys
            # are not guaranteed to be 0, so the key is not indexed directly.
            return next(iter(graph.get_edge_data(src.id, dst.id).values()))

        # graph.nodes() works on both networkx 1.x and 2.x; the graph.node
        # attribute was removed in networkx 2.4.
        data_nodes = [node for node in (Node(graph, node_id) for node_id in graph.nodes())
                      if node.kind == 'data']
        for node in data_nodes:
            # Requested shape (as a hashable tuple) -> consumers that requested it.
            mapping = {}
            for consumer in node.out_nodes():
                edge_attrs = first_edge_attrs(node, consumer)
                if 'new_shape' not in edge_attrs:
                    continue
                if np.array_equal(edge_attrs['new_shape'], node.shape):
                    continue
                mapping.setdefault(tuple(edge_attrs['new_shape']), []).append(consumer)

            if node.has_valid('value'):
                # Constant data: materialize each requested shape as its own copy.
                for shape_key, consumers in mapping.items():
                    new_value = np.reshape(node.value, list(shape_key))
                    node_copy = Op.create_input_data_node(
                        graph, node.id + '/copy', value=np.array(new_value))
                    for consumer in consumers:
                        edge_attrs = first_edge_attrs(node, consumer)
                        del edge_attrs['new_shape']

                        # Move the edge (with its remaining attributes) onto the copy.
                        graph.remove_edge(node.id, consumer.id)
                        graph.add_edge(node_copy.id, consumer.id, **edge_attrs)
            else:
                # Non-constant data: insert one Reshape op per requested shape.
                for shape_key, consumers in mapping.items():
                    reshape = Reshape(graph,
                                      attrs={
                                          'dim': list(shape_key),
                                          'name': 'EltwiseReshapeNormalization'
                                      })
                    reshape_data = reshape.create_node_with_data(inputs=[node])

                    for consumer in consumers:
                        edge_attrs = first_edge_attrs(node, consumer)
                        del edge_attrs['new_shape']

                        # Move the edge onto the Reshape output data node.
                        graph.remove_edge(node.id, consumer.id)
                        graph.add_edge(reshape_data.id, consumer.id,
                                       **edge_attrs)
    def find_and_replace_pattern(self, graph: Graph):
        """
        Give every consumer of a shared Const-produced data node its own copy.

        Applies to data nodes that meet all of these conditions:
          * at least one input;
          * the input returned by ``in_node()`` has type 'Const';
          * more than one consumer;
          * a non-None ``value``.

        Every consumer edge is removed, including the first one. Each consumer is
        re-connected to a fresh data node named ``Copy_<node id>``, created from
        the original node's attributes and value. The new edge reuses the removed
        edge's attributes. The original data node is therefore left with no
        consumers.

        :param graph: graph to modify in place
        """
        data_nodes = [node for node in (Node(graph, node_id) for node_id in graph.nodes())
                      if node.soft_get('kind') == 'data']
        for node in data_nodes:
            is_shared_const = (len(node.in_nodes()) and node.in_node().soft_get('type') == 'Const'
                               and len(node.out_nodes()) > 1 and node.value is not None)
            if not is_shared_const:
                continue
            # NOTE(review): edges are removed while iterating get_outputs(); this
            # assumes it returns a materialized list — TODO confirm.
            for consumer_id, edge_attrs in node.get_outputs():
                consumer = Node(graph, consumer_id)
                graph.remove_edge(node.id, consumer.id)
                data_copy = Op.create_input_data_node(graph, "Copy_{}".format(node.id), np.array(node.value),
                                                      graph.node[node.id])

                graph.add_edges_from([(data_copy.id, consumer.id, edge_attrs)])
Exemplo n.º 9
0
def _fuse_linear_sequence(graph: nx.MultiDiGraph, start_node: Node) -> bool:
    """
    This function finds the sequence of Mul/Add operations and replaces this sequence with two ops (Mul->Add).

    Starting at ``start_node``, the chain is extended while the last op's output
    data node has exactly one consumer that meets all of these conditions:
      * it is a Mul or Add;
      * it has a constant input (``get_value_id`` is not None);
      * its ``can_be_fused`` attribute equals True.

    The constants of the chain are folded into one multiplier and one addend, and
    the chain is removed from the graph. At most one Mul and one Add are then
    inserted between the chain's input and output data nodes. If the folded
    result is an identity, the two data nodes are merged instead.

    :param graph: graph to modify in place; ``graph.graph['layout']`` is read
    :param start_node: The first operation of the sequence
    :return: True if the sequence was fused; False if there was nothing to fuse
        (a single op, or already exactly Mul followed by Add)
    """
    fnodes = [start_node]
    # Grow the chain along single-consumer edges of fusable Mul/Add ops.
    while True:
        node = fnodes[-1]
        data_node = node.out_node()
        if (len(data_node.out_nodes()) != 1):
            break
        if (data_node.out_node().op in ['Mul', 'Add']) and get_value_id(
                data_node.out_node()) is not None and data_node.out_node(
                ).soft_get('can_be_fused') == True:
            fnodes.append(data_node.out_node())
        else:
            break

    # Nothing to gain: a lone op, or the chain is already in canonical Mul->Add form.
    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul'
                            and fnodes[1].op == 'Add'):
        return False

    input_shape = start_node.in_node(get_tensor_id(start_node)).shape

    # Rank of the neutral accumulators below: len(shape) - 2 for NCHW, else 1.
    init_dims_cnt = len(
        input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1

    # Neutral starting values: multiplier of ones, addend of zeros.
    mul = np.ones([1 for x in range(init_dims_cnt)])
    add = np.zeros([1 for x in range(init_dims_cnt)])

    # Names of the first Mul/Add in the chain, used to name the fused ops.
    first_mul_name = None
    first_add_name = None

    for idx in range(len(fnodes)):
        node = fnodes[idx]
        const_node = get_value_id(node)
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            # (x*m + a) * c == x*(m*c) + a*c: scale both accumulators.
            mul = mul * node.in_node(const_node).value
            add = add * node.in_node(const_node).value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            # (x*m + a) + c == x*m + (a + c): only the addend changes.
            add = add + node.in_node(const_node).value

    # If mul is scalar we broadcast it to biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.array([mul[0] for x in range(add.shape[0])])

    # The chain must be shape-preserving for the replacement to be valid.
    assert (np.array_equal(fnodes[0].in_node(get_tensor_id(fnodes[0])).shape,
                           fnodes[-1].out_node().shape))

    # Conditional expressions bind loosest: the name is '<first name>/Fused_*_'
    # when such an op existed in the chain, otherwise ''.
    mul_node = Mul(
        graph,
        dict(name=first_mul_name +
             '/Fused_Mul_' if first_mul_name is not None else ''))
    add_node = Add(
        graph,
        dict(name=first_add_name +
             '/Fused_Add_' if first_add_name is not None else ''))

    in_node = fnodes[0].in_node(get_tensor_id(fnodes[0]))
    out_node = fnodes[-1].out_node()

    # Detach the chain from its input and output data nodes first, so the removal
    # loop below never deletes in_node or out_node.
    graph.remove_edge(in_node.id, fnodes[0].id)
    graph.remove_edge(fnodes[-1].id, out_node.id)

    # Remove deleted subgraph
    for node in fnodes:
        for tmp_node in node.in_nodes().values():
            # Remove node only if it has one consumer (for case with shared weights)
            if len(tmp_node.out_nodes()) == 1:
                graph.remove_node(tmp_node.id)
        # Intermediate data nodes between chain ops (out_node is already detached).
        for tmp_node in node.out_nodes().values():
            graph.remove_node(tmp_node.id)
        graph.remove_node(node.id)
    """
    Four cases considered below:
        1. Mul and Add have valid values (mul value != 1 and add value != 0)
        2. Only Mul has valid values, so we add only Mul node
        3. Only Add has valid values, so we add only Add node
        4. When Mul and Add has not valid values we just merge two data nodes
    """
    if any([x != 0
            for x in np.nditer(add)]) and any([x != 1
                                               for x in np.nditer(mul)]):
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[
            mul_node.create_node_with_data([in_node, data_mul]), data_add
        ],
                                       data_nodes=out_node)
    elif any([x != 1 for x in np.nditer(mul)]):
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        mul_node.create_node_with_data(inputs=[in_node, data_mul],
                                       data_nodes=out_node)
    elif any([x != 0 for x in np.nditer(add)]):
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[in_node, data_add],
                                       data_nodes=out_node)
    else:
        # Identity chain: fold in_node into out_node (exact semantics of
        # merge_data_nodes live elsewhere — verify), then drop in_node.
        merge_data_nodes(graph, out_node, in_node)
        graph.remove_node(in_node.id)

    log.debug('Fused {} operations'.format(len(fnodes)))
    return True