    def replace_op(self, graph: Graph, node: Node):

        # Add new nodes
        mvn = MVN(graph, {
            'eps': node.epsilon,
            'name': node.name + '/Ins_Norm/MVN_',
        }).create_node()
        mul = Mul(graph, {
            'axis': 1,
            'name': node.name + '/Ins_Norm/mul_'
        }).create_node()
        add = Add(graph, {
            'axis': 1,
            'name': node.name + '/Ins_Norm/add_'
        }).create_node()

        # Connect nodes
        node.in_port(0).get_connection().set_destination(mvn.in_port(0))
        node.in_port(1).get_connection().set_destination(mul.in_port(1))
        node.in_port(2).get_connection().set_destination(add.in_port(1))

        mvn.out_port(0).connect(mul.in_port(0))
        mul.out_port(0).connect(add.in_port(0))

        return [add.id]
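# A minimal numpy sketch of the identity the replacement above relies on:
# InstanceNorm(x, gamma, beta) == MVN(x) * gamma + beta, where MVN normalizes
# each (N, C) slice over the spatial axes. Shapes, layout and eps are
# illustrative assumptions.
import numpy as np

x = np.random.randn(2, 3, 4, 5).astype(np.float32)      # NCHW input
gamma = np.random.randn(1, 3, 1, 1).astype(np.float32)  # per-channel scale
beta = np.random.randn(1, 3, 1, 1).astype(np.float32)   # per-channel shift
eps = 1e-5

mean = x.mean(axis=(2, 3), keepdims=True)
var = x.var(axis=(2, 3), keepdims=True)
mvn = (x - mean) / np.sqrt(var + eps)  # what the MVN node computes
out = mvn * gamma + beta               # the Mul -> Add chain above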
Example #2
def _fused_batch_norm_decomposition(graph: Graph,
                                    tinput: Port,
                                    toutput: Port,
                                    gamma: Port,
                                    beta: Port,
                                    mean: np.ndarray,
                                    variance: np.ndarray,
                                    can_be_fused=True):
    """
    This is common function for TF, Caffe and MXNet
    It creates Mul->Add->Mul->Add sub graph
    """
    batch_norm_name = tinput.get_connection().get_destination().node.name

    # Create first Mul & Add operations
    mul1_node = Mul(
        graph, dict(name=batch_norm_name + "/mean",
                    can_be_fused=can_be_fused)).create_node()
    add1_node = Add(
        graph,
        dict(name=batch_norm_name + "/variance",
             can_be_fused=can_be_fused)).create_node()

    const_mul1_node = Const(graph, dict(name="data_mul_",
                                        value=np.array(mean))).create_node()
    const_add1_node = Const(graph,
                            dict(name="data_add_",
                                 value=np.array(variance))).create_node()

    # Broadcast the constant from a scalar
    # We can broadcast only when const.value is a scalar
    if gamma.data.get_shape()[0] != gamma.data.get_value().shape[0]:
        value = gamma.data.get_value().copy()
        value.resize(gamma.data.get_shape())  # ndarray.resize is in-place and returns None
        value.fill(value[0])
        gamma.data.set_value(value)

    # Create second Mul & Add
    mul2_node = Mul(
        graph, dict(name=batch_norm_name + "/gamma",
                    can_be_fused=can_be_fused)).create_node()
    add2_node = Add(
        graph, dict(name=batch_norm_name + "/beta",
                    can_be_fused=can_be_fused)).create_node()

    # Connect edges Mul1->Add1->Mul2->Add2
    tinput.get_connection().set_destination(mul1_node.in_port(0))
    mul1_node.in_port(1).get_connection().set_source(
        const_mul1_node.out_port(0))

    add1_node.in_port(0).get_connection().set_source(mul1_node.out_port(0))
    add1_node.in_port(1).get_connection().set_source(
        const_add1_node.out_port(0))

    mul2_node.in_port(0).get_connection().set_source(add1_node.out_port(0))
    gamma.get_connection().set_destination(mul2_node.in_port(1))

    add2_node.in_port(0).get_connection().set_source(mul2_node.out_port(0))
    beta.get_connection().set_destination(add2_node.in_port(1))

    toutput.get_connection().set_source(add2_node.out_port(0))
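# A sketch of the algebra realized by the Mul1->Add1->Mul2->Add2 chain above.
# Despite the parameter names, this assumes 'mean' and 'variance' already hold
# the precomputed per-channel multiplier and addend (callers typically pass
# 1/sqrt(var + eps) and -mean/sqrt(var + eps)); values below are illustrative.
import numpy as np

x = np.random.randn(4, 3).astype(np.float32)
m = np.float32(0.5)      # const_mul1_node value
a = np.float32(0.1)      # const_add1_node value
gamma = np.float32(2.0)
beta = np.float32(-1.0)

y = ((x * m) + a) * gamma + beta  # Mul1 -> Add1 -> Mul2 -> Add2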
    def replace_sub_graph(self, graph: Graph, match: dict):
        resize_node = match['resize']
        if match['mul_1'].in_node(1).value != match['mul_2'].in_node(1).value or \
                match['mul_1'].in_node(1).value != match['mul_3'].in_node(1).value:
            log.info(
                'Pattern matched around resize op {} has different scale values.'
                .format(resize_node.name))
            return

        interpolate_node = Interpolate(
            graph, {
                'name': resize_node.name + '/Interpolate',
                'mode': resize_node.mode,
                'axes': int64_array([2, 3, 4])
            }).create_node()

        scale = match['mul_1'].in_node(1).value
        scale_value = int64_array([scale, scale, scale])
        scale_const = Const(graph, {
            'value': scale_value,
            'name': resize_node.name + '/Scale'
        }).create_node()

        interpolated_shape = Mul(graph, {
            'name': resize_node.name + '/OutputShape'
        }).create_node()
        match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
        scale_const.out_port(0).connect(interpolated_shape.in_port(1))

        resize_node.in_port(0).get_connection().set_destination(
            interpolate_node.in_port(0))
        interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
        resize_node.out_port(0).get_connection().set_source(
            interpolate_node.out_port(0))
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = np.array([get_split_scale(split)], dtype=np.float32)
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph,
                                                     StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {
                                                        'name': split_node_name + '/StridedSlice',
                                                        'begin_mask': int64_array([1]),
                                                        'end_mask': int64_array([1]),
                                                        'new_axis_mask': int64_array([0]),
                                                        'shrink_axis_mask': int64_array([0]),
                                                        'ellipsis_mask': int64_array([0])
                                                     })
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest',
                                   antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
                                   cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
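# A numeric sketch of how the sub-graph above derives the 'sizes' input of the
# opset4 Interpolate: slice the input shape at `axis`, multiply by the scale,
# floor, and cast back to int64. Shape and scale are illustrative assumptions.
import numpy as np

input_shape = np.array([1, 3, 10, 10], dtype=np.int64)
axis, scale = 2, np.float32(2.0)

sliced = input_shape[axis:axis + 1]                   # StridedSlice
sizes = np.floor(sliced.astype(np.float32) * scale)   # Cast -> Mul -> Floor
sizes = sizes.astype(np.int64)                        # Cast to int64
# sizes == [20]; this feeds interp_node.in_port(1)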
Example #5
    def div_to_mul_replacement(div: Node):
        # This transformation is executed for V10 IR later, in the middle phase, despite graph_condition,
        # so we prevent Div replacement on shape-calculating sub-graphs
        if div.in_port(0).data.get_value() is not None and div.in_port(
                1).data.get_value() is not None:
            return

        graph = div.graph
        name = div.soft_get('name', div.id)

        # keep the Mul name the same as the Div name because the output tensors are mathematically equal
        rename_node(node=div, name=name + '/to_be_removed')

        # reconnect the Div inputs and outputs to the new Mul
        mul = Mul(graph, {'name': name}).create_node()
        rename_node(mul, name)

        div.in_port(0).get_connection().set_destination(mul.in_port(0))
        div.in_port(1).get_connection().set_destination(mul.in_port(1))
        div.out_port(0).get_connection().set_source(mul.out_port(0))

        # restore mathematical equivalence to Div operation: Div(A, B) = Mul(A, Pow(B, -1))
        reciprocal = create_op_with_const_inputs(
            graph, Pow, {1: np.float64(-1)}, {'name': name + '/reciprocal_'})
        mul.in_port(1).get_connection().insert_node(reciprocal)
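# A tiny numpy check of the identity used above:
# Div(A, B) == Mul(A, Pow(B, -1)). Values are illustrative.
import numpy as np

a = np.array([6.0, 9.0, -4.0])
b = np.array([2.0, 3.0, 8.0])
assert np.allclose(a / b, a * np.power(b, -1.0))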
Example #6
    def replace_op(self, graph: Graph, node: Node):

        # Add new nodes
        const = Const(graph,
                      dict(value=np.array(-1, dtype=np.int32))).create_node()
        negate = Mul(graph, {'name': node.name + '/negate_'}).create_node()
        add = Add(graph, {'name': node.name + '/add_'}).create_node()

        # Connect nodes
        node.in_port(1).get_connection().set_destination(negate.in_port(0))
        const.out_port(0).connect(negate.in_port(1))
        node.in_port(0).get_connection().set_destination(add.in_port(1))
        negate.out_port(0).connect(add.in_port(0))

        # The "explicit" version of the return value is: [(out_node.id, 0)])
        return [add.id]
    def replace_pattern(graph: Graph, match: dict):
        node = match['normalize']
        assert node.in_port(0).data.get_shape().size in [2, 3, 4]
        assert node.has_valid('across_spatial')
        assert node.has_valid('channel_shared')
        assert node.has_valid('eps')

        if 'bin' in node.in_edge(1):
            del node.in_edge(1)['bin']

        weights = node.in_port(1).data.get_value()
        assert weights is not None
        if node.channel_shared or all(weights == weights[0]):
            node.in_port(1).data.set_value(np.array([weights[0]]))

        mul = Mul(graph, {
            'name': node.name + '/Normalize_weights_multiplication'
        }).create_node()
        node.out_port(0).get_connection().set_source(mul.out_port(0))
        node.out_port(0).connect(mul.in_port(0))
        node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
        node.in_port(1).disconnect()

        node['type'] = 'NormalizeL2'
        node['eps_mode'] = 'add'
        node['force_precision_in_ports'] = {1: 'int64'}

        axes_val = np.array([1]) if not node.across_spatial else \
            np.arange(start=1, stop=node.in_port(0).data.get_shape().size)
        axes = Const(graph, {'value': axes_val}).create_node()
        node.in_port(1).connect(axes.out_port(0))

        del node['across_spatial']
        del node['channel_shared']
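# A sketch of the axes choice above: with across_spatial the L2 norm is taken
# over the channel and spatial axes together, otherwise over the channel axis
# only. Shapes and eps are illustrative assumptions.
import numpy as np

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
eps = 1e-10

axes_channel = (1,)    # across_spatial == False -> axes_val = [1]
axes_all = (1, 2, 3)   # across_spatial == True  -> axes_val = [1, 2, 3]
norm = np.sqrt((x ** 2).sum(axis=axes_channel, keepdims=True) + eps)
y = x / norm           # NormalizeL2 with eps_mode == 'add'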
Example #8
    def replace_sub_graph(self, graph: Graph, match: dict):
        op = match['op']
        out_port = op.in_port(0).get_source()

        if op.soft_get('scale', 1) != 1:
            const = Const(graph, {'value': np.array(op.scale)}).create_node()
            mul = Mul(graph, {'name': op.name + '/mul_'}).create_node()
            const.out_port(0).connect(mul.in_port(1))
            out_port.connect(mul.in_port(0))
            out_port = mul.out_port(0)

        if op.soft_get('shift', 0) != 0:
            const = Const(graph, {'value': np.array(op.shift)}).create_node()
            add = Add(graph, {'name': op.name + '/add_'}).create_node()
            const.out_port(0).connect(add.in_port(1))
            out_port.connect(add.in_port(0))
            out_port = add.out_port(0)

        if op.soft_get('power', 1) != 1:
            const = Const(graph, {'value': np.array(op.power)}).create_node()
            power = Pow(graph, {'name': op.name + '/pow_'}).create_node()
            const.out_port(0).connect(power.in_port(1))
            out_port.connect(power.in_port(0))
            out_port = power.out_port(0)

        op.out_port(0).get_connection().set_source(out_port)
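# A sketch of the decomposition above: a Power-style op
# y = (x * scale + shift) ** power becomes Mul -> Add -> Pow, with each stage
# skipped when its parameter is the identity (scale 1, shift 0, power 1).
# Parameter values are illustrative.
import numpy as np

x = np.random.rand(4).astype(np.float32)
scale, shift, power = 2.0, 1.0, 3.0

y = np.power(x * scale + shift, power)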
Example #9
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']

        if not node.has_valid('bias') or node.bias == 1:
            return

        # Calculate scale value & create Const op
        scale_value = np.array(1. / (pow(node.bias, node.beta)))
        node.alpha /= node.bias
        const_node = Const(
            graph, {
                'value': scale_value,
                'shape': scale_value.shape,
                'name': node.name + "/Const_Mul_"
            }).create_node()

        # Create Mul node
        mul_node = Mul(graph, {'name': node.name + "/Mul_"}).create_node()

        # Connect nodes
        const_node.out_port(0).connect(mul_node.in_port(1))
        node.out_port(0).get_connection().set_source(mul_node.out_port(0))
        node.out_port(0).connect(mul_node.in_port(0))

        # Delete 'bias'; if it is not deleted, it will appear in IR v6
        del node['bias']
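# The worked equation behind the bias folding above. An LRN with bias b
# satisfies
#   x / (b + alpha * s) ** beta == (1 / b**beta) * x / (1 + (alpha / b) * s) ** beta
# so dividing alpha by the bias and scaling the output by 1 / b**beta removes
# the bias. A numeric spot check with illustrative values:
import numpy as np

x, s = 3.0, 0.7                  # input value and local square-sum
alpha, beta, bias = 1e-4, 0.75, 2.0

lhs = x / (bias + alpha * s) ** beta
rhs = (1.0 / bias ** beta) * x / (1.0 + (alpha / bias) * s) ** beta
assert np.isclose(lhs, rhs)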
Example #10
    def replace_sub_graph(self, graph: Graph, match: dict):
        # This replacer replaces the ImageScaler operation with a Mul->Add sequence
        # It also checks that the weights and biases are actually needed
        op = match['op']

        # Check that weights and biases are not useless
        has_bias, has_weights = True, True
        if all([x == 1 for x in np.nditer(op.scale)]):
            has_weights = False
        if all([x == 0 for x in np.nditer(op.bias)]):
            has_bias = False

        assert len(op.in_ports()) == 1

        last_port = op.in_port(0).get_source()

        # Create Mul & Add nodes
        if has_weights:
            mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape)).create_node()
            mul_op = Mul(graph, dict(name=op.id + '/mul_')).create_node()
            op.in_port(0).get_connection().set_destination(mul_op.in_port(0))
            mul_weights.out_port(0).connect(mul_op.in_port(1))
            last_port = mul_op.out_port(0)

        if has_bias:
            add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape)).create_node()
            add_op = Add(graph, dict(name=op.id + '/add_')).create_node()
            last_port.get_connection().set_destination(add_op.in_port(0))
            add_bias.out_port(0).connect(add_op.in_port(1))
            last_port = add_op.out_port(0)

        op.in_port(0).disconnect()
        op.out_port(0).get_connection().set_source(last_port)
Example #11
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name

    shape_node = Shape(graph,
                       dict(name=split_node_name + '/Shape_')).create_node()
    scales_node = Const(graph,
                        dict(name=split_node_name + '/scales_',
                             value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    slice_begin = Const(
        graph,
        dict(name=split_node_name + '/slice_begin_',
             value=int64_array([axis]))).create_node()
    slice_end = Const(
        graph,
        dict(name=split_node_name + '/slice_end_',
             value=int64_array([axis + 1]))).create_node()

    strided_slice_node = StridedSlice(
        graph, {
            'name': split_node_name + '/StridedSlice_',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }).create_node([shape_node, slice_begin, slice_end])
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(
        graph,
        dict(name=split_node_name + '/Interpolate_',
             axes=int64_array([axis]),
             mode='nearest')).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    match['concat'].out_port(0).get_connection().set_source(
        interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
    def replace_pattern(graph: Graph, match: dict):
        node = match['sub']

        # Add new nodes
        negate_const = Const(
            graph, dict(name=node.name + '/negate_const',
                        value=np.array(-1))).create_node()
        negate = Mul(graph, {'name': node.name + '/negate_'}).create_node()
        add = Add(graph, {'name': node.name + '/add_'}).create_node()

        # Connect nodes
        node.in_port(1).get_connection().set_destination(negate.in_port(0))
        negate_const.out_port(0).connect(negate.in_port(1))
        node.in_port(0).get_connection().set_destination(add.in_port(1))
        negate.out_port(0).connect(add.in_port(0))

        node.out_port(0).get_connection().set_source(add.out_port(0))
    def replace_pattern(self, graph: Graph, match: dict):
        quantize = match['quantize']

        sum_node = Add(graph, dict()).create_node()
        const = Const(graph, {'value': np.array(0.5)}).create_node()
        mul_node = Mul(graph, dict()).create_node()

        mul_node.in_port(0).connect(sum_node.out_port(0))
        mul_node.in_port(1).connect(const.out_port(0))

        quantize.in_port(1).get_connection().get_source().connect(
            sum_node.in_port(0))
        quantize.in_port(2).get_connection().get_source().connect(
            sum_node.in_port(1))

        quantize.in_port(1).disconnect()
        quantize.in_port(2).disconnect()

        mul_node.out_port(0).connect(quantize.in_port(1))
        mul_node.out_port(0).connect(quantize.in_port(2))
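# A sketch of what the sub-graph above feeds into ports 1 and 2 of
# FakeQuantize: the midpoint of the original input range,
# 0.5 * (input_low + input_high), used as a single threshold. Values are
# illustrative.
import numpy as np

input_low, input_high = np.float32(-1.0), np.float32(3.0)
threshold = (input_low + input_high) * np.float32(0.5)  # Add -> Mul(0.5)
# threshold == 1.0 replaces both in_port(1) and in_port(2) of the quantize op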
Example #14
    def find_and_replace_pattern(self, graph: Graph):
        for dequantize_node in graph.get_op_nodes(op='DequantizeLinear'):
            node_name = dequantize_node.soft_get('name', dequantize_node.id)
            axis = dequantize_node.soft_get('axis', None)
            scale_y_shape = dequantize_node.in_port(1).data.get_shape()
            model_data_type = data_type_str_to_np(
                graph.graph['cmd_params'].data_type)
            cast = Cast(graph, {
                'dst_type': model_data_type,
                'name': node_name + '/Cast'
            }).create_node()
            dequantize_node.in_port(0).get_connection().set_destination(
                cast.in_port(0))
            mul = Mul(graph, {}).create_node()

            is_second_port_connected = dequantize_node.is_in_port_connected(2)
            if is_second_port_connected:
                sub = Sub(graph, {'name': node_name + '/Sub'}).create_node()
                cast.out_port(0).connect(sub.in_port(0))
                dequantize_node.in_port(2).get_connection().set_destination(
                    sub.in_port(1))
                sub.out_port(0).connect(mul.in_port(0))
            else:
                cast.out_port(0).connect(mul.in_port(0))

            dequantize_node.in_port(1).get_connection().set_destination(
                mul.in_port(1))
            dequantize_node.out_port(0).get_connection().set_source(
                mul.out_port(0))
            rename_nodes([(dequantize_node, node_name + '/TBD'),
                          (mul, node_name)])

            assert scale_y_shape is not None
            if axis is not None and len(
                    scale_y_shape) > 0 and scale_y_shape[0] > 1:
                input_shape = cast.in_port(0).data.get_shape()
                target_shape = np.ones(len(input_shape), np.int64)
                target_shape[axis] = input_shape[axis]

                mul_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul'})
                mul.in_port(1).get_connection().set_destination(
                    mul_reshape.in_port(0))
                mul_reshape.out_port(0).connect(mul.in_port(1))

                if is_second_port_connected:
                    sub_reshape = create_op_with_const_inputs(
                        graph, Reshape, {1: int64_array(target_shape)},
                        {'name': node_name + '/Reshape/Sub'})
                    sub.in_port(1).get_connection().set_destination(
                        sub_reshape.in_port(0))
                    sub_reshape.out_port(0).connect(sub.in_port(1))
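# A numeric sketch of the decomposition above:
# DequantizeLinear(x, scale, zero_point) == (Cast(x) - zero_point) * scale,
# with per-axis scale/zero_point reshaped to broadcast along `axis`. Data and
# axis are illustrative assumptions.
import numpy as np

x = np.array([[0, 128, 255], [64, 32, 16]], dtype=np.uint8)
scale = np.array([0.1, 0.2], dtype=np.float32)   # per-axis, axis == 0
zero_point = np.array([128, 0], dtype=np.uint8)

target_shape = [scale.shape[0], 1]               # the Reshape built above
y = (x.astype(np.float32) - zero_point.reshape(target_shape).astype(np.float32)) \
    * scale.reshape(target_shape)                # Cast -> Sub -> Mul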
    def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
        graph = fake_quantize.graph
        quantized_data = fake_quantize.in_port(0).get_source().node
        name = fake_quantize.soft_get('name', fake_quantize.id)

        assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
            'Weights aren`t compressed as expected for node {}'.format(fake_quantize.soft_get('name', fake_quantize.id))

        dequantizing_cast = Cast(graph, dict(
            name=quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type)),
            dst_type=dst_type, stop_value_propagation=True)).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

        # limits of dequantize
        in_low = fake_quantize.in_port(1).get_source()
        in_high = fake_quantize.in_port(2).get_source()
        out_low = fake_quantize.in_port(3).get_source()
        out_high = fake_quantize.in_port(4).get_source()

        # scale calculation
        output_range = Sub(graph, {'name': name + '/output_range'}).create_node()
        output_range.in_port(0).connect(out_high)
        output_range.in_port(1).connect(out_low)

        input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
        input_range.in_port(0).connect(in_high)
        input_range.in_port(1).connect(in_low)

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        scale.in_port(0).connect(output_range.out_port(0))
        scale.in_port(1).connect(input_range.out_port(0))

        # shift calculation
        descaled_output_low = Div(graph, {'name': name + '/descaled_output_low'}).create_node()
        descaled_output_low.in_port(0).connect(out_low)
        descaled_output_low.in_port(1).connect(scale.out_port(0))

        shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
        shift.in_port(0).connect(in_low)
        shift.in_port(1).connect(descaled_output_low.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
        sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
        sub_zp.in_port(1).connect(shift.out_port(0))

        mul_scale = Mul(graph, {'name': name + '/multiply_by_scale'}).create_node()
        mul_scale.in_port(0).connect(sub_zp.out_port(0))
        mul_scale.in_port(1).connect(scale.out_port(0))

        fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
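# A sketch of the scale/zero-point math built above for decompressing
# quantized weights:
#   scale      = (out_high - out_low) / (in_high - in_low)
#   zero_point = in_low - out_low / scale
#   dequantized = (x - zero_point) * scale
# Limits below are illustrative int8-style values.
import numpy as np

in_low, in_high = np.float32(-128), np.float32(127)    # quantized limits
out_low, out_high = np.float32(-0.5), np.float32(0.5)  # original limits

scale = (out_high - out_low) / (in_high - in_low)
zero_point = in_low - out_low / scale
x = np.array([-128, 0, 127], dtype=np.float32)         # Cast output
dequantized = (x - zero_point) * scale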
    def replace_pattern(graph: Graph, match: dict):
        log.debug(
            '================== GNMTBeforeConditionFind ==================')
        input_sequence_lengths = match['Max'].in_port(0).get_source().node
        encoder_sequence_lengths = looking_for_op_in_list([
            port.node
            for port in input_sequence_lengths.out_port(0).get_destinations()
        ], 'Identity')

        # Looking for the Sequence_length node; in the encoder the chain looks like:
        # Sequence_length -> CheckSeqLen -> Max -> Maximum -> Minimum

        check_seq_len = looking_for_op_in_list([
            port.node for port in encoder_sequence_lengths.out_port(
                0).get_destinations()
        ], 'Identity')
        reduce_max = looking_for_op_in_list([
            port.node for port in check_seq_len.out_port(0).get_destinations()
        ], 'ReduceMax')
        maximum = reduce_max.out_port(0).get_destinations()[0].node
        assert maximum.op == 'Maximum'
        minimum = maximum.out_port(0).get_destinations()[0].node
        assert minimum.op == 'Minimum'

        tensor_seq_len = looking_for_op_in_list([
            minimum.in_port(port).get_source().node
            for port in minimum.in_ports()
        ], 'StridedSlice')

        # Create node for multiplying seq_len by 2
        const = Const(graph, {
            'name': 'FakeSeqLenMultiplyer',
            'value': np.array(2)
        }).create_node()
        mul_op = Mul(graph, {'name': 'FakeSeqLen'}).create_node()

        const.out_port(0).get_connection().set_destination(mul_op.in_port(1))
        tensor_seq_len.out_port(0).get_connection().add_destination(
            mul_op.in_port(0))

        # Connect seq_len * 2 to TensorArray from GNMT loop
        ta_writes = [
            port.node
            for port in match['Identity_1'].out_port(0).get_destinations()
            if port.node.op == 'TensorArrayWriteV3'
        ]

        for ta_write in ta_writes:
            ta = ta_write.in_port(0).get_source().node.in_port(
                0).get_source().node

            ta.in_port(0).disconnect()
            ta.in_port(0).get_connection().set_source(mul_op.out_port(0))
    def replace_op(self, graph: Graph, node: Node):
        # Create nodes
        const_neg = Const(
            graph, dict(value=np.array(-1),
                        name=node.name + '/negate_const_')).create_node()
        negate = Mul(graph, {'name': node.name + '/negate_'}).create_node()
        add = Add(graph, {'name': node.name + '/add_'}).create_node()

        const = Const(graph, {'value': np.array(2)}).create_node()
        squared = Pow(graph, {'name': node.name + '/squared_'}).create_node()

        # Connect nodes
        node.in_port(0).get_connection().set_destination(add.in_port(0))
        node.in_port(1).get_connection().set_destination(negate.in_port(0))
        const_neg.out_port(0).connect(negate.in_port(1))
        negate.out_port(0).connect(add.in_port(1))
        add.out_port(0).connect(squared.in_port(0))
        const.out_port(0).connect(squared.in_port(1))

        # The "explicit" version of the return value is: [(out_node.id, 0)])
        return [squared.id]
    def replace_op(self, graph: Graph, node: Node):
        name = node.soft_get('name', node.id)

        # create the range of axes [2, rank) for MVN, i.e. all spatial axes of the input
        rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        rng = create_op_with_const_inputs(graph, Range, {
            0: int64_array(2),
            2: int64_array(1)
        }, {
            'name': name + '/Range',
            'output_type': np.int64
        })
        mvn = MVN(
            graph, {
                'eps': node.epsilon,
                'eps_mode': 'inside_sqrt',
                'normalize_variance': 1,
                'name': name + '/Ins_Norm/MVN_',
            }).create_node()
        node.in_port(0).get_connection().set_destination(mvn.in_port(0))
        rng.out_port(0).connect(mvn.in_port(1))
        mul = Mul(graph, {
            'axis': 1,
            'name': name + '/Ins_Norm/mul_'
        }).create_node()
        mvn.out_port(0).connect(mul.in_port(0))
        node.in_port(1).get_connection().set_destination(mul.in_port(1))
        add = Add(graph, {
            'axis': 1,
            'name': name + '/Ins_Norm/add_'
        }).create_node()
        mul.out_port(0).connect(add.in_port(0))
        node.in_port(2).get_connection().set_destination(add.in_port(1))

        mvn.in_port(0).get_connection().add_destination(rank.in_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        rename_nodes([(node, name + '/TBD'), (add, name)])

        return [add.id]
Example #19
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)

        rename_node(node, node_name + '/TBR')
        sqr_node = Mul(graph, {}).create_node()
        reduce_sum_node = ReduceSum(
            graph, {
                'keep_dims': node.soft_get('keep_dims', 0),
                'axis': node.soft_get('axis', None)
            }).create_node()
        sqrt_node = create_op_with_const_inputs(graph, Pow,
                                                {1: float_array(0.5)})
        rename_node(sqrt_node, node_name)

        # Connect nodes
        node.in_port(0).get_connection().set_destination(sqr_node.in_port(0))
        sqr_node.in_port(0).get_connection().add_destination(
            sqr_node.in_port(1))
        sqr_node.out_port(0).connect(reduce_sum_node.in_port(0))
        reduce_sum_node.out_port(0).connect(sqrt_node.in_port(0))

        return [sqrt_node.id]
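# A tiny check of the identity used above:
# ReduceL2(x, axis) == Pow(ReduceSum(Mul(x, x), axis), 0.5). Data is
# illustrative.
import numpy as np

x = np.random.randn(3, 4)
axis = 1
assert np.allclose(np.sqrt((x * x).sum(axis=axis)),
                   np.linalg.norm(x, axis=axis))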
Example #20
    def replace_pattern(graph: Graph, match: dict):
        node = match['pool']

        if node.pool_step is None:
            node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

        # create Reshape before the pooling
        # shape = [in_shape[0], in_shape[1]/pool_stride, 1, pool_stride]
        shape = Shape(graph, {}).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        split = create_op_with_const_inputs(graph, VariadicSplit, {
            1: int64_array(0),
            2: int64_array([1, -1])
        }, {'out_ports_count': 2}, shape)
        node_pool_stride = Const(graph, {
            'value': int64_array([node.pool_stride])
        }).create_node()
        pow_node = create_op_node_with_second_input(graph, Pow,
                                                    int64_array([-1]))
        pow_node.in_port(0).connect(node_pool_stride.out_port(0))

        mul = Mul(graph, {}).create_node()
        mul.in_port(0).connect(split.out_port(1))
        mul.in_port(1).connect(pow_node.out_port(0))

        const_1 = Const(graph, {'value': int64_array([1])}).create_node()

        concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
        concat.in_port(0).connect(split.out_port(0))
        concat.in_port(3).connect(mul.out_port(0))
        concat.in_port(2).connect(const_1.out_port(0))
        concat.in_port(1).connect(node_pool_stride.out_port(0))

        reshape_in = Reshape(graph, {
            'name': '/Reshape/' + node.name
        }).create_node()
        reshape_in.in_port(1).connect(concat.out_port(0))

        # create Reshape after the pooling
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node.name + '/Reshape/'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
    def replace_pattern(graph: Graph, match: dict):
        node = match['normalize']

        # rename the normalize node since it will no longer be the output node after the transformation
        output_name = node.soft_get('name', node.id)
        normalizel2_name = output_name + '/normalizel2'
        rename_node(node, normalizel2_name)

        assert node.in_port(0).data.get_shape().size in [2, 3, 4]
        assert node.has_valid('across_spatial')
        assert node.has_valid('channel_shared')
        assert node.has_valid('eps')

        if 'bin' in node.in_edge(1):
            del node.in_edge(1)['bin']

        weights = node.in_port(1).data.get_value()
        assert weights is not None
        # in the code below we intentionally use get_source() to get the out port, because updating the out port
        # will also update the Const node 'value' and 'shape' attributes
        if node.channel_shared or all(weights == weights[0]):
            node.in_port(1).get_source().data.set_value(np.array([weights[0]]))
        else:
            new_shape = np.ones((len(node.in_port(0).data.get_shape())),
                                dtype=np.int64)
            new_shape[1] = -1
            node.in_port(1).get_source().data.set_value(
                np.array(weights).reshape(new_shape))

        mul = Mul(graph, {'name': output_name}).create_node()
        rename_node(mul, output_name)

        if not node.across_spatial:
            axes = int64_array([1])
        else:
            axes = int64_array(
                np.arange(start=1, stop=node.in_port(0).data.get_shape().size))

        normalizel2 = create_op_with_const_inputs(graph, NormalizeL2Op,
                                                  {1: axes}, {
                                                      'eps_mode': 'add',
                                                      'eps': node.eps
                                                  })

        node.out_port(0).get_connection().set_source(mul.out_port(0))
        node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
        normalizel2.out_port(0).connect(mul.in_port(0))
        node.in_port(0).get_connection().set_destination(
            normalizel2.in_port(0))
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='ThresholdedRelu'):
            name = node.soft_get('name', node.id)

            greater = create_op_with_const_inputs(graph, Greater, {1: float_array([node.alpha])})
            greater.in_port(0).connect(node.in_port(0).get_source())
            float_greater = Cast(graph,
                                 {'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
            greater.out_port(0).connect(float_greater.in_port(0))

            mul = Mul(graph, {}).create_node()
            node.out_port(0).get_connection().set_source(mul.out_port(0))
            mul.in_port(0).connect(node.in_port(0).get_source())
            mul.in_port(1).connect(float_greater.out_port(0))

            rename_nodes([(node, name + '/TBR'), (mul, name)])
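# A sketch of the decomposition above:
# ThresholdedRelu(x) == x * Cast(x > alpha, float). Values are illustrative.
import numpy as np

x = np.array([-1.0, 0.5, 2.0], dtype=np.float32)
alpha = 1.0
y = x * (x > alpha).astype(np.float32)   # Greater -> Cast -> Mul
# y == [0., 0., 2.]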
    def replace_pattern(graph: Graph, match: dict):
        node = match['div']
        power_of_exponent = Const(graph, {
            'value': np.float64(-1)
        }).create_node()
        reciprocal = Pow(graph, {
            'name': node.name + '/reciprocal_'
        }).create_node()
        mul = Mul(graph, {'name': node.name + '/mul_'}).create_node()

        # Connect nodes
        node.in_port(1).get_connection().set_destination(reciprocal.in_port(0))
        power_of_exponent.out_port(0).connect(reciprocal.in_port(1))
        node.in_port(0).get_connection().set_destination(mul.in_port(1))
        reciprocal.out_port(0).connect(mul.in_port(0))

        node.out_port(0).get_connection().set_source(mul.out_port(0))
Example #24
    def replace_pattern(graph: Graph, match: dict):
        log.debug('================== GNMTBeforeConditionFind ==================')
        input_sequence_lengths = match['Max'].in_port(0).get_source().node
        encoder_sequence_lengths = looking_for_op_in_list([port.node for port in input_sequence_lengths.out_port(0).get_destinations()],
                                                          'Identity')

        # Looking for the Sequence_length node; in the encoder the chain looks like:
        # Sequence_length -> CheckSeqLen -> Max -> Maximum -> Minimum

        check_seq_len = looking_for_op_in_list([port.node for port in encoder_sequence_lengths.out_port(0).get_destinations()],
                                               'Identity')
        reduce_max = looking_for_op_in_list([port.node for port in check_seq_len.out_port(0).get_destinations()], 'ReduceMax')
        maximum = reduce_max.out_port(0).get_destinations()[0].node
        assert maximum.op == 'Maximum'
        minimum = maximum.out_port(0).get_destinations()[0].node
        assert minimum.op == 'Minimum'

        tensor_seq_len = looking_for_op_in_list([minimum.in_port(port).get_source().node for port in minimum.in_ports()], 'StridedSlice')

        # Create node for multiplying seq_len by 2
        const = Const(graph, {'name': 'FakeSeqLenMultiplyer', 'value': np.array(2)}).create_node()
        mul_op = Mul(graph, {'name': 'FakeSeqLen'}).create_node()

        const.out_port(0).get_connection().set_destination(mul_op.in_port(1))
        tensor_seq_len.out_port(0).get_connection().add_destination(mul_op.in_port(0))

        # Connect seq_len * 2 to TensorArray from GNMT loop
        ta_writes = [port.node for port in match['Identity_1'].out_port(0).get_destinations() if port.node.op == 'TensorArrayWriteV3']

        for ta_write in ta_writes:
            ta = ta_write.in_port(0).get_source().node.in_port(0).get_source().node

            ta.in_port(0).disconnect()
            ta.in_port(0).get_connection().set_source(mul_op.out_port(0))

        if not graph.graph['cmd_params'].static_shape:
            log.error(
                "Model can not be translated in a reshape-able way.\n"
                "Model Optimizer key static_shape was turned on to prevent related errors.\n"
                "There will be no success changing input shapes of the model with the help of "
                "InferenceEngine reshape method", extra={'is_warning': True})
            graph.graph['cmd_params'].static_shape = True
    def replace_pattern(graph: Graph, match: dict):
        node = match['normalize']
        assert node.in_port(0).data.get_shape().size in [2, 3, 4]
        assert node.has_valid('across_spatial')
        assert node.has_valid('channel_shared')
        assert node.has_valid('eps')

        if 'bin' in node.in_edge(1):
            del node.in_edge(1)['bin']

        weights = node.in_port(1).data.get_value()
        assert weights is not None
        # in the code below we intentionally use get_source() to get the out port, because updating the out port
        # will also update the Const node 'value' and 'shape' attributes
        if node.channel_shared or all(weights == weights[0]):
            node.in_port(1).get_source().data.set_value(np.array([weights[0]]))
        else:
            new_shape = np.ones((len(node.in_port(0).data.get_shape())),
                                dtype=np.int64)
            new_shape[1] = -1
            node.in_port(1).get_source().data.set_value(
                np.array(weights).reshape(new_shape))

        mul = Mul(graph, {
            'name': node.name + '/Normalize_weights_multiplication'
        }).create_node()
        node.out_port(0).get_connection().set_source(mul.out_port(0))
        node.out_port(0).connect(mul.in_port(0))
        node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
        node.in_port(1).disconnect()

        node['type'] = 'NormalizeL2'
        node['eps_mode'] = 'add'
        node['force_precision_in_ports'] = {1: 'int64'}

        axes_val = np.array([1]) if not node.across_spatial else \
            np.arange(start=1, stop=node.in_port(0).data.get_shape().size)
        axes = Const(graph, {'value': axes_val}).create_node()
        node.in_port(1).connect(axes.out_port(0))

        del node['across_spatial']
        del node['channel_shared']
Example #26
    def replace_pattern(graph: Graph, match: dict):
        node = match['conv']
        node_name = node.soft_get('name', node.id)

        # create Reshape before convolution
        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
        shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        split = create_op_with_const_inputs(graph, VariadicSplit, {
            1: int64_array(0),
            2: int64_array([1, -1])
        }, {
            'name': shape.name + '/split_batch',
            'out_ports_count': 2
        }, shape)

        pow_node = create_op_node_with_second_input(
            graph, Pow, int64_array([-1]),
            {'name': node_name + '/patch_stride/inverse'})
        conv_patch_stride = Const(
            graph, {
                'value': int64_array([node.patch_stride]),
                'name': node_name + '/patch_stride/'
            }).create_node()
        pow_node.in_port(0).connect(conv_patch_stride.out_port(0))

        mul = Mul(graph, {
            'name': node_name + '/mul_inverse_stride_h'
        }).create_node()
        mul.in_port(0).connect(split.out_port(1))
        mul.in_port(1).connect(pow_node.out_port(0))

        concat = create_op_with_const_inputs(
            graph, Concat, {2: int64_array([1])}, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })

        concat.in_port(0).connect(split.out_port(0))
        concat.in_port(1).connect(mul.out_port(0))
        concat.in_port(3).connect(conv_patch_stride.out_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(concat.out_port(0))

        # create Reshape after Convolution
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
    def replace_op(self, graph: Graph, node: Node):
        # split the input into (i_part, f_part, c_part, o_part, ct_1)
        split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
        split_node = Split(graph, {
            'name': 'Split_lstm_input_',
            'num_splits': 5
        }).create_node()
        node.in_port(0).get_connection().set_destination(split_node.in_port(0))
        split_node.in_port(1).connect(split_node_axis.out_port(0))

        # i_t = Sigmoid(i_part + w_ic*ct_1)
        i_scale_attrs = {'name': 'i_scaleshift', 'bias_term': False}
        i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
        input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
        split_node.out_port(4).connect(i_scale.in_port(0))

        sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
        split_node.out_port(0).connect(sum_i_c.in_port(0))
        i_scale.out_port(0).connect(sum_i_c.in_port(1))

        i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
        sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

        # f_t = Sigmoid(f_part + w_fc*ct_1)
        f_scale_attrs = {'name': 'f_scaleshift', 'bias_term': False}
        f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
        input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
        split_node.out_port(4).connect(f_scale.in_port(0))

        sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
        split_node.out_port(1).connect(sum_f_c.in_port(0))
        f_scale.out_port(0).connect(sum_f_c.in_port(1))

        f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
        sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

        # c_t = f_t*ct_1 + i_t * tanh(c_part)
        c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
        split_node.out_port(2).connect(c_tanh.in_port(0))

        prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
        i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
        c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

        prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
        f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
        split_node.out_port(4).connect(prod_f_ct_1.in_port(1))

        sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
        prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
        prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

        #  o_t = Sigmoid(o_part + w_oc*c_t)
        o_scale_attrs = {'name': 'o_scaleshift', 'bias_term': False}
        o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
        input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
        sum_f_i.out_port(0).connect(o_scale.in_port(0))

        sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
        split_node.out_port(3).connect(sum_o_c.in_port(0))
        o_scale.out_port(0).connect(sum_o_c.in_port(1))

        o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
        sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

        # m_t = o_t * Tanh(c_t)
        c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
        sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

        prod_o_c_t_tanh = Mul(graph, {
            'name': 'prod_o_c_t_tanh_'
        }).create_node()
        o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
        c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

        # add concat to create 1 output
        concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        sum_f_i.out_port(0).connect(concat.in_port(0))
        prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

        return [concat.id]
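# A numpy sketch of the cell math assembled above, following the inline
# comments: the input is split into (i_part, f_part, c_part, o_part, ct_1)
# and the peephole weights w_ic/w_fc/w_oc act element-wise on the cell state.
# Sizes and values are illustrative.
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

n = 8
i_part, f_part, c_part, o_part, ct_1 = np.random.randn(5, n).astype(np.float32)
w_ic, w_fc, w_oc = np.random.randn(3, n).astype(np.float32)

i_t = sigmoid(i_part + w_ic * ct_1)
f_t = sigmoid(f_part + w_fc * ct_1)
c_t = f_t * ct_1 + i_t * np.tanh(c_part)
o_t = sigmoid(o_part + w_oc * c_t)
m_t = o_t * np.tanh(c_t)
out = np.concatenate([c_t, m_t])   # the final Concat_c_m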
Example #28
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        input_shape = upsample.in_port(0).data.get_shape()
        input_shape_rank = len(input_shape)
        if input_shape_rank not in [4, 5]:
            log.warning('The input shape is not 4D or 5D for op {}'.format(
                upsample.soft_get('name')))
            return

        if len(upsample.in_nodes()) == 2:
            if upsample.in_node(1).value is None:
                return
            scales = upsample.in_node(1).value
            assert scales.shape == (4, )
            if not (math.isclose(scales[0], 1, rel_tol=1e-5)
                    and math.isclose(scales[1], 1, rel_tol=1e-5)):
                return
            height_scale = scales[2]
            width_scale = scales[3]
        else:
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']

        if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
            upsample.in_port(1).disconnect()

        factor = Const(graph, {
            'value': np.array([height_scale, width_scale])
        }).create_node()

        shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()

        layout = graph.graph['layout']
        if input_shape_rank == 4:
            begin = Const(graph, {
                'value':
                int64_array([get_height_dim(layout, input_shape_rank)])
            }).create_node()
        else:
            begin = Const(graph, {
                'value':
                int64_array([get_depth_dim(layout, input_shape_rank)])
            }).create_node()
        end = Const(graph, {
            'value':
            int64_array([get_width_dim(layout, input_shape_rank) + 1])
        }).create_node()

        stride = Const(graph, {'value': int64_array([1])}).create_node()
        ss = StridedSlice(
            graph, {
                'name': upsample.name + '/ss_0_port',
                'begin_mask': np.array([1]),
                'end_mask': np.array([0]),
                'new_axis_mask': np.array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0])
            }).create_node()

        mul = Mul(graph, {
            'name': upsample.name + '/factor_mul_'
        }).create_node()

        source = upsample.in_port(0).get_connection().get_source()
        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))
        begin.out_port(0).connect(ss.in_port(1))
        end.out_port(0).connect(ss.in_port(2))
        stride.out_port(0).connect(ss.in_port(3))
        ss.out_port(0).connect(mul.in_port(0))
        factor.out_port(0).connect(mul.in_port(1))

        # Create Interpolate operation
        if input_shape_rank == 4:
            axes = int64_array([
                get_height_dim(layout, input_shape_rank),
                get_width_dim(layout, input_shape_rank)
            ])
        else:
            axes = int64_array([
                get_depth_dim(layout, input_shape_rank),
                get_height_dim(layout, input_shape_rank),
                get_width_dim(layout, input_shape_rank)
            ])

        resample_op = Interpolate(
            graph,
            dict(name='Interpolate/{}'.format(upsample.name),
                 axes=axes,
                 mode=upsample.attrs()['mode'],
                 antialias=0,
                 convert_to_resample=True)).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        mul.out_port(0).connect(resample_op.in_port(1))

        upsample.in_port(0).get_connection().set_destination(
            resample_op.in_port(0))
        upsample.out_port(0).get_connection().set_source(
            resample_op.out_port(0))
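# A sketch of the target-size computation wired above: slice the spatial part
# of the input shape and multiply it by the scale factors to get the
# Interpolate sizes input. Shape, layout and scales are illustrative
# assumptions.
import numpy as np

input_shape = np.array([1, 3, 32, 48], dtype=np.int64)   # NCHW
height_scale, width_scale = 2.0, 2.0

spatial = input_shape[2:4]                               # StridedSlice
sizes = spatial * np.array([height_scale, width_scale])  # Mul by 'factor'
# sizes == [64., 96.] feeds resample_op.in_port(1)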
Example #29
    def replace_pattern(self, graph: Graph, match: dict):

        merge = match['merge']
        power = Pow(graph, {
            'name': merge.name + '/reciprocal_',
            'type': 'PNORM'
        }).create_node()
        const1 = Const(graph, {
            'value': -1.0,
            'name': merge.name + '/negate_const'
        }).create_node()
        merge.in_port(0).get_connection().set_destination(power.in_port(0))
        const1.out_port(0).connect(power.in_port(1))

        concat_node = Concat(
            graph, {
                'axis': 0,
                'name': merge.name + '/Concat_',
                'override_output_shape': True
            }).create_node()
        const3 = Const(graph, {
            'name': merge.name + '/const_reduce',
            'value': 0
        }).create_node()

        for ii, idx in enumerate(
                range(merge.significant, merge.to_significant + 1, 1)):
            const_node = Const(
                graph, {
                    'value': float_array(math.pow(10.0, idx)),
                    'name': merge.name + '/Const_' + str(ii)
                }).create_node()

            mul_node = Mul(graph, {
                'name': merge.name + '/Mul_' + str(ii)
            }).create_node()
            const_node.out_port(0).connect(mul_node.in_port(0))

            power.out_port(0).connect(
                mul_node.in_port(1))  # connect to the graph node
            mul_node2 = Mul(graph, {
                'name': merge.name + '/Mul_Div_' + str(ii)
            }).create_node()

            const_node2 = Const(
                graph, {
                    'value': float_array(math.pow(10.0, -1 * idx)),
                    'name': merge.name + '/Const_Pow_' + str(ii)
                }).create_node()
            cast_node = Cast(
                graph, {
                    'name': merge.name + '/Cast_' + str(idx),
                    'dst_type': np.float32
                }).create_node()

            mul_node.out_port(0).connect(cast_node.in_port(0))
            const_node2.out_port(0).connect(mul_node2.in_port(1))
            cast_node.out_port(0).connect(mul_node2.in_port(0))
            concat_node.add_input_port(ii, skip_if_exist=True)
            concat_node.in_port(ii).get_connection().set_source(
                mul_node2.out_port(0))

        reducesum_node = ReduceMean(
            graph, {
                'name': merge.id + '/_pnorm_reduced_sum',
                'keep_dims': False,
                'in_ports_count': 2,
                'need_shape_inference': None,
                'infer': reduce_infer
            }).create_node()

        const3.out_port(0).connect(reducesum_node.in_port(1))
        reducesum_node.in_port(0).get_connection().set_source(
            concat_node.out_port(0))

        reshape = Reshape(graph, {
            'name': merge.name + '/Reshape_Node'
        }).create_node()
        reshape_dim = Const(graph, {
            'value': np.array([1, 5]),
            'name': merge.id + '/Reshape_Dim'
        }).create_node()
        reducesum_node.out_port(0).connect(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_dim.out_port(0))
        merge.out_port(0).get_connection().set_source(reshape.out_port(0))
Example #30
    def replace_pattern(self, graph: Graph, match: dict):
        assert match['operator'].has('multiplication_transparent_ports')

        quantize = match['quantize']

        port = match['operator'].input_ports_with(match['quantized'])
        assert len(port) >= 1
        if len(port) > 1:
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it is consumed more'
                ' than once'.format(match['quantized'].name))
            return

        assert len(port) == 1
        port = port[0]
        applicable = [
            pair for pair in match['operator'].multiplication_transparent_ports
            if pair[0] == port
        ]
        if len(applicable) == 0:
            return

        # Look at the 3rd and 4th inputs of FakeQuantize -- they hold constants that should be passed through.
        # Assume that the constant to be passed through is a scalar.
        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)
        assert len(output_low.out_nodes()) == 1
        assert len(output_high.out_nodes()) == 1

        if not output_low.has_valid('value') or not output_high.has_valid(
                'value'):
            return

        output_low = output_low.value
        output_high = output_high.value

        operator = match['operator']

        weights = operator.in_node(1).value
        weights_rounded = np.round(weights)
        weights_consistent = np.all(np.isclose(weights, weights_rounded)) and \
                             set(np.unique(weights_rounded)).issubset({-1, 1})

        if weights_consistent and np.all(np.isclose(output_low, 0)) and np.all(
                np.isclose(output_high, 1)):
            reduction_indices = set(range(len(weights.shape))) - set(
                [operator.output_feature_channel])
            weights_reduced = np.add.reduce(weights,
                                            axis=tuple(reduction_indices))
            weights_reduced = weights_reduced.reshape(
                [len(weights_reduced), 1, 1])  # FIXME: works for NCHW only

            add_term = Const(graph, {'value': weights_reduced}).create_node()
            add = Add(graph, {}).create_node()
            add.in_port(1).connect(add_term.out_port(0))
            mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
            mul = Mul(graph, {}).create_node()
            mul.in_port(1).connect(mul_term.out_port(0))
            add.out_port(0).connect(mul.in_port(0))

            operator.out_port(0).get_connection().set_source(mul.out_port(0))
            add.in_port(0).connect(operator.out_port(0))

            operator['pad_value'] = float(-1.0)
        elif weights_consistent and np.all(np.isclose(
                output_low, -1)) and np.all(np.isclose(output_high, +1)):
            pass
        else:
            log.debug(
                'ConvToBinaryConv: cannot apply transformation because input range is neither in [0, +1] nor '
                'in [-1, +1].')
            return

        operator['type'] = 'BinaryConvolution'
        operator['mode'] = 'xnor-popcount'
        operator['pad_value'] = operator.soft_get('pad_value', float(0))
        operator['input'] = operator.in_node(0).shape[1]
        # Weights are not bit-packed yet; there should be a separate transformation to do that

        assert output_low.size == 1
        assert output_high.size == 1

        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)

        # Make sure that low/high values are exactly 0/1
        output_low.value = np.zeros(output_low.shape)
        output_high.value = np.ones(output_high.shape)
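# The worked equation behind the Add/Mul(0.5) correction built above. With
# weights w in {-1, +1}, mapping activations a in {0, 1} to s = 2a - 1 gives
#   conv(a, w) = conv((s + 1) / 2, w) = 0.5 * (conv(s, w) + sum(w))
# so the {0, 1}-range result equals 0.5 * (xnor-popcount result + per-channel
# weight sum), i.e. Add(weights_reduced) followed by Mul(0.5). A dot-product
# spot check with illustrative data:
import numpy as np

w = np.random.choice([-1.0, 1.0], size=16)
a = np.random.choice([0.0, 1.0], size=16)
s = 2.0 * a - 1.0

lhs = np.dot(a, w)                       # convolution over the {0,1} input
rhs = 0.5 * (np.dot(s, w) + w.sum())
assert np.isclose(lhs, rhs)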