def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'
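        # replace Interpolate's target-size input with a ShapeOf->Gather subgraph
        # over a non-Interpolate Concat source, keeping the model reshape-able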
        interp_axes = Interpolate.get_axes(interpolate)
        concat_axis = self.get_concat_axis(concat)

        if concat_axis is None or interp_axes is None \
                or np.any(interp_axes < 0) or concat_axis < 0 \
                or concat_axis in interp_axes:
            # checks that interpolate axes and concat axis are valid and do not intersect
            return

        non_interp_concat_srcs = self.get_non_interpolate_concat_sources(
            concat)
        if not non_interp_concat_srcs:
            # there is no non-Interpolate Concat source to take the shape from
            return

        graph = interpolate.graph
        src = non_interp_concat_srcs[0]

        shape = Shape(graph, {
            'name': src.node.soft_get('name', src.node.id) + '/Shape'
        }).create_node()
        shape.in_port(0).connect(src)
        gather = create_op_with_const_inputs(
            graph,
            Gather, {
                1: np.array(interp_axes, dtype=np.int32),
                2: int64_array(0)
            }, {'name': shape.name + '/Gathered'},
            input_node=shape)
        interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
Example #2
    def find_and_replace_pattern(self, graph: Graph):
        for roll_node in graph.get_op_nodes(op='Roll'):
            if not roll_node.in_port(2).disconnected():
                continue
            node_name = roll_node.soft_get('name', roll_node.id)

            # reshape to 1d tensor
            reshape_to_1d = create_op_node_with_second_input(
                graph, Reshape, int64_array([-1]),
                {'name': node_name + '/reshape'})
            roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

            # add zero const as axes input to roll
            const_zero = Const(graph, {
                'value': int64_array([0]),
                'name': node_name + '/axes'
            }).create_node()
            const_zero.out_port(0).connect(roll_node.in_port(2))

            # reshape to original shape
            shape_of = Shape(graph, {
                'name': node_name + '/shape_of'
            }).create_node()
            reshape_to_1d.in_port(0).get_connection().add_destination(
                shape_of.in_port(0))
            reshape_to_orig_shape = Reshape(graph, {}).create_node()
            rename_nodes([(roll_node, node_name + '/roll'),
                          (reshape_to_orig_shape, node_name)])
            shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
            roll_node.out_port(0).get_connection().insert_node(
                reshape_to_orig_shape)
Example #3
    def extract(cls, node: Node):
        Shape.update_node_stat(
            node, {
                'output_type':
                tf_dtype_extractor(node.pb.attr['out_type'].type, np.int32)
            })
        return cls.enabled
Example #4
    def replace_sub_graph(self, graph: Graph, match: dict):
        div_sqrt = match['op']
        div_sqrt_name = div_sqrt.soft_get('name', div_sqrt.id)
        shape_node = Shape(graph,
                           dict(name=div_sqrt_name + '/Shape')).create_node()
        data_out_port = div_sqrt.in_port(0).get_source()
        shape_node.in_port(0).connect(data_out_port)

        shape_values_node = node_to_get_shape_value_of_indices(
            shape_node=shape_node, indices=[-1])

        pow_node = AttributedPower(
            graph, dict(name=div_sqrt_name + '/Sqrt',
                        power=mo_array(0.5))).create_node()

        # Due to specification, Power must have inputs with the same data type.
        convert_pow_input = Cast(
            graph,
            dict(dst_type=np.float32,
                 name=shape_values_node.name +
                 '/ConvertToFP32')).create_node()
        div_node = Div(graph, dict(name=div_sqrt_name + '/Div')).create_node()

        shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
        convert_pow_input.out_port(0).connect(pow_node.in_port(0))
        div_sqrt.in_port(0).get_connection().set_destination(
            div_node.in_port(0))
        div_node.in_port(1).connect(pow_node.out_port(0))
        div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))

        rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'),
                      (div_node, div_sqrt_name)])
Example #5
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['mxreshape']

        input_index = 0
        reshape_index = 0
        shape_node = Shape(graph, dict(name=node.id +
                                       '/ShapeMXReshape')).create_node()
        shape_node.in_port(0).connect(node.in_port(0).get_source())
        output_dims_nodes = []
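        # walk the MXNet reshape 'dim' spec once per entry; resolve() advances the
        # indices and appends the nodes producing each output dimension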
        for _ in node.dim:
            if reshape_index < len(node.dim):
                input_index, reshape_index, output_dims_nodes = self.resolve(
                    input_index, reshape_index, node.dim, shape_node,
                    output_dims_nodes)

        concat_node = Concat(
            shape_node.graph,
            dict(name=shape_node.id + '/ConcatMXReshape_',
                 axis=0,
                 in_ports_count=len(output_dims_nodes))).create_node()

        for in_port_index, dim_node in enumerate(output_dims_nodes):
            concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

        reshape_node = Reshape(graph,
                               dict(name=node.id + '/Reshape_')).create_node()
        reshape_node.in_port(1).connect(concat_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
Example #6
    def squeeze_initial_states(graph: Graph, match: dict):
        """
        Squeeze input initial states of recurrent node to 2-D shape.
        """
        hidden_init_port = 5
        cell_init_port = 6

        rnn_layer = match['rnn_layer']
        # Add input ports to rnn_layer
        rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
        rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)

        assert hidden_init_port in rnn_layer.in_nodes()
        hidden_size = rnn_layer.hidden_size
        shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
        rnn_layer.in_port(0).get_source().connect(shape.in_port(0))

        batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
        new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
                                                   op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
                                                                 in_ports_count=2, axis=0), input_node=batch)
        reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize', override_output_shape=True)).create_node()
        new_dim.out_port(0).connect(reshape_h.in_port(1))
        rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)

        if rnn_layer.op == 'LSTM':
            assert cell_init_port in rnn_layer.in_nodes()

            reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize', override_output_shape=True)).create_node()
            new_dim.out_port(0).connect(reshape_c.in_port(1))
            rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
Example #7
    def make_interpolate_reshapeable(interpolate, concat):
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'

        output_shape = interpolate.out_port(0).data.get_shape()

        interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in Interpolate.get_axes(interpolate)]
        concat_axis = get_canonical_axis_index(output_shape, concat.axis)
        if concat_axis in interp_axes:
            return

        concat_srcs = [port.get_source() for port in concat.in_ports().values() if not port.disconnected()]
        non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
        if len(non_interp_concat_srcs) == 0:
            return

        graph = interpolate.graph
        src = non_interp_concat_srcs[0]

        shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
        shape.in_port(0).connect(src)
        gather = create_op_with_const_inputs(graph, Gather,
                                             {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                             {'name': shape.name + '/Gathered'}, shape)
        interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
Example #8
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid(
            'axis'
        ), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(
            name)
        axis = node.axis

        reshape_node = Reshape(graph, {
            'name': node.id + '/Reshape'
        }).create_node()

        if axis == 0:
            dim = Const(
                graph, {
                    'value': int64_array([1, -1]),
                    'name': reshape_node.name + '/shape'
                }).create_node()
        elif axis == 1:
            dim = Const(
                graph, {
                    'value': int64_array([0, -1]),
                    'name': reshape_node.name + '/shape'
                }).create_node()
        else:
            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

            idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

            axis_shape_portion = node_to_get_shape_value_of_indices(
                shape, idxs)
            first_dims = create_op_node_with_second_input(
                graph, ReduceProd, int64_array([0]), {
                    'name': name + '/first_dims',
                    'keep_dims': True
                })
            second_dims = Const(graph, {
                'value': int64_array([-1]),
                'name': name + '/second_dims'
            }).create_node()

            node.in_port(0).get_source().connect(shape.in_port(0))
            axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

            order_of_dims = [first_dims, second_dims
                             ] if axis > 0 else [second_dims, first_dims]

            dim = new_shape_node_from_shape_nodes(order_of_dims)

        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
Example #9
    def placeholder_scales(self, placeholder: Node):
        """
        Helper function to get scales for prior boxes out of input image size:
                [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
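                e.g. for a 1x300x600x3 (NHWC) input: [1/600, 1/300, 1/600, 1/300]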
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        shape.in_port(0).connect(placeholder.out_port(0))

        begin = Const(graph, {'value': int64_array([1])}).create_node()
        end = Const(graph, {'value': int64_array([3])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()
        spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': int64_array([1]),
                                       'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0])}).create_node()

        spatial.in_port(0).connect(shape.out_port(0))
        spatial.in_port(1).connect(begin.out_port(0))
        spatial.in_port(2).connect(end.out_port(0))
        spatial.in_port(3).connect(stride.out_port(0))

        power = Const(graph, {'value': float32_array([-1.])}).create_node()
        spatial_scale = Pow(graph, {}).create_node()

        spatial_scale.in_port(0).connect(spatial.out_port(0))
        spatial_scale.in_port(1).connect(power.out_port(0))

        # Power `type_infer` requires inputs to have equal data type
        convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

        order = Const(graph, {'value': int64_array([1, 0])}).create_node()
        axis_const = Const(graph, {'value': int64_array(0)}).create_node()
        reverse = Gather(graph, {}).create_node()

        reverse.in_port(0).connect(spatial_scale.out_port(0))
        reverse.in_port(1).connect(order.out_port(0))
        axis_const.out_port(0).connect(reverse.in_port(2))

        priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
        priors_scale_node.add_input_port(0, skip_if_exist=True)
        priors_scale_node.add_input_port(1, skip_if_exist=True)

        priors_scale_node.in_port(0).connect(reverse.out_port(0))
        priors_scale_node.in_port(1).connect(reverse.out_port(0))
        return priors_scale_node
Example #10
    def find_and_replace_pattern(self, graph: Graph):
        reverse_nodes = graph.get_op_nodes(op='Reverse')
        for reverse in reverse_nodes:
            reverse_name = reverse.soft_get('name', reverse.id)

            assert reverse.in_port(1).disconnected()
            assert reverse.has_valid('axis')

            in_shape_rank = len(reverse.in_port(0).data.get_shape())
            # 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
            if in_shape_rank == 1:
                unsq_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                             {'name': reverse_name+"/Unsqueeze"})
                reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
                new_in = unsq_node.out_port(0)
                batch_axis = 0
                seq_axis = 1
            else:
                new_in = reverse.in_port(0).get_source()
                seq_axis = reverse['axis']
                batch_axis = 0 if seq_axis != 0 else 1

            # 2. For ReverseSequence 1-port input is seq_lengths => create this input node as
            # shape[seq_axis] broadcasted to shape[batch_axis]
            # in ---> ShapeOf ----> Gather(seq_axis)  ----> Broadcast----->
            #            |                                      |
            #            | -------> Gather(batch_axis)----------|
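            # the result is a vector of length shape[batch_axis] filled with
            # shape[seq_axis], i.e. every batch element reverses the full sequence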
            shape_node = Shape(graph, {'name': reverse_name + "/Shape"}).create_node()
            new_in.connect(shape_node.in_port(0))
            seq_axis_node = node_to_get_shape_value_of_indices(shape_node, [seq_axis])
            batch_node = node_to_get_shape_value_of_indices(shape_node, [batch_axis])
            broadcast_node = Broadcast(graph, {'name': reverse_name + "/Broadcast"}).create_node()
            broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
            broadcast_node.in_port(1).connect(batch_node.out_port(0))

            # 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
            rename_node(reverse, reverse_name + '/to_delete')
            reverse_sequence = ReverseSequence(graph, {'name':  reverse_name, 'seq_axis': seq_axis,
                                                       'batch_axis': batch_axis}).create_node()
            reverse_sequence.in_port(0).connect(new_in)
            reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))

            # 4. remove added dimension for rank = 1
            if in_shape_rank == 1:
                rename_node(reverse_sequence, reverse_name + '/ReverseSequence')
                squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                {'name': reverse_name})
                squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
                reverse.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
            else:
                reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))

        # 5. Delete old Reverse node
        graph.remove_nodes_from([reverse.id for reverse in reverse_nodes])
Example #11
    def replace_pattern(self, graph: Graph, match: dict):
        if not self.is_applicable(match):
            return

        unsqueeze_node = match['unsqueeze']
        unsqueeze_name = unsqueeze_node.soft_get('name', unsqueeze_node.id)
        second_input_of_unsqueeze = unsqueeze_node.in_port(
            1).get_connection().get_source().node
        d_idx = int(second_input_of_unsqueeze.value)
        axis = d_idx - 1
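        # the Unsqueeze axis (d_idx) is where Tile inserted the repeated dimension;
        # the axis being upsampled in the original tensor is therefore d_idx - 1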

        shape_node = Shape(graph,
                           dict(name=unsqueeze_name + '/Shape')).create_node()
        axis_len_node = node_to_get_shape_value_of_indices(shape_node, [axis])

        second_input_of_tile = match['tile'].in_port(
            1).get_connection().get_source().node
        scale = int64_array([second_input_of_tile.value[d_idx]])
        float_scale = float32_array([second_input_of_tile.value[d_idx]])
        mul_node = create_op_with_const_inputs(
            graph, Mul, {1: scale}, {'name': unsqueeze_name + '/Mul'})

        axis_len_node.out_port(0).connect(mul_node.in_port(0))

        interp_node = create_op_with_const_inputs(
            graph, Interpolate, {
                2: float_scale,
                3: int64_array([axis])
            }, {
                'mode': 'nearest',
                'antialias': 0,
                'pads_begin': int64_array([0]),
                'pads_end': int64_array([0]),
                'coordinate_transformation_mode': 'half_pixel',
                'nearest_mode': 'round_prefer_floor',
                'cube_coeff': -0.75,
                'version': 'opset4',
                'shape_calculation_mode': 'scales',
                'in_ports_count': 4,
                'maybe_part_of_sequence': True
            })
        mul_node.out_port(0).connect(interp_node.in_port(1))

        reshape_node = match['reshape']
        reshape_node.out_port(0).get_connection().set_source(
            interp_node.out_port(0))
        reshape_name = reshape_node.soft_get('name', reshape_node.id)
        rename_nodes([(reshape_node, reshape_name + '/delete'),
                      (interp_node, reshape_name)])

        unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        unsqueeze_connection.get_source().connect(shape_node.in_port(0))
Example #12
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = float32_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
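    # the Split->Concat upsampling pattern is collapsed into a single Interpolate;
    # the target size is computed dynamically as floor(input_shape[axis] * scale)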
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph,
                                                     StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {
                                                        'name': split_node_name + '/StridedSlice',
                                                        'begin_mask': int64_array([1]),
                                                        'end_mask': int64_array([1]),
                                                        'new_axis_mask': int64_array([0]),
                                                        'shrink_axis_mask': int64_array([0]),
                                                        'ellipsis_mask': int64_array([0])
                                                     })
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest',
                                   antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
                                   cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
Example #13
    def replace(node: Node, const: Node):
        graph = node.graph
        shape = const.shape
        const_name = const.soft_get('name', const.id)

        non_one_dims = np.argwhere(shape != 1).flatten()
        one_dims = np.argwhere(shape == 1).flatten()

        if not (non_one_dims.size == 1 and 5 < np.prod(shape) < 500):
            # the (5, 500) range was chosen empirically to affect as few models as possible
            return

        value = const.value
        if not np.array_equal(np.arange(0, np.prod(shape), 1).reshape(shape), value):
            return

        positive_idx = non_one_dims.item(0)
        negative_idx = positive_idx - len(shape)
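        # index the broadcast dimension from the end so it remains valid
        # if broadcasting prepends extra leading dimensions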

        node_name = node.soft_get('name', node.id)
        gather = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
                                             {'name': node_name + '/BroadcastingDim'})
        gather_for_const = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
                                                       {'name': const_name + '/BroadcastingDim'})
        shapeof_node = Shape(graph, {'name': const_name + '/ShapeOf'}).create_node()
        shapeof_node.out_port(0).connect(gather_for_const.in_port(0))

        equal_node = create_op_with_const_inputs(graph, Equal, {1: int64_array(1)}, {'name': node_name + '/ConstOne'})
        gather.out_port(0).connect(equal_node.in_port(0))

        select_node = Select(graph, {'name': node_name + '/Select',
                                      'auto_broadcast': 'numpy'}).create_node([equal_node, gather_for_const, gather])

        const.out_port(0).connect(shapeof_node.in_port(0))

        range_node = create_op_with_const_inputs(graph, Range,
                                                 {0: mo_array(0, dtype=value.dtype),
                                                  2: mo_array(1, dtype=value.dtype)},
                                                 {'name': const_name + '/Range', 'dtype': value.dtype})
        select_node.out_port(0).connect(range_node.in_port(1))

        node.in_port(1).get_connection().add_destination(gather.in_port(0))

        node.in_port(0).get_connection().set_source(range_node.out_port(0))

        if one_dims.size:
            unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, one_dims,
                                                         {'name': const_name + '/KeepShape'})
            range_node.out_port(0).get_connection().insert_node(unsqueeze)
            rename_nodes([(const, const_name + '/ToBeDeleted'), (unsqueeze, const_name)])
        else:
            rename_nodes([(const, const_name + '/ToBeDeleted'), (range_node, const_name)])
Example #14
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                               'mode is supported for node {}.'.format(node.id)
            node_name = node.soft_get('name', node.id)
            rename_node(node, node_name + '/TBR')
            is_packed = False
            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
                is_packed = True
                embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
            else:
                embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
                node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
            rename_node(embedding_bag, node_name)
            node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
            node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
            node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
            if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
                if is_packed:
                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
                else:
                    # connect per_sample_weights
                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

                    weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

                    weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
                    last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
                    weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

                    weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

                    zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
                                                                {'name': node_name + '/Broadcast'})
                    zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

                    default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                          {'name': node_name + '/Unsqueeze'})
                    default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

                    # expand embedding table with zeros
                    weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                                    'name': node_name + '/Concat'}).create_node()
                    embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
                    weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
                    weights_concat.out_port(0).connect(embedding_bag.in_port(0))

                    # point default index to expanded part of embedding table
                    weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
Example #15
    def resolve_minus2(self, shape_node, input_index, reshape_index, dims):
        rank_node = Shape(
            shape_node.graph,
            dict(name=shape_node.id +
                 '/RankShapeMXReshapeMinus2')).create_node()
        rank_node.in_port(0).connect(shape_node.out_port(0))
        shape_values_node = get_shape_values_by_range_idxs(shape=shape_node,
                                                           rank=rank_node,
                                                           begin=input_index,
                                                           end=-1,
                                                           include_begin=True,
                                                           include_end=True)
        input_index = None
        reshape_index = reshape_index + 1
        return input_index, reshape_index, dims, shape_values_node
Example #16
    def replace_sub_graph(self, graph: Graph, match: dict):
        mxreshape = match['op']
        if not mxreshape.reverse:
            return

        shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
        forward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                          dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
        forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()

        forward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                        dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
        reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()
        shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
        mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))

        forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
        forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
        forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
        reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

        reshape_shape_node = create_op_node_with_second_input(graph, Reshape, int64_array(np.flip(mxreshape.dim, 0)),
                                                              dict(name=str(mxreshape.id) + '/ReshapeShape'))
        if np.any(np.in1d([-2, -3, -4], mxreshape.dim)):
            reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/Reshape',
                                     dim=int64_array(np.flip(mxreshape.dim, 0)))).create_node()

        reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

        backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
        backward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                           dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
        backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
        backward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                         dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
        backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

        backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
        backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
        backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))

        backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

        mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))
Example #17
    def append_variances(priors_scale_node: Node, variance: list):
        graph = priors_scale_node.graph
        name = priors_scale_node.name

        sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

        begin = Const(graph, {'value': int64_array([-2])}).create_node()
        end = Const(graph, {'value': int64_array([-1])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()
        shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': int64_array([1]),
                                                     'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
                                                     'shrink_axis_mask': int64_array([0]),
                                                     'ellipsis_mask': int64_array([0])}).create_node()

        sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
        begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
        end.out_port(0).connect(shape_part_for_tiling.in_port(2))
        stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

        shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                        {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                                         'axis': int64_array(0)},
                                                        shape_part_for_tiling)

        variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
        tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
        variance.out_port(0).connect(tile.in_port(0))
        shape_concat.out_port(0).connect(tile.in_port(1))

        reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
        sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
        sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

        concat = Concat(graph,
                        {'name': name + '/priors_concat', 'axis': int64_array(0), 'in_ports_count': 2}).create_node()
        sp_reshape.out_port(0).connect(concat.in_port(0))
        tile.out_port(0).connect(concat.in_port(1))

        output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
        concat.out_port(0).connect(output_node.in_port(0))
        output_dims.out_port(0).connect(output_node.in_port(1))

        return output_node
Example #18
    def make_interpolate_reshapeable(interpolate):
        assert interpolate.soft_get('type') == 'Interpolate'
        axes = Interpolate.get_axes(interpolate)
        input_shape = interpolate.in_port(0).data.get_shape()
        output_shape = interpolate.out_port(0).data.get_shape()
        if not np.all(np.remainder(output_shape, input_shape) == 0) and \
                not np.all(np.remainder(input_shape, output_shape) == 0):
            return
        graph = interpolate.graph
        name = interpolate.soft_get('name', interpolate.id)
        shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
        shape.in_port(0).connect(interpolate.in_port(0).get_source())
        gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
                                             {'name': shape.name + '/Gathered'}, shape)
        multipliers = output_shape[axes] / input_shape[axes]
        mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
        interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
Example #19
    def replace_sub_graph(self, graph: Graph, match: dict):
        tf_slice_node = match['op']
        slice_name = tf_slice_node.soft_get('name', tf_slice_node.id)
        slice_node = Slice(graph).create_node()
        rename_nodes([(tf_slice_node, slice_name + '/to_be_removed'),
                      (slice_node, slice_name)])
        ends_node = Add(graph, {'name': slice_name + '/ends'}).create_node()

        # reconnect input, begin, and size from TFSlice to the subgraph with Slice
        tf_slice_node.in_port(0).get_connection().set_destination(
            slice_node.in_port(0))
        tf_slice_node.in_port(1).get_connection().set_destination(
            slice_node.in_port(1))
        tf_slice_node.in_port(2).get_connection().set_destination(
            ends_node.in_port(0))
        slice_node.in_port(1).get_connection().add_destination(
            ends_node.in_port(1))

        max_ends = Shape(graph, {
            'name': slice_name + '/ShapeOf'
        }).create_node()
        slice_node.in_port(0).get_connection().add_destination(
            max_ends.in_port(0))

        # check if size[i] == -1; the comparison is applied element-wise: len(size) == len(begin) == input_rank
        where_max_ends_is_needed = create_op_with_const_inputs(
            graph, Equal, {0: int64_array(-1)},
            {'name': slice_name + '/where_max_ends_is_needed'})
        ends_node.in_port(0).get_connection().add_destination(
            where_max_ends_is_needed.in_port(1))
        # select requires equal dtypes, need to convert ends to I64
        ends_casted_to_i64 = Cast(graph, {
            'name': slice_name + '/CastToI64',
            'dst_type': np.int64
        }).create_node([ends_node])
        # if size[i] == -1 then take the max_ends value for that dimension
        correct_ends = Select(graph, {
            'name': slice_name + '/chosen_ends'
        }).create_node(
            [where_max_ends_is_needed, max_ends, ends_casted_to_i64])
        correct_ends.out_port(0).connect(slice_node.in_port(2))
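        # e.g. input shape [4, 6], begin = [1, 0], size = [2, -1]:
        # ends = begin + size = [3, -1] -> Select takes max_ends there -> [3, 6]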

        tf_slice_node.out_port(0).get_connection().set_source(
            slice_node.out_port(0))
Example #20
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
        assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

        axis = node.axis
        end_axis = node.end_axis

        if end_axis == -1 and axis >= 0:
            begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
            middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
            end_dims = Const(graph, {'value': int64_array([])}).create_node()
        else:
            rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
            node.in_port(0).get_source().connect(rank.in_port(0))

            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            begin_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=0, end=axis)
            middle_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=axis, end=end_axis, include_end=True)
            end_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=end_axis, end=-1, include_begin=False, include_end=True)

            middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
            middle_dims.out_port(0).connect(middle_dim.in_port(0))

        dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

        original_name = node.soft_get('name')
        abandoned_name = original_name + '/ShouldBeDeleted'
        reshape_node = Reshape(graph, {}).create_node()
        # Keep a node with the same name to avoid confusion when renaming
        rename_nodes([(node, abandoned_name), (reshape_node, original_name)])
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
Example #21
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)
        assert node.has_valid(
            'axis'
        ), 'The node "{}" does not have mandatory attribute "axis"'.format(
            node_name)

        flatten_node = FlattenONNX(graph, {
            'name': node_name + '/FlattenONNX_',
            'axis': node.axis
        }).create_node()
        shape_node = Shape(graph, {
            'name': node_name + '/ShapeOf_'
        }).create_node()
        logsoftmax_node = LogSoftmax(graph, {
            'name': node_name + '/LogSoftmax_',
            'axis': 1
        }).create_node()
        reshape_node = Reshape(graph, {}).create_node()

        rename_nodes([(node, node_name + '/delete'),
                      (reshape_node, node_name)])

        shape_node.out_port(0).connect(reshape_node.in_port(1))
        logsoftmax_node.out_port(0).connect(reshape_node.in_port(0))
        flatten_node.out_port(0).connect(logsoftmax_node.in_port(0))

        source = node.in_port(0).get_source()

        flatten_node.in_port(0).connect(source)
        shape_node.in_port(0).connect(source)

        return [reshape_node.id]
Example #22
    def replace_pattern(self, graph: Graph, match: dict):
        matmul = match['matmul']
        reshape = match['reshape']
        other_input_port_idx = 0 if match['matmul'].in_port(
            0).get_source().node.id == match['other_input'].id else 1
        shape_source = match['matmul'].in_port(
            other_input_port_idx).get_source()
        initial_reshape_pattern = reshape.in_port(1).data.get_value()
        if len(initial_reshape_pattern) != 2:
            return

        reshape_is_A_input = matmul.in_port(
            0).get_source().node.id == reshape.id
        if reshape_is_A_input:
            idx = -1 if matmul.transpose_b else -2
        else:
            idx = -2 if matmul.transpose_a else -1
        idx = get_canonical_axis_index(initial_reshape_pattern, idx)
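        # idx points at the shared (K) dimension in the other MatMul input's shape;
        # the Reshape keeps that dimension and collapses the rest into -1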

        shape_name = shape_source.node.soft_get('name', shape_source.node.id)
        shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node()
        shape.in_port(0).connect(shape_source)
        C = node_to_get_shape_value_of_indices(shape, [idx])
        N = Const(graph, {
            'name': shape_name + '/MinusOne',
            'value': int64_array([-1])
        }).create_node()

        if reshape_is_A_input:
            reshape_pattern = [C, N] if matmul.transpose_a else [N, C]
        else:
            reshape_pattern = [N, C] if matmul.transpose_b else [C, N]
        new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern)
        reshape.in_port(1).get_connection().set_source(
            new_reshape_pattern.out_port(0))
Example #23
def get_shape_and_rank_nodes_by_port(port: Port,
                                     return_as_a_scalar: bool = True):
    """
    The function returns nodes producing shape and rank of the data from the desired port in order to use those
    operations on the middle/back phase
    :param port: Port object that specifies node output port
    :param return_as_a_scalar: boolean flag to return 1D or 0D rank
    :return: shape and rank nodes
    """
    input_node_name = port.node.soft_get('name', port.node.id)
    graph = port.node.graph

    shape = Shape(graph, dict(name=input_node_name + '/ShapeOf')).create_node()
    rank_1_d = Shape(graph,
                     dict(name=input_node_name + '/1dRankOf')).create_node()
    rank_1_d.in_port(0).connect(shape.out_port(0))
    shape.in_port(0).connect(port)
    if not return_as_a_scalar:
        return shape, rank_1_d

    rank = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]),
        {'name': input_node_name + '/0dRankOf'}, rank_1_d)
    return shape, rank
Example #24
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        name = node.soft_get('name', node.id)

        assert node.has_valid('output_type'), \
            'Rank node should have `output_type` attribute, but it is not set for node {}'.format(name)

        shape_of = Shape(graph, {
            'name': name + '/shape_of',
            'output_type': node.output_type
        }).create_node()
        rank_1d = Shape(graph, {
            'name': name + '/rank_of',
            'output_type': node.output_type
        }).create_node()
        rank_0d = create_op_node_with_second_input(
            graph, Squeeze, int64_array(0), {'name': name + '/0d_rank_of'},
            rank_1d)

        shape_of.out_port(0).connect(rank_1d.in_port(0))
        node.out_port(0).get_connection().set_source(rank_0d.out_port(0))
        node.in_port(0).get_connection().set_destination(shape_of.in_port(0))

        rename_nodes([(node, name + '/ToBeDeleted'), (rank_0d, name)])
Example #25
    def replace_pattern(graph: Graph, match: Dict[str, Node]):
        node = match['op']
        name = node.soft_get('name', node.id)
        input_shape = node.in_port(0).data.get_shape()
        second_input_shape = node.in_port(1).data.get_shape()

        begin_mask = np.zeros(len(input_shape), dtype=np.int64)
        end_mask = np.zeros(len(input_shape), dtype=np.int64)

        for i in node.axes:
            end_mask[i] = np.int64(1)

        new_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
        shrink_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
        ellipsis_mask = np.zeros(len(input_shape), dtype=np.int64)
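        # begin_mask is all zeros (begin is ignored), end_mask is set only on
        # node.axes, so the slice is bounded by the computed 'end' values there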

        ss = create_op_with_const_inputs(graph, StridedSlice,
                                         port_value_dict={1: np.zeros(len(input_shape), dtype=np.int64)},
                                         op_attrs={'name': 'StridedSlice', 'begin_mask': begin_mask,
                                                   'end_mask': end_mask, 'new_axis_mask': new_axis_mask,
                                                   'shrink_axis_mask': shrink_axis_mask,
                                                   'ellipsis_mask': ellipsis_mask})
        if input_shape.size == second_input_shape.size:
            end = Shape(graph, dict(name=name + '/End')).create_node()
            end.in_port(0).connect(node.in_port(1).get_source())
            ss.in_port(2).connect(end.out_port(0))
        else:
            shape_like, rank_like = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
            end_first_part = get_shape_values_by_range_idxs(shape_like, rank_like, 0, node.axes[-1], include_end=True)
            if input_shape.size - 1 == node.axes[-1]:
                ss.in_port(2).connect(end_first_part.out_port(0))
            else:
                shape, rank = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
                end_second_part = get_shape_values_by_range_idxs(shape, rank, node.axes[-1], -1, include_begin=False,
                                                                 include_end=True)
                end = new_shape_node_from_shape_nodes([end_first_part, end_second_part])
                ss.in_port(2).connect(end.out_port(0))

        node.in_port(0).get_connection().set_destination(ss.in_port(0))
        node.in_port(1).disconnect()
        node.out_port(0).get_connection().set_source(ss.out_port(0))

        rename_nodes([(node, name + '/ShouldBeDeleted'), (ss, name)])
Example #26
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(
                op='BatchToSpace'):
            node.add_input_port(3, skip_if_exist=True)

            # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
            transposed_pads = create_op_with_const_inputs(
                graph, Transpose, {1: int64_array([1, 0])})
            node.in_port(2).get_connection().set_destination(
                transposed_pads.in_port(0))
            split_pads = create_op_with_const_inputs(graph, Split,
                                                     {1: int64_array(0)},
                                                     {'num_splits': 2})
            transposed_pads.out_port(0).connect(split_pads.in_port(0))
            for port_ind in range(2):
                node.in_port(port_ind + 2).connect(
                    split_pads.out_port(port_ind))
                node.in_port(port_ind + 2).get_connection().insert_node(
                    create_op_with_const_inputs(graph, Squeeze,
                                                {1: int64_array([0])}))

            # add zeros/ones to related inputs to align it with data input
            in0_rank = Rank(graph, {
                'name': node.name + '/rank_0'
            }).create_node()
            in1_shape = Shape(graph, {
                'name': node.name + '/rank_1'
            }).create_node()

            diff_size = Sub(graph, {
                'name': node.name + '/sub_0'
            }).create_node()
            diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()
            const_begin = Const(graph, {
                'value': int64_array([1])
            }).create_node()
            const_pad_val = Const(graph, {
                'value': int64_array(1)
            }).create_node()

            block_shape = Pad(graph, {
                'name': node.name + '/aligned_block_shape',
                'mode': 'constant'
            }).create_node()

            # in case of SpaceToBatch begin = pads_begin, end = pads_end
            # in case of BatchToSpace begin = crops_begin, end = crops_end
            new_begin_name = '/aligned_pads_begin'
            new_end_name = '/aligned_pads_end'
            if node.type == 'BatchToSpace':
                new_begin_name = '/aligned_crops_begin'
                new_end_name = '/aligned_crops_end'

            begin = Pad(graph, {
                'name': node.name + new_begin_name,
                'mode': 'constant'
            }).create_node()
            end = Pad(graph, {
                'name': node.name + new_end_name,
                'mode': 'constant'
            }).create_node()

            in0_rank_1d = create_op_node_with_second_input(
                graph, Unsqueeze, int64_array([0]),
                {'name': node.name + '/1d_rank_of_0'}, in0_rank)

            node.in_port(0).get_source().connect(in0_rank.in_port(0))
            node.in_port(1).get_source().connect(in1_shape.in_port(0))
            in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
            in1_shape.out_port(0).connect(diff_size.in_port(1))
            diff_size.out_port(0).connect(diff.in_port(0))
            const_begin.out_port(0).connect(diff.in_port(1))
            const_pad_val.out_port(0).connect(block_shape.in_port(3))

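            # align each 1-D input with the data rank: pad one element at the
            # front and `diff` at the back (pad value 1 for block_shape;
            # begin/end use the Pad default fill of zeros)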
            inputs_array = [block_shape, begin, end]
            for idx, input_to_node in enumerate(inputs_array):
                name_of_input_to_node = input_to_node.name
                node.in_port(idx + 1).get_connection().set_destination(
                    input_to_node.in_port(0))
                const_begin.out_port(0).connect(input_to_node.in_port(1))
                diff.out_port(0).connect(input_to_node.in_port(2))
                input_to_node.out_port(0).connect(node.in_port(idx + 1))
                convert = Cast(graph, {
                    'name': name_of_input_to_node + '/i64',
                    'dst_type': np.int64
                }).create_node()
                input_to_node.in_port(0).get_connection().insert_node(convert)
Example #27
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']

        if 1 not in node.in_ports() or node.in_port(1).disconnected():

            if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
                factor = Const(graph, {'value': np.array(node.factor)}).create_node()

                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                mul = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(mul.in_port(0))
                factor.out_port(0).connect(mul.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                mul.out_port(0).connect(node.in_port(1))

            else:
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))

                pads_value = node.pads_begin + node.pads_end
                pads_const = Const(graph, {'value': np.array(pads_value)}).create_node()
                add = Add(graph, {'name': node.name + '/pad_add'}).create_node()
                ss.out_port(0).connect(add.in_port(0))
                add.in_port(1).connect(pads_const.out_port(0))
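                # 'add' now yields the input spatial dims (H, W) plus total padding,
                # which the factor branches below refine into the target size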

                if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
                    shrink_factor = node.shrink_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None

                    const = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                          'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    sub.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'value': np.array(1 / shrink_factor),
                                          'name': node.name + '/shrink_factor_div_const'}).create_node()
                    div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    div.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'name': node.name + '/shrink_factor_add_one_const', 'value': np.array(1)
                                          }).create_node()
                    add = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
                    div.out_port(0).connect(add.in_port(0))
                    const.out_port(0).connect(add.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    add.out_port(0).connect(node.in_port(1))
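                    # the chain above computes Caffe's shrink rule:
                    # out = (in_eff - 1) / shrink_factor + 1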

                elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
                    zoom_factor = node.zoom_factor
                    if zoom_factor < 1:
                        log.error('Zoom factor should be at least 1 in node {}'.format(node.id))
                        return None

                    node['debug_message'] = 'Interpolate layer replacer may be wrong; please try to update it in the' \
                                            ' file (openvino/tools/mo/front/InterpolateNormalizer.py at line {}).' \
                                            ''.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)

                    # Reshape methods can differ in some cases.
                    # The commented-out section represents the reshape used in deeplab-caffe.
                    # Uncomment the following lines if your model was trained with deeplab-caffe
                    # or uses the same reshape method.
                    # const = Const(graph, {'value': np.array(-1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
                    # sub = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
                    # add.out_port(0).connect(sub.in_port(0))
                    # const.out_port(0).connect(sub.in_port(1))
                    #
                    # const = Const(graph, {'value': np.array(zoom_factor - 1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
                    # mul = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
                    # sub.out_port(0).connect(mul.in_port(0))
                    # const.out_port(0).connect(mul.in_port(1))
                    #
                    # sum = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
                    # add.out_port(0).connect(sum.in_port(0))
                    # mul.out_port(0).connect(sum.in_port(1))
                    #
                    # node.add_input_port(1, skip_if_exist=True)
                    # assert node.in_port(1).disconnected()
                    # sum.out_port(0).connect(node.in_port(1))

                    # Comment out the following lines if you use the reshape method from the previous section
                    const = Const(graph, {'value': np.array(zoom_factor),
                                          'name': node.name + '/zoom_factor_mul_const'}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()

                    add.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    mul.out_port(0).connect(node.in_port(1))

                elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
                    const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    const.out_port(0).connect(node.in_port(1))

                elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
                    shrink_factor = node.shrink_factor
                    zoom_factor = node.zoom_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be at least 1 in node {}'.format(node.id))
                        return None
                    if zoom_factor < 1:
                        log.error('Zoom factor should be at least 1 in node {}'.format(node.id))
                        return None

                    const = Const(graph, {'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    const.out_port(0).connect(sub.in_port(1))

                    const = Const(graph, {'value': np.array(1 / (shrink_factor + 1))}).create_node()
                    div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    const.out_port(0).connect(div.in_port(1))

                    const = Const(graph, {'value': np.array(-1),
                                          'name': node.name + '/shrink_zoom_factor_sum_const'}).create_node()
                    sum = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(sum.in_port(0))
                    const.out_port(0).connect(sum.in_port(1))

                    const = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
                    sum.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    sum = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(sum.in_port(0))
                    mul.out_port(0).connect(sum.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    sum.out_port(0).connect(node.in_port(1))
        else:
            if node.soft_get('fw') == 'caffe':
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                source = node.in_port(1).get_connection().get_source()
                node.in_port(1).disconnect()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(node.in_port(1))
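
# Illustrative sketch (not part of the pass above): the output spatial sizes that the
# subgraphs built in this method compute, written directly in numpy. The function name,
# the NCHW layout and the combined shrink+zoom formula (which simply mirrors the constant
# wiring above) are assumptions made for the illustration.
import numpy as np

def caffe_interp_out_size(hw, factor=None, shrink=1, zoom=1, pads_begin=0, pads_end=0):
    # the explicit width/height branch is omitted: it just feeds the constant [height, width]
    hw = np.asarray(hw, dtype=np.float64)  # spatial dims, i.e. shape[2:4] for NCHW
    if factor is not None:
        return hw * factor                 # 'factor' branch: Mul by the factor const
    hw_eff = hw + pads_begin + pads_end    # the Add with pads_begin + pads_end
    if shrink != 1 and zoom == 1:
        return (hw_eff - 1) / shrink + 1   # shrink branch: Sub(-1), Mul(1/shrink), Add(1)
    if shrink == 1 and zoom != 1:
        return hw_eff * zoom               # zoom branch: Mul by zoom_factor
    d = (hw_eff - 1) / (shrink + 1)        # combined branch, as wired above
    return d + (d - 1) * (zoom - 1)

# e.g. caffe_interp_out_size([33, 33], shrink=2) -> array([17., 17.])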
Example #28
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                dict(name='do_reshape_classes'),
                                                                match.single_input_node(1)[0])

        initial_priors_node = match.single_input_node(2)[0]
        priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
        # the model calculates identical prior boxes for each batch item, so we take the first slice of them
        begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
        end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
        stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()

        priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                           'begin_mask': int64_array([1, 1, 1]),
                                           'end_mask': int64_array([1, 0, 0]),
                                           'new_axis_mask': int64_array([0]),
                                           'shrink_axis_mask': int64_array([0]),
                                           'ellipsis_mask': int64_array([0])}).create_node()

        initial_priors_node.out_port(0).connect(priors_node.in_port(0))
        begin.out_port(0).connect(priors_node.in_port(1))
        end.out_port(0).connect(priors_node.in_port(2))
        stride.out_port(0).connect(priors_node.in_port(3))

        placeholders = graph.get_op_nodes(type='Parameter')
        assert len(placeholders) == 1, "{} replacer requires the model to have one Placeholder, but the current " \
                                       "model has {} placeholders".format(self.replacement_id, len(placeholders))
        placeholder = placeholders[0]

        # scale prior boxes to the [0, 1] interval
        node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
        priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

        broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
        shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
        priors_node.out_port(0).connect(shape_of_priors.in_port(0))
        broadcast.in_port(1).connect(shape_of_priors.out_port(0))
        broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

        priors_scale_node.in_port(0).connect(priors_node.out_port(0))
        priors_scale_node.in_port(1).connect(broadcast.out_port(0))

        try:
            variance = match.custom_replacement_desc.custom_attributes['variance']
        except KeyError:
            raise Error('There is no "variance" attribute in `custom_attributes` of the {} '
                        'replacement config file'.format(self.replacement_id))

        priors = self.append_variances(priors_scale_node, variance)

        # calculate prior boxes widths and heights
        split_node = create_op_with_const_inputs(
            graph, VariadicSplit, {1: int64_array(2), 2: int64_array([1, 1, 1, 1])}, {'out_ports_count': 4},
            priors_scale_node)

        priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                                ).create_node([(split_node, 2), (split_node, 0)])
        priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                                 ).create_node([(split_node, 3), (split_node, 1)])

        # concat widths and heights into a single tensor and multiply it with the box coordinate regression values
        # WA: 3 Concats instead of 1 to keep the model reshape-able
        # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
        #                                           'in_ports_count': 4}).create_node(
        # [priors_width_node, priors_height_node, priors_width_node, priors_height_node])

        concat_1 = Concat(graph, {'name': 'concat_width_height',
                                  'axis': -1, 'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
        concat_2 = Concat(graph, {'name': 'concat_width_height_width',
                                  'axis': -1, 'in_ports_count': 2}).create_node([concat_1, priors_width_node])
        concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1, 'in_ports_count': 2}
                                          ).create_node([concat_2, priors_height_node])

        applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
            [concat_width_height_node, match.single_input_node(0)[0]])

        # reshape to a 2D tensor, as the Inference Engine DetectionOutput layer expects
        reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                   dict(name='reshape_regression'),
                                                                   applied_width_height_regressions_node)

        detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
        # get nms from the original network
        iou_threshold = None
        nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
        if len(nms_nodes) > 0:
            # it is highly unlikely that NMS parameters differ between classes;
            # moreover, DetectionOutput accepts only a scalar value for iou_threshold (nms_threshold)
            iou_threshold = nms_nodes[0].in_node(3).value
        if iou_threshold is None:
            raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

        detection_output_node = detection_output_op.create_node(
            [reshape_regression_node, reshape_classes_node, priors],
            dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
                 variance_encoded_in_target=0, background_label_id=1000))

        # As the outputs are replaced with a postprocessing node, the outgoing tensor names no longer
        # correspond to the original tensors and should be removed from output->Result edges
        out_nodes = []
        for out in range(match.outputs_count()):
            out_nodes.append(match.output_node(out)[0])
        clear_tensor_names_info(out_nodes)

        return {'detection_output_node': detection_output_node}
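
# Illustrative sketch (not part of the replacer): what the VariadicSplit/Sub/Concat wiring
# above computes for corner-encoded priors [x_min, y_min, x_max, y_max]. The sample array
# and the use of the last axis are assumptions for the illustration (the real split runs
# over axis 2 of a [1, N, 4] priors tensor).
import numpy as np

priors = np.array([[0.1, 0.2, 0.5, 0.8],
                   [0.0, 0.0, 1.0, 1.0]])
x_min, y_min, x_max, y_max = np.split(priors, 4, axis=-1)
w = x_max - x_min  # priors_width_node
h = y_max - y_min  # priors_height_node
# three pairwise Concats instead of a single 4-input Concat keep the model reshape-able
wh = np.concatenate([w, h], axis=-1)      # concat_1
whw = np.concatenate([wh, w], axis=-1)    # concat_2
whwh = np.concatenate([whw, h], axis=-1)  # -> [w, h, w, h] per prior box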
Example #29
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        name = node.soft_get('name', node.id)
        axis = node.axis
        input_shape_node = Shape(graph, {
            'name': name + '/ShapeOf'
        }).create_node()
        range_node = create_op_with_const_inputs(graph, Range, {
            0: mo_array(node.start),
            2: mo_array(node.step)
        }, {'name': name + '/Range'})
        node.in_port(0).get_connection().set_destination(
            input_shape_node.in_port(0))

        if axis is not None:
            '''
            Replace the arange_like op with the subgraph:
            Shape - Gather - Range
            '''
            gather_node = create_op_with_const_inputs(graph, Gather, {
                1: int64_array([axis]),
                2: int64_array(0)
            }, {'name': name + '/Gather'})
            input_shape_node.out_port(0).connect(gather_node.in_port(0))
            gather_node.out_port(0).connect(range_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                range_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (range_node, name)])
        else:
            r'''
            Replace the arange_like op with the subgraph:
                    |
                 ShapeOf ----------- | 
                    |                |
                 ReduceProd          |
                    |                |
                  Range              |
                    |                |
                 Reshape ----------- | 
                    |
            '''
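            # numpy equivalent of this branch (with the default start=0, step=1):
            # np.arange(np.prod(x.shape)).reshape(x.shape)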

            flattened_shape_node = create_op_with_const_inputs(
                graph, ReduceProd, {1: int64_array([0])}, {
                    'name': input_shape_node.name + '/ReduceProd',
                    'keep_dims': True
                })
            reshape_backward_node = Reshape(graph, {
                'name': name + '/Reshape_backward'
            }).create_node()

            input_shape_node.out_port(0).connect(
                flattened_shape_node.in_port(0))
            flattened_shape_node.out_port(0).connect(range_node.in_port(1))
            range_node.out_port(0).connect(reshape_backward_node.in_port(0))
            input_shape_node.out_port(0).connect(
                reshape_backward_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                reshape_backward_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (reshape_backward_node, name)])

        if node.repeat != 1:
            r"""
            First, we generate the correct stop value for Range as new_stop_value = stop_value // repeat + 1.
            Then we repeat each value of the interval using Tile. The result can be longer than required,
            so we trim it with Slice.

            The sub-graph after the Range node will look like

            Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape([-1]) - Slice
            """

            if node.repeat < 1:
                raise Error(
                    "Unexpected value {} of the attribute 'repeat' for the node {}"
                    .format(node.repeat, name))

            div_node = create_op_with_const_inputs(
                graph, Div, {1: int64_array([node.repeat])},
                {'name': name + '/Divide'})
            add_node = create_op_with_const_inputs(
                graph, Add, {1: int64_array([1])},
                {'name': div_node.name + '/Add'})
            cast_node = Cast(graph, {
                'name': name + '/ConvertToI64',
                'dst_type': np.int64
            }).create_node()

            cast_node.out_port(0).connect(div_node.in_port(0))
            div_node.out_port(0).connect(add_node.in_port(0))
            range_node.in_port(1).get_connection().set_destination(
                cast_node.in_port(0))
            add_node.out_port(0).connect(range_node.in_port(1))

            tile_forward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1, 1])},
                {'name': range_node.name + '/ForwardReshape'})
            tile = create_op_with_const_inputs(
                graph, Tile, {1: int64_array([1, node.repeat])},
                {'name': tile_forward_reshape.name + '/Tile'})
            tile_backward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1])},
                {'name': tile.name + '/BackwardReshape'})
            slice_node = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': tile_backward_reshape.name + '/Slice'})

            tile_forward_reshape.out_port(0).connect(tile.in_port(0))
            tile.out_port(0).connect(tile_backward_reshape.in_port(0))
            tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
            slice_node.in_port(2).connect(div_node.in_port(0).get_source())

            range_node.out_port(0).get_connection().set_source(
                slice_node.out_port(0))
            range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

            if axis is not None:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_node, name)])

        # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
        # we have to correct the stop value for the Range node if step != 1 or start != 0
        if node.step != 1:
            # If the step attribute is not an integer, we generate an interval of a larger size and then trim it
            # using Slice
            true_elements_count_port = range_node.in_port(1).get_source()
            mul_value = np.ceil(node.step) if node.step > 0 else np.floor(
                node.step)
            stop_value = create_op_with_const_inputs(
                graph,
                Mul,
                port_value_dict={1: mo_array(np.ceil(mul_value))},
                op_attrs={'name': range_node.name + '/Stop'})
            range_node.in_port(1).get_connection().insert_node(stop_value)

            slice_range_values = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': range_node.name + '/Slice'})
            slice_range_values.in_port(2).connect(true_elements_count_port)
            range_node.out_port(0).get_connection().insert_node(
                slice_range_values)

            if axis is not None and node.repeat == 1:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_range_values, name)])

        if node.start != 0:
            correct_stop_value = create_op_with_const_inputs(
                graph,
                Add,
                port_value_dict={1: mo_array(node.start)},
                op_attrs={'name': range_node.name + '/Correct_Stop'})
            range_node.in_port(1).get_connection().insert_node(
                correct_stop_value)

        # Range node supports only scalar inputs
        squeeze_node = create_op_with_const_inputs(
            graph,
            Squeeze,
            port_value_dict={1: int64_array(0)},
            op_attrs={"name": range_node.name + '/Stop/Squeeze'})
        range_node.in_port(1).get_connection().insert_node(squeeze_node)
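
# Illustrative sketch (not part of the transform): a numpy emulation of the subgraph
# built above for MXNet's arange_like. The function name is hypothetical, and the
# handling of a non-integer step is simplified relative to the Mul/Slice trimming above.
import numpy as np

def arange_like_ref(x, start=0.0, step=1.0, repeat=1, axis=None):
    n = x.shape[axis] if axis is not None else x.size
    base = start + step * np.arange(n // repeat + 1)  # Range with the corrected stop
    vals = np.repeat(base, repeat)[:n]                # Tile + Slice from the repeat branch
    return vals if axis is not None else vals.reshape(x.shape)

# e.g. arange_like_ref(np.zeros((2, 3)), axis=1) -> array([0., 1., 2.])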
Example #30
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    resize_name = resize.soft_get('name', resize.id)
    log.debug(
        "Converting of {} to Interpolate-4 is triggered for node {}.".format(
            resize.op, resize_name))

    num_of_inputs = len([
        port for port in resize.in_ports().values() if not port.disconnected()
    ])
    assert num_of_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)

    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
        attrs_msg.format(resize_name, resize.op)

    shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()

    ss = create_op_with_const_inputs(graph, StridedSlice, {
        1: int64_array([1]),
        2: int64_array([3]),
        3: int64_array([1])
    }, {
        'name': resize_name + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    })

    div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()

    shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()

    size_to_float.out_port(0).connect(div_node.in_port(0))
    shape_to_float.out_port(0).connect(div_node.in_port(1))
    ss.out_port(0).connect(shape_to_float.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers

    nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if interpolation_mode == 'nearest':
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'
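    # recap of the mapping above: align_corners -> 'align_corners' (plus
    # 'round_prefer_ceil' for nearest), half_pixel_centers -> 'tf_half_pixel_for_nn'
    # for nearest or 'half_pixel' otherwise, and 'asymmetric' when neither flag is set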

    interpolate4 = create_op_with_const_inputs(
        graph, Interpolate, {3: int64_array([1, 2])}, {
            'name': resize_name + '/interpolate_4',
            'mode': interpolation_mode,
            'antialias': False,
            'coordinate_transformation_mode': coordinate_transformation_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'nearest_mode': nearest_mode,
            'cube_coeff': -0.75,
            'shape_calculation_mode': 'sizes',
            'version': 'opset4',
            'in_ports_count': 4,
        })

    resize_input_connection = resize.in_port(0).get_connection()
    resize_input_connection.set_destination(interpolate4.in_port(0))
    resize_input_connection.get_source().connect(shape.in_port(0))

    div_node.out_port(0).connect(interpolate4.in_port(2))

    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate4.in_port(1))
    sizes_connection.get_source().connect(size_to_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate4, resize_name)])
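
# Illustrative sketch (not part of the conversion): the value the Div node feeds into
# Interpolate-4's scales input (port 2), computed directly. The NHWC layout is implied
# by the StridedSlice over shape[1:3]; the concrete numbers below are made up.
import numpy as np

input_shape = np.array([1, 480, 640, 3])      # hypothetical NHWC input shape
sizes = np.array([240, 320], dtype=np.int64)  # requested output [H, W] (resize input 1)
scales = sizes.astype(np.float32) / input_shape[1:3].astype(np.float32)
# scales == [0.5, 0.5]; Interpolate-4 gets sizes on port 1 and these scales on port 2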