Example #1
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        node = match['op']

        if not node.has_valid('bias') or node.bias == 1:
            return

        # Calculate scale value & create Const op
        scale_value = np.array(1. / (pow(node.bias, node.beta)))
        node.alpha /= node.bias
        const_node = Const(graph,
                           dict(value=scale_value, shape=scale_value.shape))

        # Get all outputs for LRN layer
        out_nodes = list(node.out_nodes().values())

        # Create Mul node with inputs
        mul_node = Mul(graph, dict(name=node.id + "/Mul_"))
        mnode = mul_node.create_node(inputs=[node, const_node.create_node()])

        # Move edges from LRN to Mul node
        for out_node in out_nodes:
            edge_attrs = graph.get_edge_data(node.id, out_node.id)[0]
            graph.remove_edge(node.id, out_node.id)
            graph.add_edges_from([(mnode.id, out_node.id, edge_attrs)])
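
The rewrite relies on factoring the LRN bias out of the denominator: (bias + alpha * s) ** beta == bias ** beta * (1 + (alpha / bias) * s) ** beta, which is why alpha is divided by bias and the output is scaled by 1 / bias ** beta. A minimal numeric check of that identity, assuming the conventional LRN denominator (s below stands for the local sum of squares; all values are illustrative):

import numpy as np

bias, alpha, beta, s = 2.0, 1e-4, 0.75, 3.5
lhs = 1.0 / (bias + alpha * s) ** beta
rhs = (1.0 / (1.0 + (alpha / bias) * s) ** beta) * (1.0 / bias ** beta)
assert np.isclose(lhs, rhs)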
Example #2
def apply_scale(graph: nx.MultiDiGraph, input_node: Node,
                node_mean_scale_values: dict):
    if 'scale' in node_mean_scale_values and node_mean_scale_values[
            'scale'] is not None:
        if all([x == 1 for x in node_mean_scale_values['scale']]):
            return
        out_node = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has not valid shape attribute".format(
                input_node.id))
        input_shape = input_node.shape

        # Create Mul node
        value = 1 / np.array(node_mean_scale_values['scale'])
        graph.remove_edge(input_node.id, out_node.id)

        mul_node = Mul(graph, dict(name="Mul_"))
        mul_data = Op.create_input_data_node(graph, "data_mul_",
                                             np.array(value))
        Op.expand_node_shape(
            mul_data,
            (len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0))
        mul_input = Op.create_data_node(graph, input_node,
                                        {'shape': out_node.shape})

        mul_node.create_node_with_data(inputs=[mul_input, mul_data],
                                       data_nodes=out_node)
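
Op.expand_node_shape appends trailing singleton dimensions so that a per-channel constant broadcasts over the spatial axes. A rough sketch of the effect for a 4D NCHW input (the shapes are illustrative, not taken from the source):

import numpy as np

value = 1 / np.array([2.0, 4.0, 8.0])   # one scale factor per channel
for _ in range(4 - 2):                  # len(input_shape) - 2 for NCHW
    value = np.expand_dims(value, axis=-1)
assert value.shape == (3, 1, 1)         # broadcasts against (N, 3, H, W)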
Example #3
def _scale_input_action_mul(graph: nx.MultiDiGraph, match: dict, scale: float):
    assert len(match['placeholder'].out_nodes()) >= 1

    tinput = match['placeholder']
    if not tinput.has_valid('shape'):
        raise Error("Node {} has not valid shape attribute".format(tinput.id))

    input_shape = tinput.shape
    toutput = match['data']

    # Compute the value for the Mul constant
    value = np.array([1 / scale])

    # Disconnect input with data node
    graph.remove_edge(tinput.id, toutput.id)

    # Create Mul node
    mul_node = Mul(graph, dict(name="Mul1_"))
    mul_data = Op.create_input_data_node(graph, "data_mul_scale_",
                                         np.array(value))
    Op.expand_node_shape(
        mul_data,
        len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0)
    mul_input = Op.create_data_node(graph, tinput, {'shape': toutput.shape})

    mul_node.create_node_with_data(inputs=[mul_input, mul_data],
                                   data_nodes=toutput)
Example #4
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        mul_op = Mul(
            graph,
            dict(name=node.id + '/mul_',
                 symbol_dict={'name': node.id + '/mul_'}))
        mul_node = mul_op.create_node(
            inputs=[node.in_node(0), node.in_node(1)])
        replace_node(node, mul_node)
        return [mul_node.id]
Example #5
    def replace_op(self, graph: Graph, node: Node):
        prefix = node.name + '/InstanceNormalization'
        mvn = MVN(graph, dict(name=prefix + '/MVN', eps=node.epsilon))
        mul = Mul(graph, dict(name=prefix + '/Mul', axis=1))
        add = Add(graph, dict(name=prefix + '/Add', axis=1))

        new_subgraph = add.create_node([
            mul.create_node(
                [mvn.create_node([node.in_node(0)]),
                 node.in_node(1)]),
            node.in_node(2)
        ])

        return [new_subgraph.id]
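
The subgraph implements the usual InstanceNormalization definition, gamma * MVN(x) + beta, with the Mul/Add applied along axis 1 (the channel axis). A hedged numpy reference of that reading, assuming NCHW input and per-channel gamma/beta (the helper name is illustrative):

import numpy as np

def instance_norm_ref(x, gamma, beta, eps):
    # Normalize each (n, c) plane over its spatial axes, then scale and shift.
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    mvn = (x - mean) / np.sqrt(var + eps)
    return gamma[None, :, None, None] * mvn + beta[None, :, None, None]

x = np.random.randn(2, 3, 5, 5)
out = instance_norm_ref(x, np.ones(3), np.zeros(3), 1e-5)
assert np.allclose(out.mean(axis=(2, 3)), 0, atol=1e-6)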
Example #6
    def replace_op(self, graph: Graph, node: Node):
        in_node = node.in_node()
        out_nodes = list(node.out_nodes().values())
        graph.remove_edge(node.in_node().id, node.id)
        scalar_value_op = Const(graph, dict(value=node.scalar, shape=node.scalar.shape, symbol_dict={'name': node.id + '/const'}))
        mul_op = Mul(graph, dict(name=node.id + '/mul_', symbol_dict={'name': node.id + '/mul_'}))
        mul_node = mul_op.create_node(inputs=[in_node, scalar_value_op.create_node()])

        for out_node in out_nodes:
            edge_attrs = graph.get_edge_data(node.id, out_node.id)[0]
            graph.remove_edge(node.id, out_node.id)
            graph.add_edges_from([(mul_node.id, out_node.id, edge_attrs)])

        return [mul_node.id]
Example #7
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['softmax']
        if 'temperature' in node and node['temperature'] != 1.0:
            in_node = node.in_node()
            out_nodes = list(node.out_nodes().values())
            graph.remove_edge(node.in_node().id, node.id)
            temperature = np.array([1.0 / node.temperature])
            scalar_value_op = Const(
                graph,
                dict(value=temperature,
                     shape=temperature.shape,
                     symbol_dict={'name': node.id + '/const'}))
            mul_op = Mul(
                graph,
                dict(name=node.id + '/mul_',
                     symbol_dict={'name': node.id + '/mul_'}))
            mul_node = mul_op.create_node(
                inputs=[in_node, scalar_value_op.create_node()])
            # Reconnect the graph as input -> Mul -> Softmax
            edge_attrs = graph.get_edge_data(node.id, out_nodes[0].id)[0]
            graph.add_edges_from([(mul_node.id, node.id, edge_attrs)])
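
Temperature softmax is softmax(x / T); the pass implements the division as a Mul by the precomputed constant 1 / T. A quick illustration (the softmax helper and the values are illustrative):

import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

x, T = np.array([1.0, 2.0, 3.0]), 2.0
scaled = x * np.array([1.0 / T])   # what the inserted Mul computes
assert np.allclose(softmax(scaled), softmax(x / T))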
Example #8
def _fused_batch_norm_decomposition(graph: Graph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
                                    mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
    """
    This is common function for TF, Caffe and MXNet
    It creates Mul->Add->Mul->Add subgraph
    """
    shape = tinput.shape

    # Create first Mul & Add operations
    mul1_node = Mul(graph, dict(name="Mul1_", can_be_fused=can_be_fused))
    add1_node = Add(graph, dict(name="Add1_", can_be_fused=can_be_fused))

    mul1_data = Op.create_input_data_node(graph, "data_mul_", np.array(mean))
    add1_data = Op.create_input_data_node(graph, "data_add_", np.array(variance))

    # Broadcast const from scalar
    # We can broadcast only when const.value is scalar
    if gamma.shape[0] != gamma.value.shape[0]:
        gamma.value.resize(gamma.shape)
        gamma.value.fill(gamma.value[0])

    # Create second Mul & Add
    mul2_node = Mul(graph, dict(name="Mul2_", can_be_fused=can_be_fused))
    add2_node = Add(graph, dict(name="Add2_", can_be_fused=can_be_fused))

    add2_node.create_node_with_data(
        inputs=[mul2_node.create_node_with_data(
            inputs=[add1_node.create_node_with_data(
                inputs=[mul1_node.create_node_with_data(inputs=[tinput, mul1_data]),
                        add1_data]),
                gamma]),
            beta],
        data_nodes=toutput)
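
Despite the parameter names, the values wired into the data_mul_/data_add_ nodes act as a linear scale and shift, so the chain computes (x * scale + shift) * gamma + beta. Assuming a caller passes scale = 1 / sqrt(variance + eps) and shift = -mean * scale, that reproduces batch normalization; a scalar sanity check under that assumption:

import numpy as np

x = np.random.randn(5)
mu, var, eps, gamma, beta = 0.2, 1.3, 1e-5, 1.5, -0.4
scale = 1.0 / np.sqrt(var + eps)
shift = -mu * scale
bn = gamma * (x - mu) / np.sqrt(var + eps) + beta
assert np.allclose((x * scale + shift) * gamma + beta, bn)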
Example #9
def _bn_to_mul_add_action(graph: nx.MultiDiGraph, match: dict):
    # Data nodes
    tinput = match['input']
    toutput = match['output']
    mean = match['mean']
    variance = match['variance']

    # Op node
    bn_node = match['batch_norm']

    # Disconnect the data nodes from the BatchNorm node
    graph.remove_edge(tinput.node, bn_node.node)
    graph.remove_edge(mean.node, bn_node.node)
    graph.remove_edge(variance.node, bn_node.node)

    graph.remove_edge(bn_node.node, toutput.node)

    scale = 1. / np.sqrt(variance.value + bn_node.epsilon)
    shift = -mean.value * scale

    # Reuse the existing mean/variance data nodes to hold the computed scale/shift
    mean.value = np.array(scale)
    variance.value = np.array(shift)

    # Expand dims for current layout
    broadcast_dims_cnt = len(
        tinput.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
    # Update values and shapes with new shape
    Op.expand_node_shape(mean, broadcast_dims_cnt)
    Op.expand_node_shape(variance, broadcast_dims_cnt)

    can_be_fused = bool(bn_node.soft_get('can_be_fused'))

    mul_node = Mul(graph, dict(name="Mul_", can_be_fused=can_be_fused))
    add_node = Add(graph, dict(name="Add_", can_be_fused=can_be_fused))

    # Connect input->mul->add
    add_node.create_node_with_data(inputs=[
        mul_node.create_node_with_data(inputs=[tinput, mean]), variance
    ],
                                   data_nodes=toutput)
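
The conversion rests on the identity (x - mean) / sqrt(variance + eps) == x * scale + shift, with the scale and shift computed above; a minimal check:

import numpy as np

x = np.random.randn(4)
mean, variance, eps = 0.3, 1.7, 1e-3
scale = 1.0 / np.sqrt(variance + eps)
shift = -mean * scale
assert np.allclose((x - mean) / np.sqrt(variance + eps), x * scale + shift)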
Example #10
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        # This replacer replaces the ImageScaler operation with a Mul->Add sequence
        # and checks that the weights and biases are not identities
        op = match['op']

        # Check that weights and biases are not useless
        has_bias, has_weights = True, True
        if all([x == 1 for x in np.nditer(op.scale)]):
            has_weights = False
        if all([x == 0 for x in np.nditer(op.bias)]):
            has_bias = False

        # Get all outputs for op node
        out_nodes = list(op.out_nodes().values())

        assert len(op.in_nodes()) == 1

        last_node = op.in_node()
        # Create Mul & Add nodes
        if has_weights:
            mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape))
            mul_op = Mul(graph, dict(name=op.id + '/mul_'))
            last_node = mul_op.create_node(inputs=[last_node, mul_weights.create_node()])

        if has_bias:
            add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape))
            add_op = Add(graph, dict(name=op.id + '/add_'))
            last_node = add_op.create_node(inputs=[last_node, add_bias.create_node()])

        # Move edges from ImageScaler to last_node (Mul or Add)
        for out_node in out_nodes:
            edge_attrs = graph.get_edge_data(op.id, out_node.id)[0]
            graph.remove_edge(op.id, out_node.id)
            graph.add_edges_from([(last_node.id, out_node.id, edge_attrs)])

        # Disconnect the ImageScaler node
        graph.remove_edge(op.in_node().id, op.id)
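
ImageScaler computes input * scale + bias per channel, so the checks above drop the Mul when the scale is all ones and the Add when the bias is all zeros, since both are identities. A small illustration of the identity cases (shapes are illustrative):

import numpy as np

x = np.random.randn(1, 3, 4, 4)
scale = np.ones((3, 1, 1))    # has_weights becomes False
bias = np.zeros((3, 1, 1))    # has_bias becomes False
assert np.allclose(x * scale + bias, x)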
Example #11
    def replace_sub_graph(self, graph: Graph, match: dict):
        # This replacer replaces the ImageScaler operation with a Mul->Add sequence
        # and checks that the weights and biases are not identities
        op = match['op']

        # Check that weights and biases are not useless
        has_bias, has_weights = True, True
        if all([x == 1 for x in np.nditer(op.scale)]):
            has_weights = False
        if all([x == 0 for x in np.nditer(op.bias)]):
            has_bias = False

        assert len(op.in_ports()) == 1

        last_port = op.in_port(0).get_source()

        # Create Mul & Add nodes
        if has_weights:
            mul_weights = Const(graph,
                                dict(value=op.scale,
                                     shape=op.scale.shape)).create_node()
            mul_op = Mul(graph, dict(name=op.id + '/mul_')).create_node()
            op.in_port(0).get_connection().set_destination(mul_op.in_port(0))
            mul_weights.out_port(0).connect(mul_op.in_port(1))
            last_port = mul_op.out_port(0)

        if has_bias:
            add_bias = Const(graph, dict(value=op.bias,
                                         shape=op.bias.shape)).create_node()
            add_op = Add(graph, dict(name=op.id + '/add_')).create_node()
            last_port.get_connection().set_destination(add_op.in_port(0))
            add_bias.out_port(0).connect(add_op.in_port(1))
            last_port = add_op.out_port(0)

        op.in_port(0).disconnect()
        op.out_port(0).get_connection().set_source(last_port)
Example #12
    def replace_pattern(self, graph: Graph, match: dict):
        assert match['operator'].has('multiplication_transparent_ports')

        quantize = match['quantize']
        # This pass is applicable for binarization only. Other intX variants are not relevant.
        if quantize.levels != 2:
            return

        port = match['operator'].input_ports_with(match['quantized'])
        assert len(port) >= 1
        if len(port) > 1:
            log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it is consumed'
                      ' more than once'.format(match['quantized'].name))
            return

        assert len(port) == 1
        port = port[0]
        applicable = [pair for pair in match['operator'].multiplication_transparent_ports if pair[0] == port]
        if len(applicable) == 0:
            return

        # Look at the 3rd and 4th inputs of Quantize -- they hold constants that should be passed through.
        # Assume that the constant that should be passed through is a scalar.
        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)
        assert len(output_low.out_nodes()) == 1
        assert len(output_high.out_nodes()) == 1

        # Both values are dereferenced below, so bail out if either one is missing
        if not output_low.has_valid('value') or not output_high.has_valid('value'):
            return

        output_low = output_low.value
        output_high = output_high.value

        operator = match['operator']

        if np.all(np.isclose(output_low, 0)) and np.all(np.isclose(output_high, 1)):

            weights = operator.in_node(1).value
            reduction_indices = set(range(len(weights.shape))) - set([operator.output_feature_channel])
            weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
            weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1])

            add_term = Const(graph, {'value': weights_reduced}).create_node()
            add = Add(graph, {}).create_node()
            add.in_port(1).connect(add_term.out_port(0))
            mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
            mul = Mul(graph, {}).create_node()
            mul.in_port(1).connect(mul_term.out_port(0))
            add.out_port(0).connect(mul.in_port(0))

            operator.out_port(0).get_connection().set_source(mul.out_port(0))
            add.in_port(0).connect(operator.out_port(0))

            operator['pad_value'] = float(-1.0)
        elif np.all(np.isclose(output_low, -1)) and np.all(np.isclose(output_high, +1)):
            pass
        else:
            log.debug('ConvToBinaryConv: cannot apply transformation because input range is neither in [0, +1] nor '
                      'in [-1, +1].')
            return

        operator['type'] = 'BinaryConvolution'
        operator['mode'] = 'xnor-popcount'
        operator['input'] = operator.in_node(0).shape[1]
        # Weights are not bit-packed yet; there should be a separate transformation to do that

        assert output_low.size == 1
        assert output_high.size == 1

        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)

        # Make sure that low/high values are exactly 0/1
        output_low.value = np.zeros(output_low.shape)
        output_high.value = np.ones(output_high.shape)
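
The Add-then-Mul tail created for the [0, 1] case follows from reading a {0, 1} tensor as a shifted {-1, +1} tensor: with x01 = (x_pm1 + 1) / 2, dot(w, x01) == 0.5 * (dot(w, x_pm1) + sum(w)), which is where the per-output-channel reduced weights and the 0.5 factor come from (borders aside, hence pad_value = -1). A scalar-product sketch of that identity, assuming the quantized tensor is the activation input (sizes are illustrative):

import numpy as np

w = np.random.randn(8)
x_pm1 = np.random.choice([-1.0, 1.0], size=8)
x01 = (x_pm1 + 1.0) / 2.0
assert np.isclose(np.dot(w, x01), 0.5 * (np.dot(w, x_pm1) + w.sum()))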
Example #13
def convert_scale_shift_to_mul_add(graph: Graph):
    nodes = graph.get_op_nodes(op='ScaleShift')
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)
        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
        output_port = node.out_port(0)

        has_biases = True
        has_weights = True

        # We don't need zero biases
        if shift_port is None or (shift_port.data.get_value() is not None and all([x == 0 for x in shift_port.data.get_value()])):
            has_biases = False

        # We don't need weights with ones
        if scale_port is None or (scale_port.data.get_value() is not None and all([x == 1 for x in scale_port.data.get_value()])):
            has_weights = False

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Expand dims for current layout
        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0

        # If the weights/biases are constant, we broadcast them according to the graph layout;
        # otherwise we insert a Reshape with the broadcast dimensions.
        def broadcast_value(port):
            value = np.array(port.data.get_value())
            for idx in range(broadcast_dims_cnt):
                value = np.expand_dims(value, axis=-1)
            port.data.set_value(value)

        def broadcast_with_reshape(port):
            input_shape = input_port.data.get_shape()
            reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            data_shape = port.data.get_shape()
            for i in range(node.axis, node.axis + len(data_shape)):
                reshape_dims[i] = data_shape[i - node.axis]
            for i in range(node.axis + len(data_shape), len(input_shape)):
                reshape_dims[i] = 1
            reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
            port.get_connection().set_destination(reshape.in_port(0))
            reshape.out_port(0).connect(port)

        if has_weights and scale_port.data.get_value() is not None:
            broadcast_value(scale_port)
        elif has_weights:
            broadcast_with_reshape(scale_port)

        if has_biases and shift_port.data.get_value() is not None:
            broadcast_value(shift_port)
        elif has_biases:
            broadcast_with_reshape(shift_port)

        if has_biases and has_weights:
            # Connect input->mul->out->add->out
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            # Connect Add operation with inputs
            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # Connect input->mul->out
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            output_port.get_connection().set_source(mul_node.out_port(0))
        elif has_biases:
            # Connect input->add->out
            add_node = add_op.create_node()

            # Connect Add operation with inputs
            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        else:
            # Connect input->out
            producer_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(producer_port)
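
broadcast_with_reshape builds a shape that keeps the scale/shift values at node.axis and pads every other position with ones. A quick sketch of the shape math, assuming an NCHW input and axis == 1 (values are illustrative):

input_shape = (1, 16, 32, 32)
axis, data_shape = 1, (16,)
reshape_dims = ([1] * axis
                + list(data_shape)
                + [1] * (len(input_shape) - axis - len(data_shape)))
assert reshape_dims == [1, 16, 1, 1]   # broadcasts against (N, C, H, W)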
Example #14
def _fuse_linear_sequence(graph: nx.MultiDiGraph, start_node: Node):
    """
    This function finds the sequence of Mul/Add operations and replaces this sequence with two ops (Mul->Add).
    :param graph:
    :param start_node: The first operation of the sequence
    """
    fnodes = [start_node]
    while True:
        node = fnodes[-1]
        data_node = node.out_node()
        if len(data_node.out_nodes()) != 1:
            break
        consumer = data_node.out_node()
        if (consumer.op in ['Mul', 'Add'] and get_value_id(consumer) is not None
                and consumer.soft_get('can_be_fused') == True):
            fnodes.append(consumer)
        else:
            break

    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul'
                            and fnodes[1].op == 'Add'):
        return False

    input_shape = start_node.in_node(get_tensor_id(start_node)).shape

    init_dims_cnt = len(
        input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1

    mul = np.ones([1 for x in range(init_dims_cnt)])
    add = np.zeros([1 for x in range(init_dims_cnt)])

    first_mul_name = None
    first_add_name = None

    for idx in range(len(fnodes)):
        node = fnodes[idx]
        const_node = get_value_id(node)
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            mul = mul * node.in_node(const_node).value
            add = add * node.in_node(const_node).value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + node.in_node(const_node).value

    # If mul is scalar we broadcast it to biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.array([mul[0] for x in range(add.shape[0])])

    assert (np.array_equal(fnodes[0].in_node(get_tensor_id(fnodes[0])).shape,
                           fnodes[-1].out_node().shape))

    mul_node = Mul(
        graph,
        dict(name=(first_mul_name + '/Fused_Mul_')
             if first_mul_name is not None else ''))
    add_node = Add(
        graph,
        dict(name=(first_add_name + '/Fused_Add_')
             if first_add_name is not None else ''))

    in_node = fnodes[0].in_node(get_tensor_id(fnodes[0]))
    out_node = fnodes[-1].out_node()

    graph.remove_edge(in_node.id, fnodes[0].id)
    graph.remove_edge(fnodes[-1].id, out_node.id)

    # Remove deleted subgraph
    for node in fnodes:
        for tmp_node in node.in_nodes().values():
            # Remove node only if it has one consumer (for case with shared weights)
            if len(tmp_node.out_nodes()) == 1:
                graph.remove_node(tmp_node.id)
        for tmp_node in node.out_nodes().values():
            graph.remove_node(tmp_node.id)
        graph.remove_node(node.id)
    """
    Four cases considered below:
        1. Mul and Add have valid values (mul value != 1 and add value != 0)
        2. Only Mul has valid values, so we add only Mul node
        3. Only Add has valid values, so we add only Add node
        4. When Mul and Add has not valid values we just merge two data nodes
    """
    if any(x != 0 for x in np.nditer(add)) and any(x != 1 for x in np.nditer(mul)):
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[
            mul_node.create_node_with_data([in_node, data_mul]), data_add
        ],
                                       data_nodes=out_node)
    elif any(x != 1 for x in np.nditer(mul)):
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        mul_node.create_node_with_data(inputs=[in_node, data_mul],
                                       data_nodes=out_node)
    elif any(x != 0 for x in np.nditer(add)):
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[in_node, data_add],
                                       data_nodes=out_node)
    else:
        merge_data_nodes(graph, out_node, in_node)
        graph.remove_node(in_node.id)

    log.debug('Fused {} operations'.format(len(fnodes)))
    return True
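
The accumulation rule in the loop composes affine maps: passing y = x * m + a through another Mul by c yields x * (m * c) + (a * c), so a Mul scales both running constants while an Add only shifts the running add. A scalar check of one Mul -> Add -> Mul sequence (values are illustrative):

import numpy as np

x, m1, a1, m2 = 2.0, 3.0, 0.5, 4.0
mul, add = 1.0, 0.0
mul, add = mul * m1, add * m1   # Mul node: scales both constants
add = add + a1                  # Add node: shifts only the running add
mul, add = mul * m2, add * m2   # second Mul node
assert np.isclose((x * m1 + a1) * m2, x * mul + add)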
Example #15
    @staticmethod
    def extract(node):
        axis = onnx_attr(node, 'axis', 'i', default=None)
        Mul.update_node_stat(node, {'axis': axis})
        return __class__.enabled
Example #16
def convert_scale_shift_to_mul_add(graph: nx.MultiDiGraph):
    nodes = [
        Node(graph, node) for node in graph.nodes()
        if Node(graph, node).soft_get('op') == 'ScaleShift'
    ]
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        has_biases = True
        has_weights = True
        # We don't need zero biases
        if len(node.in_nodes()) < 3 or all(
            [x == 0 for x in node.in_node(2).value]):
            has_biases = False
        input_node = node.in_node(0)
        scale_node = node.in_node(1)
        shift_node = node.in_node(2) if has_biases else None
        output_node = node.out_node()

        if scale_node.has_valid("value") and all(
            [x == 1 for x in scale_node.value]):
            has_weights = False

        mul_node = Mul(graph, dict(name=node.name + "/Mul_"))
        add_node = Add(graph, dict(name=node.name + "/Add_"))

        # Disconnect ScaleShift node
        graph.remove_edge(input_node.id, node.id)
        graph.remove_edge(node.id, output_node.id)

        # Expand dims for current layout
        broadcast_dims_cnt = len(
            input_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
        if scale_node.has_valid("value"):
            Op.expand_node_shape(scale_node, broadcast_dims_cnt)
        else:
            # Insert a Reshape to make the shapes broadcastable
            reshape_dims = np.zeros(len(input_node.shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            for i in range(node.axis, node.axis + len(scale_node.shape)):
                reshape_dims[i] = scale_node.shape[i - node.axis]
            for i in range(node.axis + len(scale_node.shape),
                           len(input_node.shape)):
                reshape_dims[i] = 1
            reshape = Reshape(
                graph,
                dict(name=scale_node.name + "/Broadcast_", dim=reshape_dims))
            scale_node = reshape.create_node_with_data(inputs=[scale_node])

        Op.expand_node_shape(shift_node, broadcast_dims_cnt)

        # Connect input->mul->out->add->out
        if has_biases:
            add_node.create_node_with_data(inputs=[
                mul_node.create_node_with_data(
                    inputs=[input_node, scale_node]), shift_node
            ],
                                           data_nodes=output_node)
        elif has_weights:
            mul_node.create_node_with_data(inputs=[input_node, scale_node],
                                           data_nodes=output_node)
        else:
            merge_data_nodes(graph, input_node, output_node)
            graph.remove_node(output_node.id)