Пример #1
0
 def extract(cls, node: Node):
     """
     Extract the output data type of the node from its TF protobuf.

     The ``out_type`` attribute of ``node.pb`` is converted with ``tf_dtype_extractor``
     (``np.int32`` is passed as its second, presumably default, argument) and stored as the
     ``data_type`` attribute via ``Shape.update_node_stat``.

     :param node: node being extracted
     :return: the ``enabled`` flag of the extractor class
     """
     out_type = node.pb.attr['out_type'].type
     node_attrs = {'data_type': tf_dtype_extractor(out_type, np.int32)}
     Shape.update_node_stat(node, node_attrs)
     return cls.enabled
    def make_interpolate_reshapeable(interpolate, concat):
        """
        Make the target sizes of an Interpolate that feeds a Concat reshape-able.

        When the Concat axis is not among the Interpolate axes, the Interpolate sizes input (port 1)
        is re-sourced from ``Gather(ShapeOf(src), interpolate_axes)``, where ``src`` is the first
        connected Concat input that does not come from an Interpolate node.

        :param interpolate: Interpolate node
        :param concat: Concat node consuming the Interpolate output
        """
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'

        out_shape = interpolate.out_port(0).data.get_shape()

        axes = [get_canonical_axis_index(out_shape, a) for a in Interpolate.get_axes(interpolate)]
        if get_canonical_axis_index(out_shape, concat.axis) in axes:
            # concatenation goes along an interpolated axis: sizes cannot be borrowed from a neighbour
            return

        connected = (p.get_source() for p in concat.in_ports().values() if not p.disconnected())
        candidates = [s for s in connected if s.node.soft_get('type') != 'Interpolate']
        if not candidates:
            return

        graph = interpolate.graph
        donor = candidates[0]

        shape_of = Shape(graph, {'name': donor.node.soft_get('name', donor.node.id) + '/Shape'}).create_node()
        shape_of.in_port(0).connect(donor)
        sizes = create_op_with_const_inputs(graph, Gather,
                                            {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
                                            {'name': shape_of.name + '/Gathered'}, shape_of)
        interpolate.in_port(1).get_connection().set_source(sizes.out_port(0))
Пример #3
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Rewrite the matched BatchNorm-like node to run on top of an MVN normalization.

        The node is switched to ``is_training = False``. Its inputs 3 and 4 are re-sourced from
        constant zeros (``/mean``) and ones (``/std``) shaped like the scale input (port 1).
        The data input (port 0) is routed through
        ``Reshape([1, -1, 0, 0]) -> MVN -> Reshape(ShapeOf(original input))``.

        :raises AssertionError: if the shape of the scale input (port 1) is unknown
        """
        node = match['op']
        node_name = node.soft_get('name', node.id)
        node.is_training = False

        shape = node.in_port(1).data.get_shape()
        assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node_name)

        bn_mean = Const(graph, {'name': node_name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                                'override_output_shape': True}).create_node()
        bn_std = Const(graph, {'name': node_name + '/std', 'value': np.ones(shape, dtype=np.float32),
                               'override_output_shape': True}).create_node()
        node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
        node.in_port(4).get_connection().set_source(bn_std.out_port(0))

        # save the original shape (taken before anything is inserted on the data input)
        source = node.in_port(0).get_source()
        source_name = source.node.soft_get('name', source.node.id)
        # the suffix keeps the ShapeOf name distinct from the node it reads from (duplicate layer names
        # otherwise), and the id fallback avoids `None + str` when the source has no name
        original_shape = Shape(graph, {'name': source_name + '/Shape'}).create_node()
        original_shape.in_port(0).connect(source)

        mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                          'override_output_shape': True}).create_node()
        node.in_port(0).get_connection().insert_node(mvn)

        # [1, -1, 0, 0]: per the node name this fuses batch and channels — assumes a 4D input, TODO confirm
        reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                      {'override_output_shape': True,
                                                       'name': node_name + '/fused_batch_and_channels'})
        mvn.in_port(0).get_connection().insert_node(reshape_4d)

        # restore original shape
        reshape_back = Reshape(graph, {'name': mvn.soft_get('name') + '/restore_shape',
                                       'override_output_shape': True}).create_node()
        reshape_back.in_port(1).connect(original_shape.out_port(0))
        mvn.out_port(0).get_connection().insert_node(reshape_back)
Пример #4
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched ``mxreshape`` node with ShapeOf -> per-dimension nodes -> Concat -> Reshape.

        ``self.resolve`` is called repeatedly (at most ``len(dim)`` times, and only while the reshape
        index is still inside ``dim``). It returns the updated input/reshape indices and the list of
        nodes producing the output dimensions. Those nodes are concatenated on axis 0 into the target
        shape of a new Reshape, which takes over the original node's input and output connections.
        """
        mx_reshape = match['mxreshape']

        shape_of = Shape(graph, dict(name=mx_reshape.id + '/ShapeMXReshape')).create_node()
        shape_of.in_port(0).connect(mx_reshape.in_port(0).get_source())

        in_idx, out_idx, dim_nodes = 0, 0, []
        for _ in mx_reshape.dim:
            if out_idx >= len(mx_reshape.dim):
                continue
            in_idx, out_idx, dim_nodes = self.resolve(in_idx, out_idx, mx_reshape.dim, shape_of, dim_nodes)

        target_shape = Concat(shape_of.graph,
                              dict(name=shape_of.id + '/ConcatMXReshape_',
                                   axis=0,
                                   in_ports_count=len(dim_nodes))).create_node()
        for port_idx, dim_node in enumerate(dim_nodes):
            target_shape.in_port(port_idx).connect(dim_node.out_port(0))

        new_reshape = Reshape(graph, dict(name=mx_reshape.id + '/Reshape_')).create_node()
        new_reshape.in_port(1).connect(target_shape.out_port(0))
        mx_reshape.in_port(0).get_connection().set_destination(new_reshape.in_port(0))
        mx_reshape.out_port(0).get_connection().set_source(new_reshape.out_port(0))
    def replace_pattern(graph: Graph, match: dict):
        """
        Surround the matched patch-stride convolution with an input and an output Reshape.

        Input Reshape target: ``[in_shape[0], in_shape[1] / patch_stride, 1, patch_stride]``. It is
        computed in the float type given by ``cmd_params.data_type`` and cast to int64.
        Output Reshape target: ``[0, -1]``.
        """
        conv = match['conv']
        conv_name = conv.soft_get('name', conv.id)

        # ShapeOf(input) converted to float so it can be divided by the stride
        shape_of = Shape(graph, {'name': conv_name + '/Shape'}).create_node()
        float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        shape_as_float = Cast(graph, {'name': conv_name + '/to_float', 'dst_type': float_type}).create_node()
        shape_of.in_port(0).connect(conv.in_port(0).get_source())
        shape_as_float.in_port(0).connect(shape_of.out_port(0))

        batch = node_to_get_shape_value_of_indices(shape_as_float, [0])
        height = node_to_get_shape_value_of_indices(shape_as_float, [1])

        height_by_stride = create_op_with_const_inputs(graph, Div, {1: float_array([conv.patch_stride])},
                                                       {'name': conv_name + '/div_stride_h'})
        height_by_stride.in_port(0).connect(height.out_port(0))

        all_dims = create_op_with_const_inputs(graph, Concat,
                                               {2: float_array([1]), 3: float_array([conv.patch_stride])},
                                               {'name': conv_name + '/concat_all_dims',
                                                'in_ports_count': 4,
                                                'axis': 0})
        all_dims.in_port(0).connect(batch.out_port(0))
        all_dims.in_port(1).connect(height_by_stride.out_port(0))

        int_pattern = Cast(graph, {'name': conv_name + '/to_int', 'dst_type': np.int64}).create_node()
        all_dims.out_port(0).connect(int_pattern.in_port(0))

        reshape_in = Reshape(graph, {'name': conv_name + '/reshape_in'}).create_node()
        reshape_in.in_port(1).connect(int_pattern.out_port(0))

        # Reshape that follows the convolution
        reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                       {'name': conv_name + '/reshape_out'})

        # producer -> reshape_in -> conv
        producer = conv.in_port(0).get_source()
        conv.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(producer)

        # conv -> reshape_out -> former conv consumers
        conv.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        conv.out_port(0).connect(reshape_out.in_port(0))
Пример #6
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched sub-graph with ``Range(0, batch, 1)``.

        ``batch`` is computed at runtime with ``node_to_get_batch_value`` from the shape of the tensor
        feeding the matched ``split`` node, then squeezed to 0D. The consumers of the matched ``cast``
        node are reconnected to the Range output, and all matched nodes are removed.
        """
        source_connection = match['split'].in_port(0).get_connection()
        source_port = source_connection.get_source()
        source_node = source_port.node
        cast_node = match['cast']

        range_node = Range(graph, {
            'name': source_node.id + '/Range'
        }).create_node()
        start_node = Const(graph, {
            'name': range_node.id + '/Start',
            'value': int64_array(0)
        }).create_node()

        step_node = Const(graph, {
            'name': range_node.id + '/Step',
            'value': int64_array(1)
        }).create_node()
        input_shape_node = Shape(graph, {
            'name': start_node.id + '/Shape'
        }).create_node()
        # read the shape of the exact output port feeding `split`; it is not necessarily port 0
        # of a multi-output producer
        input_shape_node.in_port(0).connect(source_port)

        limit_node_1D = node_to_get_batch_value(input_shape_node)
        limit_node = create_op_node_with_second_input(
            graph, Squeeze, int64_array([0]),
            {'name': source_node.id + '/batch_0D_value'}, limit_node_1D)

        range_node.in_port(0).connect(start_node.out_port(0))
        range_node.in_port(1).connect(limit_node.out_port(0))
        range_node.in_port(2).connect(step_node.out_port(0))
        cast_node.out_port(0).get_connection().set_source(
            range_node.out_port(0))

        graph.remove_nodes_from([node.id for node in match.values()])
Пример #7
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace a Flatten node with a Reshape whose target shape depends on ``axis``.

        * ``axis == 0``: ``[1, -1]``
        * ``axis == 1``: ``[0, -1]``
        * ``axis > 1``: ``[prod(shape[:axis]), -1]``
        * ``axis < 0``: ``[-1, prod(shape[axis:])]``

        Every helper node gets a name derived from the Flatten name.

        :raises AssertionError: if the Flatten node has no valid ``axis`` attribute
        """
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
        axis = node.axis

        if axis == 0:
            dim = Const(graph, {'value': int64_array([1, -1]), 'name': name + '/shape'}).create_node()
        elif axis == 1:
            dim = Const(graph, {'value': int64_array([0, -1]), 'name': name + '/shape'}).create_node()
        else:
            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

            # indices of the dimensions whose product forms the explicitly computed part of the shape
            idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

            axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
            first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
                                                          {'name': name + '/prod_dims', 'keep_dims': True})
            second_dims = Const(graph, {'value': int64_array([-1]), 'name': name + '/minus_one'}).create_node()

            node.in_port(0).get_source().connect(shape.in_port(0))
            axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

            # the product goes first for a positive axis and last for a negative one
            order_of_dims = [first_dims, second_dims] if axis > 0 else [second_dims, first_dims]

            dim = new_shape_node_from_shape_nodes(order_of_dims)

        reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
Пример #8
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Give every Roll without an axes input explicit axes by flattening around it.

        Each Roll whose port 2 (axes) is disconnected becomes
        ``Reshape(-1) -> Roll(axes=[0]) -> Reshape(ShapeOf(original input))``.
        The trailing Reshape takes over the Roll's name, and the Roll is renamed to ``<name>/roll``.
        Roll nodes that already have axes are skipped; the remaining Roll nodes are still processed.
        """
        for roll_node in graph.get_op_nodes(op='Roll'):
            if not roll_node.in_port(2).disconnected():
                # axes already provided for this node; keep processing the other Roll nodes
                continue
            node_name = roll_node.soft_get('name', roll_node.id)

            # reshape to 1d tensor
            reshape_to_1d = create_op_node_with_second_input(
                graph, Reshape, int64_array([-1]),
                {'name': node_name + '/reshape'})
            roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

            # add zero const as axes input to roll
            const_zero = Const(graph, {
                'value': int64_array([0]),
                'name': node_name + '/axes'
            }).create_node()
            const_zero.out_port(0).connect(roll_node.in_port(2))

            # reshape to original shape: ShapeOf must read the ORIGINAL input (the input of
            # reshape_to_1d), not the flattened tensor that now feeds the Roll
            shape_of = Shape(graph, {
                'name': node_name + '/shape_of'
            }).create_node()
            reshape_to_1d.in_port(0).get_connection().add_destination(
                shape_of.in_port(0))
            reshape_to_orig_shape = Reshape(graph, {}).create_node()
            rename_nodes([(roll_node, node_name + '/roll'),
                          (reshape_to_orig_shape, node_name)])
            shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
            roll_node.out_port(0).get_connection().insert_node(
                reshape_to_orig_shape)
Пример #9
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Make a 2D Reshape that feeds a MatMul reshape-able.

        The constant 2-element target shape of ``reshape`` is replaced with a runtime shape built from
        ``C`` and ``-1``. ``C`` is gathered from the ShapeOf of the other MatMul input at the index
        selected by ``transpose_a`` / ``transpose_b``; presumably this is the dimension MatMul reduces
        over (TODO confirm). The order of ``C`` and ``-1`` is also chosen from the transpose flags.
        Nothing is changed when the original target shape is unknown or not 2-element.
        """
        matmul = match['matmul']
        reshape = match['reshape']
        other_input_port_idx = 0 if matmul.in_port(0).get_source().node.id == match['other_input'].id else 1
        shape_source = matmul.in_port(other_input_port_idx).get_source()
        initial_reshape_pattern = reshape.in_port(1).data.get_value()
        # a non-constant target shape has no value: nothing to rewrite (len(None) would raise)
        if initial_reshape_pattern is None or len(initial_reshape_pattern) != 2:
            return

        reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id
        if reshape_is_A_input:
            idx = -1 if matmul.transpose_b else -2
        else:
            idx = -2 if matmul.transpose_a else -1
        idx = get_canonical_axis_index(initial_reshape_pattern, idx)

        shape_name = shape_source.node.soft_get('name', shape_source.node.id)
        shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node()
        shape.in_port(0).connect(shape_source)
        C = node_to_get_shape_value_of_indices(shape, [idx])
        N = Const(graph, {'name': shape_name + '/MinusOne', 'value': int64_array([-1])}).create_node()

        if reshape_is_A_input:
            reshape_pattern = [C, N] if matmul.transpose_a else [N, C]
        else:
            reshape_pattern = [N, C] if matmul.transpose_b else [C, N]
        new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern)
        reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0))
Пример #10
0
    def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
        """
        Re-source the Interpolate sizes input from the runtime shape of another Concat input.

        Port 1 of ``interpolate`` is connected to ``Gather(ShapeOf(src), interpolate_axes)``, where
        ``src`` is the first entry returned by ``self.get_non_interpolate_concat_sources(concat)``.
        Nothing is changed in the following cases:
        * either axis is unknown or negative;
        * the Concat axis is one of the Interpolate axes;
        * no such source exists.
        """
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'
        interp_axes = Interpolate.get_axes(interpolate)
        concat_axis = self.get_concat_axis(concat)

        # axes must be known, non-negative and must not intersect
        axes_unknown = concat_axis is None or interp_axes is None
        if axes_unknown or np.any(interp_axes < 0) or concat_axis < 0 or concat_axis in interp_axes:
            return

        donors = self.get_non_interpolate_concat_sources(concat)
        if len(donors) == 0:
            # no Concat input to take the sizes from
            return

        graph = interpolate.graph
        donor = donors[0]
        donor_name = donor.node.soft_get('name', donor.node.id)

        shape_of = Shape(graph, {'name': donor_name + '/Shape'}).create_node()
        shape_of.in_port(0).connect(donor)
        sizes = create_op_with_const_inputs(graph, Gather,
                                            {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                            {'name': shape_of.name + '/Gathered'},
                                            input_node=shape_of)
        interpolate.in_port(1).get_connection().set_source(sizes.out_port(0))
    def squeeze_initial_states(graph: Graph, match: dict):
        """
        Squeeze input initial states of recurrent node to 2-D shape.

        The hidden state (port 5) and, for ``op == 'LSTM'``, the cell state (port 6) are each passed
        through a Reshape to ``[ShapeOf(input 0)[batch_dim], hidden_size]``.
        """
        rnn_layer = match['rnn_layer']
        # make sure input ports 0..6 exist on the layer
        rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
        layer_name = rnn_layer.soft_get('name', rnn_layer.id)

        hidden_port, cell_port = 5, 6

        assert hidden_port in rnn_layer.in_nodes()
        hidden_size = rnn_layer.hidden_size
        data_shape = Shape(graph, dict(name=layer_name + '/ShapeOf')).create_node()
        rnn_layer.in_port(0).get_source().connect(data_shape.in_port(0))

        batch = node_to_get_shape_value_of_indices(data_shape, int64_array([rnn_layer.batch_dim]))
        target_shape = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
                                                        op_attrs=dict(name=layer_name + '/HiddenStateResizeDim',
                                                                      in_ports_count=2, axis=0),
                                                        input_node=batch)

        def _reshape_state(port_idx, suffix):
            # insert a Reshape to `target_shape` on the given state input of the layer
            resize = Reshape(graph, dict(name=layer_name + suffix, override_output_shape=True)).create_node()
            target_shape.out_port(0).connect(resize.in_port(1))
            rnn_layer.in_port(port_idx).get_connection().insert_node(resize)

        _reshape_state(hidden_port, '/HiddenStateResize')

        if rnn_layer.op == 'LSTM':
            assert cell_port in rnn_layer.in_nodes()
            _reshape_state(cell_port, '/CellStateResize')
Пример #12
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Feed both inputs of the matched PriorBox with a slice of their ShapeOf.

        Each input source is routed through ShapeOf -> StridedSlice. The slice uses begin=2, end=4 and
        stride=1 (int32 consts, forced to int64 precision) with ``begin_mask=[1]`` and ``end_mask=[0]``.
        The PriorBox is marked for shape re-inference, and its output is unsqueezed at axis 0; the
        Unsqueeze inherits the PriorBox name. As a side effect, ``cmd_params.static_shape`` is
        set to False.
        """
        node = match['pb']
        name = node.soft_get('name', node.id)

        graph.graph['cmd_params'].static_shape = False

        assert len(node.in_ports()) == 2

        begin = Const(graph, {'value': np.array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
        end = Const(graph, {'value': np.array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
        stride = Const(graph, {'value': np.array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()

        def _slice_shape_of_input(port_idx):
            # replace PriorBox input `port_idx` with StridedSlice(ShapeOf(original source))
            shape_of = Shape(graph, {'name': '{}/{}_port'.format(name, port_idx)}).create_node()
            strided_slice = StridedSlice(graph, {'name': '{}/ss_{}_port'.format(name, port_idx),
                                                 'begin_mask': np.array([1], dtype=np.int32),
                                                 'end_mask': np.array([0], dtype=np.int32),
                                                 'new_axis_mask': np.array([0], dtype=np.int32),
                                                 'shrink_axis_mask': np.array([0], dtype=np.int32),
                                                 'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

            shape_of.out_port(0).connect(strided_slice.in_port(0))
            for ss_port, const in enumerate((begin, end, stride), start=1):
                const.out_port(0).connect(strided_slice.in_port(ss_port))

            producer = node.in_port(port_idx).get_connection().get_source()
            node.in_port(port_idx).disconnect()
            producer.connect(shape_of.in_port(0))
            strided_slice.out_port(0).connect(node.in_port(port_idx))
            return strided_slice

        slices = [_slice_shape_of_input(0), _slice_shape_of_input(1)]
        for strided_slice in slices:
            strided_slice['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}

        node['need_shape_inference'] = True
        node['override_output_shape'] = True
        node['V10_infer'] = True
        unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]), {'name': name + '/unsqueeze'})
        rename_nodes([(node, name + '/naked_not_unsqueezed'), (unsqueeze, name)])

        node.out_port(0).get_connection().set_source(unsqueeze.out_port(0))
        node.out_port(0).connect(unsqueeze.in_port(0))
Пример #13
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Feed both inputs of the matched PriorBox with a slice of their ShapeOf.

        For input ports 0 and 1, the original source is routed through
        ``ShapeOf (stop_value_propagation) -> StridedSlice(begin=2, end=4, stride=1)``.
        The slice uses ``begin_mask=[1]`` and ``end_mask=[0]``, and its begin/end/stride inputs
        are forced to int64 precision.
        """
        node = match['pb']
        assert len(node.in_ports()) == 2

        begin = Const(graph, {'value': np.array([2])}).create_node()
        end = Const(graph, {'value': np.array([4])}).create_node()
        stride = Const(graph, {'value': np.array([1])}).create_node()

        created_slices = []
        for port_idx in range(2):
            shape_of = Shape(graph, {
                'name': node.name + '/{}_port'.format(port_idx),
                'stop_value_propagation': True
            }).create_node()
            strided_slice = StridedSlice(graph, {
                'name': node.name + '/ss_{}_port'.format(port_idx),
                'begin_mask': np.array([1]),
                'end_mask': np.array([0]),
                'new_axis_mask': np.array([0]),
                'shrink_axis_mask': np.array([0]),
                'ellipsis_mask': np.array([0])
            }).create_node()

            shape_of.out_port(0).connect(strided_slice.in_port(0))
            begin.out_port(0).connect(strided_slice.in_port(1))
            end.out_port(0).connect(strided_slice.in_port(2))
            stride.out_port(0).connect(strided_slice.in_port(3))

            # move the original producer of this input onto the ShapeOf
            producer = node.in_port(port_idx).get_connection().get_source()
            node.in_port(port_idx).disconnect()
            producer.connect(shape_of.in_port(0))
            strided_slice.out_port(0).connect(node.in_port(port_idx))

            created_slices.append(strided_slice)

        for strided_slice in created_slices:
            strided_slice['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
Пример #14
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Convert TF-style SpaceToBatch/BatchToSpace inputs to the IE representation.

        * The ``[N, 2]`` pads/crops input (port 2) is transposed, split in two on axis 0 and squeezed.
          The two halves become the begin (port 2) and end (port 3) inputs.
        * block_shape (port 1), begin (port 2) and end (port 3) are each padded to match the rank of
          the data input (port 0):
          - one element is prepended;
          - ``rank(data) - len(block_shape) - 1`` elements are appended;
          - block_shape is padded with 1, while begin/end leave the Pad value input (port 3) unconnected.
        """
        for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(op='BatchToSpace'):
            node.add_input_port(3, skip_if_exist=True)

            # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
            transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
            node.in_port(2).get_connection().set_destination(transposed_pads.in_port(0))
            split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
            transposed_pads.out_port(0).connect(split_pads.in_port(0))
            for port_ind in range(2):
                node.in_port(port_ind + 2).connect(split_pads.out_port(port_ind))
                node.in_port(port_ind + 2).get_connection().insert_node(
                    create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}))

            # add zeros/ones to related inputs to align it with data input
            in0_rank = Rank(graph, {'name': node.name + '/rank_0'}).create_node()
            # ShapeOf of the 1D block_shape is already a 1D tensor [len(block_shape)]; unsqueezing it
            # (as is done for the data rank below) would make the difference, i.e. pads_end, 2D
            in1_shape = Shape(graph, {'name': node.name + '/rank_1'}).create_node()

            diff_size = Sub(graph, {'name': node.name + '/sub_0'}).create_node()
            diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()

            const_begin = Const(graph, {'value': int64_array([1])}).create_node()
            # the Pad value input is a scalar
            const_pad_val = Const(graph, {'value': int64_array(1)}).create_node()

            block_shape = Pad(graph, {'name': node.name + '/aligned_block_shape', 'mode': 'constant'}).create_node()

            # in case of SpaceToBatch begin = pads_begin, end = pads_end
            # in case of BatchToSpace begin = crops_begin, end = crops_end
            new_begin_name = '/aligned_pads_begin'
            new_end_name = '/aligned_pads_end'
            if node.type == 'BatchToSpace':
                new_begin_name = '/aligned_crops_begin'
                new_end_name = '/aligned_crops_end'

            begin = Pad(graph, {'name': node.name + new_begin_name, 'mode': 'constant'}).create_node()
            end = Pad(graph, {'name': node.name + new_end_name, 'mode': 'constant'}).create_node()

            # the Rank output is unsqueezed to 1D so it can be combined with the 1D ShapeOf output
            in0_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                           {'name': node.name + '/1d_rank_of_0'}, in0_rank)

            node.in_port(0).get_source().connect(in0_rank.in_port(0))
            node.in_port(1).get_source().connect(in1_shape.in_port(0))
            in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
            in1_shape.out_port(0).connect(diff_size.in_port(1))
            diff_size.out_port(0).connect(diff.in_port(0))
            const_begin.out_port(0).connect(diff.in_port(1))
            const_pad_val.out_port(0).connect(block_shape.in_port(3))

            inputs_array = [block_shape, begin, end]
            for idx, input_to_node in enumerate(inputs_array):
                node.in_port(idx + 1).get_connection().set_destination(input_to_node.in_port(0))
                const_begin.out_port(0).connect(input_to_node.in_port(1))
                diff.out_port(0).connect(input_to_node.in_port(2))
                input_to_node.out_port(0).connect(node.in_port(idx + 1))
    def placeholder_scales(self, placeholder: Node):
        """
        Helper function to get scales for prior boxes out of input image size:
                [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]

        The sub-graph is:
        * ``ShapeOf(placeholder)[1:3]`` -> Cast to float32 -> Pow(-1);
        * the two values are reversed with Gather([1, 0]);
        * the result is concatenated with itself.

        Whether dims 1 and 2 are (H, W) depends on the placeholder layout — presumably channels-last.

        :param placeholder: 4D input Parameter node
        :return: the Concat node producing the 4-element scale vector
        :raises AssertionError: if the placeholder has no ndarray ``shape`` or is not 4D
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        image_shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        image_shape.in_port(0).connect(placeholder.out_port(0))

        # take dims [1, 3) of the shape
        slice_begin = Const(graph, {'value': int64_array([1])}).create_node()
        slice_end = Const(graph, {'value': int64_array([3])}).create_node()
        slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
        spatial_dims = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
                                            'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                            'shrink_axis_mask': np.array([0]),
                                            'ellipsis_mask': np.array([0])}).create_node()
        for port_idx, producer in enumerate((image_shape, slice_begin, slice_end, slice_stride)):
            spatial_dims.in_port(port_idx).connect(producer.out_port(0))

        # element-wise reciprocal of the spatial dims
        minus_one = Const(graph, {'value': float32_array([-1.])}).create_node()
        reciprocal = Pow(graph, {}).create_node()
        reciprocal.in_port(0).connect(spatial_dims.out_port(0))
        reciprocal.in_port(1).connect(minus_one.out_port(0))

        # Power `type_infer` requires inputs to have equal data type
        to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        reciprocal.in_port(0).get_connection().insert_node(to_fp32)

        # swap the two values
        gather_indices = Const(graph, {'value': int64_array([1, 0])}).create_node()
        gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
        swapped = Gather(graph, {}).create_node()
        swapped.in_port(0).connect(reciprocal.out_port(0))
        swapped.in_port(1).connect(gather_indices.out_port(0))
        gather_axis.out_port(0).connect(swapped.in_port(2))

        # duplicate: [a, b] -> [a, b, a, b]
        priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
        for port_idx in (0, 1):
            priors_scale_node.add_input_port(port_idx, skip_if_exist=True)
        for port_idx in (0, 1):
            priors_scale_node.in_port(port_idx).connect(swapped.out_port(0))
        return priors_scale_node
def replace_interpolate_pattern(graph: Graph, match: dict):
    """Replace the subgraph between the matched 'split' and 'concat' nodes with one Interpolate node.

    The new Interpolate (version 'opset4', mode 'nearest', shape_calculation_mode 'scales') operates
    along the Split axis. Its sizes input (port 1) is computed at runtime as
    ``int64(floor(float32(ShapeOf(x)[axis:axis + 1]) * scale))``, where ``x`` is the Split data input
    and ``scale`` is ``get_split_scale(split)``; port 2 receives the same scale and port 3 the axis.

    :param graph: graph being transformed.
    :param match: pattern match holding the 'split' and 'concat' nodes.
    """
    split = match['split']
    scale = np.array([get_split_scale(split)], dtype=np.float32)
    # The split axis is the value of the node feeding Split input port 1.
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # Slice out the single dimension shape[axis:axis + 1] of the Split input.
    strided_slice_node = create_op_with_const_inputs(graph,
                                                     StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {
                                                        'name': split_node_name + '/StridedSlice',
                                                        'begin_mask': int64_array([1]),
                                                        'end_mask': int64_array([1]),
                                                        'new_axis_mask': int64_array([0]),
                                                        'shrink_axis_mask': int64_array([0]),
                                                        'ellipsis_mask': int64_array([0])
                                                     })
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    # Convert the (presumably integer) shape value to float32 so it can be multiplied by the float32 scale.
    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest',
                                   antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
                                   cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    # sizes = int64(floor(dim * scale))
    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    # Interpolate inputs: 1 = sizes, 2 = scales, 3 = axes.
    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    # Consumers of the Concat now read the Interpolate output.
    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    # Move the Split data input onto the Interpolate; the same producer also feeds the ShapeOf node.
    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched 'unsqueeze', 'tile' and 'reshape' nodes with one nearest-mode Interpolate.

        The Unsqueeze axis ``d_idx`` is read from the node feeding Unsqueeze port 1. The Interpolate
        works on ``axis = d_idx - 1`` with the scale taken from element ``d_idx`` of the value of the
        node feeding Tile port 1. Its sizes input (port 1) is ``ShapeOf(x)[axis] * scale``, computed at
        runtime, where ``x`` is the Unsqueeze data input. The Interpolate takes over the name and the
        consumers of the matched Reshape, which is renamed with a '/delete' suffix.
        """
        if not self.is_applicable(match):
            return

        unsqueeze_node = match['unsqueeze']
        unsqueeze_name = unsqueeze_node.soft_get('name', unsqueeze_node.id)
        second_input_of_unsqueeze = unsqueeze_node.in_port(
            1).get_connection().get_source().node
        d_idx = int(second_input_of_unsqueeze.value)
        # NOTE(review): d_idx == 0 would give axis -1; presumably is_applicable rules that out — confirm.
        axis = d_idx - 1

        shape_node = Shape(graph,
                           dict(name=unsqueeze_name + '/Shape')).create_node()
        # Node producing ShapeOf(x)[axis].
        axis_len_node = node_to_get_shape_value_of_indices(shape_node, [axis])

        second_input_of_tile = match['tile'].in_port(
            1).get_connection().get_source().node
        # The same Tile value is used as an int64 multiplier for the sizes and as float32 Interpolate scales.
        scale = int64_array([second_input_of_tile.value[d_idx]])
        float_scale = float32_array([second_input_of_tile.value[d_idx]])
        mul_node = create_op_with_const_inputs(
            graph, Mul, {1: scale}, {'name': unsqueeze_name + '/Mul'})

        axis_len_node.out_port(0).connect(mul_node.in_port(0))

        # Interpolate const inputs: 2 = scales, 3 = axes; port 1 (sizes) is fed by mul_node below.
        interp_node = create_op_with_const_inputs(
            graph, Interpolate, {
                2: float_scale,
                3: int64_array([axis])
            }, {
                'mode': 'nearest',
                'antialias': 0,
                'pads_begin': int64_array([0]),
                'pads_end': int64_array([0]),
                'coordinate_transformation_mode': 'half_pixel',
                'nearest_mode': 'round_prefer_floor',
                'cube_coeff': -0.75,
                'version': 'opset4',
                'shape_calculation_mode': 'scales',
                'in_ports_count': 4,
                'maybe_part_of_sequence': True
            })
        mul_node.out_port(0).connect(interp_node.in_port(1))

        # The Interpolate replaces the Reshape at the end of the matched chain and inherits its name.
        reshape_node = match['reshape']
        reshape_node.out_port(0).get_connection().set_source(
            interp_node.out_port(0))
        reshape_name = reshape_node.soft_get('name', reshape_node.id)
        rename_nodes([(reshape_node, reshape_name + '/delete'),
                      (interp_node, reshape_name)])

        # Move the Unsqueeze data input onto the Interpolate; the same producer also feeds the ShapeOf node.
        unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        unsqueeze_connection.get_source().connect(shape_node.in_port(0))
Пример #18
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Decompose the matched Size node into ShapeOf followed by ReduceProd over axis 0.

        The product of all dimensions equals the element count of the input. The ReduceProd node
        inherits the original node's name and consumers; the original node is renamed with a
        '/ToBeDeleted' suffix.
        """
        size_op = match['op']
        size_name = size_op.soft_get('name', size_op.id)
        assert size_op.has_valid('output_type'), \
            'Size node should have `output_type` attribute, but it`s not for node {}'.format(size_name)

        shape_of = Shape(graph, {'name': size_name + '/Shape/',
                                 'output_type': size_op.output_type}).create_node()
        # The data edge that fed Size now feeds ShapeOf.
        size_op.in_port(0).get_connection().set_destination(shape_of.in_port(0))

        prod_attrs = {'name': shape_of.name + 'ReduceProd/', 'keep_dims': False}
        element_count = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), prod_attrs, shape_of)
        size_op.out_port(0).get_connection().set_source(element_count.out_port(0))

        rename_nodes([(size_op, size_name + '/ToBeDeleted'), (element_count, size_name)])
Пример #19
0
    def replace_pattern(graph: Graph, match: dict):
        """Surround the matched pooling node with Reshapes so it runs on a 4D layout.

        Before pooling, the input is reshaped to ``[N, pool_stride, 1, H / pool_stride]``. After pooling,
        the output is reshaped to ``[N, -1]``. This assumes a 2D ``[N, H]`` input, the only case for which
        the previous shape computation produced a 4-element target.

        :param graph: graph being transformed.
        :param match: pattern match holding the 'pool' node.
        """
        node = match['pool']

        if node.pool_step is None:
            node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

        # Reshape before pooling: [N, H] -> [N, pool_stride, 1, H / pool_stride].
        # 0 copies the batch dim from the input and -1 lets Reshape infer H / pool_stride. This is the
        # same convention reshape_out uses below. Computing H / pool_stride explicitly as
        # H * pool_stride ** -1 on int64 tensors is invalid: negative integer powers are not defined
        # for integer dtypes.
        reshape_in = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, node.pool_stride, 1, -1]),
            {'name': '/Reshape/' + node.name})

        # Reshape after pooling back to 2D: [N, -1].
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node.name + '/Reshape/'})

        # Insert reshape_in between the pooling node and its original producer.
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # Insert reshape_out between the pooling node and its consumers.
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
Пример #20
0
def create_zero_value_with_batch_from_input(input_out_port: Port,
                                            second_dim,
                                            precision=np.float64):
    """Build a subgraph producing a zero-filled tensor of shape ``[batch, second_dim]``.

    ``batch`` is read at runtime from the shape of the tensor produced by ``input_out_port``.
    A Crop of its ShapeOf output is taken along axis 0 at offset 0, with the constant [1] on Crop port 1;
    this is presumably the first (batch) dimension.

    :param input_out_port: output port whose tensor provides the batch dimension.
    :param second_dim: size of the second output dimension.
    :param precision: dtype of the zero fill value. The default ``np.float64`` is what the former
        ``np.float`` alias resolved to; that alias was removed in NumPy 1.24.
    :return: the Broadcast node that produces the zeros.
    """
    # create init_graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name
    shape_of_input = Shape(graph, {
        'name': 'shape/' + input_name
    }).create_node()
    shape_of_input.in_port(0).connect(input_out_port)
    dim_for_get_batch = Const(
        graph, {
            'name': 'dim/crop_batch/' + shape_of_input.name,
            'value': int64_array([1]),
            'shape': int64_array([1])
        }).create_node()
    get_batch = Crop(
        graph, {
            'name': 'crop_batch/' + shape_of_input.name,
            'axis': int64_array([0]),
            'offset': int64_array([0])
        }).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
    mem_shape_2nd_dim = Const(
        graph, {
            'name': 'gifo_r_weights_shape/' + input_name,
            'value': int64_array([second_dim]),
            'shape': int64_array([1])
        }).create_node()
    # Target shape: [batch, second_dim].
    mem_shape = Concat(
        graph, {
            'name': 'gather_memory_shape/' + input_name,
            'axis': 0,
            'in_ports_count': 2
        }).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
    fill_value = Const(
        graph, {
            'name': 'fill_value/' + input_name,
            'value': np.array([0.0], precision),
            'shape': int64_array([1])
        }).create_node()
    # Broadcast the single zero to the target shape.
    init_value_prev_lstm_output = Broadcast(graph, {
        'name': 'init_value/' + input_name,
    }).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
    return init_value_prev_lstm_output
Пример #21
0
 def resolve_minus2(self, shape_node, input_index, reshape_index, dims):
     """Handle a -2 entry of an MXNet reshape spec: take input dims from ``input_index`` to the last one.

     A rank node (ShapeOf of ``shape_node``) is created. It is passed together with ``shape_node``,
     presumably a ShapeOf of the reshape input, to ``get_shape_values_by_range_idxs``. That helper
     selects the values in the inclusive range ``[input_index, -1]``.

     :return: tuple ``(None, reshape_index + 1, dims, node_with_selected_shape_values)``.
     """
     graph = shape_node.graph
     rank_of_shape = Shape(graph, dict(name=shape_node.id + '/RankShapeMXReshapeMinus2')).create_node()
     rank_of_shape.in_port(0).connect(shape_node.out_port(0))

     tail_dims = get_shape_values_by_range_idxs(shape=shape_node, rank=rank_of_shape,
                                                begin=input_index, end=-1,
                                                include_begin=True, include_end=True)
     # The remaining input dims are consumed, so no input index is left; step past the -2 entry.
     return None, reshape_index + 1, dims, tail_dims
Пример #22
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with ShapeOf -> ReduceProd(axis 0, keep_dims=False).

        The product of all input dimensions is the element count of the input.

        :return: list holding the id of the ReduceProd node that provides the replacement output.
        """
        shape_of = Shape(graph, {'name': node.name + '/Shape/'}).create_node()
        prod_attrs = {'name': shape_of.name + 'ReduceProd/', 'keep_dims': False}
        product = ReduceProd(graph, prod_attrs).create_node()
        axis_const = Const(graph, {'value': int64_array([0])}).create_node()

        # Route the data input into ShapeOf, then feed ShapeOf (port 0) and the axis (port 1) to ReduceProd.
        node.in_port(0).get_connection().set_destination(shape_of.in_port(0))
        for port_idx, producer in enumerate((shape_of, axis_const)):
            product.in_port(port_idx).get_connection().set_source(producer.out_port(0))

        # Equivalent explicit form of the return value: [(product.id, 0)]
        return [product.id]
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched normalization subgraph with an explicit MVN-based subgraph.

        The new chain runs from the input of the matched 'transpose' to the consumers of 'transpose2':
        Reshape to ``[0, 0, 0, c1, c2]`` -> Transpose(0, 1, 2, 4, 3) -> MVN over axes (1, 2, 3)
        -> Transpose(0, 1, 2, 4, 3) -> Reshape to the shape of the original input
        -> Mul by gamma -> Add beta.
        ``c1`` and ``c2`` are elements 1 and 2 of the value feeding the matched Reshape's port 1.
        Gamma and beta are reshaped to ``[1, 1, 1, c1 * c2]``. Nothing happens when
        ``check_applicability(match)`` is false.
        """
        if not check_applicability(match):
            return

        reshape = match['reshape']
        div_name = match['division'].name

        input_shape = Shape(graph, dict(name=div_name + '/shape/MVN_T_')).create_node()
        shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
        c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
        c = c1 * c2

        new_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, 0, 0, c1, c2]),
                                                       dict(name=div_name + '/first_reshape/MVN_T_'))
        # Swaps the last two axes; the same order is reused after MVN to swap them back.
        permute_order = int64_array([0, 1, 2, 4, 3])
        first_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                         dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

        add = match['add']
        variance = match['variance']
        # eps is whichever input of 'add' is not produced by the 'variance' node.
        eps_port_num = 0 if add.in_port(0).get_connection().get_source().node.id != variance.id else 1
        eps = add.in_port(eps_port_num).get_connection().get_source().node
        mvn_node = create_op_with_const_inputs(graph, MVN, {1: int64_array([1, 2, 3])},
                                               dict(name=div_name + '/MVN/MVN_T_',
                                                    eps=eps.value, normalize_variance=1,
                                                    eps_mode='inside_sqrt'))
        first_permute.out_port(0).connect(mvn_node.in_port(0))

        second_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                          dict(name=div_name + '/second_permute/MVN_T_'), mvn_node)
        # Target shape (port 1) is connected below from input_shape.
        new_reshape2 = Reshape(graph, dict(name=div_name + '/second_reshape/MVN_T_')).create_node()
        second_permute.out_port(0).connect(new_reshape2.in_port(0))
        gamma_val = np.reshape(match['gamma_identity'].in_port(0).get_connection().get_source().node.value,
                               int64_array([1, 1, 1, c]))
        new_mul = create_op_node_with_second_input(graph, Mul, gamma_val,
                                                   dict(name=match['mul'].name + '/MVN_T_'), new_reshape2)
        beta_val = np.reshape(match['beta_identity'].in_port(0).get_connection().get_source().node.value,
                              int64_array([1, 1, 1, c]))
        new_add2 = create_op_node_with_second_input(graph, Add, beta_val,
                                                    dict(name=match['add2'].name + '/MVN_T_'), new_mul)

        # Splice the new chain in: the producer feeding 'transpose' now feeds new_reshape and input_shape,
        # and consumers of 'transpose2' now read from new_add2.
        transpose_connection = match['transpose'].in_port(0).get_connection()
        before_transpose = transpose_connection.get_source().node
        transpose_connection.set_destination(new_reshape.in_port(0))
        input_shape.out_port(0).connect(new_reshape2.in_port(1))
        before_transpose.out_port(0).connect(input_shape.in_port(0))
        match['transpose2'].out_port(0).get_connection().set_source(new_add2.out_port(0))
Пример #24
0
def replace_interpolate_pattern(graph: Graph, match: dict):
    """Replace the subgraph between the matched 'split' and 'concat' nodes with one Interpolate node.

    This variant passes the axis through the Interpolate ``axes`` attribute and uses an int64 scale.
    The sizes input (port 1) is ``ShapeOf(x)[axis:axis + 1] * scale``, where ``x`` is the Split data
    input and ``scale`` is ``get_split_scale(split)``.

    :param graph: graph being transformed.
    :param match: pattern match holding the 'split' and 'concat' nodes.
    """
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    # The split axis is the value of the node feeding Split input port 1.
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name

    shape_node = Shape(graph,
                       dict(name=split_node_name + '/Shape_')).create_node()
    scales_node = Const(graph,
                        dict(name=split_node_name + '/scales_',
                             value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # Bounds for slicing shape[axis:axis + 1].
    slice_begin = Const(
        graph,
        dict(name=split_node_name + '/slice_begin_',
             value=int64_array([axis]))).create_node()
    slice_end = Const(
        graph,
        dict(name=split_node_name + '/slice_end_',
             value=int64_array([axis + 1]))).create_node()

    # Inputs are connected at creation as (shape, begin, end); no stride input is given.
    strided_slice_node = StridedSlice(
        graph, {
            'name': split_node_name + '/StridedSlice_',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }).create_node([shape_node, slice_begin, slice_end])
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(
        graph,
        dict(name=split_node_name + '/Interpolate_',
             axes=int64_array([axis]),
             mode='nearest')).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    # Consumers of the Concat now read the Interpolate output.
    match['concat'].out_port(0).get_connection().set_source(
        interp_node.out_port(0))

    # Move the Split data input onto the Interpolate; the same producer also feeds the ShapeOf node.
    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
    def append_variances(priors_scale_node: Node, variance: list):
        """Append broadcast variances to the prior boxes and reshape the result to ``[1, 2, -1]``.

        The output of ``priors_scale_node`` is reshaped to ``[-1, 4]``. The ``variance`` values are
        broadcast to ``[shape[-2], 4]``, where ``shape`` is the shape of the ``priors_scale_node`` output.
        The broadcast variances are concatenated after the reshaped priors along axis 0, and the
        concatenation is reshaped to ``[1, 2, -1]``.

        :param priors_scale_node: node whose output holds the prior boxes.
        :param variance: variance values to broadcast (presumably 4 values, one per box coordinate — TODO confirm).
        :return: the final Reshape node.
        """
        graph = priors_scale_node.graph
        name = priors_scale_node.name

        sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

        # Shape constants are created as int64 explicitly, as `reshape_dim` below already is. A bare
        # np.array([...]) gets the platform-default integer type (int32 on Windows with NumPy < 2). It
        # would then be mixed with the (int64 by default) ShapeOf values in the Concat below.
        begin = Const(graph, {'value': int64_array([-2])}).create_node()
        end = Const(graph, {'value': int64_array([-1])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()
        shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                                     'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                                     'shrink_axis_mask': np.array([0]),
                                                     'ellipsis_mask': np.array([0])}).create_node()

        sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
        begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
        end.out_port(0).connect(shape_part_for_tiling.in_port(2))
        stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

        # Broadcast target shape: [shape[-2], 4].
        concat_value = Const(graph, {'value': int64_array([4])}).create_node()
        shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                      'axis': np.array(0)}).create_node()
        shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
        concat_value.out_port(0).connect(shape_concat.in_port(1))

        # Named variance_node so the `variance` parameter is not shadowed by a graph node.
        variance_node = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
        tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
        variance_node.out_port(0).connect(tile.in_port(0))
        shape_concat.out_port(0).connect(tile.in_port(1))

        reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
        sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
        sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

        # Priors first, then the variances, stacked along axis 0.
        concat = Concat(graph,
                        {'name': name + '/priors_concat', 'axis': np.array(0), 'in_ports_count': 2}).create_node()
        sp_reshape.out_port(0).connect(concat.in_port(0))
        tile.out_port(0).connect(concat.in_port(1))

        output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
        concat.out_port(0).connect(output_node.in_port(0))
        output_dims.out_port(0).connect(output_node.in_port(1))

        return output_node
 def make_interpolate_reshapeable(interpolate):
     """Make input port 1 of an Interpolate node follow the runtime shape of its data input.

     Port 1 is re-fed with ``ShapeOf(data)[axes] * (out_shape[axes] / in_shape[axes])``, where the
     shapes are those recorded on the node's input and output ports. The node is left unchanged
     unless every output dimension is a multiple of the matching input dimension, or every input
     dimension is a multiple of the matching output dimension.
     """
     assert interpolate.soft_get('type') == 'Interpolate'
     axes = Interpolate.get_axes(interpolate)
     in_shape = interpolate.in_port(0).data.get_shape()
     out_shape = interpolate.out_port(0).data.get_shape()

     # Bail out unless the resize is an integral up- or down-scale on every dimension.
     if not (np.all(np.remainder(out_shape, in_shape) == 0) or
             np.all(np.remainder(in_shape, out_shape) == 0)):
         return

     graph = interpolate.graph
     base_name = interpolate.soft_get('name', interpolate.id)

     shape_of = Shape(graph, {'name': base_name + '/ShapeOf'}).create_node()
     shape_of.in_port(0).connect(interpolate.in_port(0).get_source())

     const_inputs = {1: np.array(axes, dtype=np.int32), 2: int64_array(0)}
     gathered = create_op_with_const_inputs(graph, Gather, const_inputs,
                                            {'name': shape_of.name + '/Gathered'}, shape_of)

     # NOTE(review): `/` is true division, so these factors are floating point (e.g. 0.5 for a 2x
     # downscale), while the gathered shape values are presumably integers. Confirm the mixed-type
     # Mul is resolved downstream.
     scale_factors = out_shape[axes] / in_shape[axes]
     scaled_sizes = create_op_node_with_second_input(graph, Mul, scale_factors,
                                                     {'name': gathered.name + '/Multiplied'}, gathered)
     interpolate.in_port(1).get_connection().set_source(scaled_sizes.out_port(0))
Пример #27
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Emulate an MXNet reshape with ``reverse`` set by using shape-reversing Reshape nodes.

        The steps are:
        1. Reshape the input to the reverse of its own shape.
        2. Apply the node's ``dim`` spec in reversed order.
        3. Reshape that result to the reverse of its own shape.
        The original node's consumers are moved to the final Reshape. Nothing is changed when
        ``reverse`` is not set.
        """
        mxreshape = match['op']
        if not mxreshape.reverse:
            return

        shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
        forward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                          dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
        forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()

        forward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                        dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
        reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()
        shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
        mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))

        # Forward pass: reverse the input shape (ShapeOf -> Unsqueeze -> Reverse(axis=1) -> Squeeze)
        # and reshape the input to it.
        forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
        forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
        forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
        reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

        # Apply the `dim` spec in reversed order. Only the node that is actually used gets built. The
        # MXReshape gets the '/ReshapeShape' suffix so it does not clash with reshape_node's '/Reshape' name.
        reversed_dim = int64_array(np.flip(mxreshape.dim, 0))
        if np.any(np.isin([-2, -3, -4], mxreshape.dim)):
            # The special values -2/-3/-4 are left for MXReshape to resolve.
            reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/ReshapeShape',
                                                       dim=reversed_dim)).create_node()
        else:
            reshape_shape_node = create_op_node_with_second_input(graph, Reshape, reversed_dim,
                                                                  dict(name=str(mxreshape.id) + '/ReshapeShape'))

        reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

        # Backward pass: reshape the result to the reverse of its own shape.
        backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
        backward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                           dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
        backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
        backward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                         dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
        backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

        backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
        backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
        backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))

        backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

        mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched Slice(input, begin, size) node with Slice(input, begin, ends).

        ``ends`` is ``begin + size``, cast to int64. Wherever ``size`` equals -1, the matching
        element of ``ShapeOf(input)`` is used instead, so that dimension is sliced to its end. The
        new Slice takes over the original node's name and consumers.
        """
        tf_slice_node = match['op']
        slice_name = tf_slice_node.soft_get('name', tf_slice_node.id)
        slice_node = Slice(graph).create_node()
        rename_nodes([(tf_slice_node, slice_name + '/to_be_removed'),
                      (slice_node, slice_name)])
        # ends = size + begin
        ends_node = Add(graph, {'name': slice_name + '/ends'}).create_node()

        # reconnect input, begin, and size from TFSlice to the subgraph with Slice
        tf_slice_node.in_port(0).get_connection().set_destination(
            slice_node.in_port(0))
        tf_slice_node.in_port(1).get_connection().set_destination(
            slice_node.in_port(1))
        tf_slice_node.in_port(2).get_connection().set_destination(
            ends_node.in_port(0))
        slice_node.in_port(1).get_connection().add_destination(
            ends_node.in_port(1))

        # Full input shape, used as ends where size == -1.
        max_ends = Shape(graph, {
            'name': slice_name + '/ShapeOf'
        }).create_node()
        slice_node.in_port(0).get_connection().add_destination(
            max_ends.in_port(0))

        # check if size[i] == -1, will be applied elementwisely: len(size) = len(begin) = input_rank
        where_max_ends_is_needed = create_op_with_const_inputs(
            graph, Equal, {0: int64_array(-1)},
            {'name': slice_name + '/where_max_ends_is_needed'})
        # ends_node port 0 carries `size`, so Equal compares -1 with size.
        ends_node.in_port(0).get_connection().add_destination(
            where_max_ends_is_needed.in_port(1))
        # select requires equal dtypes, need to convert ends to I64
        ends_casted_to_i64 = Cast(graph, {
            'name': slice_name + '/CastToI64',
            'dst_type': np.int64
        }).create_node([ends_node])
        # if size[i] == -1 then take max_ends values, otherwise begin[i] + size[i]
        correct_ends = Select(graph, {
            'name': slice_name + '/chosen_ends'
        }).create_node(
            [where_max_ends_is_needed, max_ends, ends_casted_to_i64])
        correct_ends.out_port(0).connect(slice_node.in_port(2))

        tf_slice_node.out_port(0).get_connection().set_source(
            slice_node.out_port(0))
Пример #29
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched node with a subgraph computing the rank of its input as a 0D value.

        ShapeOf(ShapeOf(x)) yields the rank as a 1-element tensor, and Squeeze over axis 0 turns it into
        a scalar.
        """
        rank_op = match['op']
        prefix = rank_op.soft_get('name', rank_op.id)

        input_shape = Shape(graph, {'name': prefix + '/shape_of'}).create_node()
        rank_as_1d = Shape(graph, {'name': prefix + '/rank_of'}).create_node()
        rank_as_0d = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                      {'name': prefix + '/0d_rank_of'}, rank_as_1d)

        rank_as_1d.in_port(0).connect(input_shape.out_port(0))
        # Hand the matched node's consumers and its data input over to the new subgraph.
        rank_op.out_port(0).get_connection().set_source(rank_as_0d.out_port(0))
        rank_op.in_port(0).get_connection().set_destination(input_shape.in_port(0))
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with FlattenONNX(axis) -> LogSoftmax(axis=1) -> Reshape(original input shape).

        The input is flattened at ``node.axis``, LogSoftmax is applied over axis 1, and the result
        is reshaped back to the input shape taken via ShapeOf. The final Reshape inherits the
        original node's name; the original node is renamed with a '/delete' suffix.

        :return: list holding the id of the Reshape node that provides the replacement output.
        """
        base_name = node.soft_get('name', node.id)
        assert node.has_valid(
            'axis'
        ), 'The node "{}" does not have mandatory attribute "axis"'.format(
            base_name)

        flatten = FlattenONNX(graph, {'name': base_name + '/FlattenONNX_', 'axis': node.axis}).create_node()
        input_shape = Shape(graph, {'name': base_name + '/ShapeOf_'}).create_node()
        log_softmax = LogSoftmax(graph, {'name': base_name + '/LogSoftmax_', 'axis': 1}).create_node()
        restore_shape = Reshape(graph, {}).create_node()

        rename_nodes([(node, base_name + '/delete'), (restore_shape, base_name)])

        input_shape.out_port(0).connect(restore_shape.in_port(1))
        log_softmax.out_port(0).connect(restore_shape.in_port(0))
        flatten.out_port(0).connect(log_softmax.in_port(0))

        # Both the flattening branch and the ShapeOf branch read the original data input.
        data_source = node.in_port(0).get_source()
        for consumer in (flatten, input_shape):
            consumer.in_port(0).connect(data_source)

        return [restore_shape.id]