Example #1
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['pb']
        name = node.soft_get('name', node.id)

        graph.graph['cmd_params'].static_shape = False

        assert len(node.in_ports()) == 2

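        # the begin/end/stride constants below slice elements [2:4] out of each input's shape,
        # i.e. the spatial dimensions the PriorBox computation needs (assuming a 4D NCHW input)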
        begin = Const(graph, {'value': np.array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
        end = Const(graph, {'value': np.array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
        stride = Const(graph, {'value': np.array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()

        shape_0 = Shape(graph, {'name': name + '/0_port'}).create_node()
        ss_0 = StridedSlice(graph, {'name': name + '/ss_0_port',
                                    'begin_mask': np.array([1], dtype=np.int32),
                                    'end_mask': np.array([0], dtype=np.int32),
                                    'new_axis_mask': np.array([0], dtype=np.int32),
                                    'shrink_axis_mask': np.array([0], dtype=np.int32),
                                    'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

        shape_0.out_port(0).connect(ss_0.in_port(0))
        begin.out_port(0).connect(ss_0.in_port(1))
        end.out_port(0).connect(ss_0.in_port(2))
        stride.out_port(0).connect(ss_0.in_port(3))

        source = node.in_port(0).get_connection().get_source()
        node.in_port(0).disconnect()
        source.connect(shape_0.in_port(0))
        ss_0.out_port(0).connect(node.in_port(0))

        shape_1 = Shape(graph, {'name': name + '/1_port'}).create_node()
        ss_1 = StridedSlice(graph, {'name': name + '/ss_1_port',
                                    'begin_mask': np.array([1], dtype=np.int32),
                                    'end_mask': np.array([0], dtype=np.int32),
                                    'new_axis_mask': np.array([0], dtype=np.int32),
                                    'shrink_axis_mask': np.array([0], dtype=np.int32),
                                    'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

        shape_1.out_port(0).connect(ss_1.in_port(0))
        begin.out_port(0).connect(ss_1.in_port(1))
        end.out_port(0).connect(ss_1.in_port(2))
        stride.out_port(0).connect(ss_1.in_port(3))

        source = node.in_port(1).get_connection().get_source()
        node.in_port(1).disconnect()
        source.connect(shape_1.in_port(0))
        ss_1.out_port(0).connect(node.in_port(1))

        ss_0['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
        ss_1['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}

        node['need_shape_inference'] = True
        node['override_output_shape'] = True
        node['V10_infer'] = True
        unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]), {'name': name + '/unsqueeze'})
        naked_priorbox_name = name + '/naked_not_unsqueezed'
        rename_nodes([(node, naked_priorbox_name), (unsqueeze, name)])

        node.out_port(0).get_connection().set_source(unsqueeze.out_port(0))
        node.out_port(0).connect(unsqueeze.in_port(0))
Example #2
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)

        # broadcast default value to required shape
        broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
        node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
        if not node.in_port(3).disconnected():
            # TODO: remove the casting once I64 model inputs are supported
            # cast the default value to I32 because of the limited I64 input support,
            # so that the input parameter and the default value have the same I32 type, as ScatterNDUpdate requires
            cast_default_value = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
            node.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
            broadcast_node.in_port(0).connect(cast_default_value.out_port(0))
        else:
            broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
                                                            'value': np.float32(0)}
                                                    ).create_node().out_port(0))

        # update broadcasted tensor with required values at required locations
        scatternd_node = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
        scatternd_node.in_port(0).connect(broadcast_node.out_port(0))
        node.in_port(0).get_connection().set_destination(scatternd_node.in_port(1))
        node.in_port(2).get_connection().set_destination(scatternd_node.in_port(2))

        rename_nodes([(node, node_name + "/AbandonedName"), (scatternd_node, node_name)])

        return [scatternd_node.id]
Example #3
    def find_and_replace_pattern(self, graph: Graph):
        for roll_node in graph.get_op_nodes(op='Roll'):
            if not roll_node.in_port(2).disconnected():
                continue
            node_name = roll_node.soft_get('name', roll_node.id)

            # reshape to 1d tensor
            reshape_to_1d = create_op_node_with_second_input(
                graph, Reshape, int64_array([-1]),
                {'name': node_name + '/reshape'})
            roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

            # add zero const as axes input to roll
            const_zero = Const(graph, {
                'value': int64_array([0]),
                'name': node_name + '/axes'
            }).create_node()
            const_zero.out_port(0).connect(roll_node.in_port(2))

            # reshape to original shape
            shape_of = Shape(graph, {
                'name': node_name + '/shape_of'
            }).create_node()
            roll_node.in_port(0).get_connection().add_destination(
                shape_of.in_port(0))
            reshape_to_orig_shape = Reshape(graph, {}).create_node()
            rename_nodes([(roll_node, node_name + '/roll'),
                          (reshape_to_orig_shape, node_name)])
            shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
            roll_node.out_port(0).get_connection().insert_node(
                reshape_to_orig_shape)
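A minimal NumPy sketch (illustrative, not part of the transformation) of the equivalence this replacement relies on: a Roll with no axes equals rolling the flattened tensor and restoring the original shape, which is what np.roll does when axis is omitted.

import numpy as np

x = np.arange(6).reshape(2, 3)
flat_roll = np.roll(x.reshape(-1), 2).reshape(x.shape)  # flatten, roll, restore shape
assert np.array_equal(np.roll(x, 2), flat_roll)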
Example #4
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='MVNCaffe'):
            node_name = node.soft_get('name', node.id)

            start_axis = 2
            if node['across_channels'] == 1:
                start_axis = 1

            rank = Rank(graph, {'name': node_name + '/Rank'}).create_node()

            # create range of axes based on `start_axis` and rank of input
            rng = create_op_with_const_inputs(graph, Range, {
                0: int64_array(start_axis),
                2: int64_array(1)
            }, {
                'name': node_name + '/Range',
                'output_type': np.int64
            })
            rng.in_port(1).connect(rank.out_port(0))

            new_mvn = MVN(
                graph, {
                    'eps': node.soft_get('eps', 1e-9),
                    'eps_mode': 'inside_sqrt',
                    'normalize_variance': node.soft_get(
                        'normalize_variance', 1)
                }).create_node([node.in_port(0).get_source().node, rng])
            new_mvn.in_port(0).get_connection().add_destination(
                rank.in_port(0))
            node.out_port(0).get_connection().set_source(new_mvn.out_port(0))
            rename_nodes([(node, node_name + '/tbd'), (new_mvn, node_name)])

            graph.remove_node(node.id)
Example #5
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)
        assert node.has_valid(
            'axis'
        ), 'The node "{}" does not have mandatory attribute "axis"'.format(
            node_name)

        flatten_node = FlattenONNX(graph, {
            'name': node_name + '/FlattenONNX_',
            'axis': node.axis
        }).create_node()
        shape_node = Shape(graph, {
            'name': node_name + '/ShapeOf_'
        }).create_node()
        logsoftmax_node = LogSoftmax(graph, {
            'name': node_name + '/LogSoftmax_',
            'axis': 1
        }).create_node()
        reshape_node = Reshape(graph, {}).create_node()

        rename_nodes([(node, node_name + '/delete'),
                      (reshape_node, node_name)])

        shape_node.out_port(0).connect(reshape_node.in_port(1))
        logsoftmax_node.out_port(0).connect(reshape_node.in_port(0))
        flatten_node.out_port(0).connect(logsoftmax_node.in_port(0))

        source = node.in_port(0).get_source()

        flatten_node.in_port(0).connect(source)
        shape_node.in_port(0).connect(source)

        return [reshape_node.id]
Example #6
    def replace_gelu(self, graph: Graph, match: dict):
        # Gaussian Error Linear Unit
        # f(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        out_node = match['mul0']
        node_name = out_node.soft_get('name', out_node.id)
        div = match['div']
        inp_node = div.in_port(0).get_source().node
        inp_name = inp_node.soft_get('name', inp_node.id)
        log.debug('Found potential Erf-based GeLU pattern after {} with name {}'.format(inp_node.op, inp_name))

        # take the values of the mul, add and div
        div_param = match['div_param']
        add_param = match['add_param']
        mul_param = match['mul_param']

        if add_param.value.size == 1 and mul_param.value.size == 1 and div_param.value.size == 1:
            mul_param = match['mul_param'].value.item()
            add_param = match['add_param'].value.item()
            div_param = match['div_param'].value.item()

            sqrt2 = sqrt(2.0)
            # check that the values match the approximation
            if fabs(div_param - sqrt2) < 1e-06 and mul_param == 0.5 and add_param == 1.0:
                log.debug('Confirmed Erf-based GELU pattern after {} with name {}'.format(inp_node.op, inp_name))
                gelu = GeLUOP(graph, dict(name=inp_name + '/GELU_', approximation='erf')).create_node()
                div.in_port(0).get_connection().set_destination(gelu.in_port(0))
                out_node.out_port(0).get_connection().set_source(gelu.out_port(0))
                rename_nodes([(out_node, node_name + '/TBD'), (gelu, node_name)])
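For reference, a NumPy/SciPy sketch (illustrative only) of the function this pattern encodes; the conversion fires only when the matched constants satisfy mul == 0.5, add == 1.0 and div == sqrt(2):

import numpy as np
from scipy.special import erf

def gelu_erf(x):
    # f(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))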
Example #7
    def find_and_replace_pattern(self, graph: Graph):
        for attr_pad in graph.get_op_nodes(op='AttributedPad'):
            # save the original node name to use it in the new Pad op instance
            original_name = attr_pad.soft_get('name', attr_pad.id)

            new_pad = Pad(graph, {
                'mode': attr_pad.soft_get('mode', None),
            }).create_node()
            rename_nodes([(attr_pad, original_name + '/to_be_removed'),
                          (new_pad, original_name)])

            attr_pad.in_port(0).get_connection().set_destination(
                new_pad.in_port(0))
            new_pad.in_port(1).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 0]
                }).create_node().out_port(0))
            new_pad.in_port(2).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 1]
                }).create_node().out_port(0))
            if attr_pad.soft_get('mode') == 'constant':
                new_pad.in_port(3).connect(
                    Const(graph, {
                        'value': attr_pad.fill_value
                    }).create_node().out_port(0))

            attr_pad.out_port(0).get_connection().set_source(
                new_pad.out_port(0))
            graph.remove_node(attr_pad.id)
Example #8
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['op']
        name = node.soft_get('name', node.id)

        # biases normalization
        bias_node = Add(graph, {'name': name + '/Bias_', 'can_be_scaleshift': False}).create_node()
        if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
            node_name = node.name + '/WithoutBiases'
            bias_node_name = node.name
            rename_nodes([(node, node_name), (bias_node, bias_node_name)])
        node.out_port(0).get_connection().set_source(bias_node.out_port(0))
        node.in_port(2).get_connection().set_destination(bias_node.in_port(1))
        node.out_port(0).connect(bias_node.in_port(0))

        if node.has_valid('alpha') and not math.isclose(node.alpha, 1):
            bias_node.insert_op_on_input_port(in_port_idx=0, new_op_class=Mul, value=np.array(node.alpha),
                                              new_op_attrs={'name': name + '/Alpha_', 'can_be_scaleshift': False})
            del node['alpha']

        if node.has_valid('beta') and not math.isclose(node.beta, 1):
            bias_node.insert_op_on_input_port(in_port_idx=1, new_op_class=Mul, value=np.array(node.beta),
                                              new_op_attrs={'name': name + '/Beta_', 'can_be_scaleshift': False})
            del node['beta']

        MatMul.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })
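The alpha/beta handling follows the usual Gemm semantics (as in ONNX), Y = alpha * (A @ B) + beta * C: a Mul is inserted on the Add's data input for alpha and on its bias input for beta. A NumPy sketch with illustrative values:

import numpy as np

A, B, C = np.ones((2, 3)), np.ones((3, 4)), np.ones(4)
alpha, beta = 2.0, 0.5
y = (alpha * (A @ B)) + (beta * C)  # MatMul -> Mul(alpha) -> Add <- Mul(beta) <- C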
Example #9
def replace_ctc_greedy_decoder(graph: Graph, match: dict):
    ctc_greedy_decoder_tf = match['decoder']
    cast = match['cast']
    sparse_to_dense = match['sparse_to_dense']
    sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
    ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)

    # to normalize the input, the data needs to be transposed from [T, N, C] to [N, T, C],
    # which is the layout supported by the CTCGreedyDecoderSeqLen op
    ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                   {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})

    assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
        'The CTCGreedyDecoderSeqLen node "{}" is missing the "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)

    ctc_greedy_decoder_tf.in_port(0).get_source().connect(ctc_data_permute.in_port(0))
    merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
    ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': sparse_to_dense_name,
                                                          'merge_repeated': merge_repeated_tf}).create_node()
    rename_nodes(
        [(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'), (ctc_greedy_decoder, sparse_to_dense_name)])
    ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
    ctc_greedy_decoder_tf.in_port(1).get_source().connect(ctc_greedy_decoder.in_port(1))

    # set output of the new sub-graph as a source for SparseToDense consumer
    sparse_to_dense.out_port(0).get_connection().set_source(ctc_greedy_decoder.out_port(0))

    # remove no longer needed nodes
    graph.remove_nodes_from([sparse_to_dense.id, cast.id, ctc_greedy_decoder_tf.id])
Example #10
    def find_and_replace_pattern(self, graph: Graph):
        for fake_output in graph.get_op_nodes(op='FakeOutput'):
            name = fake_output.soft_get('name', fake_output.id)

            producer = fake_output.in_port(0).get_source().node
            producer_outputs = 0
            for port in producer.out_ports().values():
                if not port.disconnected():
                    producer_outputs += 1
            if producer_outputs != 1:
                # At this stage we don't know the output data type, so we rely on the MO transformation that updates
                # the Const type for elementwise operations in case of an input data type mismatch
                add = create_op_with_const_inputs(graph, Add,
                                                  {1: int64_array(0)},
                                                  {'can_be_fused': False})
                rename_nodes([(fake_output, name + '/TBD'), (add, name)])

                fake_output.in_port(0).get_connection().set_destination(
                    add.in_port(0))
                fake_output.out_port(0).get_connection().set_source(
                    add.out_port(0))
            else:
                result_in_port = fake_output.out_port(0).get_destination()
                result_in_port.disconnect()
                fake_output.in_port(0).get_connection().set_destination(
                    result_in_port)
                rename_nodes([(fake_output, name + '/TBD'), (producer, name)])
Example #11
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['op']
        name = node.soft_get('name', node.id)

        # biases normalization
        if 2 in node.in_ports() and not node.in_port(2).disconnected():
            bias_node = Add(graph, {'name': name + '/Bias_'}).create_node()
            if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
                node_name = node.name + '/WithoutBiases'
                bias_node_name = node.name
                rename_nodes([(node, node_name), (bias_node, bias_node_name)])
            node.out_port(0).get_connection().set_source(bias_node.out_port(0))
            node.in_port(2).get_connection().set_destination(bias_node.in_port(1))
            node.out_port(0).connect(bias_node.in_port(0))

        # weights normalization
        assert node.has_valid('out-size')
        out_size = node['out-size']
        reshape_dim = int64_array([-1, out_size])
        if node.has_and_set('transpose_weights'):
            reshape_dim = int64_array([out_size, -1])
        node.insert_op_on_input_port(in_port_idx=1, new_op_class=Reshape,
                                     new_op_attrs={'name': name + '/weights_reshape'}, value=reshape_dim)
        if node.has_and_set('transpose_weights'):
            node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose,
                                         new_op_attrs={'name': name + '/weights_transpose'}, value=int64_array([1, 0]))

        # input normalization for 4D Caffe and MxNet FullyConnected
        if graph.graph['fw'] in ['caffe', 'mxnet']:
            node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
                                         new_op_attrs={'name': name + '/flatten_fc_input'}, value=int64_array([0, -1]))

        MatMul.update_node_stat(node, {})
Example #12
    def replace_layer_norm(self, graph: Graph, match: dict):
        inp = match['pool0']
        node_before = inp.in_port(0).get_source().node
        node_before_name = node_before.soft_get('name', node_before.id)

        # take/check the values of the add, pow and axes for ReduceMean
        pow_param = match['pow_param']
        add_param = match['add_param']
        if add_param.value.size == 1 and pow_param.value.size == 1 and add_param.value.item() <= 1e-05 \
                and pow_param.value.item() == 0.5 and match['pool0_param'].value == match['pool1_param'].value:
            log.debug('Found LayerNorm pattern after {} with name {}'.format(
                node_before.op, node_before_name))
            mvn = create_op_with_const_inputs(
                graph, MVN, {1: match['pool1_param'].value}, {
                    'eps': add_param.value.item(),
                    'normalize_variance': 1,
                    'eps_mode': 'inside_sqrt'
                })
            div_name = match['div'].soft_get('name', match['div'].id)
            rename_nodes([(match['div'], div_name + '/to_be_removed'),
                          (mvn, div_name)])

            inp.in_port(0).get_connection().set_destination(mvn.in_port(0))
            match['div'].out_port(0).get_connection().set_source(
                mvn.out_port(0))
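A small NumPy sketch of what the resulting MVN computes with eps_mode='inside_sqrt'; the axes and eps below are illustrative, the real ones come from the matched ReduceMean axes and the add constant:

import numpy as np

x = np.random.rand(2, 8).astype(np.float32)
eps = 1e-6
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
out = (x - mean) / np.sqrt(var + eps)  # eps is added inside the square root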
Example #13
    def replace_pattern(self, graph: Graph, match: dict):
        mul_node = match['mul_op']
        const_node = match['const_op']
        max_node = match['max_op']
        max_name = max_node.soft_get('name', max_node.id)

        const_value = const_node.out_port(0).data.get_value()
        if const_value is None or const_value.size != 1:
            log.debug(
                'Mul layer "{}" can not participate in conversion to the LeakyReLU because constant "{}" '
                'contains more than one element: {}'.format(
                    mul_node.id, const_node.id, const_value.size))
            return

        # Create new LeakyReLU operation
        leaky_relu_node = LeakyReLU(
            graph, dict(negative_slope=const_value.item(0))).create_node()

        data_in_port = int(
            mul_node.in_port(0).get_source().node.type == 'Const')
        mul_node.in_port(data_in_port).get_source().connect(
            leaky_relu_node.in_port(0))
        max_node.out_port(0).get_connection().set_source(
            leaky_relu_node.out_port(0))

        rename_nodes([(max_node, max_name + '/TBR'),
                      (leaky_relu_node, max_name)])

        log.debug('Successfully converted Maximum "{}" and Mul "{}" to LeakyReLU '
                  'with a negative slope'.format(max_node.id, mul_node.id))
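For context, a replacer like this is paired with a pattern() method that binds the match keys ('mul_op', 'const_op', 'max_op') used above. A plausible sketch assuming the usual MO pattern format (the exact node attributes may differ):

    @staticmethod
    def pattern():
        return dict(
            nodes=[
                ('const_op', dict(op='Const')),
                ('mul_op', dict(op='Mul')),
                ('max_op', dict(op='Maximum')),
            ],
            edges=[
                ('const_op', 'mul_op'),
                ('mul_op', 'max_op'),
            ])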
Example #14
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        beta = match['beta']
        mul = match['mul']
        mul_beta = match['mul_beta']
        mul_name = mul.soft_get('name', mul.id)

        # determine which input port of each Mul receives the 'input' tensor
        mul_beta_input_port_idx = int(
            mul_beta.in_port(0).get_connection().get_source().node.id ==
            beta.id)
        mul_input_port_idx = int(
            mul.in_port(0).get_connection().get_source().node.soft_get('op') ==
            'Sigmoid')

        # check that the same tensor is provided as input to both Mul and MulBeta
        if mul.in_port(mul_input_port_idx).get_source() != mul_beta.in_port(
                mul_beta_input_port_idx).get_source():
            return

        swish = Swish(graph, {}).create_node()
        swish.in_port(0).connect(
            mul_beta.in_port(mul_beta_input_port_idx).get_source())

        # connect Beta value
        swish.in_port(1).connect(
            mul_beta.in_port(1 - mul_beta_input_port_idx).get_source())

        mul.out_port(0).get_connection().set_source(swish.out_port(0))

        rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)])
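The rewrite relies on the identity Swish(x, beta) = x * sigmoid(beta * x); a quick NumPy reference (illustrative only):

import numpy as np

def swish(x, beta=1.0):
    return x / (1.0 + np.exp(-beta * x))  # == x * sigmoid(beta * x)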
Example #15
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='Interpolate', version='opset1'):
            transformation_mode = 'align_corners' if int(
                node.soft_get('align_corners', 0)) else 'half_pixel'
            interpolate1_name = node.soft_get('name', node.id)
            interpolate4 = create_op_with_const_inputs(
                graph, Interpolate, {
                    2: np.array([1.0, 1.0]),
                    3: int64_array(node.axes)
                }, {
                    'mode': node.mode,
                    'antialias': node.antialias,
                    'coordinate_transformation_mode': transformation_mode,
                    'pads_begin': correct_pad(node.soft_get('pads_begin', 0)),
                    'pads_end': correct_pad(node.soft_get('pads_end', 0)),
                    'nearest_mode': 'round_prefer_floor',
                    'cube_coeff': -0.75,
                    'shape_calculation_mode': 'sizes',
                    'version': 'opset4',
                    'in_ports_count': 4,
                })

            interpolate1_input_connection = node.in_port(0).get_connection()
            interpolate1_input_connection.set_destination(
                interpolate4.in_port(0))

            sizes_connection = node.in_port(1).get_connection()
            sizes_connection.set_destination(interpolate4.in_port(1))

            node.out_port(0).get_connection().set_source(
                interpolate4.out_port(0))
            rename_nodes([(node, interpolate1_name + '/delete'),
                          (interpolate4, interpolate1_name)])
Example #16
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        if node.has_port('in', 2) and not node.in_port(
                2).disconnected() and not node.has_and_set('shape_input'):
            bias_name = node.name
            new_node_name = node.name + '/WithoutBiases'
            add = Add(graph, dict(name=bias_name)).create_node()
            rename_nodes([(node, new_node_name), (add, bias_name)])
            node.out_port(0).get_connection().set_source(add.out_port(0))
            node.out_port(0).connect(add.in_port(0))
            node.in_port(2).get_connection().set_destination(add.in_port(1))

            bias = add.in_port(1).get_source().node
            if bias.has_valid("type") and bias.type == "Const":
                input_shape = add.in_port(0).data.get_shape()
                if len(input_shape) > 2:
                    dims_to_add = len(input_shape) - 2 if graph.graph[
                        'layout'] == 'NCHW' else 0
                    if dims_to_add > 0:
                        reshape = create_op_node_with_second_input(
                            graph, Reshape,
                            np.array([input_shape[1]] + [1] * dims_to_add,
                                     dtype=np.int64),
                            {'name': node.id + '/Dims'})
                        add.in_port(1).get_connection().set_destination(
                            reshape.in_port(0))
                        reshape.out_port(0).connect(add.in_port(1))
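Why the reshape: in an NCHW layout a 1-D bias of length C must become [C, 1, ..., 1] so it broadcasts over the trailing spatial dimensions. A NumPy sketch with an assumed 4D input:

import numpy as np

input_shape = (1, 8, 5, 5)  # NCHW
bias = np.ones(input_shape[1], dtype=np.float32)
dims_to_add = len(input_shape) - 2
bias_reshaped = bias.reshape([input_shape[1]] + [1] * dims_to_add)  # shape (8, 1, 1)
out = np.zeros(input_shape, dtype=np.float32) + bias_reshaped       # broadcasts cleanly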
Example #17
    def find_and_replace_pattern(self, graph: Graph):
        for attr_pad in graph.get_op_nodes(op='AttributedPad'):
            # save the original node name to use it in the new Pad op instance
            original_name = attr_pad.soft_get('name', attr_pad.id)

            new_pad = Pad(graph, {
                'mode': attr_pad.soft_get('mode', None),
            }).create_node()
            rename_nodes([(attr_pad, original_name + '/to_be_removed'),
                          (new_pad, original_name)])

            attr_pad.in_port(0).get_connection().set_destination(
                new_pad.in_port(0))
            new_pad.in_port(1).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 0]
                }).create_node().out_port(0))
            new_pad.in_port(2).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 1]
                }).create_node().out_port(0))
            if attr_pad.soft_get('mode') == 'constant':
                # create Constant node of proper data type (equal to the data type of the Pad first input)
                convert_pad_value = create_op_with_const_inputs(
                    graph, ConvertLike, {0: attr_pad.fill_value},
                    {'name': original_name + '/pad_value_convert'})
                convert_pad_value.in_port(1).connect(
                    new_pad.in_port(0).get_source())
                new_pad.in_port(3).connect(convert_pad_value.out_port(0))

            attr_pad.out_port(0).get_connection().set_source(
                new_pad.out_port(0))
            graph.remove_node(attr_pad.id)
Example #18
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        cmp = match['complex']
        complex_abs = match['abs']
        complex_abs_name = complex_abs.soft_get('name', complex_abs.id)

        power_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

        pow0 = create_op_with_const_inputs(
            graph, Pow, {1: power_type(2.0)},
            {'name': complex_abs_name + '/real_part_squared'})
        pow1 = create_op_with_const_inputs(
            graph, Pow, {1: power_type(2.0)},
            {'name': complex_abs_name + '/imag_part_squared'})

        cmp.in_port(0).get_connection().set_destination(pow0.in_port(0))
        cmp.in_port(1).get_connection().set_destination(pow1.in_port(0))

        add = Add(graph, {
            'name': complex_abs_name + '/squared_abs'
        }).create_node([pow0, pow1])
        sqrt = create_op_with_const_inputs(graph, Pow, {1: power_type(0.5)},
                                           {})
        add.out_port(0).connect(sqrt.in_port(0))

        complex_abs.out_port(0).get_connection().set_source(sqrt.out_port(0))

        rename_nodes([(complex_abs, complex_abs_name + '/to_be_removed'),
                      (sqrt, complex_abs_name)])
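The subgraph computes |z| = sqrt(re^2 + im^2) out of two Pow(2.0) nodes, an Add and a Pow(0.5); a scalar NumPy check with illustrative values:

import numpy as np

real, imag = np.float32(3.0), np.float32(4.0)
assert (real ** 2 + imag ** 2) ** 0.5 == np.abs(3 + 4j)  # both yield 5.0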
Example #19
    def replace_sub_graph(self, graph: Graph, match: dict):
        identity_spw = match['identity_spw']
        gather0_1 = match['gather0_1']
        gather0_2 = match['gather0_2']
        greaterequal0 = match['greaterequal0']
        sparse_fill_empty_rows = match['sparse_fill_empty_rows']
        gather = match['gather']
        select = match['select']
        where0 = match['where0']
        output_node_name = select.soft_get('name', select.id)

        log.debug('Found EmbeddingSegmentsSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
                                                                                     sparse_fill_empty_rows.name))

        split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)},
                                                        {'num_splits': 2,
                                                         'name': output_node_name + '/SplitForIndices'})
        squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
        split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)},
                                                            {'num_splits': 2,
                                                             'name': output_node_name + '/SplitForDenseShape'})
        squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
        cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node()
        cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
        cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node()
        embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
        rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])

        # connect parameters table
        gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
        # connect indices values
        greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
        # split and connect segment ids
        gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
        squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
        # TODO: remove casting once we start to support I64 model input
        cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
        embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
        # split and connect number of segments
        identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
        squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
        # TODO: remove casting once we start to support I64 model input
        cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
        embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
        # connect default value
        # TODO: remove casting once we start to support I64 model input
        sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
        embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
        # no input port for per_sample_weight

        identity_spw.in_port(0).disconnect()
        gather0_1.in_port(0).disconnect()
        gather0_2.in_port(0).disconnect()
        greaterequal0.in_port(0).disconnect()
        sparse_fill_empty_rows.in_port(2).disconnect()
        gather.in_port(0).disconnect()

        select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
        graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
Example #20
    def convert_fft_to_dft(self, graph: Graph, mx_fft: Node):
        mx_fft_name = mx_fft.soft_get('name', mx_fft.id)
        unsqueeze_node = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array([-1])},
            {'name': mx_fft_name + '/Unsqueeze'})
        rank_node = Rank(graph, {'name': mx_fft_name + '/Rank'}).create_node()

        mx_fft_connection = mx_fft.in_port(0).get_connection()
        mx_fft_connection.set_destination(unsqueeze_node.in_port(0))
        mx_fft_connection.get_source().connect(rank_node.in_port(0))

        add_node = create_op_with_const_inputs(graph, Add, {1: int64_array(1)},
                                               {'name': mx_fft_name + '/Add'},
                                               rank_node)
        broadcast_node1 = create_op_with_const_inputs(
            graph, Broadcast, {0: int64_array(0)},
            {'name': mx_fft_name + '/Pad_broadcast'})
        add_node.out_port(0).connect(broadcast_node1.in_port(1))

        scatter_node = create_op_with_const_inputs(
            graph, ScatterUpdate, {
                2: int64_array(1),
                3: int64_array(0)
            }, {'name': mx_fft_name + '/ScatterUpdate'})
        broadcast_node1.out_port(0).connect(scatter_node.in_port(0))
        rank_node.out_port(0).connect(scatter_node.in_port(1))

        pad_node = Pad(graph, {
            'name': mx_fft_name + '/Pad',
            'mode': 'constant'
        }).create_node([unsqueeze_node, broadcast_node1, scatter_node])

        dft_node = create_op_with_const_inputs(
            graph, DFT, {1: int64_array([-1])}, {
                'name': mx_fft_name + '/DFT',
                'in_ports_count': 2
            }, pad_node)

        sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)},
                                               {'name': mx_fft_name + '/Sub'})
        rank_node.out_port(0).connect(sub_node.in_port(0))
        broadcast_node2 = create_op_with_const_inputs(
            graph, Broadcast, {0: int64_array(0)},
            {'name': mx_fft_name + '/Reshape_broadcast'})
        sub_node.out_port(0).connect(broadcast_node2.in_port(1))
        concat_node = create_op_with_const_inputs(
            graph, Concat, {1: int64_array([-1, 2])}, {
                'name': mx_fft_name + '/New_shape',
                'in_ports_count': 2,
                'axis': 0
            }, broadcast_node2)

        reshape_node = Reshape(graph, {}).create_node([dft_node, concat_node])

        mx_fft.out_port(0).get_connection().set_source(
            reshape_node.out_port(0))
        rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'),
                      (reshape_node, mx_fft_name)])
Example #21
    def find_and_replace_pattern(self, graph: Graph):
        for attr_random_uniform in graph.get_op_nodes(
                op='AttributedRandomUniform'):
            original_name = attr_random_uniform.soft_get(
                'name', attr_random_uniform.id)

            if not attr_random_uniform.has_valid('output_type'):
                raise Error(
                    "RandomUniform should have valid ''output_type'' attribute."
                )
            output_type = attr_random_uniform.soft_get('output_type')

            if attr_random_uniform.has_valid('min_val'):
                min_val = attr_random_uniform['min_val']
            else:
                min_val = output_type(0)
            if attr_random_uniform.has_valid('max_val'):
                max_val = attr_random_uniform['max_val']
            else:
                max_val = output_type(1)

            port_value_dict = {1: min_val, 2: max_val}

            if not attr_random_uniform.has_port(
                    'in', 0) or attr_random_uniform.in_port(0).disconnected():
                if not attr_random_uniform.has_valid('shape'):
                    raise Error(
                        "RandomUniform should have valid ''shape'' attribute or input node on 0 port."
                    )
                else:
                    port_value_dict.update({0: attr_random_uniform.shape})

            attrs = {
                'global_seed': attr_random_uniform.soft_get('global_seed', 0),
                'op_seed': attr_random_uniform.soft_get('op_seed', 0),
                'output_type': output_type
            }

            new_random_uniform = create_op_with_const_inputs(
                graph,
                op=RandomUniform,
                port_value_dict=port_value_dict,
                op_attrs=attrs)
            rename_nodes([(attr_random_uniform,
                           original_name + '/to_be_removed'),
                          (new_random_uniform, original_name)])
            attr_random_uniform.out_port(0).get_connection().set_source(
                new_random_uniform.out_port(0))
            if new_random_uniform.in_port(0).disconnected():
                if attr_random_uniform.in_port(0).disconnected():
                    raise Error(
                        'RandomUniform should have input node on 0 port.')
                else:
                    new_random_uniform.in_port(0).connect(
                        attr_random_uniform.in_port(
                            0).get_connection().get_source())

            graph.remove_node(attr_random_uniform.id)
Example #22
    def transform_keras_rnn_output_concatenation(external_match: dict,
                                                 internal_match: dict):
        """
        Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
        :param external_match: a match used for handling a part of the main graph responsible for output concatenation
        :param internal_match: a match used for handling a part of the body graph responsible for output concatenation
        """
        loop_node = external_match['while']
        stack_node = external_match['stack']
        list_reserve_node = external_match['reserve']
        body_graph = loop_node['body']

        tensor_list_set_item_node = internal_match['concatenation']
        tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get(
            'name', tensor_list_set_item_node.id)
        list_result_node = internal_match['concatenation_result']

        # replace TensorListSetItem with Unsqueeze and use axis attribute for corresponding Result node
        # to concatenate results from different iterations
        unsqueeze_list_element = create_op_with_const_inputs(
            body_graph, Unsqueeze, {1: int64_array(0)},
            {'name': 'TensorListSetItemUnsqueeze'})
        tensor_list_set_item_node.in_port(2).get_connection().set_destination(
            unsqueeze_list_element.in_port(0))
        tensor_list_set_item_node.out_port(0).get_connection().set_source(
            unsqueeze_list_element.out_port(0))
        rename_nodes([(tensor_list_set_item_node,
                       tensor_list_set_item_node_name + '/AbandonedName'),
                      (unsqueeze_list_element, tensor_list_set_item_node_name)
                      ])
        list_result_node_layer_id = list_result_node.internal_layer_id
        Loop.update_port_map_value_ext(loop_node.output_port_map,
                                       'internal_layer_id',
                                       list_result_node_layer_id, 'axis', 0)

        # remove TensorListStack to bypass the node, since the result from the Loop node is already concatenated
        stack_node.out_port(0).get_connection().set_source(
            stack_node.in_port(0).get_connection().get_source())

        # disconnect ListReserve node because it is no longer needed for Loop
        list_reserve_node.out_port(0).disconnect()

        # connect the number of iterations as the trip count, which can be taken from the second input of ListReserve;
        # create a constant True value for execution_condition so that IE can ignore the execution condition and
        # perform trip_count iterations. Relying on a known trip count value avoids dynamism.
        loop_node.in_port(1).disconnect()
        list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
        for record in loop_node.output_port_map:
            if 'purpose' in record and record[
                    'purpose'] == 'execution_condition':
                exec_cond_layer_id = record['internal_layer_id']
                exec_cond_node = Loop.get_body_node_by_internal_id(
                    loop_node, exec_cond_layer_id)
                const_true = Const(body_graph, {
                    'value': np.array(True, dtype=bool)
                }).create_node()
                exec_cond_node.in_port(0).get_connection().set_source(
                    const_true.out_port(0))
Example #23
def replace_strided_slice(node: Node, mask: np.ndarray, op: callable):
    node_name = node.soft_get('name', node.id)
    axes = np.where(mask == 1)[0]
    new_node = create_op_node_with_second_input(node.graph, op, int64_array(axes))
    node.in_port(0).get_connection().set_destination(new_node.in_port(0))
    node.out_port(0).get_connection().set_source(new_node.out_port(0))

    rename_nodes([(node, node_name + '/ShouldBeDeleted'), (new_node, node_name)])
    node.graph.remove_node(node.id)
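Illustration of the mask-to-axes conversion: np.where collects the axis positions where the mask is set, and those become the second input of the Squeeze/Unsqueeze-style op passed in as `op` (values illustrative):

import numpy as np

mask = np.array([0, 1, 0, 1])
axes = np.where(mask == 1)[0]  # array([1, 3])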
Example #24
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)
        non_zero_node = NonZero(graph, {'name': node_name + '/NonZero_', 'output_type': np.int64}).create_node()
        transpose_node = create_op_node_with_second_input(graph, Transpose, int64_array([1, 0]), op_attrs={})
        non_zero_node.out_port(0).connect(transpose_node.in_port(0))
        rename_nodes([(node, node_name + '/delete'), (transpose_node, node_name)])

        non_zero_node.in_port(0).connect(node.in_port(0).get_source())
        return [transpose_node.id]
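A NumPy sketch of why the Transpose([1, 0]) is needed: OpenVINO's NonZero returns indices shaped [rank, N], while a tf.where-style op (presumably what is being replaced here) returns [N, rank]:

import numpy as np

x = np.array([[0, 1], [2, 0]])
ov_style = np.stack(np.nonzero(x))  # shape [rank, N]
tf_style = ov_style.T               # shape [N, rank]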
Example #25
    def replace_sub_graph(self, graph: Graph, match: dict):
        # TODO: once Inference Engine's CTCGreedyDecoder starts to support the sequence length format used in
        # TensorFlow, CTCGreedyDecoderReplacement2 needs to be removed and CTCGreedyDecoderReplacement, a more
        # generic transformation, adopted for all cases
        ctc_greedy_decoder = match['decoder']
        cast = match['cast']
        sparse_to_dense = match['sparse_to_dense']
        sparse_to_dense_name = sparse_to_dense.soft_get(
            'name', sparse_to_dense.id)

        # disconnect SparseToDense and Cast nodes
        sparse_to_dense.in_port(0).disconnect()
        cast.in_port(0).disconnect()

        # transform CTCGreedyDecoder output to TensorFlow's one:
        # 1. squeeze the output to [N, T] shape
        # 2. cast it to integer
        squeeze_dec_seq = create_op_with_const_inputs(
            graph, Squeeze, {1: int64_array([2, 3])},
            {'name': sparse_to_dense_name})
        squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_to_int = Cast(graph, {
            'name': sparse_to_dense_name + '/CastToInt',
            'dst_type': np.int32
        }).create_node()
        cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))

        # preserve output name from original graph
        rename_nodes([(sparse_to_dense,
                       sparse_to_dense_name + '/AbandonedName'),
                      (cast_to_int, sparse_to_dense_name)])

        # set output of the new sub-graph as a source for SparseToDense consumer
        sparse_to_dense.out_port(0).get_connection().set_source(
            cast_to_int.out_port(0))

        # remove no longer needed nodes
        graph.remove_nodes_from([sparse_to_dense.id, cast.id])

        # mark CTCGreedyDecoder node as a node that requires transformation of sequence length to a mask format
        # in the middle phase
        ctc_greedy_decoder['use_mask_format'] = True

        # unless the second input of CTCGreedyDecoder is a Parameter, MO is forced to use --static-shape
        # to try to get a value for the second input
        sequence_length_node = ctc_greedy_decoder.in_node(1)
        if sequence_length_node.soft_get(
                'op'
        ) != 'Parameter' and not graph.graph['cmd_params'].static_shape:
            log.error(
                "The model cannot be translated in a reshape-able way.\n"
                "The Model Optimizer key static_shape was turned on to prevent related errors.\n"
                "There will be no success changing input shapes of the model with the help of "
                "the InferenceEngine reshape method",
                extra={'is_warning': True})
            graph.graph['cmd_params'].static_shape = True
Example #26
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        ln = match['ln']
        exp = match['exp']

        ln_name = ln.soft_get('name', ln.id)

        softplus = SoftPlus(graph, {}).create_node()
        softplus.in_port(0).connect(exp.in_port(0).get_source())
        ln.out_port(0).get_connection().set_source(softplus.out_port(0))

        rename_nodes([(ln, ln_name + '/TBR'), (softplus, ln_name)])
Example #27
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        q = match['quantize']
        dq = match['dequantize']

        q_scale = q.in_port(1).get_source().node
        q_zerop = q.in_port(2).get_source().node
        dq_scale = dq.in_port(1).get_source().node
        dq_zerop = dq.in_port(2).get_source().node

        inp_port = q.in_port(0).get_source()
        name = inp_port.node.soft_get('name', inp_port.node.id)

            # only Const nodes are supported for zero_point/scale
        if q_scale.soft_get('type') == 'Const' and dq_scale.soft_get('type') == 'Const' and \
                q_zerop.soft_get('type') == 'Const' and dq_zerop.soft_get('type') == 'Const':

                # only patterns with the same scale/zero_point values for Q and DQ are supported
            if q_scale.value == dq_scale.value and q_zerop.value == dq_zerop.value:
                log.debug('Found Q-DQ pattern after {}'.format(name))

                zero_point_type = q_zerop.value.dtype
                # data type affects range of output values: [-128..127] or [0..255]
                if zero_point_type == np.int8:
                    output_min_value = -128.0
                    output_max_value = 127.0
                elif zero_point_type == np.uint8:
                    output_min_value = 0.0
                    output_max_value = 255.0
                else:
                    raise Error('Not supported type {} for zero point value in node {}'.format(
                        zero_point_type, q_zerop.soft_get('name')))
                min_value = q_scale.value * (output_min_value - q_zerop.value)
                max_value = q_scale.value * (output_max_value - q_zerop.value)
                input_min = Const(graph, {'value': np.array(min_value)}).create_node()
                input_max = Const(graph, {'value': np.array(max_value)}).create_node()

                FQ = FakeQuantize(graph, {
                    'levels': 256,
                    'name': match['quantize'].name + '_Dequantize/FakeQuantize'
                }).create_node()

                FQ.in_port(0).connect(match['quantize'].in_port(0).get_source())
                FQ.in_port(1).connect(input_min.out_port(0))
                FQ.in_port(2).connect(input_max.out_port(0))
                FQ.in_port(3).connect(input_min.out_port(0))
                FQ.in_port(4).connect(input_max.out_port(0))

                match['dequantize'].out_port(0).get_connection().set_source(FQ.out_port(0))
                dq_name = match['dequantize'].soft_get('name', match['dequantize'].id)
                rename_nodes([(match['dequantize'], dq_name + '/to_be_removed'), (FQ, dq_name)])
            else:
                raise Error('QuantizeLinear and DequantizeLinear (after {}) have different scale or zero-point values, '
                            'cannot fuse into FakeQuantize!'.format(name))
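The min/max range computed above is plain dequantization arithmetic, scale * (q - zero_point), applied at both ends of the quantized range. A scalar sketch with illustrative int8 values:

import numpy as np

scale, zero_point = np.float32(0.5), np.int8(10)
min_value = scale * (-128.0 - zero_point)  # -69.0
max_value = scale * (127.0 - zero_point)   # 58.5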
Example #28
    def replace_pattern(self, graph: Graph, match: dict):
        if not self.is_applicable(match):
            return

        unsqueeze_node = match['unsqueeze']
        unsqueeze_name = unsqueeze_node.soft_get('name', unsqueeze_node.id)
        second_input_of_unsqueeze = unsqueeze_node.in_port(
            1).get_connection().get_source().node
        d_idx = int(second_input_of_unsqueeze.value)
        axis = d_idx - 1

        shape_node = Shape(graph,
                           dict(name=unsqueeze_name + '/Shape')).create_node()
        axis_len_node = node_to_get_shape_value_of_indices(shape_node, [axis])

        second_input_of_tile = match['tile'].in_port(
            1).get_connection().get_source().node
        scale = int64_array([second_input_of_tile.value[d_idx]])
        float_scale = float32_array([second_input_of_tile.value[d_idx]])
        mul_node = create_op_with_const_inputs(
            graph, Mul, {1: scale}, {'name': unsqueeze_name + '/Mul'})

        axis_len_node.out_port(0).connect(mul_node.in_port(0))

        interp_node = create_op_with_const_inputs(
            graph, Interpolate, {
                2: float_scale,
                3: int64_array([axis])
            }, {
                'mode': 'nearest',
                'antialias': 0,
                'pads_begin': int64_array([0]),
                'pads_end': int64_array([0]),
                'coordinate_transformation_mode': 'half_pixel',
                'nearest_mode': 'round_prefer_floor',
                'cube_coeff': -0.75,
                'version': 'opset4',
                'shape_calculation_mode': 'scales',
                'in_ports_count': 4,
                'maybe_part_of_sequence': True
            })
        mul_node.out_port(0).connect(interp_node.in_port(1))

        reshape_node = match['reshape']
        reshape_node.out_port(0).get_connection().set_source(
            interp_node.out_port(0))
        reshape_name = reshape_node.soft_get('name', reshape_node.id)
        rename_nodes([(reshape_node, reshape_name + '/delete'),
                      (interp_node, reshape_name)])

        unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        unsqueeze_connection.get_source().connect(shape_node.in_port(0))
Example #29
    def find_and_replace_pattern(self, graph: Graph):
        for dequantize_node in graph.get_op_nodes(op='DequantizeLinear'):
            node_name = dequantize_node.soft_get('name', dequantize_node.id)
            axis = dequantize_node.soft_get('axis', None)
            scale_y_shape = dequantize_node.in_port(1).data.get_shape()
            model_data_type = data_type_str_to_np(
                graph.graph['cmd_params'].data_type)
            cast = Cast(graph, {
                'dst_type': model_data_type,
                'name': node_name + '/Cast'
            }).create_node()
            dequantize_node.in_port(0).get_connection().set_destination(
                cast.in_port(0))
            mul = Mul(graph, {}).create_node()

            is_second_port_connected = dequantize_node.is_in_port_connected(2)
            if is_second_port_connected:
                sub = Sub(graph, {'name': node_name + '/Sub'}).create_node()
                cast.out_port(0).connect(sub.in_port(0))
                dequantize_node.in_port(2).get_connection().set_destination(
                    sub.in_port(1))
                sub.out_port(0).connect(mul.in_port(0))
            else:
                cast.out_port(0).connect(mul.in_port(0))

            dequantize_node.in_port(1).get_connection().set_destination(
                mul.in_port(1))
            dequantize_node.out_port(0).get_connection().set_source(
                mul.out_port(0))
            rename_nodes([(dequantize_node, node_name + '/TBD'),
                          (mul, node_name)])

            assert scale_y_shape is not None
            if axis is not None and len(
                    scale_y_shape) > 0 and scale_y_shape[0] > 1:
                input_shape = cast.in_port(0).data.get_shape()
                target_shape = np.ones(len(input_shape), np.int64)
                target_shape[axis] = input_shape[axis]

                mul_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul'})
                mul.in_port(1).get_connection().set_destination(
                    mul_reshape.in_port(0))
                mul_reshape.out_port(0).connect(mul.in_port(1))

                if is_second_port_connected:
                    sub_reshape = create_op_with_const_inputs(
                        graph, Reshape, {1: int64_array(target_shape)},
                        {'name': node_name + '/Reshape/Sub'})
                    sub.in_port(1).get_connection().set_destination(
                        sub_reshape.in_port(0))
                    sub_reshape.out_port(0).connect(sub.in_port(1))
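The Cast/Sub/Mul chain implements the DequantizeLinear formula y = (x - zero_point) * scale; a scalar NumPy check with illustrative values:

import numpy as np

x, scale, zero_point = np.uint8(200), np.float32(0.1), np.uint8(128)
y = (np.float32(x) - np.float32(zero_point)) * scale  # ~= 7.2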
Example #30
    def replace_op(self, graph: Graph, node: Node):
        out_node = Concat(graph, {'axis': node.axis, 'in_ports_count': len(node.in_ports())}).create_node()
        pack_name = node.soft_get('name', node.id)

        for ind in node.in_ports():
            unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([node.axis])},
                                                         {'name': node.soft_get('name', node.id) + '/Unsqueeze'})
            node.in_port(ind).get_connection().set_destination(unsqueeze_node.in_port(0))
            unsqueeze_node.out_port(0).connect(out_node.in_port(ind))

        rename_nodes([(node, pack_name + '/TBR'), (out_node, pack_name)])
        return [out_node.id]
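The decomposition mirrors the NumPy identity behind Pack/Stack: stacking along an axis equals unsqueezing every input on that axis and concatenating. A quick check with illustrative shapes:

import numpy as np

a, b = np.zeros((2, 3)), np.ones((2, 3))
packed = np.stack([a, b], axis=0)
emulated = np.concatenate([np.expand_dims(a, 0), np.expand_dims(b, 0)], axis=0)
assert np.array_equal(packed, emulated)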