Example 1
    def unroll_ellipsis_for_inputs(graph: Graph, node: Node, ellipsis_start: int, num_insertions: int):
        node_name = node.soft_get('name', node.id)
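        # Insert `num_insertions` blank values (zeros for begin/end, ones for
        # strides) at the ellipsis position of each StridedSlice input, using
        # VariadicSplit + Concat when the ellipsis is not at position 0.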

        for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
            if i == 3 and not node.is_in_port_connected(3):
                continue  # no need to extend strides if they are not connected

            blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
            blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
                                              'value': int64_array(blank_values_arr)}).create_node()

            concat_in_ports_count = 3 if ellipsis_start != 0 else 2
            concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
                                    'in_ports_count': concat_in_ports_count}).create_node()

            if ellipsis_start != 0:
                split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0),
                                                                           2: int64_array([ellipsis_start, -1])},
                                                    {'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
                                                     'out_ports_count': 2})
                node.in_port(i).get_connection().set_destination(split.in_port(0))

                concat.in_port(0).connect(split.out_port(0))
                concat.in_port(1).connect(blank_values_node.out_port(0))
                concat.in_port(2).connect(split.out_port(1))
            else:
                concat.in_port(0).connect(blank_values_node.out_port(0))
                node.in_port(i).get_connection().set_destination(concat.in_port(1))

            concat.out_port(0).get_connection().set_destination(node.in_port(i))
Example 2
    def extend_inputs(node: Node, num_insertions: int):
        graph = node.graph
        node_name = node.soft_get('name', node.id)
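        # Append `num_insertions` blank values (zeros for begin/end, ones for
        # strides) to each StridedSlice input, extending an existing Concat
        # producer when one is already present.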

        for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
            if i == 3 and not node.is_in_port_connected(3):
                continue  # no need to extend strides if they are not connected

            blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
            blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
                                              'value': int64_array(blank_values_arr)}).create_node()

            if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
                # concat already exists
                concat = node.in_port(i).get_source().node
                last_in_port = max(concat.in_ports().keys())
                assert not concat.in_port(last_in_port).disconnected(), \
                    'The last in_port of Concat node {} should be connected'.format(concat.soft_get('name', concat.id))

                concat.add_input_port(last_in_port + 1)
                concat.in_port(last_in_port + 1).connect(blank_values_node.out_port(0))
            else:
                # have to create concat
                concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
                                        'in_ports_count': 2}).create_node()
                node.in_port(i).get_connection().set_destination(concat.in_port(0))
                concat.in_port(1).connect(blank_values_node.out_port(0))
                concat.out_port(0).get_connection().set_destination(node.in_port(i))
Example 3
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        pair_node = Node(graph, node.pair_name)
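        # Replace an IfDefined/MemoryOffset pair with ReadValue/Assign state
        # nodes; for |t| > 1 the state is a rolling buffer that is cropped and
        # concatenated with the current input on every step.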

        if node.t >= 0:
            raise Error('Does not support IfDefined with t >= 0')

        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
        memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
                                       'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
                                    'offset': np.array([0]), 'axis': np.array([1])}).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
Example 4
    def replace_tdnn(self, graph: Graph, tdnn_node: Node):
        tdnn_name = tdnn_node.soft_get('name', tdnn_node.id)
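        # Decompose a TdnnComponent into MemoryOffset nodes (one per non-zero
        # time offset), a Concat over their outputs and a FullyConnected layer
        # carrying the component's weights and biases.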

        concat_node = Concat(graph, {'axis': 1}).create_node()
        rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'),
                      (concat_node, tdnn_name)])

        for offset_ind, t in enumerate(tdnn_node['time_offsets']):
            concat_node.add_input_port(offset_ind)
            if t != 0:
                memory_name = tdnn_name + '/MemoryOffset/' + str(abs(t))
                memoryoffset_node = MemoryOffset(
                    graph, {
                        'name': memory_name,
                        't': t,
                        'pair_name': memory_name + '_out',
                        'has_default': False,
                        'splitted': False
                    }).create_node()

                tdnn_node.in_port(0).get_source().connect(
                    memoryoffset_node.in_port(0))
                memoryoffset_node.out_port(0).connect(
                    concat_node.in_port(offset_ind))
            else:
                # a zero time delay is meaningless and is not allowed in IE;
                # connect the tdnncomponent input directly to Concat without a MemoryOffset
                tdnn_node.in_port(0).get_source().connect(
                    concat_node.in_port(offset_ind))

        weights = tdnn_node['weights']
        fc_inputs = {1: weights}

        bias_term = False
        if tdnn_node.has_valid('biases'):
            assert len(tdnn_node['biases']) == weights.shape[0]
            fc_inputs.update({2: tdnn_node['biases']})
            bias_term = True

        fc_node = create_op_with_const_inputs(
            graph, FullyConnected, fc_inputs, {
                'name': tdnn_name + '/FC',
                'out-size': weights.shape[0],
                'transpose_weights': True,
                'bias_term': bias_term
            })

        concat_node.out_port(0).connect(fc_node.in_port(0))
        tdnn_node.in_port(0).disconnect()
        tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0))
Example 5
    def replace_pattern(graph: Graph, match: dict):
        node = match['tile']
        name = node.soft_get('name', node.id)
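        # Align the ranks of the Tile input and its `repeats` input: either
        # unsqueeze the data input with leading axes, or left-pad `repeats`
        # with ones via Concat.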

        input_shape = node.in_port(0).data.get_shape()
        assert input_shape is not None
        tiles = node.in_port(1).data.get_value()
        assert tiles is not None, "Undefined `repeats` (1st port input value) of Tile node '{}'".format(name)

        if input_shape.size == tiles.size:
            return

        if input_shape.size < tiles.size:
            unsqueeze = create_op_node_with_second_input(graph, Unsqueeze,
                                                         int64_array(list(range(tiles.size - input_shape.size))),
                                                         {'name': name + '/input_alignment',
                                                          'override_output_shape': True})
            node.in_port(0).get_source().connect(unsqueeze.in_port(0))
            node.in_port(0).get_connection().set_source(unsqueeze.out_port(0))
        else:
            const = Const(graph, {'name': name + '/tile_alignment_const',
                                  'value': np.ones([input_shape.size - tiles.size], dtype=np.int64)}).create_node()
            concat = Concat(graph, {'axis': 0, 'override_output_shape': True}).create_node()
            concat.add_input_port(0)
            concat.add_input_port(1)

            node.in_port(1).get_source().connect(concat.in_port(1))
            node.in_port(1).disconnect()
            concat.in_port(0).connect(const.out_port(0))

            node.in_port(1).connect(concat.out_port(0))
Example 6
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        concat_node = match['concat']
        concat_node['axis'] = 1
        concat_name = concat_node.soft_get('name', concat_node.id)
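        # Rebuild the prior-box data consumed by DetectionOutput: reshape and
        # split the concatenated priors, recompute the box corner values and
        # concatenate them back with the variances half.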

        concat_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([1, 2, -1]), op_attrs=dict(
            name=concat_name + '/Reshape'))
        split_node = create_op_node_with_second_input(graph, Split, int64_array(1), op_attrs=dict(
            name=concat_name + '/Split', num_splits=2), input_node=concat_reshape)
        split_node_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 4]), op_attrs=dict(
            name=split_node.name + '/Reshape'))
        split_node.out_port(0).connect(split_node_reshape.in_port(0))
        value = create_op_node_with_second_input(graph, Split, int64_array(1), op_attrs=dict(
            name=split_node_reshape.name + '/Split', num_splits=4), input_node=split_node_reshape)

        xmin, xmax = calculate_prior_box_value(value, value_to_div=value.out_port(2), value_to_add=value.out_port(0))
        ymin, ymax = calculate_prior_box_value(value, value_to_div=value.out_port(3), value_to_add=value.out_port(1))

        concat_slice_value = Concat(graph, dict(name=value.name + '/Concat', in_ports_count=4, axis=1)).create_node()
        for ind, node in enumerate([xmin, ymin, xmax, ymax]):
            concat_slice_value.in_port(ind).connect(node.out_port(0))

        reshape_concat_values = create_op_node_with_second_input(graph, Reshape, int64_array([1, 1, -1]),
                                                                 op_attrs=dict(name=concat_slice_value.name + '/Reshape'),
                                                                 input_node=concat_slice_value)
        concat = Concat(graph, dict(name=reshape_concat_values.name + '/Concat', in_ports_count=2, axis=1)).create_node()
        concat.in_port(0).connect(reshape_concat_values.out_port(0))
        concat.in_port(1).connect(split_node.out_port(1))

        match['detection_output'].in_port(2).get_connection().set_source(concat.out_port(0))
        concat_node.out_port(0).get_connection().set_destination(concat_reshape.in_port(0))
Example 7
    def replace_with_split_concat(node):
        graph = node.graph

        name = node.soft_get('name', node.id)
        axis = node.axis
        order = node.order
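        # Emulate the permutation along `axis` with a Split into `order.size`
        # parts whose outputs are re-connected to the Concat ports given by
        # `order`.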

        split = create_op_with_const_inputs(graph, Split,
                                            {1: int64_array(axis)}, {
                                                'name': name + '/Split',
                                                'num_splits': order.size
                                            })
        concat = Concat(graph, {
            'name': name + '/Concat',
            'axis': axis,
            'in_ports_count': order.size
        }).create_node()

        for out_port_idx, in_port_idx in enumerate(order):
            split.out_port(out_port_idx).connect(concat.in_port(in_port_idx))

        node.out_port(0).get_connection().set_source(concat.out_port(0))
        node.in_port(0).get_connection().set_destination(split.in_port(0))

        graph.remove_node(node.id)
Example 8
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['mxreshape']
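        # Convert an MXNet-style Reshape: dimension nodes resolved from the
        # runtime Shape of the input are concatenated into the target shape
        # of a new Reshape node.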

        input_index = 0
        reshape_index = 0
        shape_node = Shape(graph, dict(name=node.id +
                                       '/ShapeMXReshape')).create_node()
        shape_node.in_port(0).connect(node.in_port(0).get_source())
        output_dims_nodes = []
        for d in node.dim:
            if reshape_index < len(node.dim):
                input_index, reshape_index, output_dims_nodes = self.resolve(
                    input_index, reshape_index, node.dim, shape_node,
                    output_dims_nodes)

        concat_node = Concat(
            shape_node.graph,
            dict(name=shape_node.id + '/ConcatMXReshape_',
                 axis=0,
                 in_ports_count=len(output_dims_nodes))).create_node()

        for in_port_index, dim_node in enumerate(output_dims_nodes):
            concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

        reshape_node = Reshape(graph,
                               dict(name=node.id + '/Reshape_')).create_node()
        reshape_node.in_port(1).connect(concat_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
Example 9
    def append_variances(priors_scale_node: Node, variance: list):
        graph = priors_scale_node.graph
        name = priors_scale_node.name
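        # Tile the variance values to match the priors (the tile shape is taken
        # from the -2 dimension of the priors via StridedSlice), concatenate
        # them with the reshaped priors and return a [1, 2, -1] Reshape.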

        sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

        begin = Const(graph, {'value': np.array([-2])}).create_node()
        end = Const(graph, {'value': np.array([-1])}).create_node()
        stride = Const(graph, {'value': np.array([1])}).create_node()
        shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                                     'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                                     'shrink_axis_mask': np.array([0]),
                                                     'ellipsis_mask': np.array([0])}).create_node()

        sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
        begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
        end.out_port(0).connect(shape_part_for_tiling.in_port(2))
        stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

        concat_value = Const(graph, {'value': np.array([4])}).create_node()
        shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                      'axis': np.array(0)}).create_node()
        shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
        concat_value.out_port(0).connect(shape_concat.in_port(1))

        variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
        tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
        variance.out_port(0).connect(tile.in_port(0))
        shape_concat.out_port(0).connect(tile.in_port(1))

        reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
        sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
        sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

        concat = Concat(graph,
                        {'name': name + '/priors_concat', 'axis': np.array(0), 'in_ports_count': 2}).create_node()
        sp_reshape.out_port(0).connect(concat.in_port(0))
        tile.out_port(0).connect(concat.in_port(1))

        output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
        concat.out_port(0).connect(output_node.in_port(0))
        output_dims.out_port(0).connect(output_node.in_port(1))

        return output_node
Example 10
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['node']
        node_name = node.soft_get('name', node.id)
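        # Convert ArgMax/ArgMin into TopK (sorted by index); with
        # `out_max_val` set, values and indices are concatenated along axis 1
        # to reproduce the tuple output of ArgMax.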

        connected_ports = [
            port for port in node.in_ports().values()
            if not port.disconnected()
        ]
        if len(connected_ports) == 2:
            axis = node.in_port(1).data.get_value()
        else:
            axis = node.axis

        assert axis is not None, 'The "axis" should be defined for node "{}"'.format(
            node_name)
        assert node.has_and_set(
            'output_type'), 'The data type is not set for node "{}"'.format(
                node_name)

        topk_mode = 'max' if node.op == 'ArgMax' else 'min'
        topk_node = TopK(
            graph, {
                'axis': axis,
                'mode': topk_mode,
                'sort': 'index',
                'remove_values_output':
                node.has_and_set('remove_values_output'),
                'index_element_type': node.output_type
            }).create_node()
        node.in_port(0).get_connection().set_destination(topk_node.in_port(0))
        if node.has_and_set('out_max_val'):
            # in this mode the ArgMax produces tuples (max_ind, max_value)
            concat_node = Concat(graph, {
                'axis': 1,
                'name': node.name + '/Concat'
            }).create_node()
            concat_node.add_input_port(0, skip_if_exist=True)
            concat_node.add_input_port(1, skip_if_exist=True)
            topk_node.out_port(0).connect(concat_node.in_port(1))  # indices
            topk_node.out_port(1).connect(concat_node.in_port(0))  # values
            if not node.out_port(0).disconnected():
                node.out_port(0).get_connection().set_source(
                    concat_node.out_port(0))
        else:
            if not node.out_port(0).disconnected():
                node.out_port(0).get_connection().set_source(
                    topk_node.out_port(1))

        topk_node.in_port(1).connect(
            Const(graph, {
                'name': node.soft_get('name') + '/TopK',
                'value': node.top_k
            }).create_node().out_port(0))

        graph.remove_nodes_from([node.id, node.out_node(0).id])
Example 11
    def fuse_reduces(first_reduce, second_reduce):
        first_reduce_name = first_reduce.soft_get('name', first_reduce.id)
        second_reduce_name = second_reduce.soft_get('name', second_reduce.id)
        reduce_type = first_reduce.type
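        # Fuse two stacked Reduce ops of the same type into one by
        # concatenating their axes inputs; bail out on extra consumers,
        # dynamic axes or mismatched keep_dims.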

        assert first_reduce.type == second_reduce.type

        if len(first_reduce.out_port(0).get_destinations()) != 1:
            # data dependency
            return

        if first_reduce.keep_dims != second_reduce.keep_dims:
            return

        first_axes = first_reduce.in_port(1).data.get_value()
        second_axes = second_reduce.in_port(1).data.get_value()
        if first_axes is None or second_axes is None:
            # dynamic axes merging is not supported
            return

        if not first_reduce.keep_dims:
            if not np.all(first_axes > second_axes):
                # indexing of upper reduce input dimensions changed
                return

        graph = second_reduce.graph

        new_axes = Concat(
            graph, {
                'name': second_reduce_name + '/Axes',
                'axis': int64_array(0),
                'in_ports_count': 2,
                'override_output_shape': True
            }).create_node()
        new_axes.in_port(0).connect(first_reduce.in_port(1).get_source())
        new_axes.in_port(1).connect(second_reduce.in_port(1).get_source())

        first_reduce.in_port(0).get_source().node['need_shape_inference'] = True
        first_reduce.in_port(0).get_source().node['override_output_shape'] = True

        second_reduce.in_port(1).get_connection().set_source(
            new_axes.out_port(0))

        first_reduce.out_port(0).get_connection().set_source(
            first_reduce.in_port(0).get_connection().get_source())
        first_reduce.in_port(1).disconnect()
        graph.remove_node(first_reduce.id)

        log.debug(
            '{0} nodes {1} and {2} were fused to a single {2} node with updated axes input'
            ''.format(reduce_type, first_reduce_name, second_reduce_name))
Example 12
    def replace_pattern(graph: Graph, match: dict):
        node = match['pool']
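        # Wrap the pooling node with Reshapes so that it operates on a 4D
        # tensor; the target input shape is built dynamically with
        # Shape/VariadicSplit/Pow/Mul and assembled with Concat.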

        if node.pool_step is None:
            node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

        # create Reshape before pooling
        # shape = [in_shape[0], pool_stride, 1, in_shape[1] / pool_stride]
        shape = Shape(graph, {}).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        split = create_op_with_const_inputs(graph, VariadicSplit, {
            1: int64_array(0),
            2: int64_array([1, -1])
        }, {'out_ports_count': 2}, shape)
        node_pool_stride = Const(graph, {
            'value': int64_array([node.pool_stride])
        }).create_node()
        pow_node = create_op_node_with_second_input(graph, Pow,
                                                    int64_array([-1]))
        pow_node.in_port(0).connect(node_pool_stride.out_port(0))

        mul = Mul(graph, {}).create_node()
        mul.in_port(0).connect(split.out_port(1))
        mul.in_port(1).connect(pow_node.out_port(0))

        const_1 = Const(graph, {'value': int64_array([1])}).create_node()

        concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
        concat.in_port(0).connect(split.out_port(0))
        concat.in_port(3).connect(mul.out_port(0))
        concat.in_port(2).connect(const_1.out_port(0))
        concat.in_port(1).connect(node_pool_stride.out_port(0))

        reshape_in = Reshape(graph, {
            'name': '/Reshape/' + node.name
        }).create_node()
        reshape_in.in_port(1).connect(concat.out_port(0))

        # create Reshape after pooling
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node.name + '/Reshape/'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
Example 13
def create_zero_value_with_batch_from_input(input_out_port: Port,
                                            second_dim,
                                            precision=np.float32):
    # create init_graph connected to ReadValue
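    # The zero tensor has shape [batch, second_dim]: the batch size is cropped
    # from the input's runtime Shape, concatenated with the constant second
    # dimension and fed to Broadcast over a scalar zero fill value.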
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name
    shape_of_input = Shape(graph, {
        'name': 'shape/' + input_name
    }).create_node()
    shape_of_input.in_port(0).connect(input_out_port)
    dim_for_get_batch = Const(
        graph, {
            'name': 'dim/crop_batch/' + shape_of_input.name,
            'value': int64_array([1]),
            'shape': int64_array([1])
        }).create_node()
    get_batch = Crop(
        graph, {
            'name': 'crop_batch/' + shape_of_input.name,
            'axis': int64_array([0]),
            'offset': int64_array([0])
        }).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
    mem_shape_2nd_dim = Const(
        graph, {
            'name': 'gifo_r_weights_shape/' + input_name,
            'value': int64_array([second_dim]),
            'shape': int64_array([1])
        }).create_node()
    mem_shape = Concat(
        graph, {
            'name': 'gather_memory_shape/' + input_name,
            'axis': 0,
            'in_ports_count': 2
        }).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
    fill_value = Const(
        graph, {
            'name': 'fill_value/' + input_name,
            'value': np.array([0.0], precision),
            'shape': int64_array([1])
        }).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {
        'name': 'init_value/' + input_name,
    }).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
    return init_value_prev_lstm_output
Example 14
    def replace_pattern(self, graph: Graph, match: dict):
        concat_node = match['concat']
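        # Drop Concat inputs that are Const nodes with a zero dimension; if
        # more than one real input remains, rebuild a smaller Concat,
        # otherwise bypass the Concat entirely.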
        sources_of_ports = [
            concat_node.in_port(i).get_connection().get_source()
            for i in concat_node.in_ports()
        ]
        # If 'concat' is ConcatV2 layer from TF, then this layer initially had input 'axis' as the last input.
        # But then this input was deleted and the attribute 'axis' was added. Hence, the last port source can
        # be None in such case.
        sources_of_ports = [s for s in sources_of_ports if s is not None]

        input_nodes = [s.node for s in sources_of_ports]
        if not all(n.has_valid('type') for n in input_nodes):
            return

        saved_ports = []
        disconnected_ports = []

        for port_num, node in enumerate(input_nodes):
            if node.soft_get('type') == 'Const' and len(
                    node.shape) > 1 and any(i == 0 for i in node.shape):
                disconnected_ports.append(port_num)
            else:
                saved_ports.append(port_num)

        if not saved_ports or not disconnected_ports:
            return

        if len(saved_ports) == 1:
            before_concat = concat_node.in_port(
                saved_ports[0]).get_connection().get_source()
            concat_node.out_port(0).get_connection().set_source(before_concat)
            return

        new_concat_attrs = concat_node.attrs().copy()
        new_concat_attrs['name'] = concat_node.name + '/Concat_'
        new_concat_attrs['in_ports_count'] = len(saved_ports)
        new_concat_node = Concat(graph, attrs=new_concat_attrs).create_node()

        for new_port_num, old_port_num in enumerate(saved_ports):
            concat_node.in_port(old_port_num).get_connection().set_destination(
                new_concat_node.in_port(new_port_num))

        for p in disconnected_ports:
            concat_node.in_port(p).disconnect()

        concat_node.out_port(0).get_connection().set_source(
            new_concat_node.out_port(0))
Example 15
    def replace_pattern(self, graph: Graph, match: dict):

        merge = match['merge']
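        # Decompose the component into a reciprocal (x^-1) branch plus one
        # scale/cast/unscale branch per decimal position; the branches are
        # concatenated and averaged with ReduceMean before a final Reshape.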
        power = Pow(graph, {
            'name': merge.name + '/reciprocal_',
            'type': 'PNORM'
        }).create_node()
        const1 = Const(graph, {
            'value': -1.0,
            'name': merge.name + '/negate_const'
        }).create_node()
        merge.in_port(0).get_connection().set_destination(power.in_port(0))
        const1.out_port(0).connect(power.in_port(1))

        concat_node = Concat(
            graph, {
                'axis': 0,
                'name': merge.name + '/Concat_',
                'override_output_shape': True
            }).create_node()
        const3 = Const(graph, {
            'name': merge.name + '/const_reduce',
            'value': 0
        }).create_node()

        for ii, idx in enumerate(
                range(merge.significant, merge.to_significant + 1, 1)):
            const_node = Const(
                graph, {
                    'value': float_array(math.pow(10.0, idx)),
                    'name': merge.name + '/Const_' + str(ii)
                }).create_node()

            mul_node = Mul(graph, {
                'name': merge.name + '/Mul_' + str(ii)
            }).create_node()
            const_node.out_port(0).connect(mul_node.in_port(0))

            # connect to the power (reciprocal) node
            power.out_port(0).connect(mul_node.in_port(1))
            mul_node2 = Mul(graph, {
                'name': merge.name + '/Mul_Div_' + str(ii)
            }).create_node()

            const_node2 = Const(
                graph, {
                    'value': float_array(math.pow(10.0, -1 * idx)),
                    'name': merge.name + '/Const_Pow_' + str(ii)
                }).create_node()
            cast_node = Cast(
                graph, {
                    'name': merge.name + '/Cast_' + str(idx),
                    'dst_type': np.float32
                }).create_node()

            mul_node.out_port(0).connect(cast_node.in_port(0))
            const_node2.out_port(0).connect(mul_node2.in_port(1))
            cast_node.out_port(0).connect(mul_node2.in_port(0))
            concat_node.add_input_port(ii, skip_if_exist=True)
            concat_node.in_port(ii).get_connection().set_source(
                mul_node2.out_port(0))

        reducesum_node = ReduceMean(
            graph, {
                'name': merge.id + '/_pnorm_reduced_sum',
                'keep_dims': False,
                'in_ports_count': 2,
                'need_shape_inference': None,
                'infer': reduce_infer
            }).create_node()

        const3.out_port(0).connect(reducesum_node.in_port(1))
        reducesum_node.in_port(0).get_connection().set_source(
            concat_node.out_port(0))

        reshape = Reshape(graph, {
            'name': merge.name + '/Reshape_Node'
        }).create_node()
        reshape_dim = Const(graph, {
            'value': np.array([1, 5]),
            'name': merge.id + '/Reshape_Dim'
        }).create_node()
        reducesum_node.out_port(0).connect(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_dim.out_port(0))
        merge.out_port(0).get_connection().set_source(reshape.out_port(0))
Example 16
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        in_shape = node.in_port(0).data.get_shape().copy()
        memory_element = in_shape[1] - node.const_dim
        memory_size = memory_element * len(node.context)
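        # Replace a Splice with paired Memory nodes and a Crop/Concat sliding
        # window over past frames; features in const_dim get their own
        # memory/crop/concat chain that is spliced back at the end.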

        memory_pair_id = unique_id('id')
        # Memory(in)
        input_memory = Memory(
            graph, {
                'name': 'prev_splice_memory',
                'id': memory_pair_id,
                'index': 1,
                'size': 2,
                'shape': int64_array([memory_size])
            }).create_node()
        # Memory(in)  \
        #             Crop
        # Input(temp) /
        crop = Crop(
            graph, {
                'name': 'Splice_Crop',
                'axis': int64_array([1]),
                'offset': int64_array([memory_element]),
                'dim': int64_array([memory_size - memory_element])
            }).create_node()
        crop.in_port(0).connect(input_memory.out_port(0))

        # Crop   \
        #         Concat
        # Input  /
        concat_node = Concat(graph, {
            'name': 'Splice_Concat',
            'in_ports_count': 2,
            'axis': 1
        }).create_node()
        concat_node.in_port(0).connect(crop.out_port(0))

        # Concat -> Memory(out)
        mem_out = Memory(
            graph, {
                'name': 'out_splice_memory',
                'id': memory_pair_id,
                'index': 0,
                'size': 2,
                'shape': int64_array([memory_size])
            }).create_node()
        mem_out.in_port(0).connect(concat_node.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

        if node.const_dim != 0:
            memory_element_constdim = node.const_dim
            memory_size_constdim = memory_element_constdim * len(node.context)

            split = create_op_with_const_inputs(
                graph, VariadicSplit, {
                    1: int64_array(1),
                    2: int64_array([memory_element, memory_element_constdim])
                }, {
                    'name': node.id + '_split_const',
                    'out_ports_count': 2
                })

            split.out_port(0).connect(concat_node.in_port(1))

            # create separate splice construction for const_dim
            memory_pair_id = unique_id('memory_for_const_dim')
            input_memory_const_dim = Memory(
                graph, {
                    'name': 'const_dim_in_memory',
                    'id': memory_pair_id,
                    'index': 1,
                    'size': 2,
                    'shape': int64_array([memory_size_constdim])
                }).create_node()
            crop_const_dim = Crop(
                graph, {
                    'name': 'const_dim_crop',
                    'axis': int64_array([1]),
                    'offset': int64_array([memory_element_constdim]),
                    'dim': int64_array([memory_size_constdim - memory_element_constdim])
                }).create_node()
            crop_const_dim.in_port(0).connect(
                input_memory_const_dim.out_port(0))

            concat_node_const_dim = Concat(graph, {
                'name': 'const_dim_concat',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat_node_const_dim.in_port(0).connect(
                crop_const_dim.out_port(0))

            mem_out_const_dim = Memory(
                graph, {
                    'name': 'const_dim_out_memory',
                    'id': memory_pair_id,
                    'index': 0,
                    'size': 2,
                    'shape': int64_array([memory_size_constdim])
                }).create_node()
            mem_out_const_dim.in_port(0).connect(
                concat_node_const_dim.out_port(0))
            Result(graph).create_node().in_port(0).connect(
                mem_out_const_dim.out_port(0))

            # connect splice to Split as begin and Concat as the end
            split.out_port(1).connect(concat_node_const_dim.in_port(1))
            crop_first = Crop(
                graph, {
                    'name': 'const_dim_crop_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([memory_element_constdim])
                }).create_node()
            crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

            concat_const = Concat(graph, {
                'name': node.id + '_concat_const',
                'axis': 1,
                'in_ports_count': 2
            }).create_node()
            concat_const.in_port(1).connect(crop_first.out_port(0))
            concat_const.in_port(0).connect(concat_node.out_port(0))

            node.in_port(0).get_connection().set_destination(split.in_port(0))
            node.out_port(0).get_connection().set_source(
                concat_const.out_port(0))
        else:
            node.in_port(0).get_connection().set_destination(
                concat_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                concat_node.out_port(0))

        # to avoid re-inference of shape and touching in next replacements
        graph.remove_node(node.id)
Example 17
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
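        # Insert a Select before the state-saving node so that zeros are
        # written until `context_len` frames of context have been gathered;
        # an existing iteration-counter sub-graph is reused when one matches.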

        if node.name == 'iteration_number_out':
            return

        # calculate length of context when state of inference becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if we have already appropriate iteration counter
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_zero_value_with_batch_from_input(
                in_node_port, context_len, np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Example 18
    def replace_sub_graph(self, graph: Graph, match: dict):
        # obtain references to necessary nodes and their names
        fill = match['fill']
        dims = match['dims']
        strided_slice = match['strided_slice']
        strided_slice_1 = match['strided_slice_1']
        ctc_greedy_decoder = match['decoder']
        cast = match['cast']
        sparse_to_dense = match['sparse_to_dense']
        strided_slice_name = strided_slice.soft_get('name', strided_slice.id)
        strided_slice_1_name = strided_slice_1.soft_get(
            'name', strided_slice_1.id)
        ctc_greedy_decoder_name = ctc_greedy_decoder.soft_get(
            'name', ctc_greedy_decoder.id)
        sparse_to_dense_name = sparse_to_dense.soft_get(
            'name', sparse_to_dense.id)

        # unsqueeze scalar values with batch size and time dimension
        unsqueeze_batch_size = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(0)},
            {'name': strided_slice_name + '/Unsqueeze'})
        dims.in_port(0).get_connection().set_destination(
            unsqueeze_batch_size.in_port(0))
        unsqueeze_time_size = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(0)},
            {'name': strided_slice_1_name + '/Unsqueeze'})
        fill.in_port(1).get_connection().set_destination(
            unsqueeze_time_size.in_port(0))

        # compute a sequence mask shape [T, N] required for CTCGreedyDecoder
        seq_mask_shape = Concat(
            graph, {
                'axis': 0,
                'in_ports_count': 2,
                'name': ctc_greedy_decoder_name + '/SequenceMaskShape'
            }).create_node()
        seq_mask_shape.in_port(0).connect(unsqueeze_time_size.out_port(0))
        seq_mask_shape.in_port(1).connect(unsqueeze_batch_size.out_port(0))

        # compute a sequence mask
        sequence_mask = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([1.0], dtype=np.float32)}, {
                'mode': 'numpy',
                'name': ctc_greedy_decoder_name + '/SequenceMask'
            })
        sequence_mask.in_port(1).connect(seq_mask_shape.out_port(0))

        # create CTCGreedyDecoder with the sequence mask instead of sequence length
        ctc_greedy_decoder.in_port(1).disconnect()
        ctc_greedy_decoder.in_port(1).connect(sequence_mask.out_port(0))

        # remove fill and pack nodes since they are now in unconnected component
        graph.remove_nodes_from([fill.id, dims.id])

        # transform opset CTCGreedyDecoder output to TensorFlow's one that has a shape [N, T]
        # opset CTCGreedyDecoder has an output with a shape [N, T, 1, 1]
        squeeze_dec_seq = create_op_with_const_inputs(
            graph, Squeeze, {1: int64_array([2, 3])},
            {'name': sparse_to_dense_name})
        squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_to_int = Cast(graph, {
            'name': sparse_to_dense_name + '/CastToInt',
            'dst_type': np.int32
        }).create_node()
        cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))

        # preserve output name from original graph
        rename_nodes([(sparse_to_dense,
                       sparse_to_dense_name + '/AbandonedName'),
                      (cast_to_int, sparse_to_dense_name)])

        # set output of the new sub-graph as a source for SparseToDense consumer
        sparse_to_dense.out_port(0).get_connection().set_source(
            cast_to_int.out_port(0))

        # cleanup a graph
        graph.remove_nodes_from([cast.id, sparse_to_dense.id])
Example 19
    def insert_select(graph: Graph, node: Node):
        context_len = node.frame_time + 1
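        # Select guard driven by the node's frame_time; constants are created
        # with create_const_with_batch_from_input instead of plain Const nodes.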

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
        zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if we have already appropriate iteration counter
        existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
                                                               ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                                    offset=int64_array([1]),
                                                                                    dim=int64_array([context_len - 1]))),
                                                               ('crop_mem_in_data', dict()),
                                                               ('concat', dict(op='Concat', axis=1)),
                                                               ('concat_data', dict()),
                                                               ('const_1', dict(op='Const')),
                                                               ('const_1_data', dict()),
                                                               ('mem_out', dict(op='Assign')),
                                                               ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                                 offset=int64_array([0]),
                                                                                 dim=int64_array([1]))),
                                                               ('crop_out_data', dict()),
                                                               ('select', dict(op='Select'))
                                                               ],
                                                 edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                                                        ('crop_mem_in', 'crop_mem_in_data'),
                                                        ('crop_mem_in_data', 'concat', {'in': 0}),
                                                        ('const_1', 'const_1_data'),
                                                        ('const_1_data', 'concat', {'in': 1}),
                                                        ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                                                        ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                                                        ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(graph, {'name': 'iteration_number',
                                        'variable_id': 'iteration_' + node.name}).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                     'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
            concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(graph, {'name': 'iteration_number_out',
                                    'variable_id': 'iteration_' + node.name}).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                    'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Example 20
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
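        # Same Select guard as Example 17, but built for the legacy Memory op:
        # the iteration counter lives in paired Memory nodes instead of
        # ReadValue/Assign.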

        if node.name == 'iteration_number_out':
            return

        # calculate length of context when state of inference becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if we have already appropriate iteration counter
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in',
                    dict(op='Memory',
                         index=1,
                         shape=int64_array([context_len]))),
                   ('mem_in_data', dict()),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()),
                   ('mem_out',
                    dict(op='Memory',
                         index=0,
                         shape=int64_array([context_len]))),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            mem_out = Memory(
                graph, {
                    'name': 'iteration_number',
                    'size': 2,
                    'index': 1,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len]),
                    'dst_type': np.int32
                }).create_node()
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Memory(
                graph, {
                    'name': 'iteration_number_out',
                    'size': 2,
                    'index': 0,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len])
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)
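            # the cropped first element stays 0 until ones have been shifted through the
            # window (assuming the Memory state starts zero-filled), then becomes 1;
            # it serves as the Select condition connected below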

        select_node.in_port(0).connect(input_port)
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
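
    # Replaces a Kaldi MemoryOffset pair (negative t) with a ReadValue/Assign state
    # pair plus Crop/Concat nodes that keep a sliding window of past frames.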
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            raise Error('Does not support IfDefined with t >= 0')

        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        init_value_memory_out = create_zero_value_with_batch_from_input(
            input_port, in_shape[1] * node_t)
        memory_out = ReadValue(graph, {
            'name': pair_name,
            'variable_id': node_name + pair_name
        }).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))
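        # the state is zero-initialized through init_value_memory_out, so the first
        # reads return zero history until real frames are written via Assign below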

        if node_t > 1:
            crop_concat = Crop(
                graph, {
                    'name': 'Memory_crop',
                    'dim': np.array([in_shape[1] * (node_t - 1)]),
                    'offset': np.array([in_shape[1]]),
                    'axis': np.array([1])
                }).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {
                'name': node_name,
                'variable_id': node_name + pair_name
            }).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            crop_out = Crop(
                graph, {
                    'name': 'Memory_crop_out',
                    'dim': np.array([in_shape[1]]),
                    'offset': np.array([0]),
                    'axis': np.array([1])
                }).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            memory_in = Assign(graph, {
                'name': node_name,
                'variable_id': node_name + pair_name
            }).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        if not graph.graph['cmd_params'].static_shape:
            log.error(
                "The model cannot be converted in a reshape-able way.\n"
                "The Model Optimizer option static_shape was turned on to prevent related errors.\n"
                "Changing the input shapes of the model with the InferenceEngine reshape "
                "method will not succeed",
                extra={'is_warning': True})
            graph.graph['cmd_params'].static_shape = True

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
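The core of the transform above is the ReadValue/Assign pair bound through a shared variable_id. A minimal sketch of that contract, assuming the same Model Optimizer front-end API as the snippet (node names here are illustrative):

# read the persistent state for variable 'v0' (zero-initialized via its input)
read = ReadValue(graph, {'name': 'state_read', 'variable_id': 'v0'}).create_node()
# write the updated window back under the same variable_id for the next inference
write = Assign(graph, {'name': 'state_write', 'variable_id': 'v0'}).create_node()
# as in the transform above, the Assign output is terminated by a Result,
# presumably so the dead-end write is not pruned from the graph
write.out_port(0).connect(Result(graph, {}).create_node().in_port(0))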
Example no. 22
    def replace_pattern(self, graph: Graph, match: dict):
        const = 0.99
        merge = match['merge']
        digits = significant_digits()
        pnorm = Power(
            graph, {
                'name': merge.name + '/reciprocal_',
                'type': 'PNORM',
                'significant': digits[0],
                'to_significant': digits[1],
                'scale': 1,
                'shift': 0,
                'power': get_power_attr()
            }).create_node()
        merge.in_port(0).get_connection().set_destination(pnorm.in_port(0))

        in_shape = pnorm.in_port(0).data.get_shape()
        in_shape = list(in_shape)
        in_shape.insert(0, 1)

        reshape1 = Reshape(graph, {
            'name': merge.name + '/Reshape_Node1'
        }).create_node()
        reshape_dim1 = Const(graph, {
            'value': np.array(in_shape),
            'name': merge.id + '/Reshape_Dim1'
        }).create_node()
        pnorm.out_port(0).connect(reshape1.in_port(0))
        reshape1.in_port(1).connect(reshape_dim1.out_port(0))

        concat_node = Concat(
            graph, {
                'axis': 0,
                'name': merge.name + '/Concat_',
                'override_output_shape': True
            }).create_node()
        const3 = Const(graph, {
            'name': merge.name + '/const_reduce',
            'value': 0
        }).create_node()
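        # const3 feeds in_port(1) of the ReduceMean created below, i.e. the reduction axes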

        for ii, idx in enumerate(range(pnorm.significant, pnorm.to_significant + 1)):
            const_node = Const(
                graph, {
                    'value': float_array(math.pow(const, idx)),
                    'name': merge.name + '/Const_' + str(ii)
                }).create_node()

            mul_node = Mul(graph, {
                'name': merge.name + '/Mul_' + str(ii)
            }).create_node()
            const_node.out_port(0).connect(mul_node.in_port(0))

            reshape1.out_port(0).connect(
                mul_node.in_port(1))  # scale the reshaped pnorm output by const ** idx
            mul_node2 = Mul(graph, {
                'name': merge.name + '/Mul_Div_' + str(ii)
            }).create_node()

            const_node2 = Const(
                graph, {
                    'value': float_array(math.pow(const, -1 * idx)),
                    'name': merge.name + '/Const_Pow_' + str(ii)
                }).create_node()
            exp_node = ExpOp(graph, {
                'name': merge.name + '/Exp_' + str(idx)
            }).create_node()

            mul_node.out_port(0).connect(exp_node.in_port(0))
            const_node2.out_port(0).connect(mul_node2.in_port(1))
            exp_node.out_port(0).connect(mul_node2.in_port(0))
            concat_node.add_input_port(ii, skip_if_exist=True)
            concat_node.in_port(ii).get_connection().set_source(
                mul_node2.out_port(0))

        # re-fetch the original input shape (without the prepended 1) for the final Reshape
        in_shape = pnorm.in_port(0).data.get_shape()
        in_shape = list(in_shape)

        reducesum_node = ReduceMean(
            graph, {
                'name': merge.id + '/_pnorm_reduced_sum',
                'keep_dims': True,
                'in_ports_count': 2,
                'shape': in_shape,
                'axis': 0,
                'need_shape_inference': None,
                'infer': reduce_infer
            }).create_node()

        const3.out_port(0).connect(reducesum_node.in_port(1))
        reducesum_node.in_port(0).get_connection().set_source(
            concat_node.out_port(0))

        reshape = Reshape(graph, {
            'name': merge.name + '/Reshape_Node'
        }).create_node()
        reshape_dim = Const(graph, {
            'value': np.array(in_shape),
            'name': merge.id + '/Reshape_Dim'
        }).create_node()
        reducesum_node.out_port(0).connect(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_dim.out_port(0))
        merge.out_port(0).get_connection().set_source(reshape.out_port(0))
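For reference, the Kaldi p-norm nonlinearity that this decomposition approximates computes a grouped p-norm over the input. A plain-numpy sketch of that reference semantics, assuming group size `group` and `p = 2`; this helper is illustrative and not part of the transform:

import numpy as np

def pnorm_reference(x, group, p=2.0):
    # x: [batch, dim] with dim divisible by group
    batch, dim = x.shape
    grouped = np.abs(x).reshape(batch, dim // group, group)
    # p-norm over each group of `group` consecutive values
    return (grouped ** p).sum(axis=-1) ** (1.0 / p)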
Example no. 23
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                               'mode is supported for node {}.'.format(node.id)
            node_name = node.soft_get('name', node.id)
            rename_node(node, node_name + '/TBR')
            is_packed = False
            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
                is_packed = True
                embedding_bag = EmbeddingBagPackedSum(graph, {
                    'name': node_name
                }).create_node()
            else:
                embedding_bag = EmbeddingBagOffsetsSum(graph, {
                    'name': node_name
                }).create_node()
                node.in_port(2).get_connection().set_destination(
                    embedding_bag.in_port(2))
            rename_node(embedding_bag, node_name)
            node.in_port(0).get_connection().set_destination(
                embedding_bag.in_port(0))
            node.in_port(1).get_connection().set_destination(
                embedding_bag.in_port(1))
            node.out_port(0).get_connection().set_source(
                embedding_bag.out_port(0))
            if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
                if is_packed:
                    node.in_port(3).get_connection().set_destination(
                        embedding_bag.in_port(2))
                else:
                    # connect per_sample_weights
                    node.in_port(3).get_connection().set_destination(
                        embedding_bag.in_port(4))

                    weights_shape_node = Shape(
                        graph, {
                            'name': node_name + '/WeightsShape'
                        }).create_node()

                    weights_rank_node = Rank(graph, {
                        'name': node_name + '/WeightsRank'
                    }).create_node()
                    last_dim_node = get_canonical_axis_index_node(
                        weights_rank_node, -1)
                    weights_last_dim = get_shape_values_by_indices_node(
                        weights_shape_node, last_dim_node)

                    weights_first_dim = node_to_get_shape_value_of_indices(
                        weights_shape_node, [0])

                    zero_col_node = create_op_with_const_inputs(
                        graph, Broadcast, {0: int64_array([0])},
                        {'name': node_name + '/Broadcast'})
                    zero_col_node.in_port(1).connect(
                        weights_last_dim.out_port(0))

                    default_embeddings_node = create_op_with_const_inputs(
                        graph, Unsqueeze, {1: int64_array(0)},
                        {'name': node_name + '/Unsqueeze'})
                    default_embeddings_node.in_port(0).connect(
                        zero_col_node.out_port(0))

                    # expand embedding table with zeros
                    weights_concat = Concat(
                        graph, {
                            'axis': 0,
                            'in_ports_count': 2,
                            'name': node_name + '/Concat'
                        }).create_node()
                    embedding_bag.in_port(0).get_connection().set_destination(
                        weights_concat.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(
                        weights_shape_node.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(
                        weights_rank_node.in_port(0))
                    weights_concat.in_port(1).connect(
                        default_embeddings_node.out_port(0))
                    weights_concat.out_port(0).connect(
                        embedding_bag.in_port(0))

                    # point default index to expanded part of embedding table
                    weights_first_dim.out_port(0).connect(
                        embedding_bag.in_port(3))
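
    # Replaces Kaldi TimeHeightConvolution: MemoryOffset nodes gather the temporal
    # context, a Concat stacks it, and a regular Convolution whose kernel, padding
    # and dilation are derived from the (time, height) offsets does the actual work.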
    def replace_timeheightconv(self, graph: Graph, node: Node):
        req_time_offsets = node.soft_get('time_offsets')
        offsets = node.soft_get('offsets', [[]])
        all_time_offsets = sorted(set(offsets[:, 0]))
        in_name = node.soft_get('name', node.id)
        rename_node(node, in_name + '/to_delete')

        # create MemoryOffset nodes for context gathering;
        # a Concat is needed when there is more than one time offset
        concat = Concat(graph,
                        attrs={
                            'name': in_name + '/Concat',
                            'in_ports_count': len(all_time_offsets)
                        }).create_node()
        for i, t in enumerate(all_time_offsets):
            # if the time offset is in required_time_offsets, no default value is needed
            has_default = t not in req_time_offsets
            memoff = MemoryOffset(graph,
                                  attrs={'name': in_name + '/MemoryOffset_' + str(i),
                                         't': t,
                                         'has_default': has_default,
                                         'splitted': False,
                                         'pair_name': in_name + '/MemoryOffset_pair_' + str(i)}).create_node()
            concat.in_port(i).connect(memoff.out_port(0))
            memoff.in_port(0).connect(node.in_port(0).get_source())

        stride = node.soft_get("height_subsample", 1)

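        # kernel size: number of distinct time offsets x number of distinct height offsets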
        kernel = int64_array([0, 0])
        kernel[0] = len(set(offsets[:, 0]))
        kernel[1] = len(set(offsets[:, 1]))

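        # height padding: the left pad absorbs negative height offsets, the right
        # pad tops the input up so that stride * height_out rows can be produced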
        pad_h = int64_array([0, 0])
        pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
        pad_h[1] = stride * node.height_out - (node.height_in - max(max(offsets[:, 1]), 0))

        dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (kernel[0] - 1) if kernel[0] > 1 else 1
        dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (kernel[1] - 1) if kernel[1] > 1 else 1

        conv_attrs = {
            'name': in_name,
            'output': node['out_channels'],
            'height_in': node.height_in,
            'bias_term': None,
            'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
            'pad_spatial_shape': int64_array([[0, 0], pad_h]),
            'dilation': int64_array([1, 1, dilation_t, dilation_h]),
            'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
            'stride': int64_array([1, 1, 1, stride]),
            'kernel_spatial': kernel,
            'input_feature_channel': 1,
            'output_feature_channel': 0,
            'channel_dims': int64_array([1]),
            'spatial_dims': int64_array([2, 3]),
            'batch_dims': int64_array([0]),
            'kernel_spatial_idx': int64_array([2, 3]),
            'group': 1,
            'reshape_kernel': True,
            'bias_addable': True,
        }
        conv = Convolution(graph, attrs=conv_attrs).create_node()
        conv.in_port(0).connect(concat.out_port(0))
        conv.in_port(1).connect(node.in_port(1).get_source())

        # change layout for weights from OHWI to OIHW
        # in the future this should be replaced by the common Permute mechanics
        weights = conv.in_port(1).get_source().node.value
        weights = weights.reshape(
            int64_array([node.out_channels, -1, node.in_channels]))
        weights = weights.transpose(int64_array([0, 2, 1]))
        weights = weights.flatten()
        conv.in_port(1).get_source().node.value = weights

        conv.in_port(2).connect(node.in_port(2).get_source())
        node.out_port(0).get_connection().set_source(conv.out_port(0))
        graph.remove_node(node.id)
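The in-place weight relayout above can be checked in isolation. A standalone numpy equivalent; the function name and the assumption that H*W is the collapsed middle axis are illustrative:

import numpy as np

def ohwi_to_oihw_flat(flat_weights, out_channels, in_channels):
    # restore the [O, H*W, I] view used by the Kaldi weight layout
    w = flat_weights.reshape(out_channels, -1, in_channels)
    # swap the last two axes to get [O, I, H*W], then flatten back
    return w.transpose(0, 2, 1).flatten()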