Example #1
0
File: axpy.py Project: pc2/CustoNN2
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        """Build a ScaleShift -> Add chain out of the three inputs of ``node``.

        A ScaleShift node (``axis=0``) is created with inputs ``[input 1, input 0]``;
        its result is then combined with input 2 by an Add node.

        :param graph: graph in which the new operations are created.
        :param node: node being replaced; inputs 0, 1 and 2 are read from it.
        :return: one-element list holding the id of the created Add node.
        """
        first_in, second_in, third_in = [node.in_node(idx) for idx in range(3)]

        # Note the swapped order: input 1 goes first, input 0 second.
        scale_shift_op = ScaleShiftOp(graph, dict(name=node.id + "/ScaleShift_", axis=0))
        scaled = scale_shift_op.create_node(inputs=[second_in, first_in])

        add_op = Add(graph, dict(name=node.id + "/Add_"))
        summed = add_op.create_node(inputs=[scaled, third_in])

        return [summed.id]
Example #2
0
    def replace_op(self, graph: Graph, node: Node):
        """Expand ``node`` into an MVN -> Mul -> Add chain.

        Input 0 goes through MVN (``eps`` taken from ``node.epsilon``), the result
        is combined with input 1 by Mul and then with input 2 by Add.  Mul and
        Add are both created with ``axis=1``.

        :param graph: graph in which the new operations are created.
        :param node: node being replaced; ``name``, ``epsilon`` and inputs 0-2 are read.
        :return: one-element list holding the id of the final Add node.
        """
        base_name = node.name + '/InstanceNormalization'

        mvn_op = MVN(graph, {'name': base_name + '/MVN', 'eps': node.epsilon})
        mul_op = Mul(graph, {'name': base_name + '/Mul', 'axis': 1})
        add_op = Add(graph, {'name': base_name + '/Add', 'axis': 1})

        # Build the chain innermost-first, one named step at a time.
        normalized = mvn_op.create_node([node.in_node(0)])
        scaled = mul_op.create_node([normalized, node.in_node(1)])
        shifted = add_op.create_node([scaled, node.in_node(2)])

        return [shifted.id]
Example #3
0
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        """Swap ``node`` for an Add of its input and a constant built from ``node.scalar``.

        The edge from the input to ``node`` is removed, a Const node holding
        ``node.scalar`` and an Add node are created, and every outgoing edge of
        ``node`` is re-attached to the Add node with its attributes preserved.

        :param graph: graph being modified in place.
        :param node: node being replaced; ``id``, ``scalar`` and its edges are read.
        :return: one-element list holding the id of the created Add node.
        """
        producer = node.in_node()
        consumers = list(node.out_nodes().values())
        graph.remove_edge(producer.id, node.id)

        const_op = Const(graph, dict(value=node.scalar, shape=node.scalar.shape,
                                     symbol_dict={'name': node.id + '/const'}))
        sum_op = Add(graph, dict(name=node.id + '/add_', symbol_dict={'name': node.id + '/add_'}))
        # The constant node is created first, then the Add node that consumes it.
        const_node = const_op.create_node()
        sum_node = sum_op.create_node(inputs=[producer, const_node])

        # Re-point each outgoing edge (key 0) from the old node to the Add node.
        for consumer in consumers:
            attrs = graph.get_edge_data(node.id, consumer.id)[0]
            graph.remove_edge(node.id, consumer.id)
            graph.add_edges_from([(sum_node.id, consumer.id, attrs)])

        return [sum_node.id]
Example #4
0
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        """Replace the matched ``op`` node with an optional Mul followed by an optional Add.

        The Mul (by ``op.scale``) is skipped when every scale element equals 1 and
        the Add (of ``op.bias``) is skipped when every bias element equals 0.  The
        outgoing edges of the matched node are moved to the last node of the new
        chain (or to the original input when both steps are skipped), and the
        matched node is finally disconnected from its input.

        :param graph: graph being modified in place.
        :param match: pattern match; the node is taken from ``match['op']``.
        """
        scaler = match['op']

        # A scale of all ones / a bias of all zeros would be a no-op.
        needs_mul = not all(x == 1 for x in np.nditer(scaler.scale))
        needs_add = not all(x == 0 for x in np.nditer(scaler.bias))

        # Remember the consumers before any edge is touched.
        consumers = list(scaler.out_nodes().values())

        assert len(scaler.in_nodes()) == 1

        # ``tail`` always points at the node whose output replaces the matched node's output.
        tail = scaler.in_node()
        if needs_mul:
            scale_const = Const(graph, dict(value=scaler.scale, shape=scaler.scale.shape))
            mul_op = Mul(graph, dict(name=scaler.id + '/mul_'))
            tail = mul_op.create_node(inputs=[tail, scale_const.create_node()])

        if needs_add:
            bias_const = Const(graph, dict(value=scaler.bias, shape=scaler.bias.shape))
            add_op = Add(graph, dict(name=scaler.id + '/add_'))
            tail = add_op.create_node(inputs=[tail, bias_const.create_node()])

        # Hand every outgoing edge (key 0) of the matched node over to ``tail``.
        for consumer in consumers:
            attrs = graph.get_edge_data(scaler.id, consumer.id)[0]
            graph.remove_edge(scaler.id, consumer.id)
            graph.add_edges_from([(tail.id, consumer.id, attrs)])

        # Finally cut the matched node off from its producer.
        graph.remove_edge(scaler.in_node().id, scaler.id)
Example #5
0
def convert_scale_shift_to_mul_add(graph: Graph):
    """Rewire every fusable ``ScaleShift`` operation in ``graph`` as ``Mul`` and/or ``Add``.

    ``ScaleShift`` nodes whose ``can_be_fused`` attribute is exactly ``False``
    are left untouched.  For every other node:

    * the scale input (port 1) is dropped when it is missing, disconnected, or
      a known constant consisting only of ones;
    * the shift input (port 2) is dropped when it is missing, disconnected, or
      a known constant consisting only of zeros;
    * the remaining inputs are broadcast (constants get trailing singleton
      axes for the ``NCHW`` layout, non-constant inputs get a ``Reshape``) and
      connected to a new ``Mul`` and/or ``Add`` node that takes over the
      original output connection;
    * when neither input remains, the producer of port 0 is connected
      straight to the consumers of the output.

    :param graph: graph to transform in place; ``graph.graph['layout']`` is read.
    """

    def is_filled_with(port, constant):
        # True only when the port carries a known constant whose every element
        # equals ``constant``.  ``np.all`` compares element-wise for any rank
        # (0-d and multi-dimensional values included), whereas iterating the
        # value only walks its first axis.
        value = port.data.get_value()
        return value is not None and bool(np.all(np.asarray(value) == constant))

    nodes = graph.get_op_nodes(op='ScaleShift')
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)
        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
        output_port = node.out_port(0)

        # We don't need zero biases
        has_biases = shift_port is not None and not is_filled_with(shift_port, 0)

        # We don't need weights with ones
        has_weights = scale_port is not None and not is_filled_with(scale_port, 1)

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Number of trailing singleton axes to append for the current layout
        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0

        # In case if we have constant weights/biases we have to broadcast them according to graph layout
        # otherwise we insert Reshape with broadcast dim attribute.
        def broadcast_value(port):
            # Append ``broadcast_dims_cnt`` trailing axes of size 1 to the constant value.
            value = np.array(port.data.get_value())
            for _ in range(broadcast_dims_cnt):
                value = np.expand_dims(value, axis=-1)
            port.data.set_value(value)

        def broadcast_with_reshape(port):
            # Target shape: ones everywhere except the span starting at
            # ``node.axis``, which receives the shape of the port's data.
            input_shape = input_port.data.get_shape()
            reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            data_shape = port.data.get_shape()
            for i in range(node.axis, node.axis + len(data_shape)):
                reshape_dims[i] = data_shape[i - node.axis]
            for i in range(node.axis + len(data_shape), len(input_shape)):
                reshape_dims[i] = 1
            # Insert the Reshape between the port's current source and the port itself.
            reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
            port.get_connection().set_destination(reshape.in_port(0))
            reshape.out_port(0).connect(port)

        if has_weights and scale_port.data.get_value() is not None:
            broadcast_value(scale_port)
        elif has_weights:
            broadcast_with_reshape(scale_port)

        if has_biases and shift_port.data.get_value() is not None:
            broadcast_value(shift_port)
        elif has_biases:
            broadcast_with_reshape(shift_port)

        if has_biases and has_weights:
            # Connect input->mul->out->add->out
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            # Connect Add operation with inputs
            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # Connect input->mul->out
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            output_port.get_connection().set_source(mul_node.out_port(0))
        elif has_biases:
            # Connect input->add->out
            add_node = add_op.create_node()

            # Connect Add operation with inputs
            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        else:
            # Connect input->out
            producer_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(producer_port)