示例#1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Decompose a power-like node into an optional ``Mul -> Add -> Pow`` chain.

        A stage is emitted only when its attribute differs from the neutral value:
        ``scale`` (neutral 1) becomes a Mul, ``shift`` (neutral 0) an Add and
        ``power`` (neutral 1) a Pow, each with a Const second operand. The consumers
        of the original node are re-pointed at the output of the last emitted stage,
        or at the producer of the node's input when no stage was needed.
        """
        node = match['op']
        # (attribute, neutral value, operation class, name suffix) in emission order
        stages = (
            ('scale', 1, Mul, '/mul_'),
            ('shift', 0, Add, '/add_'),
            ('power', 1, Pow, '/pow_'),
        )

        tail = node.in_port(0).get_source()
        for attr_name, neutral, op_class, suffix in stages:
            if node.soft_get(attr_name, neutral) != neutral:
                operand = Const(graph, {'value': np.array(getattr(node, attr_name))}).create_node()
                stage = op_class(graph, {'name': node.name + suffix}).create_node()
                operand.out_port(0).connect(stage.in_port(1))
                tail.connect(stage.in_port(0))
                tail = stage.out_port(0)

        node.out_port(0).get_connection().set_source(tail)
    def resolve_minus3(self, shape_node, input_index, reshape_index, dims):
        """Build the product of two consecutive input dimensions.

        Gathers ``shape[input_index]`` and ``shape[input_index + 1]`` from
        ``shape_node`` (via Const index nodes) and multiplies them with a Mul node.

        Returns:
            Tuple ``(input_index + 2, reshape_index + 1, dims, mul_node)``: the input
            cursor advanced past the two consumed dimensions, the reshape cursor
            advanced by one, ``dims`` passed through unchanged, and the new Mul node.
        """
        graph = shape_node.graph
        index_str = str(input_index)

        picked_dims = []
        for const_suffix, offset in (('/ShapeMinus3_index_const1_', 0),
                                     ('/ShapeMinus3_index_const2_', 1)):
            index_const = Const(
                graph,
                dict(name=shape_node.id + const_suffix + index_str,
                     value=int64_array([input_index + offset]))).create_node()
            picked_dims.append(get_shape_values_by_indices_node(shape_node, index_const))

        product = Mul(graph, dict(name=shape_node.id + '/MulMinus3_' + index_str)).create_node()
        for port_idx, dim_node in enumerate(picked_dims):
            product.in_port(port_idx).connect(dim_node.out_port(0))

        return input_index + 2, reshape_index + 1, dims, product
 def __insert_mul_node_with_coeff(node: Node, port: int, coeff: float):
     """Scale the tensor arriving at input ``port`` of ``node`` by ``coeff``.

     A Mul node is inserted into that input connection and its second input is
     fed by a Const holding ``[coeff]``. Nothing is inserted when ``coeff == 1``.
     """
     if coeff == 1:
         return

     graph = node.graph
     scaler = Mul(graph, {'name': node.id + '/coeff_mul'}).create_node()
     factor = Const(graph, {'name': node.id + '/coeff', 'value': mo_array([coeff])}).create_node()
     node.in_port(port).get_connection().insert_node(scaler)
     factor.out_port(0).connect(scaler.in_port(1))
示例#4
0
    def find_and_replace_pattern(self, graph: Graph):
        """Decompose every DequantizeLinear node into ``Cast -> [Sub] -> Mul``.

        The emitted sub-graph computes ``(Cast(x) - zero_point) * scale``:
        ``x`` comes from input 0, ``scale`` from input 1 and the optional
        ``zero_point`` from input 2. The Sub is emitted only when input 2 is
        connected. The Cast targets the model data type from ``cmd_params``. The
        Mul takes over the original node name.

        When ``axis`` is set and the scale shape is non-empty with a first
        dimension greater than 1, scale (and zero point) are reshaped to
        ``[1, ..., input_shape[axis], ..., 1]`` so they broadcast along ``axis``.

        Raises:
            AssertionError: if the shape of the scale input is unknown. The check
                runs before the node is rewired, so the graph is not left
                half-transformed.
        """
        for dequantize_node in graph.get_op_nodes(op='DequantizeLinear'):
            node_name = dequantize_node.soft_get('name', dequantize_node.id)
            axis = dequantize_node.soft_get('axis', None)
            scale_y_shape = dequantize_node.in_port(1).data.get_shape()
            # Validate before any rewiring: failing after the rewiring below would
            # leave a partially replaced node in the graph.
            assert scale_y_shape is not None, \
                'Shape of the scale input of node {} is not defined'.format(node_name)
            model_data_type = data_type_str_to_np(
                graph.graph['cmd_params'].data_type)
            cast = Cast(graph, {
                'dst_type': model_data_type,
                'name': node_name + '/Cast'
            }).create_node()
            dequantize_node.in_port(0).get_connection().set_destination(
                cast.in_port(0))
            mul = Mul(graph, {'can_be_fused': False}).create_node()

            # Input 2 (zero point) is optional.
            has_zero_point = dequantize_node.is_in_port_connected(2)
            if has_zero_point:
                # It is necessary not to replace Subtract for this pattern in offline
                # transformations. See the ConvertQuantizeDequantize transformation in ngraph.
                sub = Sub(graph, {
                    'name': node_name + '/Sub',
                    'zero_point_sub': True
                }).create_node()
                cast.out_port(0).connect(sub.in_port(0))
                dequantize_node.in_port(2).get_connection().set_destination(
                    sub.in_port(1))
                sub.out_port(0).connect(mul.in_port(0))
            else:
                cast.out_port(0).connect(mul.in_port(0))

            dequantize_node.in_port(1).get_connection().set_destination(
                mul.in_port(1))
            dequantize_node.out_port(0).get_connection().set_source(
                mul.out_port(0))
            rename_nodes([(dequantize_node, node_name + '/TBD'),
                          (mul, node_name)])

            if axis is not None and len(
                    scale_y_shape) > 0 and scale_y_shape[0] > 1:
                # Per-axis scale: make it broadcastable along `axis` of the input.
                input_shape = cast.in_port(0).data.get_shape()
                target_shape = np.ones(len(input_shape), np.int64)
                target_shape[axis] = input_shape[axis]

                mul_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul'})
                mul.in_port(1).get_connection().set_destination(
                    mul_reshape.in_port(0))
                mul_reshape.out_port(0).connect(mul.in_port(1))

                if has_zero_point:
                    sub_reshape = create_op_with_const_inputs(
                        graph, Reshape, {1: int64_array(target_shape)},
                        {'name': node_name + '/Reshape/Sub'})
                    sub.in_port(1).get_connection().set_destination(
                        sub_reshape.in_port(0))
                    sub_reshape.out_port(0).connect(sub.in_port(1))
示例#5
0
 def replace_sub_graph(self, graph: Graph, match: dict):
     """Apply a softmax ``temperature`` by scaling the softmax input.

     When the matched node has a ``temperature`` attribute different from 1.0, a
     Mul by the constant ``[1.0 / temperature]`` is inserted between the node's
     producer and the node. Otherwise the graph is left unchanged.
     """
     node = match['softmax']
     if 'temperature' in node and node['temperature'] != 1.0:
         in_node = node.in_node()
         # NOTE: the comprehension variable shadows `node` only inside the
         # comprehension; afterwards `node` still refers to the softmax node.
         out_nodes = [node for node in node.out_nodes().values()]
         # Detach the softmax from its producer; the Mul is inserted in between below.
         graph.remove_edge(node.in_node().id, node.id)
         temperature = mo_array([1.0 / node.temperature])
         scalar_value_op = Const(graph, dict(value=temperature, shape=temperature.shape,
                                             symbol_dict={'name': node.id + '/const'}))
         mul_op = Mul(graph, dict(name=node.id + '/mul_', symbol_dict={'name': node.id + '/mul_'}))
         # Mul(producer, 1 / temperature)
         mul_node = mul_op.create_node(inputs=[in_node, scalar_value_op.create_node()])
         # NOTE(review): the new Mul -> softmax edge reuses the attributes of the
         # softmax -> first-consumer edge (not the removed input edge), and this
         # fails if the softmax has no consumers; confirm the copied port
         # attributes are what is intended.
         edge_attrs = graph.get_edge_data(node.id, out_nodes[0].id)[0]
         graph.add_edges_from([(mul_node.id, node.id, edge_attrs)])
示例#6
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Expand an NHWC FusedBatchNorm with runtime mean/variance into elementwise ops.

        The node is rewritten only when it has ``data_format == b'NHWC'`` and five
        inputs. The data input must be non-constant, the ``scale`` and ``offset``
        inputs constant 1D tensors, and the ``mean`` and ``variance`` inputs
        non-constant. Otherwise the node is left unchanged. The emitted sub-graph
        computes::

            (x + mean * -1) * (variance + eps) ** -0.5 * scale + offset

        writing into the original output data node. The FusedBatchNorm node is then
        removed from the graph.
        """
        node = match['op']
        if (node.data_format != b'NHWC' or len(node.in_nodes()) != 5
                or node.in_node(0).value is not None or  # input
                node.in_node(1).value is None or  # scale
                node.in_node(2).value is None or  # offset
                node.in_node(3).value is not None or  # mean
                node.in_node(4).value is not None or  # variance
                node.in_node(1).value.ndim != 1 or
                node.in_node(2).value.ndim != 1):
            return

        scale_mul = Mul(graph, dict(name=node.name + '/scale_mul_'))
        shift_add = Add(graph, dict(name=node.name + '/shift_add_'))
        mean_add = Add(graph, dict(name=node.name + '/mean_add_'))
        variance_mul = Mul(graph, dict(name=node.name + '/variance_mul_'))

        # x - mean, expressed as x + mean * (-1)
        neg_const = Const(
            graph, dict(value=np.array(-1), name=node.name + '/mean_negate_'))
        mean_negate = Mul(graph, dict(name=node.name + '/mean_negate_'))
        mean_arg = mean_add.create_node_with_data([
            node.in_node(0),
            mean_negate.create_node_with_data(
                [node.in_node(3),
                 neg_const.create_node_with_data()])
        ])

        # (x - mean) * (variance + eps) ** -0.5
        shift_const = Const(
            graph,
            dict(value=node.eps,
                 name=node.name + '/variance_denom_shift_const_'))
        power_const = Const(
            graph,
            dict(value=-0.5, name=node.name + '/variance_denom_power_const_'))
        variance_denom_shift = Add(
            graph, dict(name=node.name + '/variance_denom_shift_'))
        variance_denom_power = Pow(
            graph, dict(name=node.name + '/variance_denom_power_'))
        variance_arg = variance_mul.create_node_with_data([
            mean_arg,
            variance_denom_power.create_node_with_data([
                variance_denom_shift.create_node_with_data(
                    [node.in_node(4),
                     shift_const.create_node_with_data()]),
                power_const.create_node_with_data()
            ])
        ])

        # ... * scale + offset, written into the original output data node
        shift_add.create_node_with_data([
            scale_mul.create_node_with_data([variance_arg,
                                             node.in_node(1)]),
            node.in_node(2)
        ],
                                        data_nodes=node.out_node())

        node.graph.remove_node(node.id)
    def replace_pattern(graph: Graph, match: dict):
        """Replace a Normalize node with ``Mul(NormalizeL2(x), weights)``.

        The input rank must be 2, 3 or 4 and the weights (input 1) must be constant.
        Weights are collapsed to a single element when ``channel_shared`` is set or
        all of them are equal. Otherwise they are reshaped to ``[1, -1, 1, ...]`` of
        the input rank. NormalizeL2 reduces over axis 1, or over all axes from 1 when
        ``across_spatial`` is set, with ``eps_mode='add'``. The Mul takes over the
        original node name.

        NOTE(review): no ``self`` parameter - presumably declared as a
        ``@staticmethod`` above the visible lines; confirm.
        """
        node = match['normalize']

        # rename normalize node since it will be no longer output node after the transformation
        output_name = node.soft_get('name', node.id)
        normalizel2_name = output_name + '/normalizel2'
        rename_node(node, normalizel2_name)

        # `.size` of the shape array is the input rank
        assert node.in_port(0).data.get_shape().size in [2, 3, 4]
        assert node.has_valid('across_spatial')
        assert node.has_valid('channel_shared')
        assert node.has_valid('eps')

        # Drop the 'bin' attribute of the weights edge
        # (NOTE(review): presumably a serialization hint of the old layer; confirm).
        if 'bin' in node.in_edge(1):
            del node.in_edge(1)['bin']

        weights = node.in_port(1).data.get_value()
        assert weights is not None
        # In the code below we intentionally use get_source() to get the out port, because updating the
        # out port data also updates the Const node 'value' and 'shape' attributes.
        if node.channel_shared or all(weights == weights[0]):
            node.in_port(1).get_source().data.set_value(np.array([weights[0]]))
        else:
            # [1, -1, 1, ...]: per-channel weights broadcastable along axis 1
            new_shape = np.ones((len(node.in_port(0).data.get_shape())),
                                dtype=np.int64)
            new_shape[1] = -1
            node.in_port(1).get_source().data.set_value(
                np.array(weights).reshape(new_shape))

        mul = Mul(graph, {'name': output_name}).create_node()
        rename_node(mul, output_name)

        if not node.across_spatial:
            axes = int64_array([1])
        else:
            axes = int64_array(
                np.arange(start=1, stop=node.in_port(0).data.get_shape().size))

        normalizel2 = create_op_with_const_inputs(graph, NormalizeL2Op,
                                                  {1: axes}, {
                                                      'eps_mode': 'add',
                                                      'eps': node.eps
                                                  })

        # Rewire: consumers <- Mul, Mul.in1 <- weights, Mul.in0 <- NormalizeL2 <- original input
        node.out_port(0).get_connection().set_source(mul.out_port(0))
        node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
        normalizel2.out_port(0).connect(mul.in_port(0))
        node.in_port(0).get_connection().set_destination(
            normalizel2.in_port(0))
    def find_and_replace_pattern(self, graph: Graph):
        """Decompose every LayerNorm node into ``MVN -> Mul(gamma) -> Add(beta)``.

        MVN normalizes over ``node.axis`` with ``eps_mode='inside_sqrt'``. The 1D
        ``gamma`` (input 1) and ``beta`` (input 2) are unsqueezed on every axis
        except the normalized one, so they broadcast against the input. The final
        Add takes over the original node name.

        Raises:
            Error: if ``output_mean_var`` is True and output port 1 or 2 is connected.
            AssertionError: if the node has no valid ``axis`` attribute.
        """
        for node in graph.get_op_nodes(op='LayerNorm'):
            node_name = node.soft_get('name', node.id)

            if node.output_mean_var is True:
                if not node.out_port(1).disconnected() or not node.out_port(2).disconnected():
                    raise Error("Node {} is supported with only one output".format(node_name))
                # The adjacent literals are concatenated; the trailing space keeps the sentences apart.
                log.error('LayerNorm node {} with attribute "output_mean_var" = True is not supported. '
                          'But since the node has one output, the conversion will continue.'.format(node_name),
                          extra={'is_warning': True})

            input_shape = node.in_port(0).data.get_shape()
            assert node.has_valid('axis'), 'Incorrect axis value for the node {}'.format(node_name)
            axis = node.axis

            mvn = create_op_node_with_second_input(graph, MVN, int64_array([axis]),
                                                   dict(eps=node.epsilon, name=node_name + '/LayerNorm/MVN_',
                                                        across_channels=1, normalize_variance=1,
                                                        eps_mode='inside_sqrt'))

            mul = Mul(graph, {'name': node_name + '/LayerNorm/mul_'}).create_node()
            # Derive the Add name from the LayerNorm node like the MVN and Mul names
            # (building it from mul.name produced '.../LayerNorm/mul_/LayerNorm/add_').
            add = Add(graph, {'name': node_name + '/LayerNorm/add_'}).create_node()

            node.in_port(0).get_connection().set_destination(mvn.in_port(0))
            node.in_port(1).get_connection().set_destination(mul.in_port(1))
            node.in_port(2).get_connection().set_destination(add.in_port(1))

            mvn.out_port(0).connect(mul.in_port(0))
            mul.out_port(0).connect(add.in_port(0))
            node.out_port(0).get_connection().set_source(add.out_port(0))

            # MXNet LayerNorm gamma and beta attributes are 1D tensors with shape = [input_shape[axis]]
            # We have to unsqueeze values for Mul and Add operations to avoid shapes incompatibility problems
            # if axis != -1
            canonical_axis = get_canonical_axis_index(input_shape, axis)
            unsqueeze_value = [idx for idx in range(len(input_shape)) if idx != canonical_axis]

            mul_const_unsqueeze = create_op_node_with_second_input(graph, Unsqueeze,
                                                                   int64_array(unsqueeze_value),
                                                                   dict(name=mul.name + '/Unsqueeze',
                                                                        override_output_shape=True))
            add_const_unsqueeze = create_op_node_with_second_input(graph, Unsqueeze,
                                                                   int64_array(unsqueeze_value),
                                                                   dict(name=add.name + '/Unsqueeze',
                                                                        override_output_shape=True))

            mul.in_port(1).get_connection().insert_node(mul_const_unsqueeze)
            add.in_port(1).get_connection().insert_node(add_const_unsqueeze)

            rename_nodes([(node, node_name + '/ShouldBeDeleted'), (add, node_name)])
示例#9
0
    def div_to_mul_replacement(div: Node):
        """Replace ``Div(A, B)`` with the equivalent ``Mul(A, Pow(B, -1))``.

        The Mul takes over the Div's name; the Div is renamed to
        ``<name>/to_be_removed``. The replacement is skipped when:

        * both inputs are constant (shape-calculating sub-graphs are left alone), or
        * the divisor is a constant of integer dtype, because the integer reciprocal
          would be truncated to 0.
        """
        # we execute this transformation for V10 IR later on middle phase despite graph_condition
        # so we prevent Div replacement on shape-calculating sub-graphs
        dividend_value = div.in_port(0).data.get_value()
        divisor_value = div.in_port(1).data.get_value()
        if dividend_value is not None and divisor_value is not None:
            return

        # cannot replace Div with Mul when the divisor is integer because the reciprocal number will be 0.
        # Inspect the dtype rather than the first element, so empty constants do not raise IndexError.
        if divisor_value is not None and np.issubdtype(np.asarray(divisor_value).dtype, np.integer):
            return

        graph = div.graph
        name = div.soft_get('name', div.id)

        # keep Mul name the same as Div -- because of mathematical equality of output tensors
        rename_node(node=div, name=name + '/to_be_removed')

        # reconnect Div in(out)puts to Mul
        mul = Mul(graph, {'name': name}).create_node()
        rename_node(mul, name)

        div.in_port(0).get_connection().set_destination(mul.in_port(0))
        div.in_port(1).get_connection().set_destination(mul.in_port(1))
        div.out_port(0).get_connection().set_source(mul.out_port(0))

        # restore mathematical equivalence to Div operation: Div(A, B) = Mul(A, Pow(B, -1))
        reciprocal = create_op_with_const_inputs(graph, Pow, {1: np.float64(-1)}, {'name': name + '/reciprocal_'})
        mul.in_port(1).get_connection().insert_node(reciprocal)
示例#10
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace a Kaldi eltwise node with a Split followed by an element-wise op.

        The node's first input is split (Split axis constant 1) into
        ``node['num_inputs']`` parts. Exactly two parts are combined with Mul
        (``operation == 'mul'``) or Add (``operation == 'sum'``). More than two
        parts are combined with an EltwiseN node carrying the same ``operation``.

        Returns:
            A one-element list with the id of the element-wise node.

        Raises:
            Error: for an unknown two-input operation, or fewer than two splits.
        """
        split_node = create_op_with_const_inputs(graph, Split, {1: int64_array(1)},
                                                 {'name': 'Split_eltwise_' + node.name,
                                                  'num_splits': node['num_inputs']})

        first_input = node.get_inputs()[0]
        graph.add_edge(first_input[0], split_node.id, **first_input[1])

        eltwise_name = 'Eltwise_' + node.name
        parts = split_node.num_splits
        if parts == 2:
            binary_ops = {'mul': Mul, 'sum': Add}
            if node['operation'] not in binary_ops:
                raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
            eltwise_node = binary_ops[node['operation']](graph, attrs={'name': eltwise_name}).create_node()
        elif parts > 2:
            eltwise_node = EltwiseN(graph, attrs={'name': eltwise_name,
                                                  'operation': node['operation']}).create_node()
        else:
            raise Error('Error on replacing Kaldi eltwise')

        for port_idx in range(parts):
            split_node.out_port(port_idx).get_connection().set_destination(eltwise_node.in_port(port_idx))
        return [eltwise_node.id]
示例#11
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Fuse the matched mean/variance pattern feeding FusedBatchNorm into ``MVN * gamma + beta``.

        The fusion applies only when the FusedBatchNorm data input is also the
        input of both the ``mean`` and ``sqdiff`` nodes of the pattern. Otherwise
        the graph is left untouched. The MVN node keeps its stock inference under
        ``old_infer`` and uses this class's ``infer`` instead.
        """
        fbn = match['fbn']
        source = fbn.in_node(0)
        log.debug('Found potential MVN pattern after {} with name {}'.format(
            source.op, source.name))
        if not (source.id == match['mean'].in_node(0).id and
                source.id == match['sqdiff'].in_node(0).id):
            return

        log.debug('Confirmed MVN pattern after {} with name {}'.format(
            source.op, source.name))

        mvn = MVN(graph, dict(name=fbn.name + '/MVN_', eps=fbn.eps,
                              eps_mode='outside_sqrt', normalize_variance=1))
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
        add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

        gamma = fbn.in_node(1)
        beta = fbn.in_node(2)
        mean_reduction_input = match['mean'].in_node(1)
        variance_reduction_input = match['variance'].in_node(1)

        normalized = mvn.create_node([source, mean_reduction_input, variance_reduction_input])
        scaled = mul.create_node([normalized, gamma])
        shifted = add.create_node([scaled, beta])
        fbn.replace_node(shifted)
示例#12
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace an ImageScaler node with an equivalent ``Mul(scale) -> Add(bias)`` sequence.

        A stage is skipped when its constant would not change the data: the Mul
        when every element of ``scale`` is 1, the Add when every element of
        ``bias`` is 0. If both are skipped, the node's consumers are connected
        directly to the producer of its input.
        """
        # This replacer replace ImageScalar operation to Mul->Add sequence
        # Also it check that weights and biases are good
        op = match['op']

        # Check that weights and biases are not useless
        has_weights = not np.all(op.scale == 1)
        has_bias = not np.all(op.bias == 0)

        assert len(op.in_ports()) == 1

        last_port = op.in_port(0).get_source()

        # Create Mul & Add nodes
        if has_weights:
            mul_weights = Const(graph,
                                dict(value=op.scale,
                                     shape=op.scale.shape)).create_node()
            mul_op = Mul(graph, dict(name=op.id + '/mul_')).create_node()
            op.in_port(0).get_connection().set_destination(mul_op.in_port(0))
            mul_weights.out_port(0).connect(mul_op.in_port(1))
            last_port = mul_op.out_port(0)

        if has_bias:
            add_bias = Const(graph, dict(value=op.bias,
                                         shape=op.bias.shape)).create_node()
            add_op = Add(graph, dict(name=op.id + '/add_')).create_node()
            # Add a destination to `last_port` instead of re-targeting its whole
            # connection: when the Mul was skipped, `last_port` is the producer of
            # the ImageScaler input and may feed other consumers that must stay attached.
            last_port.connect(add_op.in_port(0))
            add_bias.out_port(0).connect(add_op.in_port(1))
            last_port = add_op.out_port(0)

        op.in_port(0).disconnect()
        op.out_port(0).get_connection().set_source(last_port)
    def replace_pattern(self, graph: Graph, match: dict):
        """Feed both range inputs (ports 1 and 2) of the quantize node with their midpoint.

        Builds ``(in1 + in2) * 0.5`` from the current producers of ports 1 and 2,
        detaches those producers from the quantize node, and connects the result
        to both ports.
        """
        quantize = match['quantize']
        range_ports = (1, 2)

        adder = Add(graph, dict()).create_node()
        half = Const(graph, {'value': mo_array(0.5)}).create_node()
        averager = Mul(graph, dict()).create_node()

        averager.in_port(0).connect(adder.out_port(0))
        averager.in_port(1).connect(half.out_port(0))

        # port 1 -> adder input 0, port 2 -> adder input 1
        for idx in range_ports:
            quantize.in_port(idx).get_connection().get_source().connect(adder.in_port(idx - 1))

        for idx in range_ports:
            quantize.in_port(idx).disconnect()

        for idx in range_ports:
            averager.out_port(0).connect(quantize.in_port(idx))
    def replace_op(self, graph: Graph, node: Node):
        """Convert a PyTorch LayerNorm module into MVN over the last axis plus its affine part.

        MVN uses ``normalize_variance=True`` and ``eps_mode='inside_sqrt'`` with the
        module's ``eps``. The Mul by ``weight`` and the Add of ``bias`` are emitted
        only for parameters that exist. PyTorch sets them to None for
        ``elementwise_affine=False`` (and ``bias`` alone for ``bias=False``), which
        previously failed on ``None.detach()``.

        Returns:
            A one-element list with the id of the node producing the result.
        """
        axis = Const(graph, {'value': int64_array([-1])}).create_node()
        mvn = MVN(
            graph,
            dict(name=node.name + '/mvn',
                 eps=node.module.eps,
                 normalize_variance=True,
                 eps_mode='inside_sqrt')).create_node([node.in_node(0), axis])

        weight = node.module.weight
        bias = node.module.bias

        # Constants are created before Mul/Add, as in the original node-creation order.
        w = None if weight is None else Const(graph, {'value': weight.detach().numpy()}).create_node()
        b = None if bias is None else Const(graph, {'value': bias.detach().numpy()}).create_node()

        result = mvn
        if w is not None:
            result = Mul(graph, dict(name=node.name + '/mul')).create_node([result, w])
        if b is not None:
            result = Add(graph, dict(name=node.name + '/add')).create_node([result, b])
        return [result.id]
示例#15
0
    def replace_op(self, graph: Graph, node: Node):
        """Convert an FFT module into a DFT / IDFT sub-graph.

        Forward (``module.inverse`` falsy):
            The real input is paired with a zero imaginary part (``input * 0.0``).
            The pair is packed along a new last axis, and DFT runs over axes
            ``range(2, num_axes)``.

        Inverse:
            IDFT runs over axes ``range(2, num_axes - 1)``. A StridedSlice then
            keeps index 0 of the last axis (presumably the real component -
            confirm), and a Squeeze drops that axis.

        Returns:
            A one-element list with the id of the output node.
        """
        module = node.module
        source = node.in_node(0)

        if not module.inverse:
            zero = Const(graph, {'value': 0.0}).create_node()
            imag = Mul(graph, dict(name=node.name + '/imag')).create_node([source, zero])
            cmplx = PackOp(graph, dict(name=node.name + '/complex', axis=-1)).create_node([source, imag])

            dft_axes = Const(graph, {'value': int64_array(range(2, module.num_axes))}).create_node()
            forward = DFT(graph, dict(name=node.name, in_ports_count=2)).create_node([cmplx, dft_axes])
            return [forward.id]

        idft_axes = Const(graph, {'value': int64_array(range(2, module.num_axes - 1))}).create_node()
        inverse = IDFT(graph, dict(name=node.name, in_ports_count=2)).create_node([source, idft_axes])

        # Slice a real part
        begin_id = Const(graph, {'value': int64_array([0, 0])}).create_node()
        end_id = Const(graph, {'value': int64_array([0, 1])}).create_node()
        real = StridedSlice(graph, dict(name=node.name + '/real',
                                        begin_mask=[0, 0],
                                        end_mask=[0, 1],
                                        shrink_axis_mask=[0, 0],
                                        new_axis_mask=[0],
                                        ellipsis_mask=[1, 0])).create_node([inverse, begin_id, end_id])

        squeeze_axis = Const(graph, {'value': -1}).create_node()
        squeezed = Squeeze(graph, dict(name=node.name + '/squeeze')).create_node([real, squeeze_axis])
        return [squeezed.id]
def replace_interpolate_pattern(graph: Graph, match: dict):
    """Replace a matched Split -> Concat sub-graph with one nearest-mode Interpolate.

    The Interpolate (opset4, ``shape_calculation_mode='scales'``) resizes only the
    split axis, by the factor returned by ``get_split_scale(split)``. Its inputs are:

    * port 1 (sizes): ``int64(floor(float32(shape(x)[axis]) * scale))``, computed at runtime;
    * port 2 (scales): the constant ``[scale]``;
    * port 3 (axes): the constant ``[axis]``.

    Consumers of the Concat output are moved to the Interpolate output. The Split's
    data input is re-routed to Interpolate port 0 and also feeds the Shape node.

    NOTE(review): presumably the pattern is a Split whose outputs are repeated into
    a Concat (i.e. nearest upsampling along ``axis``); confirm against the pattern
    definition, which is not visible here.
    """
    split = match['split']
    scale = float32_array([get_split_scale(split)])
    # Split axis comes from the constant producer of Split input 1.
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # Extract shape[axis:axis + 1] of the input.
    strided_slice_node = create_op_with_const_inputs(graph,
                                                     StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {
                                                        'name': split_node_name + '/StridedSlice',
                                                        'begin_mask': int64_array([1]),
                                                        'end_mask': int64_array([1]),
                                                        'new_axis_mask': int64_array([0]),
                                                        'shrink_axis_mask': int64_array([0]),
                                                        'ellipsis_mask': int64_array([0])
                                                     })
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    # float32(shape[axis]) * scale
    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest',
                                   antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
                                   cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    # sizes = int64(floor(float32(shape[axis]) * scale))
    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    # Move the Split's data input to the Interpolate and also feed it to Shape.
    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
    def replace_op(self, graph: Graph, node: Node):
        """Fold an inference-mode PyTorch BatchNorm module into ``x * w + b``.

        With running statistics ``mean``/``var``, affine ``weight``/``bias`` and
        ``eps``: ``w = weight / sqrt(var + eps)`` and ``b = bias - w * mean``. Both
        constants are reshaped to ``[1, -1, 1, ...]`` of rank ``node.module.dims``
        (presumably the input rank, set by the converter - confirm) so they
        broadcast along axis 1 (channels).

        A module created with ``affine=False`` has ``weight``/``bias`` set to None.
        They are treated as 1 and 0 instead of failing on ``None.detach()``.

        Returns:
            A one-element list with the id of the final Add node.
        """
        mean = node.module.running_mean.detach().numpy()
        var = node.module.running_var.detach().numpy()
        # affine=False: the affine transform is the identity
        weight = node.module.weight
        weight = np.ones_like(var) if weight is None else weight.detach().numpy()
        bias = node.module.bias
        bias = np.zeros_like(mean) if bias is None else bias.detach().numpy()

        w = weight / np.sqrt(var + node.module.eps)
        b = bias - w * mean

        shape = np.ones(node.module.dims, dtype=np.int32)
        shape[1] = -1  # channels

        w = Const(graph, {'value': w.reshape(shape)}).create_node()
        b = Const(graph, {'value': b.reshape(shape)}).create_node()
        mul = Mul(graph, dict(name=node.name + '/mul')).create_node(
            [node.in_node(0), w])
        add = Add(graph, dict(name=node.name + '/add')).create_node([mul, b])
        return [add.id]
    def find_and_replace_pattern(self, graph: Graph):
        """Replace every ThresholdedRelu node with ``x * Cast(x > alpha)``.

        The boolean mask is cast to the model data type taken from ``cmd_params``.
        The resulting Mul takes over the original node name, and the original node
        is removed from the graph.
        """
        for relu in graph.get_op_nodes(op='ThresholdedRelu'):
            relu_name = relu.soft_get('name', relu.id)
            data_source = relu.in_port(0).get_source()

            # mask = Cast(x > alpha)
            comparison = create_op_with_const_inputs(graph, Greater, {1: float_array([relu.alpha])})
            comparison.in_port(0).connect(data_source)
            target_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
            mask = Cast(graph, {'dst_type': target_type}).create_node()
            comparison.out_port(0).connect(mask.in_port(0))

            # y = x * mask
            product = Mul(graph, {}).create_node()
            relu.out_port(0).get_connection().set_source(product.out_port(0))
            product.in_port(0).connect(data_source)
            product.in_port(1).connect(mask.out_port(0))

            rename_nodes([(relu, relu_name + '/TBR'), (product, relu_name)])
            graph.remove_node(relu.id)
    def replace_op(self, graph: Graph, node: Node):
        """Decompose InstanceNorm into ``MVN(x, axes=Range(2, Rank(x), 1)) * scale + bias``.

        ``scale`` is taken from input 1 and ``bias`` from input 2 of the original
        node. The MVN axes are computed at runtime from the input rank. The final Add
        takes over the original node name.

        Returns:
            A one-element list with the id of the final Add node.
        """
        base_name = node.soft_get('name', node.id)

        # create range of axes for MVN based on `start_axis` and rank of input
        input_rank = Rank(graph, {'name': base_name + '/Rank'}).create_node()
        axes_range = create_op_with_const_inputs(
            graph, Range, {0: int64_array(2), 2: int64_array(1)},
            {'name': base_name + '/Range', 'output_type': np.int64})
        mvn = MVN(graph, dict(eps=node.epsilon,
                              eps_mode='inside_sqrt',
                              normalize_variance=1,
                              name=base_name + '/Ins_Norm/MVN_')).create_node()
        node.in_port(0).get_connection().set_destination(mvn.in_port(0))
        axes_range.out_port(0).connect(mvn.in_port(1))

        # Affine part: Mul with input 1, then Add with input 2.
        tail = mvn
        for op_class, suffix, src_port in ((Mul, '/Ins_Norm/mul_', 1), (Add, '/Ins_Norm/add_', 2)):
            stage = op_class(graph, dict(axis=1, name=base_name + suffix)).create_node()
            tail.out_port(0).connect(stage.in_port(0))
            node.in_port(src_port).get_connection().set_destination(stage.in_port(1))
            tail = stage

        # Rank reads the same tensor that feeds MVN and supplies the Range limit.
        mvn.in_port(0).get_connection().add_destination(input_rank.in_port(0))
        axes_range.in_port(1).connect(input_rank.out_port(0))

        rename_nodes([(node, base_name + '/TBD'), (tail, base_name)])

        return [tail.id]
示例#20
0
def parse_specifier(string, graph, layer_node_map):
    """
    Parse an input descriptor and return the id of the graph node that produces it.

    ``string`` is a bytes descriptor (the syntax matches Kaldi nnet3 descriptors):
    either a plain layer name, or one of ``Append(...)``, ``Offset(x, t)``,
    ``Sum(...)``, ``IfDefined(x)``, ``ReplaceIndex(x, ...)`` or ``Scale(value, x)``.
    Nested descriptors are parsed recursively. Every node created here is memoized
    in ``layer_node_map`` under a key derived from the descriptor. A repeated
    descriptor therefore returns the existing node instead of creating and wiring
    a duplicate.

    :param string: descriptor as bytes
    :param graph: graph that receives the new nodes and edges
    :param layer_node_map: dict from layer/descriptor key to graph node id; updated in place
    :return: id of the node producing the descriptor's output (None for an
             unrecognized specifier)
    :raises Error: if an ``Offset`` descriptor has more than two arguments
    """
    pos = string.find(b'(')
    if pos == -1:
        # Plain node name. str() of bytes gives "b'...'": drop the 'b' prefix,
        # the quotes and escaped newlines to recover the bare name.
        input_name = str(string.split(b' ')[0]).strip('b').replace(
            "\'", '').replace('\\n', '')

        if input_name not in layer_node_map:
            node_name = graph.unique_id(prefix=input_name)
            # Placeholder op node; presumably completed when the layer itself is parsed.
            graph.add_node(node_name, parameters=[], op="", kind='op')
            layer_node_map[input_name] = node_name
        else:
            node_name = layer_node_map[input_name]
        return node_name

    spec = string[:pos]
    args = get_args_for_specifier(string[pos:])
    if spec == b'Append':
        nodes = [parse_specifier(arg, graph, layer_node_map) for arg in args]
        layer_name = 'Append_' + ''.join(node + '_' for node in nodes)

        if layer_name not in layer_node_map:
            concat_name = graph.unique_id(prefix=layer_name)
            graph.add_node(concat_name,
                           parameters=None,
                           op='concat',
                           kind='op')
            layer_node_map[layer_name] = concat_name
            Node(graph,
                 concat_name).add_sequence_of_ports('in', range(len(nodes)))
            # Each input gets a fresh output port on its producer, wired to the
            # concat input port matching its position in the descriptor.
            for i, node in enumerate(nodes):
                out_port = len(Node(graph, node).out_nodes())
                Node(graph, node).add_output_port(out_port)
                graph.create_edge(
                    Node(graph, node), Node(graph, concat_name), out_port, i,
                    create_edge_attrs(node, concat_name, node, i, out_port))
        else:
            concat_name = layer_node_map[layer_name]
        return concat_name
    elif spec == b'Offset':
        node = parse_specifier(args[0], graph, layer_node_map)
        t = int(args[1])
        if len(args) > 2:
            raise Error("ModelOptimizer supports only 2 arguments for Offset")
        # Negative offsets get an extra '_' so that +t and -t map to different keys.
        layer_name = 'Offset_' + node + '_'
        if t < 0:
            layer_name = layer_name + '_' + str(-t)
        else:
            layer_name = layer_name + str(t)

        if layer_name not in layer_node_map:
            memory_name = graph.unique_id(prefix=layer_name)
            layer_node_map[layer_name] = memory_name
            memory_name_2 = memory_name + '_out'
            graph.add_node(memory_name,
                           parameters=dict(t=t,
                                           pair_name=memory_name_2,
                                           has_default=False),
                           op='MemoryOffset',
                           kind='op')
            out_port = len(Node(graph, node).out_nodes())
            in_port = len(Node(graph, memory_name).in_nodes())
            Node(graph, memory_name).add_input_port(in_port)
            Node(graph, node).add_output_port(out_port, skip_if_exist=True)
            graph.create_edge(
                Node(graph, node), Node(graph, memory_name), out_port, in_port,
                create_edge_attrs(node, memory_name, node, in_port, out_port))
        else:
            memory_name = layer_node_map[layer_name]
        return memory_name
    elif spec == b'Sum':
        nodes = [parse_specifier(arg, graph, layer_node_map) for arg in args]
        layer_name = 'Sum_' + ''.join(node + '_' for node in nodes)

        if layer_name not in layer_node_map:
            sum_name = graph.unique_id(prefix=layer_name)
            graph.add_node(sum_name, parameters=None, op='Add', kind='op')
            layer_node_map[layer_name] = sum_name
            # Wire the inputs only when the node is created, as Append and Offset do.
            # A memoized Sum already has its input ports and edges; wiring it again
            # would re-add input port ``i`` (no skip_if_exist) and duplicate the edges.
            for i, node in enumerate(nodes):
                out_port = len(Node(graph, node).out_nodes())
                Node(graph, node).add_output_port(out_port, skip_if_exist=True)
                Node(graph, sum_name).add_input_port(i)
                graph.add_edge(node, sum_name,
                               **create_edge_attrs(node, sum_name, node, i))
        else:
            sum_name = layer_node_map[layer_name]
        return sum_name
    elif spec == b'IfDefined':
        node_id = parse_specifier(args[0], graph, layer_node_map)
        node = Node(graph, node_id)
        # IfDefined on a MemoryOffset marks it as having a default value.
        if node.op == 'MemoryOffset':
            node['parameters']['has_default'] = True
        return node_id
    elif spec == b'ReplaceIndex':
        node = parse_specifier(args[0], graph, layer_node_map)
        return node
    elif spec == b'Scale':
        node_name = parse_specifier(args[1], graph, layer_node_map)
        scale_value = float(args[0])
        layer_name = '{}/Mul/{}'.format(node_name, scale_value)

        if layer_name not in layer_node_map:
            scale_name = graph.unique_id(prefix=layer_name)
            scale_node = Mul(graph, {'name': scale_name}).create_node()

            layer_node_map[layer_name] = scale_name

            scale_const_name = 'Const_{}'.format(scale_value)
            const_node = Const(graph, {
                'name': scale_const_name,
                'value': float_array([scale_value])
            }).create_node()

            node = Node(graph, node_name)
            graph.create_edge(
                const_node, scale_node, 0, 0,
                create_edge_attrs(const_node.id, scale_node.id, const_node.id))
            out_port = len(node.out_nodes())
            graph.create_edge(
                node, scale_node, out_port, 1,
                create_edge_attrs(node_name, scale_node.id, node_name, 1,
                                  out_port))
        else:
            scale_name = layer_node_map[layer_name]

        return scale_name
示例#21
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Build the sub-graph feeding input port 1 of the matched interpolation node.

        When port 1 is absent or disconnected, the value connected to it is computed from
        the shape of input 0 (``spatial`` = StridedSlice(Shape(input0), [2], [4], [1])) and
        the node attributes:

        * ``factor`` set, no width/height:  spatial * factor
        * otherwise, with H_eff = spatial + (pads_begin + pads_end):

          - shrink_factor only:            (H_eff - 1) / shrink_factor + 1
          - zoom_factor only:              H_eff * zoom_factor (a deeplab-caffe variant is
                                           left commented out below)
          - width and height:              constant [height, width]
          - both shrink and zoom factors:  s = (H_eff - 1) / shrink_factor + 1,
                                           then s + (s - 1) * (zoom_factor - 1)

        The divisions are done as multiplication by a float reciprocal. Presumably
        downstream handling rounds the result (Caffe uses integer division). TODO confirm.

        When port 1 is connected and ``fw == 'caffe'``, the tensor feeding port 1 is
        replaced by the spatial slice of its own shape.

        Returns None. On a shrink/zoom factor below 1 an error is logged and the method
        returns early.
        """
        node = match['op']

        if 1 not in node.in_ports() or node.in_port(1).disconnected():

            if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
                factor = Const(graph, {'value': np.array(node.factor)}).create_node()

                ss = self._spatial_dims_of(graph, node, node.in_port(0).get_connection().get_source())
                mul = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()
                ss.out_port(0).connect(mul.in_port(0))
                factor.out_port(0).connect(mul.in_port(1))

                self._set_size_input(node, mul)

            else:
                ss = self._spatial_dims_of(graph, node, node.in_port(0).get_connection().get_source())

                # H_eff = spatial dims + total padding
                pads_value = node.pads_begin + node.pads_end
                pads_const = Const(graph, {'value': np.array(pads_value)}).create_node()
                add = Add(graph, {'name': node.name + '/pad_add'}).create_node()
                ss.out_port(0).connect(add.in_port(0))
                add.in_port(1).connect(pads_const.out_port(0))

                if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
                    shrink_factor = node.shrink_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None

                    # (H_eff - 1) / shrink_factor + 1
                    const = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                          'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    sub.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'value': np.array(1 / shrink_factor),
                                          'name': node.name + 'shrink_factor_div_const'}).create_node()
                    div = Mul(graph, {'name': node.name + 'shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    div.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'name': node.name + '/shrink_factor_add_one_const', 'value': np.array(1)
                                          }).create_node()
                    add = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
                    div.out_port(0).connect(add.in_port(0))
                    const.out_port(0).connect(add.in_port(1))

                    self._set_size_input(node, add)

                elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
                    zoom_factor = node.zoom_factor
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    node['debug_message'] = 'Interpolate layer replacer may be wrong, please, try to update it in the' \
                                            ' file (openvino/tools/mo/front/InterpolateNormalizer.py at the line {}).' \
                                            ''.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)

                    # Reshape methods can be different in some cases
                    # Commented out section represents reshape that used in deeplab-caffe
                    # Uncomment the following lines, if your model was trained with deeplab-caffe
                    # or have the same reshape method
                    # const = Const(graph, {'value': np.array(-1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
                    # sub = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
                    # add.out_port(0).connect(sub.in_port(0))
                    # const.out_port(0).connect(sub.in_port(1))
                    #
                    # const = Const(graph, {'value': np.array(zoom_factor - 1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
                    # mul = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
                    # sub.out_port(0).connect(mul.in_port(0))
                    # const.out_port(0).connect(mul.in_port(1))
                    #
                    # sum = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
                    # add.out_port(0).connect(sum.in_port(0))
                    # mul.out_port(0).connect(sum.in_port(1))
                    #
                    # node.add_input_port(1, skip_if_exist=True)
                    # assert node.in_port(1).disconnected()
                    # sum.out_port(0).connect(node.in_port(1))

                    # Comment out the following lines if you use the reshape method from previous section
                    const = Const(graph, {'value': np.array(zoom_factor),
                                          'name': node.name + '/zoom_factor_mul_const'}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()

                    add.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    self._set_size_input(node, mul)

                elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
                    const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
                    self._set_size_input(node, const)

                elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
                    shrink_factor = node.shrink_factor
                    zoom_factor = node.zoom_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    # div = (H_eff - 1) / shrink_factor
                    const = Const(graph, {'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    const.out_port(0).connect(sub.in_port(1))

                    const = Const(graph, {'value': np.array(1 / shrink_factor)}).create_node()
                    div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    const.out_port(0).connect(div.in_port(1))

                    # shrunk = div + 1, the size after shrinking (same as the shrink-only case)
                    const = Const(graph, {'value': np.array(1),
                                          'name': node.name + 'shrink_zoom_factor_sum_const'}).create_node()
                    shrunk = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(shrunk.in_port(0))
                    const.out_port(0).connect(shrunk.in_port(1))

                    # result = shrunk + (shrunk - 1) * (zoom_factor - 1), where shrunk - 1 == div
                    const = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
                    div.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    total = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
                    shrunk.out_port(0).connect(total.in_port(0))
                    mul.out_port(0).connect(total.in_port(1))

                    self._set_size_input(node, total)
        elif node.soft_get('fw') == 'caffe':
            # Port 1 is fed by a tensor; use the spatial slice of that tensor's shape instead.
            source = node.in_port(1).get_connection().get_source()
            node.in_port(1).disconnect()
            ss = self._spatial_dims_of(graph, node, source)
            ss.out_port(0).connect(node.in_port(1))

    @staticmethod
    def _spatial_dims_of(graph: Graph, node, source):
        """Create StridedSlice(Shape(source), begin=[2], end=[4], stride=[1]) named after ``node``; return the slice node."""
        shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

        begin = Const(graph, {'value': np.array([2])}).create_node()
        end = Const(graph, {'value': np.array([4])}).create_node()
        stride = Const(graph, {'value': np.array([1])}).create_node()
        ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                  'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                  'shrink_axis_mask': np.array([0]),
                                  'ellipsis_mask': np.array([0])}).create_node()

        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))
        begin.out_port(0).connect(ss.in_port(1))
        end.out_port(0).connect(ss.in_port(2))
        stride.out_port(0).connect(ss.in_port(3))
        return ss

    @staticmethod
    def _set_size_input(node, producer):
        """Connect output 0 of ``producer`` to input port 1 of ``node``, creating the port if needed."""
        node.add_input_port(1, skip_if_exist=True)
        assert node.in_port(1).disconnected()
        producer.out_port(0).connect(node.in_port(1))
示例#22
0
 def extract(cls, node: Node):
     """Mark ``node`` as a Mul, forwarding the ONNX integer ``axis`` attribute (None when absent)."""
     Mul.update_node_stat(node, {'axis': onnx_attr(node, 'axis', 'i', default=None)})
     return cls.enabled
示例#23
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """Decompose a GroupNorm node into Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

        The input is reshaped to [batch * num_groups, features / num_groups, spatial...].
        This assumes a [batch, features, spatial...] layout, judging by the helper names.
        The target shape is computed at runtime in floating point and cast to int64. MVN
        normalizes over axes Range(1, rank, 1). The result is reshaped back to the original
        input shape, multiplied by gamma (input 1) and shifted by beta (input 2). Gamma and
        beta must be constant; their values are reshaped in place to [1, -1, 1, ...] with
        the input's rank so that they broadcast.
        """
        gn_node = match['op']
        gn_name = gn_node.name
        input_rank = len(gn_node.in_port(0).data.get_shape())

        # Shape of the original GroupNorm input
        src_shape = Shape(graph, {'name': gn_name + '/Shape'}).create_node()
        src_shape.in_port(0).connect(gn_node.in_port(0).get_source())

        float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

        def cast_to_float(int_node):
            # Cast named '<producer>/to_float' fed by output 0 of ``int_node``
            casted = Cast(graph, {'name': int_node.name + '/to_float', 'dst_type': float_type}).create_node()
            int_node.out_port(0).connect(casted.in_port(0))
            return casted

        src_shape_f = cast_to_float(src_shape)

        batch_dim = node_to_get_batch_value(src_shape_f)
        features_dim = node_to_get_features_dimension_value(src_shape_f)
        spatial_dims = cast_to_float(node_to_get_spatial_dimensions_value(src_shape))

        num_groups = gn_node.num_groups
        groups_const = Const(graph, {'value': int64_array([num_groups]),
                                     'name': gn_name + '/GroupSize'}).create_node()

        # features / num_groups, computed as features * (1 / num_groups)
        inv_groups_const = Const(graph, {'value': np.array([1.0 / num_groups]),
                                         'name': gn_name + '/ReciprocalGroupSize'}).create_node()

        channels_per_group = Mul(graph, {}).create_node()
        channels_per_group.in_port(0).connect(features_dim.out_port(0))
        channels_per_group.in_port(1).connect(inv_groups_const.out_port(0))

        batch_times_groups = Mul(graph, {}).create_node()
        batch_times_groups.in_port(0).connect(batch_dim.out_port(0))
        batch_times_groups.in_port(1).connect(groups_const.out_port(0))

        # Concatenate the float dims into the grouped shape, then cast it to int64
        grouped_shape_f = new_shape_node_from_shape_nodes([batch_times_groups, channels_per_group,
                                                           spatial_dims])
        grouped_shape = Cast(graph,
                             {'name': grouped_shape_f.name + '/to_int64', 'dst_type': np.int64}).create_node()
        grouped_shape_f.out_port(0).connect(grouped_shape.in_port(0))

        to_grouped = Reshape(graph, {}).create_node()
        gn_node.in_port(0).get_connection().set_destination(to_grouped.in_port(0))
        to_grouped.in_port(1).connect(grouped_shape.out_port(0))

        # Gamma/beta target shape: all ones with -1 at axis 1 ([C] -> [1, C, 1, ...])
        affine_shape = np.ones([input_rank], dtype=np.int64)
        affine_shape[1] = -1

        gamma_src = gn_node.in_port(1).get_source()
        beta_src = gn_node.in_port(2).get_source()
        gamma = gamma_src.data.get_value()
        beta = beta_src.data.get_value()
        assert gamma is not None, 'The gamma should be constant'
        assert beta is not None, 'The beta should be constant'
        gamma_src.data.set_value(np.reshape(gamma, affine_shape))
        beta_src.data.set_value(np.reshape(beta, affine_shape))

        mvn = MVN(graph, {'name': gn_name + '/MVN',
                          'normalize_variance': 1,
                          'eps': gn_node.eps,
                          'eps_mode': 'inside_sqrt'}).create_node()
        mvn.in_port(0).connect(to_grouped.out_port(0))

        # MVN axes: Range(1, rank of the grouped tensor, 1)
        _, grouped_rank = get_shape_and_rank_nodes_by_port(mvn.in_port(0).get_connection().get_source(),
                                                           return_as_a_scalar=True)
        axes = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                           {'name': gn_name + '/Range', 'output_type': np.int64})
        mvn.in_port(1).connect(axes.out_port(0))
        axes.in_port(1).connect(grouped_rank.out_port(0))

        # Restore the original shape before applying gamma and beta
        to_original = Reshape(graph, {}).create_node()
        to_original.in_port(0).connect(mvn.out_port(0))
        to_original.in_port(1).connect(src_shape.out_port(0))

        scale = Mul(graph, {'name': mvn.name + '/Mul'}).create_node()
        scale.in_port(0).connect(to_original.out_port(0))
        gn_node.in_port(1).get_connection().set_destination(scale.in_port(1))

        shift = Add(graph, {'name': scale.name + '/Add'}).create_node()
        shift.in_port(0).connect(scale.out_port(0))
        gn_node.in_port(2).get_connection().set_destination(shift.in_port(1))

        gn_node.out_port(0).get_connection().set_source(shift.out_port(0))
示例#24
0
def _fused_batch_norm_decomposition(graph: Graph, tinput: Port, toutput: Port, gamma: Port, beta: Port,
                                    mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
    """
    Replace a batch-norm node with a Mul -> Add -> Mul -> Add sub-graph.

    This is a common function for TF, Caffe and MXNet. The resulting computation is
    ``((input * mean) + variance) * gamma + beta``. ``mean`` and ``variance`` are used
    verbatim as the constants of the first Mul/Add; presumably the caller has already
    folded them into a scale and a shift (verify against callers). ``gamma`` and
    ``beta`` keep their producers and are re-targeted to the second Mul/Add.

    :param graph: graph being transformed
    :param tinput: input port of the batch-norm node; its connection moves to the first Mul
    :param toutput: output port of the batch-norm node; its consumers move to the last Add
    :param gamma: port whose connection becomes input 1 of the second Mul
    :param beta: port whose connection becomes input 1 of the second Add
    :param mean: constant value for input 1 of the first Mul
    :param variance: constant value for input 1 of the first Add
    :param can_be_fused: forwarded to every created Mul/Add node
    """
    batch_norm_name = tinput.get_connection().get_destination().node.name

    # Create first Mul & Add operations
    mul1_node = Mul(graph, dict(name=batch_norm_name + "/mean", can_be_fused=can_be_fused)).create_node()
    add1_node = Add(graph, dict(name=batch_norm_name + "/variance", can_be_fused=can_be_fused)).create_node()

    const_mul1_node = Const(graph, dict(name="data_mul_", value=mo_array(mean))).create_node()
    const_add1_node = Const(graph, dict(name="data_add_", value=mo_array(variance))).create_node()

    # Broadcast a scalar-like constant gamma (e.g. shape [1]) to its data node's shape.
    # np.full builds a new array: ndarray.resize() works in place and returns None
    # (so it cannot be chained with .fill()), and may refuse to resize a referenced array.
    gamma_value = gamma.data.get_value()
    gamma_shape = gamma.data.get_shape()
    if gamma_value is not None and gamma_shape[0] != gamma_value.shape[0]:
        gamma.data.set_value(np.full(gamma_shape, gamma_value[0], dtype=gamma_value.dtype))

    # Create second Mul & Add
    mul2_node = Mul(graph, dict(name=batch_norm_name + "/gamma", can_be_fused=can_be_fused)).create_node()
    add2_node = Add(graph, dict(name=batch_norm_name + "/beta", can_be_fused=can_be_fused)).create_node()

    # Connect edges Mul1->Add1->Mul2->Add2
    tinput.get_connection().set_destination(mul1_node.in_port(0))
    mul1_node.in_port(1).get_connection().set_source(const_mul1_node.out_port(0))

    add1_node.in_port(0).get_connection().set_source(mul1_node.out_port(0))
    add1_node.in_port(1).get_connection().set_source(const_add1_node.out_port(0))

    mul2_node.in_port(0).get_connection().set_source(add1_node.out_port(0))
    gamma.get_connection().set_destination(mul2_node.in_port(1))

    add2_node.in_port(0).get_connection().set_source(mul2_node.out_port(0))
    beta.get_connection().set_destination(add2_node.in_port(1))

    toutput.get_connection().set_source(add2_node.out_port(0))
示例#25
0
def convert_scale_shift_to_mul_add(graph: Graph):
    """
    Re-route every fusable ScaleShift op in ``graph`` through an equivalent Mul and/or Add.

    For each ScaleShift node (skipped when ``can_be_fused`` is explicitly False):

    * the scale (input 1) is ignored when missing/disconnected, or constant and all ones;
    * the shift (input 2) is ignored when missing/disconnected, or constant and all zeros;
    * constant scale/shift values get ``rank(input) - 2`` trailing unit dims appended
      when the graph layout is NCHW (none otherwise); non-constant ones are routed
      through a Reshape that places their dims starting at ``node.axis``;
    * the node's connections are moved to input -> Mul -> Add, input -> Mul,
      input -> Add, or a direct input -> output link.

    The graph is modified in place; nothing is returned.
    """
    nodes = graph.get_op_nodes(op='ScaleShift')
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        ports_count = len(node.in_ports())

        input_port = node.in_port(0)
        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
        output_port = node.out_port(0)

        # Zero biases and unit weights are no-ops, so the matching Add / Mul is skipped.
        # np.all compares element-wise for constants of any rank; iterating over the
        # array fails for 0-d values and yields ambiguous sub-arrays for rank > 1.
        shift_value = shift_port.data.get_value() if shift_port is not None else None
        scale_value = scale_port.data.get_value() if scale_port is not None else None
        has_biases = shift_port is not None and not (shift_value is not None and np.all(shift_value == 0))
        has_weights = scale_port is not None and not (scale_value is not None and np.all(scale_value == 1))

        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
        add_op = Add(graph, dict(name=node.name + "/Add_"))

        # Expand dims for current layout
        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0

        # In case if we have constant weights/biases we have to broadcast them according to graph layout
        # otherwise we insert Reshape with broadcast dim attribute.
        def broadcast_value(port):
            value = mo_array(port.data.get_value())
            for idx in range(broadcast_dims_cnt):
                value = np.expand_dims(value, axis=-1)
            port.data.set_value(value)

        def broadcast_with_reshape(port):
            # Target shape: 1 before node.axis, the port's dims from node.axis on, 1 after.
            input_shape = input_port.data.get_shape()
            reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
            for i in range(0, node.axis):
                reshape_dims[i] = 1
            data_shape = port.data.get_shape()
            for i in range(node.axis, node.axis + len(data_shape)):
                reshape_dims[i] = data_shape[i - node.axis]
            for i in range(node.axis + len(data_shape), len(input_shape)):
                reshape_dims[i] = 1
            reshape = create_op_node_with_second_input(graph, Reshape, reshape_dims,
                                                       dict(name=port.node.name + "/Broadcast_"))
            port.get_connection().set_destination(reshape.in_port(0))
            reshape.out_port(0).connect(port)

        if has_weights and scale_port.data.get_value() is not None:
            broadcast_value(scale_port)
        elif has_weights:
            broadcast_with_reshape(scale_port)

        if has_biases and shift_port.data.get_value() is not None:
            broadcast_value(shift_port)
        elif has_biases:
            broadcast_with_reshape(shift_port)

        if has_biases and has_weights:
            # Connect input->mul->out->add->out
            add_node = add_op.create_node()
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            # Connect Add operation with inputs
            mul_node.out_port(0).connect(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        elif has_weights:
            # Connect input->mul->out
            mul_node = mul_op.create_node()

            # Connect Mul operation with inputs
            input_port.get_connection().set_destination(mul_node.in_port(0))
            scale_port.get_connection().set_destination(mul_node.in_port(1))

            output_port.get_connection().set_source(mul_node.out_port(0))
        elif has_biases:
            # Connect input->add->out
            add_node = add_op.create_node()

            # Connect Add operation with inputs
            input_port.get_connection().set_destination(add_node.in_port(0))
            shift_port.get_connection().set_destination(add_node.in_port(1))

            output_port.get_connection().set_source(add_node.out_port(0))
        else:
            # Connect input->out
            producer_port = input_port.get_source()
            input_port.disconnect()
            output_port.get_connection().set_source(producer_port)
示例#26
0
    def replace_op(self, graph: Graph, node: Node):
        input_out_port = node.in_port(0).get_source()

        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        fc_layer_after_input_attrs = {
            'name': 'input_fullyconnected',
            'out-size': node.gifo_x_weights_shape[0],
            'transpose_weights': True,
            'bias_term': True,
        }

        fc_layer_after_input = FullyConnected(
            graph, fc_layer_after_input_attrs).create_node()
        fc_layer_after_input.in_port(0).connect(input_out_port)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1,
                       'weights', node.gifo_x_weights)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2,
                       'biases', node.gifo_biases)

        init_value_prev_lstm_output = create_const_with_batch_from_input(
            input_out_port, node.gifo_r_weights_shape[1])
        prev_lstm_output = ReadValue(graph, {
            'name': 'prev_memory_output',
            'variable_id': memory_pair_input
        }).create_node()
        prev_lstm_output.in_port(0).connect(
            init_value_prev_lstm_output.out_port(0))

        # *Memory(output) -> FullyConnected
        fc_layer_from_prev_state_attrs = {
            'name': 'prev_memory_output_fullyconnected',
            'out-size': node.gifo_r_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False,
        }

        fc_layer_from_prev_state = FullyConnected(
            graph, fc_layer_from_prev_state_attrs).create_node()
        fc_layer_from_prev_state.in_port(0).connect(
            prev_lstm_output.out_port(0))
        input_as_const(fc_layer_from_prev_state,
                       fc_layer_from_prev_state_attrs, 1, 'weights',
                       node.gifo_r_weights)

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Add(graph, {
            'name': 'join_input_eltwise'
        }).create_node()
        join_input_prev_state_sum.in_port(0).connect(
            fc_layer_from_prev_state.out_port(0))
        join_input_prev_state_sum.in_port(1).connect(
            fc_layer_after_input.out_port(0))

        # *Eltwise(sum) -> Split
        # it is split into 4 nodes: Act, Eltw*3
        # the following order is mandatory
        #       ___Tanh
        #      /
        # Split ---(2)Eltwise(sum)
        #     |\
        #     | \__(3)Eltwise(sum)
        #     |____(4)Eltwise(sum)
        split_joined_input_axis = Const(graph, {
            'value': np.int64(1)
        }).create_node()
        split_joined_input = Split(graph, {
            'name': 'join_input_split',
            'num_splits': 4,
            'out_ports_count': 4
        }).create_node()
        split_joined_input.in_port(0).connect(
            join_input_prev_state_sum.out_port(0))
        split_joined_input.in_port(1).connect(
            split_joined_input_axis.out_port(0))

        init_value_prev_lstm_state = create_const_with_batch_from_input(
            split_joined_input.out_port(0), node.input_gate_weights.shape[0])
        prev_lstm_state = ReadValue(graph, {
            'name': 'prev_memory_state',
            'variable_id': memory_pair_output
        }).create_node()
        prev_lstm_state.in_port(0).connect(
            init_value_prev_lstm_state.out_port(0))

        # *Memory(state) -> *ScaleShift(input)
        state_input_scaleshift_attrs = {
            'name': 'input_scaleshift',
            'bias_term': False
        }
        state_input_scaleshift = ScaleShiftOp(
            graph, state_input_scaleshift_attrs).create_node()
        state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1,
                       'weights', node.input_gate_weights)

        # *Memory(state) -> *ScaleShift(forget)
        state_forget_scaleshift_attrs = {
            'name': 'forget_scaleshift',
            'bias_term': False
        }
        state_forget_scaleshift = ScaleShiftOp(
            graph, state_forget_scaleshift_attrs).create_node()
        state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs,
                       1, 'weights', node.forget_gate_weights)

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        join_prev_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_input_eltwise'
            }).create_node()
        join_prev_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(1))
        join_prev_lstm_input_joined_input_sum.in_port(1).connect(
            state_input_scaleshift.out_port(0))
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_forget_sum',
            }).create_node()
        join_prev_lstm_input_joined_forget_sum.in_port(0).connect(
            split_joined_input.out_port(2))
        join_prev_lstm_input_joined_forget_sum.in_port(1).connect(
            state_forget_scaleshift.out_port(0))

        # Split -> Tanh
        remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
        remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

        # Split -> (2)Eltwise(sum) -> *Sigmoid
        remember_sigmoid = Sigmoid(graph, {
            'name': 'remember_sigmoid'
        }).create_node()
        remember_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_input_sum.out_port(0))

        # Split -> (3)Eltwise(sum) -> **Sigmoid
        forget_sigmoid = Sigmoid(graph, {
            'name': 'forget_sigmoid'
        }).create_node()
        forget_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_forget_sum.out_port(0))

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Mul(graph, {
            'name': 'join_forget_prev_state_mul'
        }).create_node()
        join_forget_prev_state_mul.in_port(0).connect(
            forget_sigmoid.out_port(0))
        join_forget_prev_state_mul.in_port(1).connect(
            prev_lstm_state.out_port(0))

        # Split -> Tahn                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Mul(
            graph, {
                'name': 'join_remember_candidates_mul'
            }).create_node()
        join_remember_candidates_mul.in_port(0).connect(
            remember_tahn.out_port(0))
        join_remember_candidates_mul.in_port(1).connect(
            remember_sigmoid.out_port(0))

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        join_forget_remember_sum = Add(graph, {
            'name': 'join_forget_remember_sum'
        }).create_node()
        join_forget_remember_sum.in_port(0).connect(
            join_forget_prev_state_mul.out_port(0))
        join_forget_remember_sum.in_port(1).connect(
            join_remember_candidates_mul.out_port(0))

        # (7)Eltwise(sum) -> Clamp
        join_forget_clamp = create_op_with_const_inputs(
            graph, Clamp, {
                1: np.array(-node.clip_value, dtype=np.float32),
                2: np.array(node.clip_value, dtype=np.float32)
            }, {'name': 'join_forget_clamp'}, join_forget_remember_sum)
        #
        # Clamp -> (2)Memory(state)
        next_lstm_state = Assign(graph, {
            'name': 'next_lstm_state',
            'variable_id': memory_pair_output
        }).create_node()
        next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

        res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
        res_node.in_port(0).connect(next_lstm_state.out_port(0))

        # Clamp -> (2)Tahn
        state_filtered_tahn = Tanh(graph, {
            'name': 'state_filtered_tahn'
        }).create_node()
        state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

        # Clamp -> (2)ScaleShift
        clamp_scaleshift_attrs = {
            'name': 'clamp_scaleshift',
            'bias_term': False
        }
        clamp_scaleshift = ScaleShiftOp(graph,
                                        clamp_scaleshift_attrs).create_node()
        clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
        input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights',
                       node.output_gate_weights)

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_next_lstm_input_joined_input_sum',
            }).create_node()
        join_next_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(3))
        join_next_lstm_input_joined_input_sum.in_port(1).connect(
            clamp_scaleshift.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid
        output_sigmoid = Sigmoid(graph, {
            'name': 'output_sigmoid'
        }).create_node()
        output_sigmoid.in_port(0).connect(
            join_next_lstm_input_joined_input_sum.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (5)Eltwise(mul)
        # Clamp -> (2)Tahn                      /
        joined_output_mul = Mul(graph, {
            'name': 'joined_output_mul'
        }).create_node()
        joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
        joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

        # (5)Eltwise(mul) -> (3)FullyConnected
        fc_output_attrs = {
            'name': 'FullyConnected',
            'out-size': node.projection_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False
        }
        fc_output = FullyConnected(graph, fc_output_attrs).create_node()
        fc_output.in_port(0).connect(joined_output_mul.out_port(0))
        input_as_const(fc_output, fc_output_attrs, 1, 'weights',
                       node.projection_weights)

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        next_lstm_output = Assign(graph, {
            'name': 'next_lstm_output',
            'variable_id': memory_pair_input
        }).create_node()
        next_lstm_output.in_port(0).connect(fc_output.out_port(0))

        res_node_lstm_output = Result(graph, {
            'name': 'next_lstm_output_out'
        }).create_node()
        res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

        return [fc_output.id]
示例#27
0
 def extract(cls, node):
     """Fill ``node`` with Mul attributes derived from its TensorFlow proto.

     The element type comes from the proto's ``T`` attribute. It is converted
     with ``tf_dtype_extractor`` and stored under ``data_type`` via
     ``Mul.update_node_stat``.

     :param node: node whose ``pb`` holds the TensorFlow NodeDef.
     :return: ``cls.enabled``.
     """
     element_type = tf_dtype_extractor(node.pb.attr["T"].type)
     Mul.update_node_stat(node, {'data_type': element_type})
     return cls.enabled
    def replace_pattern(graph: Graph, match: dict):
        """Feed ``2 * sequence_length`` into the TensorArrays of the GNMT loop.

        Starting from the producer of ``match['Max']``'s input, walk the encoder
        sequence-length chain: Identity -> Identity -> ReduceMax -> Maximum ->
        Minimum. Then find the ``StridedSlice`` that feeds the ``Minimum`` and
        multiply its output by a constant 2.

        The TensorArray two hops upstream of input 0 of every
        ``TensorArrayWriteV3`` consumer of ``match['Identity_1']`` has its
        input 0 reconnected to that product.

        Afterwards ``cmd_params.static_shape`` is forced on, logging a warning,
        because the result cannot be reshaped.

        :param graph: graph being transformed; modified in place.
        :param match: matched pattern; must contain the 'Max' and 'Identity_1' nodes.
        """
        log.debug(
            '================== GNMTBeforeConditionFind ==================')
        input_sequence_lengths = match['Max'].in_port(0).get_source().node
        encoder_sequence_lengths = looking_for_op_in_list([
            port.node
            for port in input_sequence_lengths.out_port(0).get_destinations()
        ], 'Identity')

        # The sequence-length chain in the encoder looks like:
        # Sequence_length -> CheckSeqLen -> Max -> Maximum -> Minimum
        check_seq_len = looking_for_op_in_list([
            port.node
            for port in encoder_sequence_lengths.out_port(0).get_destinations()
        ], 'Identity')
        # Named reduce_max rather than max so the builtin is not shadowed.
        reduce_max = looking_for_op_in_list([
            port.node for port in check_seq_len.out_port(0).get_destinations()
        ], 'ReduceMax')
        maximum = reduce_max.out_port(0).get_destinations()[0].node
        assert maximum.op == 'Maximum', \
            'Expected a Maximum node after ReduceMax, got "{}"'.format(maximum.op)
        minimum = maximum.out_port(0).get_destinations()[0].node
        assert minimum.op == 'Minimum', \
            'Expected a Minimum node after Maximum, got "{}"'.format(minimum.op)

        tensor_seq_len = looking_for_op_in_list([
            minimum.in_port(port).get_source().node
            for port in minimum.in_ports()
        ], 'StridedSlice')

        # Create node for multiplying seq_len by 2
        const = Const(graph, {
            'name': 'FakeSeqLenMultiplyer',
            'value': mo_array(2)
        }).create_node()
        mul_op = Mul(graph, {'name': 'FakeSeqLen'}).create_node()

        const.out_port(0).get_connection().set_destination(mul_op.in_port(1))
        # add_destination (rather than set_destination) so tensor_seq_len
        # presumably keeps feeding its existing consumers.
        tensor_seq_len.out_port(0).get_connection().add_destination(
            mul_op.in_port(0))

        # Connect seq_len * 2 to TensorArray from GNMT loop
        ta_writes = [
            port.node
            for port in match['Identity_1'].out_port(0).get_destinations()
            if port.node.op == 'TensorArrayWriteV3'
        ]

        for ta_write in ta_writes:
            # The TensorArray is two hops upstream of the write's input 0.
            ta = ta_write.in_port(0).get_source().node.in_port(
                0).get_source().node

            ta.in_port(0).disconnect()
            ta.in_port(0).get_connection().set_source(mul_op.out_port(0))

        if not graph.graph['cmd_params'].static_shape:
            log.error(
                "Model can not be translated in a reshape-able way.\n"
                "Model Optimizer key static_shape was turned on to prevent related errors.\n"
                "There will be no success changing input shapes of the model with the help of "
                "InferenceEngine reshape method",
                extra={'is_warning': True})
            graph.graph['cmd_params'].static_shape = True
示例#29
0
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """Replace the matched RetinaNet post-processing with a DetectionOutput node.

        The DetectionOutput inputs are built as follows:

        * box regressions: ``match.single_input_node(0)`` multiplied by the
          prior ``[width, height, width, height]`` tensor, then reshaped to
          ``[0, -1]``;
        * class scores: ``match.single_input_node(1)`` reshaped to ``[0, -1]``;
        * priors: the first batch slice of ``match.single_input_node(2)``,
          multiplied by the broadcast scales from ``self.placeholder_scales``,
          then passed through ``self.append_variances`` with the ``variance``
          custom attribute.

        The NMS threshold is taken from input 3 of the first
        ``NonMaxSuppression`` node in the graph. Tensor names of the matched
        outputs are cleared because those outputs are replaced.

        :param graph: graph being transformed.
        :param match: sub-graph match carrying the custom replacement description.
        :return: ``{'detection_output_node': <created DetectionOutput node>}``.
        :raises Error: if the ``variance`` custom attribute is missing, or no
            IoU threshold can be retrieved from the graph.
        """
        reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                dict(name='do_reshape_classes'),
                                                                match.single_input_node(1)[0])

        initial_priors_node = match.single_input_node(2)[0]
        priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
        # model calculates identical prior boxes for each batch, so we take the first slice of them
        begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
        end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
        stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()

        priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                           'begin_mask': int64_array([1, 1, 1]),
                                           'end_mask': int64_array([1, 0, 0]),
                                           'new_axis_mask': int64_array([0]),
                                           'shrink_axis_mask': int64_array([0]),
                                           'ellipsis_mask': int64_array([0])}).create_node()

        initial_priors_node.out_port(0).connect(priors_node.in_port(0))
        begin.out_port(0).connect(priors_node.in_port(1))
        end.out_port(0).connect(priors_node.in_port(2))
        stride.out_port(0).connect(priors_node.in_port(3))

        placeholders = graph.get_op_nodes(type='Parameter')
        assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                       "{} placeholders".format(self.replacement_id, len(placeholders))
        placeholder = placeholders[0]

        # scale prior boxes to the [0, 1] interval
        node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
        priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

        # broadcast the scales to the shape of the sliced priors before multiplying
        broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
        shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
        priors_node.out_port(0).connect(shape_of_priors.in_port(0))
        broadcast.in_port(1).connect(shape_of_priors.out_port(0))
        broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

        priors_scale_node.in_port(0).connect(priors_node.out_port(0))
        priors_scale_node.in_port(1).connect(broadcast.out_port(0))

        try:
            variance = match.custom_replacement_desc.custom_attributes['variance']
        except (AttributeError, KeyError, TypeError) as exc:
            # Only lookup failures are converted. A bare ``except:`` would also swallow
            # KeyboardInterrupt/SystemExit and hide the original cause.
            raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                        ''.format(self.replacement_id)) from exc

        priors = self.append_variances(priors_scale_node, variance)

        # calculate prior boxes widths and heights from the scaled priors
        # (priors_scale_node, i.e. before the variances are appended)
        split_node = create_op_with_const_inputs(
            graph, VariadicSplit, {1: int64_array(2), 2: int64_array([1, 1, 1, 1])}, {'out_ports_count': 4},
            priors_scale_node)

        priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                                ).create_node([(split_node, 2), (split_node, 0)])
        priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                                 ).create_node([(split_node, 3), (split_node, 1)])

        # Concatenate widths and heights into a single [w, h, w, h] tensor and multiply it by the box
        # coordinate regression values. A single 4-input Concat would suffice. Three chained 2-input
        # Concats are used instead as a workaround that keeps the model reshapable.
        concat_1 = Concat(graph, {'name': 'concat_width_height',
                                  'axis': -1, 'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
        concat_2 = Concat(graph, {'name': 'concat_width_height_width',
                                  'axis': -1, 'in_ports_count': 2}).create_node([concat_1, priors_width_node])
        concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1, 'in_ports_count': 2}
                                          ).create_node([concat_2, priors_height_node])

        applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
            [concat_width_height_node, match.single_input_node(0)[0]])

        # reshape to 2D tensor as Inference Engine Detection Output layer expects
        reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                   dict(name='reshape_regression'),
                                                                   applied_width_height_regressions_node)

        detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
        # get nms from the original network
        iou_threshold = None
        nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
        if nms_nodes:
            # It is highly unlikely that NMS uses different IoU thresholds for different classes.
            # Moreover, DetectionOutput accepts only a scalar iou_threshold (nms_threshold).
            iou_threshold = nms_nodes[0].in_node(3).value
        if iou_threshold is None:
            raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

        detection_output_node = detection_output_op.create_node(
            [reshape_regression_node, reshape_classes_node, priors],
            dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
                 variance_encoded_in_target=0, background_label_id=1000))

        # The outputs are replaced with a postprocessing node, so the outgoing tensor names no longer
        # correspond to the original tensors and must be removed from the output->Result edges.
        out_nodes = [match.output_node(out)[0] for out in range(match.outputs_count())]
        clear_tensor_names_info(out_nodes)

        return {'detection_output_node': detection_output_node}
示例#30
0
def _create_resize10_spatial_slice(graph: Graph, name: str) -> Node:
    """Create the StridedSlice used to trim sizes/scales for ``replace_resize``.

    Constant inputs: begin=[2], end=[0], stride=[1], with begin_mask=[1] and
    end_mask=[0]. The end value is presumably ignored, so the slice keeps
    elements from index 2 onward; this is what lines it up with the
    ``Range(2, rank, 1)`` axes input of the Interpolate. Input port 0 is left
    for the caller to connect.
    """
    return create_op_with_const_inputs(graph, StridedSlice, {
        1: int64_array([2]),
        2: int64_array([0]),
        3: int64_array([1])
    }, {
        'name': name,
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([0]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    })


def replace_resize(graph: Graph, resize: Node):
    """Replace an ONNX Resize-10 node with an opset4 Interpolate node.

    The Interpolate node takes over the Resize's name and its output
    connection; the Resize is renamed with a ``/delete`` suffix. Interpolate
    inputs:

    * port 0 - the Resize data input;
    * port 1 - sizes: ``floor(shape(data) * (scales + 1e-5))`` cast to int64,
      sliced from index 2;
    * port 2 - the Resize scales input, sliced from index 2;
    * port 3 - axes: ``Range(2, rank(data), 1)``.

    ``shape_calculation_mode`` is ``'scales'``. The mode is ``'linear_onnx'``
    for a ``'linear'`` Resize and ``'nearest'`` otherwise.

    :param graph: graph containing ``resize``; modified in place.
    :param resize: the Resize-10 node to replace.
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node %s.", resize_name)

    rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
    range_node = create_op_with_const_inputs(graph, Range, {
        0: int64_array(2),
        2: int64_array(1)
    }, {'name': resize_name + '/axes'})

    sizes_ss = _create_resize10_spatial_slice(graph, resize_name + '/sizes_ss')
    scales_ss = _create_resize10_spatial_slice(graph, resize_name + '/scales_ss')

    # axes = Range(2, rank(data), 1)
    rank_node.out_port(0).connect(range_node.in_port(1))

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest',
            'coordinate_transformation_mode': 'asymmetric',
            'cube_coeff': -0.75,
            'nearest_mode': 'simple',
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': 'scales',
            'in_ports_count': 4
        }).create_node()

    range_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # Computing 'sizes' as floor(input_shape * scales) can give wrong results. For example, with
    # scales = [1.0, 1.0, 1.33333, 2.0] and input_shape = [1, 3, 30, 200],
    # input_shape * scales = [1, 3, 39.9999, 400], so floor(input_shape * scales)[2] == 39, not 40.
    # Adding a small eps (e.g. 1.0e-5) after the multiplication does not help either.
    # For the same inputs, input_shape[2] * scales[2] + 1.0e-5 = 39.99991, which still floors to 39.
    # Adding eps to the scales instead works: 30 * (1.33333 + 1.0e-5) = 40.0002, which floors to 40.
    # Hence 'sizes' is calculated as floor(input_shape * (scales + eps)).
    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()

    shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
    mul_node = Mul(graph, {
        'name': resize_name + '/Mul'
    }).create_node([cast_shape_to_float, add_node])
    floor_node = Floor(graph, {
        'name': resize_name + '/Floor'
    }).create_node([mul_node])
    cast_mul_result_to_int = Cast(graph, {
        'dst_type': np.int64
    }).create_node([floor_node])
    cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
    sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

    scales_ss.out_port(0).connect(interpolate_node.in_port(2))

    # Move the Resize data and scales inputs onto the new nodes, then fan their sources
    # out to the ShapeOf/Rank (data) and the Add (scales) that compute sizes and axes.
    connection_of_resize_input = resize.in_port(0).get_connection()
    connection_of_resize_input.set_destination(interpolate_node.in_port(0))

    connection_of_scales = resize.in_port(1).get_connection()
    connection_of_scales.set_destination(scales_ss.in_port(0))

    connection_of_resize_input.get_source().connect(shape_of.in_port(0))
    connection_of_resize_input.get_source().connect(rank_node.in_port(0))
    connection_of_scales.get_source().connect(add_node.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))