Exemplo n.º 1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap the node matched as ``fbn`` for an ``MVN -> Mul -> Add`` chain.

        The rewrite only happens when the first input of ``fbn`` is also the first
        input of the matched ``mean`` and ``sqdiff`` nodes; otherwise the graph is
        left untouched.

        :param graph: graph in which the new operations are created
        :param match: pattern-name -> node mapping; the entries ``fbn``, ``mean``,
            ``sqdiff`` and ``variance`` are read
        """
        fbn = match['fbn']
        data_node = fbn.in_node(0)
        log.debug('Found potential MVN pattern after {} with name {}'.format(data_node.op, data_node.name))

        # All three matched consumers must be fed by the very same node.
        feeds_whole_pattern = (
            data_node.id == match['mean'].in_node(0).id
            and data_node.id == match['sqdiff'].in_node(0).id
        )
        if not feeds_whole_pattern:
            return

        log.debug('Confirmed MVN pattern after {} with name {}'.format(data_node.op, data_node.name))

        mvn_attrs = {
            'name': fbn.name + '/MVN_',
            'eps': fbn.eps,
            'eps_mode': 'outside_sqrt',
            'normalize_variance': 1,
        }
        mvn = MVN(graph, mvn_attrs)
        # Keep the stock shape inference reachable and route inference through this class.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mul = Mul(graph, {'operation': 'mul', 'name': fbn.name + '/Mul_'})
        add = Add(graph, {'operation': 'sum', 'name': fbn.name + '/Add_'})

        gamma_node = fbn.in_node(1)
        beta_node = fbn.in_node(2)

        mean_axes = match['mean'].in_node(1)
        variance_axes = match['variance'].in_node(1)

        # Build the chain inside-out: normalize, then scale, then shift.
        normalized = mvn.create_node([data_node, mean_axes, variance_axes])
        scaled = mul.create_node([normalized, gamma_node])
        shifted = add.create_node([scaled, beta_node])
        fbn.replace_node(shifted)
Exemplo n.º 2
0
 def replace_sub_graph(self, graph: Graph, match: dict):
     """Insert a ``Mul`` by ``1 / temperature`` in front of the matched softmax node.

     Nothing is changed when the node has no ``temperature`` attribute or the
     temperature equals ``1.0``.

     :param graph: graph being modified
     :param match: pattern-name -> node mapping; only the ``softmax`` entry is read
     """
     softmax = match['softmax']
     needs_scaling = 'temperature' in softmax and softmax['temperature'] != 1.0
     if not needs_scaling:
         return

     producer = softmax.in_node()
     consumers = list(softmax.out_nodes().values())

     # Detach the softmax from its producer; the Mul created below takes that place.
     graph.remove_edge(softmax.in_node().id, softmax.id)

     inv_temperature = np.array([1.0 / softmax.temperature])
     const_attrs = dict(
         value=inv_temperature,
         shape=inv_temperature.shape,
         symbol_dict={'name': softmax.id + '/const'},
     )
     scalar_value_op = Const(graph, const_attrs)

     mul_attrs = dict(name=softmax.id + '/mul_', symbol_dict={'name': softmax.id + '/mul_'})
     mul_op = Mul(graph, mul_attrs)
     mul_node = mul_op.create_node(inputs=[producer, scalar_value_op.create_node()])

     # The new Mul -> softmax edge reuses the attributes of the softmax's first outgoing edge.
     edge_attrs = graph.get_edge_data(softmax.id, consumers[0].id)[0]
     graph.add_edges_from([(mul_node.id, softmax.id, edge_attrs)])
Exemplo n.º 3
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap the node matched as ``fbn`` for an ``MVN -> Mul -> Add`` chain.

        The rewrite only happens when the first input of ``fbn`` is also the first
        input of the matched ``mean`` and ``sqdiff`` nodes; otherwise the graph is
        left untouched. The MVN operation class is looked up by name at call time.

        :param graph: graph in which the new operations are created
        :param match: pattern-name -> node mapping; the entries ``fbn``, ``mean``,
            ``sqdiff`` and ``variance`` are read
        """
        fbn = match['fbn']
        data_node = fbn.in_node(0)
        log.debug('Found potential MVN pattern after {} with name {}'.format(
            data_node.op, data_node.name))

        # All three matched consumers must be fed by the very same node.
        feeds_whole_pattern = (
            data_node.id == match['mean'].in_node(0).id
            and data_node.id == match['sqdiff'].in_node(0).id
        )
        if not feeds_whole_pattern:
            return

        log.debug('Confirmed MVN pattern after {} with name {}'.format(
            data_node.op, data_node.name))
        mvn_cls = Op.get_op_class_by_name('MVN')

        mvn_attrs = {
            'name': fbn.name + '/MVN_',
            'eps': fbn.eps,
            # Reduction axes are [1, 2] for an NHWC data_format and [2, 3] otherwise.
            'required_reduction_indices':
                [1, 2] if fbn.data_format == b'NHWC' else [2, 3],
        }
        mvn = mvn_cls(graph, mvn_attrs)
        # Keep the stock shape inference reachable and route inference through this class.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mul = Mul(graph, {'operation': 'mul', 'name': fbn.name + '/Mul_'})
        add = Add(graph, {'operation': 'sum', 'name': fbn.name + '/Add_'})

        gamma_node = fbn.in_node(1)
        beta_node = fbn.in_node(2)

        mean_axes = match['mean'].in_node(1)
        variance_axes = match['variance'].in_node(1)

        # Build the chain inside-out: normalize, then scale, then shift.
        normalized = mvn.create_node([data_node, mean_axes, variance_axes])
        scaled = mul.create_node([normalized, gamma_node])
        shifted = add.create_node([scaled, beta_node])
        fbn.replace_node(shifted)
Exemplo n.º 4
0
def convert_scale_shift_to_mul_add(graph: Graph):
    """Rewire every fusable ``ScaleShift`` operation into ``Mul`` and/or ``Add`` nodes.

    For each ``ScaleShift`` node whose ``can_be_fused`` attribute is not ``False``:

    * input port 0 is the data, port 1 (if connected) the scale, port 2 (if connected) the shift;
    * a constant scale made only of ones, or a constant shift made only of zeros, is dropped;
    * the remaining constant operands get trailing unit dimensions appended (NCHW layout only),
      non-constant operands get a ``Reshape`` inserted in front of them;
    * the data is then routed through ``Mul`` and/or ``Add``, or straight to the consumers
      when neither operand is left.

    :param graph: graph to transform in place
    """
    for node in graph.get_op_nodes(op='ScaleShift'):
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)

        scale_port = None
        if ports_count > 1 and not node.in_port(1).disconnected():
            scale_port = node.in_port(1)

        shift_port = None
        if ports_count > 2 and not node.in_port(2).disconnected():
            shift_port = node.in_port(2)

        output_port = node.out_port(0)

        # A shift is only kept when it is present and not a constant of all zeros.
        has_biases = shift_port is not None and not (
            shift_port.data.get_value() is not None
            and all([elem == 0 for elem in shift_port.data.get_value()]))

        # A scale is only kept when it is present and not a constant of all ones.
        has_weights = scale_port is not None and not (
            scale_port.data.get_value() is not None
            and all([elem == 1 for elem in scale_port.data.get_value()]))

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Number of trailing unit dims to append to constants for the current layout.
        if graph.graph['layout'] == 'NCHW':
            broadcast_dims_cnt = len(input_port.data.get_shape()) - 2
        else:
            broadcast_dims_cnt = 0

        def broadcast_value(port):
            # Constant operand: append the trailing unit dimensions directly to its value.
            expanded = np.array(port.data.get_value())
            for _ in range(broadcast_dims_cnt):
                expanded = np.expand_dims(expanded, axis=-1)
            port.data.set_value(expanded)

        def broadcast_with_reshape(port):
            # Non-constant operand: put a Reshape in front of it whose target shape is 1
            # everywhere except the span starting at node.axis, which copies the operand shape.
            data_input_shape = input_port.data.get_shape()
            target_dims = np.zeros(len(data_input_shape), dtype=np.int64)
            for pos in range(0, node.axis):
                target_dims[pos] = 1
            operand_shape = port.data.get_shape()
            for pos in range(node.axis, node.axis + len(operand_shape)):
                target_dims[pos] = operand_shape[pos - node.axis]
            for pos in range(node.axis + len(operand_shape), len(data_input_shape)):
                target_dims[pos] = 1
            reshape_node = create_op_node_with_second_input(
                graph, Reshape, target_dims,
                dict(name=port.node.name + "/Broadcast_"))
            port.get_connection().set_destination(reshape_node.in_port(0))
            reshape_node.out_port(0).connect(port)

        if has_weights:
            if scale_port.data.get_value() is not None:
                broadcast_value(scale_port)
            else:
                broadcast_with_reshape(scale_port)

        if has_biases:
            if shift_port.data.get_value() is not None:
                broadcast_value(shift_port)
            else:
                broadcast_with_reshape(shift_port)

        if not has_weights and not has_biases:
            # Nothing left to compute: feed the consumers straight from the producer.
            upstream_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(upstream_port)
            continue

        if has_weights and has_biases:
            # input -> Mul -> Add -> consumers
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()

            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # input -> Mul -> consumers
            mul_node = mul_op.create_node()

            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            output_port.get_connection().set_source(mul_node.out_port(0))
        else:
            # input -> Add -> consumers
            add_node = add_op.create_node()

            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
Exemplo n.º 5
0
def _fuse_linear_sequence(graph: Graph, start_node: Node):
    """
    Collapse a chain of Mul/Add operations with constant operands into at most one Mul followed by one Add.

    Starting at ``start_node`` the chain is extended while the current tail has exactly one consumer
    and that consumer is a Mul or Add with a constant operand and ``can_be_fused`` set to True.
    :param graph: graph to modify in place
    :param start_node: The first operation of the sequence
    :return: False when there is nothing to fuse (a single node, or an already minimal Mul->Add pair),
             True once the chain has been replaced
    """
    chain = [start_node]
    while True:
        tail = chain[-1]
        consumers = tail.out_port(0).get_destinations()
        if len(consumers) != 1:
            break
        candidate = consumers[0].node
        fusable = (candidate.soft_get('op') in ['Mul', 'Add']
                   and get_value_in_port(candidate) is not None
                   and candidate.soft_get('can_be_fused') is True)
        if not fusable:
            break
        chain.append(candidate)

    if len(chain) == 1:
        return False
    if len(chain) == 2 and chain[0].op == 'Mul' and chain[1].op == 'Add':
        return False

    input_shape = get_tensor_in_port(start_node).data.get_shape()

    if graph.graph['layout'] == 'NCHW':
        init_dims_cnt = len(input_shape) - 2
    else:
        init_dims_cnt = 1

    # Neutral starting coefficients: multiply by one, add zero.
    unit_shape = [1] * init_dims_cnt
    mul = np.ones(unit_shape)
    add = np.zeros(unit_shape)

    first_mul_name = None
    first_add_name = None

    # Fold the chain into a single pair of coefficients: y = x * mul + add.
    for node in chain:
        const_value = get_value_in_port(node).data.get_value()
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            mul = mul * const_value
            add = add * const_value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + const_value

    # A single-element 1D mul is repeated to the length of the first dimension of add
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.array([mul[0]] * add.shape[0])

    assert (np.array_equal(
        get_tensor_in_port(chain[0]).data.get_shape(),
        chain[-1].out_port(0).data.get_shape()))

    mul_op = Mul(graph, dict(name='{}/Fused_Mul_'.format(first_mul_name or '')))
    add_op = Add(graph, dict(name='{}/Fused_Add_'.format(first_add_name or '')))

    in_port = get_tensor_in_port(chain[0])
    out_port = chain[-1].out_port(0)

    # Four cases considered below:
    #     1. Mul and Add have valid values (mul value != 1 and add value != 0)
    #     2. Only Mul has valid values, so we add only Mul node
    #     3. Only Add has valid values, so we add only Add node
    #     4. When neither has valid values the producer is connected straight to the consumers
    add_is_nontrivial = any([elem != 0 for elem in np.nditer(add)])
    mul_is_nontrivial = any([elem != 1 for elem in np.nditer(mul)])

    if add_is_nontrivial and mul_is_nontrivial:
        #  Const\    Const\
        #  ----->Mul------>Add-->
        mul_const = Const(graph, dict(name="data_mul_", value=np.array(mul))).create_node()
        add_const = Const(graph, dict(name="data_add_", value=np.array(add))).create_node()

        mul_node = mul_op.create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))

        mul_node.out_port(0).connect(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    elif mul_is_nontrivial:
        #  Const\
        #  ----->Mul-->
        mul_const = Const(graph, dict(name="data_mul_", value=np.array(mul))).create_node()
        mul_node = mul_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))
        out_port.get_connection().set_source(mul_node.out_port(0))
    elif add_is_nontrivial:
        #  Const\
        #  ----->Add-->
        add_const = Const(graph, dict(name="data_add_", value=np.array(add))).create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    else:
        # Identity chain: bypass it entirely.
        producer_port = in_port.get_source()
        in_port.disconnect()
        out_port.get_connection().set_source(producer_port)

    # Remove fused nodes
    for fused in chain:
        graph.remove_node(fused.id)

    log.debug('Fused {} operations'.format(len(chain)))
    return True