Example 1
    def test_select_infer_condition_with_value(self, else_data_shape, than_data_shape, select_output_shape,
                                               condition_value, else_value, than_value, output_value):
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('condition_data', {'shape': np.array(select_output_shape),
                                                                                    'value': condition_value}),
                                                                ('else_data', {'shape': np.array(else_data_shape),
                                                                               'value': else_value}),
                                                                ('than_data', {'shape': np.array(than_data_shape),
                                                                               'value': than_value}),
                                                                ('select_output',
                                                                 {'shape': np.array(select_output_shape),
                                                                  'value': None})
                                                                ])

        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                           update_nodes_attributes=[
                                               ('condition_data', {'shape': np.array(else_data_shape),
                                                                   'value': condition_value}),
                                               ('else_data', {'shape': np.array(else_data_shape),
                                                              'value': else_value}),
                                               ('than_data', {'shape': np.array(than_data_shape),
                                                              'value': than_value}),
                                               ('select_output', {'shape': np.array(select_output_shape),
                                                                  'value': output_value})])

        node = Node(graph, 'select')
        Select.infer(node)

        if else_value is not None and than_value is not None:
            (flag, resp) = compare_graphs(graph, graph_ref, 'select_output', check_op_attrs=True)
            self.assertTrue(flag, resp)
        self.assertTrue(np.array_equal(graph.nodes['select_output']['value'], graph_ref.nodes['select_output']['value']))
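
For reference, when the condition, 'then' and 'else' values are all known, Select value inference reduces to NumPy's where semantics. A minimal sketch of the expected output value (the concrete values are illustrative, not part of the test suite):

import numpy as np

# Select(condition, then, else) propagates values like np.where,
# with inputs broadcast under the 'numpy' auto_broadcast rule
condition = np.array([True, False, True])
than_value = np.array([1, 2, 3])      # the 'then' branch ('than' in these tests)
else_value = np.array([10, 20, 30])   # the 'else' branch
expected_output = np.where(condition, than_value, else_value)  # -> [1, 20, 3]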
Example 2
    def test_select_infer_condition_true(self):
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                       edges_with_attrs=self.edges,
                                       update_nodes_attributes=[
                                           ('condition', {
                                               'value': np.array([True])
                                           }),
                                           ('select_output', {
                                               'shape': np.array([2, 2]),
                                               'value': np.ones((2, 2))
                                           })
                                       ])

        # We should propagate shapes and values
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                           edges_with_attrs=self.edges,
                                           update_nodes_attributes=[
                                               ('select_output', {
                                                   'shape': np.array([2, 2]),
                                                   'value': np.ones((2, 2))
                                               })
                                           ])

        tested_class = Select(graph=graph, attrs={})

        node = Node(graph, 'select')
        tested_class.infer(node)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'select_output',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
Example 3
    def replace_pattern(self, graph: Graph, match: dict):
        condition = match['cond_data']
        true_value = match['Switch_2_input']
        false_value = match['identity_data']

        select = Select(
            graph, dict(name=match['Merge'].name + '/Select/',
                        format='tf')).create_node(
                            inputs=[condition, true_value, false_value])

        match['Merge'].out_port(0).get_connection().set_source(
            select.out_port(0))

        # Reconnect inputs to some_op
        op = match['some_op']
        assert 1 in op.in_ports() and 0 in op.in_ports()

        op.in_port(0).disconnect()
        op.in_port(1).disconnect()
        match['Switch'].in_port(0).get_connection().set_destination(
            op.in_port(0))
        match['Switch_1'].in_port(0).get_connection().set_destination(
            op.in_port(1))

        graph.remove_nodes_from(nodes=[
            match['Switch_1'].id, match['Switch'].id, match['Switch_2'].id,
            match['Merge'].id
        ])
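
For context, the Switch/Merge pair that this transformation collapses into a single Select implements TensorFlow's conditional dataflow. A rough functional model of the matched pattern (illustrative only, not the OpenVINO API):

# Switch routes its data input to one of two outputs depending on a predicate;
# Merge forwards whichever of its inputs actually received data.
def switch(data, pred):
    return (None, data) if pred else (data, None)  # (output_false, output_true)

def merge(inputs):
    return next(value for value in inputs if value is not None)

# The pattern is therefore equivalent to: output = true_value if pred else false_value,
# which is exactly Select(condition, true_value, false_value).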
Example 4
    def test_select_infer_condition_shapes_broadcast(self, else_data_shape, than_data_shape, select_output_shape):
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('else_data', {'shape': np.array(else_data_shape),
                                                                               'value': np.zeros(else_data_shape, dtype=float)}),
                                                                ('than_data', {'shape': np.array(than_data_shape),
                                                                               'value': np.zeros(than_data_shape, dtype=float)}),
                                                                ('select_output', {'shape': np.array(select_output_shape),
                                                                                   'value': np.zeros(select_output_shape, dtype=float)})
                                                                ])

        # We should propagate shapes and values
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                           update_nodes_attributes=[
                                               ('else_data', {'shape': np.array(else_data_shape),
                                                              'value': np.zeros(else_data_shape, dtype=float)}),
                                               ('than_data', {'shape': np.array(than_data_shape),
                                                              'value': np.zeros(than_data_shape, dtype=float)}),
                                               ('select_output', {'shape': np.array(select_output_shape), 'value': np.zeros(select_output_shape)})])

        tested_class = Select(graph=graph, attrs={})

        node = Node(graph, 'select')
        tested_class.infer(node)

        (flag, resp) = compare_graphs(graph, graph_ref, 'select_output', check_op_attrs=True)
        self.assertTrue(flag, resp)
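
The broadcasting behaviour this test exercises can be checked directly in NumPy (the shapes below are hypothetical):

import numpy as np

# 'then' and 'else' shapes broadcast to the output shape under numpy rules
else_data_shape, than_data_shape = (2, 1), (1, 3)
select_output_shape = np.broadcast_shapes(else_data_shape, than_data_shape)  # (2, 3)
out = np.where(True, np.zeros(than_data_shape), np.zeros(else_data_shape))
assert out.shape == select_output_shape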
Example 5
    def test_select_infer_assert_shapes(self):
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('else_data', {'shape': np.array([3, 3]), 'value':np.zeros((3, 3))})])

        tested_class = Select(graph=graph, attrs={})

        node = Node(graph, 'select')
        with self.assertRaisesRegex(AssertionError, "Input shape do not broadcast"):
            tested_class.infer(node)
Example 6
    def test_select_infer_assert_shapes(self):
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('else', {'shape': np.array([3,3]), 'value':np.zeros((3,3))})])

        tested_class = Select(graph=graph, attrs={})

        node = Node(graph, 'select')
        with self.assertRaisesRegex(AssertionError, "TensorFlow \'Select\' operation has 3 inputs: \'condition\',"
                                                    " \'then\' and \'else\' tensors.\'then\' and \'else\' tensors"
                                                    " must have the same shape by TensorFlow reference"):
            tested_class.infer(node)
Example 7
    def build_select_graph_and_infer(condition_value,
                                     then_value,
                                     else_value,
                                     out_value,
                                     condition_shape=None,
                                     then_shape=None,
                                     else_shape=None,
                                     out_shape=None,
                                     auto_broadcast='numpy',
                                     fw_format=None):
        if then_value is not None:
            then_shape = int64_array(then_value.shape)
        if else_value is not None:
            else_shape = int64_array(else_value.shape)

        nodes = {
            **valued_const_with_data('then', then_value, then_shape),
            **valued_const_with_data('else', else_value, else_shape),
            **valued_const_with_data('condition', condition_value, condition_shape),
            **regular_op_with_empty_data(
                'select', {
                    'op': 'Select',
                    'auto_broadcast': auto_broadcast,
                    'format': fw_format
                }),
            **result('out'),
        }
        edges = [
            *connect('condition', '0:select'),
            *connect('then', '1:select'),
            *connect('else', '2:select'),
            *connect('select', 'out'),
        ]
        graph = build_graph(nodes, edges)

        select_node = Node(graph, 'select')
        Select.infer(select_node)

        select_out_node = Node(graph, 'select_d')

        value_desc = 'values'
        ref_val = out_value
        actual_val = select_out_node['value']
        if out_shape is not None:
            value_desc = 'shapes'
            ref_val = out_shape
            actual_val = select_out_node['shape']
            assert select_out_node['value'] is None, \
                "if 'out_shape' is defined manually, 'value' must be None"

        flag = strict_compare_tensors(actual_val, ref_val)
        msg = '' if flag else 'reference {} and actual {} {} do not match\n'.format(
            ref_val, actual_val, value_desc)
        return flag, msg
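
A possible invocation of this helper (the values are illustrative; the expected output follows np.where semantics):

flag, msg = build_select_graph_and_infer(
    condition_value=np.array([True, False]),
    then_value=np.array([1, 2]),
    else_value=np.array([3, 4]),
    out_value=np.array([1, 4]),  # np.where([True, False], [1, 2], [3, 4])
)
assert flag, msg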
Example 8
    def test_select_infer_assert_condition_bool(self):
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('condition', {'value': np.array([3])})])

        tested_class = Select(graph=graph, attrs={})

        node = Node(graph, 'select')
        with self.assertRaisesRegex(AssertionError, "TensorFlow \'Select\' operation has 3 inputs: \'condition\',"
                                                    " \'then\' and \'else\' tensors. Value of \'condition\' tensor"
                                                    " must be boolen by TensorFlow reference"):
            tested_class.infer(node)
Example 9
    def find_and_replace_pattern(self, graph: Graph):
        for merge in graph.get_op_nodes(op='Merge'):
            for merge_switch_in_port in range(2):
                if merge.in_port(merge_switch_in_port).disconnected() or \
                        merge.in_port(merge_switch_in_port).get_source().node.op != 'Switch':
                    continue
                switch_2 = merge.in_port(
                    merge_switch_in_port).get_source().node

                if merge.in_port(1 - merge_switch_in_port).disconnected() or \
                        merge.in_port(1 - merge_switch_in_port).get_source().node.op != 'Identity':
                    continue
                false_value_port = merge.in_port(
                    1 - merge_switch_in_port).get_source()

                true_value_port = switch_2.in_port(0).get_source()
                op = false_value_port.node.in_port(0).get_source().node

                if op.in_port(0).disconnected(
                ) or op.in_port(0).get_source().node.op != 'Switch':
                    continue
                switch = op.in_port(0).get_source().node

                if op.in_port(1).disconnected(
                ) or op.in_port(1).get_source().node.op != 'Switch':
                    continue
                switch_1 = op.in_port(1).get_source().node

                if switch.in_port(1).get_source() == switch_1.in_port(1).get_source() and \
                        switch.in_port(1).get_source() == switch_2.in_port(1).get_source():
                    select = Select(
                        graph,
                        dict(name=merge.soft_get('name') + '/Select/',
                             format='tf')).create_node()
                    select.in_port(0).connect(switch.in_port(1).get_source())
                    select.in_port(1).connect(true_value_port)
                    select.in_port(2).connect(false_value_port)

                    merge.out_port(0).get_connection().set_source(
                        select.out_port(0))

                    assert 1 in op.in_ports() and 0 in op.in_ports()

                    op.in_port(0).disconnect()
                    op.in_port(1).disconnect()

                    switch.in_port(0).get_connection().set_destination(
                        op.in_port(0))
                    switch_1.in_port(0).get_connection().set_destination(
                        op.in_port(1))

                    graph.remove_nodes_from(
                        nodes=[switch_1.id, switch.id, switch_2.id, merge.id])
                    # need to exit from the inner for loop because the Merge op has been removed
                    break
Example 10
    def replace_sub_graph(self, graph: Graph, match: dict):
        tf_slice_node = match['op']
        slice_name = tf_slice_node.soft_get('name', tf_slice_node.id)
        slice_node = Slice(graph).create_node()
        rename_nodes([(tf_slice_node, slice_name + '/to_be_removed'),
                      (slice_node, slice_name)])
        ends_node = Add(graph, {'name': slice_name + '/ends'}).create_node()

        # reconnect input, begin, and size from TFSlice to the subgraph with Slice
        tf_slice_node.in_port(0).get_connection().set_destination(
            slice_node.in_port(0))
        tf_slice_node.in_port(1).get_connection().set_destination(
            slice_node.in_port(1))
        tf_slice_node.in_port(2).get_connection().set_destination(
            ends_node.in_port(0))
        slice_node.in_port(1).get_connection().add_destination(
            ends_node.in_port(1))

        max_ends = Shape(graph, {
            'name': slice_name + '/ShapeOf'
        }).create_node()
        slice_node.in_port(0).get_connection().add_destination(
            max_ends.in_port(0))

        # check whether size[i] == -1; the check applies element-wise: len(size) == len(begin) == input_rank
        where_max_ends_is_needed = create_op_with_const_inputs(
            graph, Equal, {0: int64_array(-1)},
            {'name': slice_name + '/where_max_ends_is_needed'})
        ends_node.in_port(0).get_connection().add_destination(
            where_max_ends_is_needed.in_port(1))
        # Select requires equal dtypes, so convert ends to int64
        ends_casted_to_i64 = Cast(graph, {
            'name': slice_name + '/CastToI64',
            'dst_type': np.int64
        }).create_node([ends_node])
        # if size[i] == -1 then take max_ends values
        correct_ends = Select(graph, {
            'name': slice_name + '/chosen_ends'
        }).create_node(
            [where_max_ends_is_needed, max_ends, ends_casted_to_i64])
        correct_ends.out_port(0).connect(slice_node.in_port(2))

        tf_slice_node.out_port(0).get_connection().set_source(
            slice_node.out_port(0))
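
The arithmetic this subgraph implements can be written out in plain NumPy (hypothetical values): ends are computed as begin + size, and any size of -1 is replaced by the corresponding input dimension.

import numpy as np

input_shape = np.array([4, 6], dtype=np.int64)  # produced by the Shape node
begin = np.array([1, 0], dtype=np.int64)
size = np.array([2, -1], dtype=np.int64)        # -1 means "slice to the end"

ends = begin + size                             # the Add ('/ends') node
ends = np.where(size == -1, input_shape, ends)  # the Equal + Select pair
# ends -> [3, 6]; Slice(input, begin, ends) matches tf.slice(input, begin, size)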
Example 11
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        slice_name = node.soft_get('name', node.id)
        slice_node = Slice(graph).create_node()
        rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])

        eq_node = Equal(graph, {'name': slice_name + '/equal'}).create_node()
        minus_one_node = Const(graph, {'name': slice_name + '/minus_one', 'value': np.array(-1)}).create_node()
        int32_max_node = Const(graph, {'name': slice_name + '/int32_max', 'value': np.iinfo(np.int32).max}).create_node()
        select_node = Select(graph, {'name': slice_name + '/select'}).create_node()

        # node to convert sizes to ends
        sum_node = Add(graph, {'name': slice_name + '/end_const'}).create_node()

        # reconnect input from tfslice to slice
        node.in_port(0).get_source().connect(slice_node.in_port(0))
        node.in_port(0).disconnect()
        # reconnect begin of tfslice to start of slice
        node.in_port(1).get_source().connect(slice_node.in_port(1))
        node.in_port(1).disconnect()

        # (size -> ends): feed begins and sizes into Add to compute ends for Slice
        # connect begins to the Add node
        slice_node.in_port(1).get_source().connect(sum_node.in_port(0))
        node.in_port(2).get_source().connect(sum_node.in_port(1))
        node.in_port(2).disconnect()

        # if size[i] == -1 then take int32_max as end[i]
        sum_node.in_port(1).get_source().connect(eq_node.in_port(0))
        minus_one_node.out_port(0).connect(eq_node.in_port(1))
        # Equal output feeds port 0 of Select
        eq_node.out_port(0).connect(select_node.in_port(0))
        # int32_max feeds port 1 of Select
        int32_max_node.out_port(0).connect(select_node.in_port(1))
        # the sum feeds port 2 of Select
        sum_node.out_port(0).connect(select_node.in_port(2))
        # Select output provides the ends (port 2 of Slice)
        select_node.out_port(0).connect(slice_node.in_port(2))

        cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
        select_node.in_port(2).get_connection().insert_node(cast)

        node.out_port(0).get_connection().set_source(slice_node.out_port(0))
Example 12
    def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
        node = match['op']
        name = node.name

        # Zero Point Nudging: scale computation
        f_min = node.in_port(1).get_source()
        node.in_port(1).disconnect()
        f_max = node.in_port(2).get_source()
        node.in_port(2).disconnect()

        f_diff = Sub(graph, {'name': name + '/float_range'}).create_node()
        f_max.connect(f_diff.in_port(0))
        f_min.connect(f_diff.in_port(1))

        quant_min_value = int(node.narrow_range)
        quant_max_value = 2 ** node.num_bits - 1
        i_diff = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        f_diff.out_port(0).connect(scale.in_port(0))
        i_diff.out_port(0).connect(scale.in_port(1))

        # Zero Point Nudging: zero point from min computation
        descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node()
        f_min.connect(descaled_min.in_port(0))
        scale.out_port(0).connect(descaled_min.in_port(1))

        zero_point_from_min = Sub(graph, {'name': name + '/zero_point_from_min'}).create_node()
        quant_min = Const(graph, {'value': quant_min_value, 'name': name + '/quant_min'}).create_node()
        quant_min.out_port(0).connect(zero_point_from_min.in_port(0))
        descaled_min.out_port(0).connect(zero_point_from_min.in_port(1))

        # Zero Point Nudging: nudged zero point computation
        zp_lesser_q_mi = Less(graph, {'name': name + '/zero_point_from_min_less_quant_min'}).create_node()
        zero_point_from_min.out_port(0).connect(zp_lesser_q_mi.in_port(0))
        quant_min.out_port(0).connect(zp_lesser_q_mi.in_port(1))

        zp_greater_q_ma = Greater(graph, {'name': name + '/zero_point_from_min_greater_quant_max'}).create_node()
        zero_point_from_min.out_port(0).connect(zp_greater_q_ma.in_port(0))
        quant_max = Const(graph, {'value': quant_max_value, 'name': name + '/quant_max'}).create_node()
        quant_max.out_port(0).connect(zp_greater_q_ma.in_port(1))

        rounded_zero_point_from_min = Round(graph, {'name': name + '/zero_point_from_min_rounding'}).create_node()
        zero_point_from_min.out_port(0).connect(rounded_zero_point_from_min.in_port(0))

        nudged_zero_point = Select(graph, {'name': name + '/nudging_zp_1_select_less_condition'}).create_node()
        greater_condition = Select(graph, {'name': name + '/nudging_zp_2_select_greater_condition'}).create_node()

        greater_condition.in_port(0).connect(zp_greater_q_ma.out_port(0))
        greater_condition.in_port(1).connect(quant_max.out_port(0))
        greater_condition.in_port(2).connect(rounded_zero_point_from_min.out_port(0))

        nudged_zero_point.in_port(0).connect(zp_lesser_q_mi.out_port(0))
        # per the TensorFlow reference, quant_min is taken when zero_point_from_min < quant_min
        nudged_zero_point.in_port(1).connect(quant_min.out_port(0))
        nudged_zero_point.in_port(2).connect(greater_condition.out_port(0))

        nudged_i_min = Sub(graph, {'name': name + '/nudged_i_min'}).create_node()
        quant_min.out_port(0).connect(nudged_i_min.in_port(0))
        nudged_zero_point.out_port(0).connect(nudged_i_min.in_port(1))

        nudged_i_max = Sub(graph, {'name': name + '/nudged_i_max'}).create_node()
        quant_max.out_port(0).connect(nudged_i_max.in_port(0))
        nudged_zero_point.out_port(0).connect(nudged_i_max.in_port(1))

        nudged_min = Mul(graph, {'name': name + '/nudged_min'}).create_node()
        nudged_i_min.out_port(0).connect(nudged_min.in_port(0))
        scale.out_port(0).connect(nudged_min.in_port(1))

        nudged_max = Mul(graph, {'name': name + '/nudged_max'}).create_node()
        nudged_i_max.out_port(0).connect(nudged_max.in_port(0))
        scale.out_port(0).connect(nudged_max.in_port(1))

        nudged_min.out_port(0).connect(node.in_port(1))
        nudged_max.out_port(0).connect(node.in_port(2))

        # FakeQuantize has 5 inputs, while the TensorFlow operation has 3
        node.add_input_port(3, skip_if_exist=True)
        node.add_input_port(4, skip_if_exist=True)

        node.in_port(3).connect(nudged_min.out_port(0))
        node.in_port(4).connect(nudged_max.out_port(0))

        FakeQuantize.update_node_stat(node, {'levels': node['levels']})
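
Numerically, the zero-point nudging above follows the TensorFlow reference; a scalar sketch with hypothetical values (num_bits=8, narrow_range=False):

import numpy as np

f_min, f_max = -0.95, 1.05
quant_min, quant_max = 0, 2 ** 8 - 1   # int(narrow_range), 2 ** num_bits - 1

scale = (f_max - f_min) / (quant_max - quant_min)
zero_point_from_min = quant_min - f_min / scale
# the two chained Selects plus Round implement this clamp-and-round:
if zero_point_from_min < quant_min:
    nudged_zero_point = quant_min
elif zero_point_from_min > quant_max:
    nudged_zero_point = quant_max
else:
    nudged_zero_point = np.round(zero_point_from_min)
nudged_min = (quant_min - nudged_zero_point) * scale
nudged_max = (quant_max - nudged_zero_point) * scale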
Example 13
    def insert_select(graph: Graph, node: Node):
        context_len = node.frame_time + 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
        zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
                                                               ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                                    offset=int64_array([1]),
                                                                                    dim=int64_array([context_len - 1]))),
                                                               ('crop_mem_in_data', dict()),
                                                               ('concat', dict(op='Concat', axis=1)),
                                                               ('concat_data', dict()),
                                                               ('const_1', dict(op='Const')),
                                                               ('const_1_data', dict()),
                                                               ('mem_out', dict(op='Assign')),
                                                               ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                                 offset=int64_array([0]),
                                                                                 dim=int64_array([1]))),
                                                               ('crop_out_data', dict()),
                                                               ('select', dict(op='Select'))
                                                               ],
                                                 edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                                                        ('crop_mem_in', 'crop_mem_in_data'),
                                                        ('crop_mem_in_data', 'concat', {'in': 0}),
                                                        ('const_1', 'const_1_data'),
                                                        ('const_1_data', 'concat', {'in': 1}),
                                                        ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                                                        ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                                                        ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(graph, {'name': 'iteration_number',
                                        'variable_id': 'iteration_' + node.name}).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                     'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
            concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(graph, {'name': 'iteration_number_out',
                                    'variable_id': 'iteration_' + node.name}).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                    'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check whether the data from memory is 1.
        # If it is, the context is complete and the data should be saved to memory;
        # otherwise the context has not been gathered yet, the data is garbage,
        # and the initial state of memory must not be changed.
        cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
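
The iteration counter built above behaves like a shift register: the stored vector is shifted left by one each inference step and a 1 is appended, so its first element becomes 1 only after context_len steps. A conceptual sketch (hypothetical values):

import numpy as np

context_len = 3
state = np.zeros((1, context_len), dtype=np.int32)       # ReadValue initial value
for step in range(5):
    state = np.concatenate([state[:, 1:],                # Crop 'cut_first'
                            np.ones((1, 1), np.int32)],  # Concat with ones
                           axis=1)
    is_ready = state[0, 0]                               # Crop 'cut_last'
    print(step, is_ready)  # 0 0, 1 0, 2 1, 3 1, 4 1: ready after 3 steps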
Example 14
    def test_select_infer_condition_with_value(self, condition_shape,
                                               else_data_shape,
                                               than_data_shape,
                                               select_output_shape,
                                               condition_value, else_value,
                                               than_value, output_value):
        """
        Unit tests generator can sporadic throw exception if we try
        to run generator with call numpy array generation functions.
        So we need to use lambda function for escape the problem.
        """
        condition_value = condition_value(condition_shape)
        else_value = else_value(else_data_shape)
        than_value = than_value(than_data_shape)
        output_value = output_value(select_output_shape)
        graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[
                ('condition_data', {'shape': np.array(condition_shape),
                                    'value': condition_value}),
                ('else_data', {'shape': np.array(else_data_shape),
                               'value': else_value}),
                ('than_data', {'shape': np.array(than_data_shape),
                               'value': than_value}),
                ('select_output', {'shape': np.array(select_output_shape),
                                   'value': None}),
            ])

        graph_ref = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[
                ('condition_data', {'shape': np.array(condition_shape),
                                    'value': condition_value}),
                ('else_data', {'shape': np.array(else_data_shape),
                               'value': else_value}),
                ('than_data', {'shape': np.array(than_data_shape),
                               'value': than_value}),
                ('select_output', {'shape': np.array(select_output_shape),
                                   'value': output_value}),
            ])

        node = Node(graph, 'select')
        Select.infer(node)

        if else_value is not None and than_value is not None:
            (flag, resp) = compare_graphs(graph,
                                          graph_ref,
                                          'select_output',
                                          check_op_attrs=True)
            self.assertTrue(flag, resp)
        self.assertTrue(
            np.array_equal(graph.nodes['select_output']['value'],
                           graph_ref.nodes['select_output']['value']))
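
A sketch of the lambda-based parametrization the docstring above describes (the decorator plumbing and the concrete shapes/values are hypothetical):

# values are deferred behind lambdas so that numpy calls happen inside the
# test body rather than at generator-collection time
test_params = [
    ([2], [2], [2], [2],
     lambda shape: np.array([True, False]),   # condition_value factory
     lambda shape: np.zeros(shape),           # else_value factory
     lambda shape: np.ones(shape),            # than_value factory
     lambda shape: np.array([1.0, 0.0])),     # output_value factory
]
# inside the test: condition_value = condition_value(condition_shape), etc.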
Example 15
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']

        if node.name == 'iteration_number_out':
            return

        # calculate the context length at which the inference state becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_zero_value_with_batch_from_input(
                in_node_port, context_len, np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check whether the data from memory is 1.
        # If it is, the context is complete and the data should be saved to memory;
        # otherwise the context has not been gathered yet, the data is garbage,
        # and the initial state of memory must not be changed.
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Example 16
    def extract(cls, node):
        Select.update_node_stat(node, {})
        return cls.enabled
Example 17
    def find_and_replace_pattern(self, graph: Graph):
        for embedding_segments_mean in graph.get_op_nodes(
                op='EmbeddingSegmentsMean'):
            embedding_segments_mean_name = embedding_segments_mean.soft_get(
                'name', embedding_segments_mean.id)
            embedding_table_input = embedding_segments_mean.in_port(0)
            segment_ids_input = embedding_segments_mean.in_port(2)
            num_segments_input = embedding_segments_mean.in_port(3)

            # TODO: support EmbeddingSegmentsMean with a specified weights vector.
            # This case has not appeared in models so far, so the EmbeddingSegmentsOperation
            # fusion transformations do not handle it either.
            if embedding_segments_mean.is_in_port_connected(5):
                return

            # 1. compute indices membership matrix, i.e. which indices belong to some object
            # the shape of this matrix is [num_segments, num_indices]
            non_norm_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Range, {
                    0: int64_array(0),
                    2: int64_array(1)
                }, {
                    'name':
                    embedding_segments_mean_name + '/Range1ToNumSegments',
                    'output_type': np.int64
                })
            num_segments_input.get_connection().add_destination(
                non_norm_range_1_to_num_segments.in_port(1))

            range_1_to_num_segments = ConvertLike(graph, {
                'name':
                embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'
            }).create_node()
            range_1_to_num_segments.in_port(0).connect(
                non_norm_range_1_to_num_segments.out_port(0))
            num_segments_input.get_connection().add_destination(
                range_1_to_num_segments.in_port(1))

            unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name +
                    '/Range1ToNumSegmentsUnsqueeze'
                })
            unsqueeze_range_1_to_num_segments.in_port(0).connect(
                range_1_to_num_segments.out_port(0))
            unsqueeze_segment_ids = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(0)}, {
                    'name':
                    embedding_segments_mean_name + '/SegmentIdsUnsqueeze'
                })
            segment_ids_input.get_connection().add_destination(
                unsqueeze_segment_ids.in_port(0))
            boolean_membership_matrix = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/BooleanMembershipMatrix'
            }).create_node()
            boolean_membership_matrix.in_port(0).connect(
                unsqueeze_range_1_to_num_segments.out_port(0))
            boolean_membership_matrix.in_port(1).connect(
                unsqueeze_segment_ids.out_port(0))
            shape_of_membership_matrix = Shape(graph, {
                'name':
                embedding_segments_mean_name + '/ShapeOfMembershipMatrix'
            }).create_node([boolean_membership_matrix])
            one_scalar_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/OneScalar',
                    'value': int64_array([1])
                }).create_node()
            one_constant = Broadcast(graph, {
                'name':
                embedding_segments_mean_name + '/One'
            }).create_node([one_scalar_constant, shape_of_membership_matrix])
            zero_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/Zero',
                    'value': int64_array(0)
                }).create_node()
            membership_matrix = Select(
                graph, {
                    'name': embedding_segments_mean_name + '/MembershipMatrix',
                    'auto_broadcast': 'numpy'
                }).create_node(
                    [boolean_membership_matrix, one_constant, zero_constant])

            # 2. compute a number of indices belong to each object from the batch
            # it computes the normalization coefficients
            num_indices_per_object = create_op_with_const_inputs(
                graph, ReduceSum, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name + '/NumIndicesPerObject'
                })
            num_indices_per_object.in_port(0).connect(
                membership_matrix.out_port(0))

            # 3. replace zero coefficient (zero number of indices belong to an object) with one
            # because for such object the single default embedding vector is used
            where_zero_number = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/WhereZeroIndicesNumber'
            }).create_node([num_indices_per_object, zero_constant])
            normalized_num_indices_per_object = Select(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/NormNumIndicesPerObject',
                    'auto_broadcast': 'numpy'
                }).create_node([
                    where_zero_number, one_scalar_constant,
                    num_indices_per_object
                ])

            # 4. cast normalized_num_indices_per_object to the same type as embedding vector table
            norm_coefficients = ConvertLike(
                graph, {
                    'name': embedding_segments_mean_name + '/NormCoefficients'
                }).create_node()
            norm_coefficients.in_port(0).connect(
                normalized_num_indices_per_object.out_port(0))
            embedding_table_input.get_connection().add_destination(
                norm_coefficients.in_port(1))

            # 5. replace EmbeddingSegmentMean with EmbeddingSegmentSum
            embedding_segments_sum = EmbeddingSegmentsSum(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/EmbeddingSegmentsSum'
                }).create_node()
            for in_port in embedding_segments_mean.in_ports():
                if embedding_segments_mean.is_in_port_connected(in_port):
                    embedding_segments_mean.in_port(
                        in_port).get_connection().set_destination(
                            embedding_segments_sum.in_port(in_port))

            # 6. normalize EmbeddingSegmentSum results by computed coefficients
            result_node = Div(graph, {
                'name': embedding_segments_mean_name + '/Div'
            }).create_node([embedding_segments_sum, norm_coefficients])
            embedding_segments_mean.out_port(0).get_connection().set_source(
                result_node.out_port(0))

            rename_nodes([(embedding_segments_mean,
                           embedding_segments_mean_name + '/AbandonedName'),
                          (result_node, embedding_segments_mean_name)])
            graph.remove_nodes_from([embedding_segments_mean.id])
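
The membership-matrix construction in step 1 and the coefficient computation in steps 2-3 can be mirrored in NumPy (the inputs below are hypothetical):

import numpy as np

segment_ids = np.array([0, 0, 2])  # segment index for each embedding index
num_segments = 3

rng = np.arange(num_segments)                        # Range1ToNumSegments
bool_matrix = rng[:, None] == segment_ids[None, :]   # Unsqueeze + Equal
membership_matrix = np.where(bool_matrix, 1, 0)      # Select(bool, 1, 0)

num_indices_per_object = membership_matrix.sum(axis=1)   # ReduceSum -> [2, 0, 1]
norm_coefficients = np.where(num_indices_per_object == 0,
                             1, num_indices_per_object)  # Equal + Select -> [2, 1, 1]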
Example 18
    def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
        node = match['op']
        name = node.name

        min_port_tuple = (node.in_port(1).get_source().node,
                          node.in_port(1).get_source().idx)
        max_port_tuple = (node.in_port(2).get_source().node,
                          node.in_port(2).get_source().idx)

        node.in_port(1).disconnect()
        node.in_port(2).disconnect()

        # make sure min < max
        min_less_max = Less(graph, {
            'name': name + '/if_min_less_max'
        }).create_node([min_port_tuple, max_port_tuple])
        minimum = Select(graph, {
            'name': name + '/minimum'
        }).create_node([min_less_max, min_port_tuple, max_port_tuple])
        maximum = Select(graph, {
            'name': name + '/maximum'
        }).create_node([min_less_max, max_port_tuple, min_port_tuple])

        # to create a zero of the limits' data type, multiply a limit by integer zero
        zero = create_op_node_with_second_input(graph,
                                                Mul,
                                                int64_array(0),
                                                {'name': name + '/zero'},
                                                input_node=minimum)

        # if 0 < min < max: min_adj = 0 and max_adj = max - min
        min_greater_zero = Greater(graph, {
            'name': name + '/if_minimum_greater_zero'
        }).create_node([minimum, zero])
        max_minus_min = Sub(graph, {
            'name': name + '/max_minus_min'
        }).create_node([maximum, minimum])
        minimum = Select(graph, {
            'name': name + '/first_adj_min'
        }).create_node([min_greater_zero, zero, minimum])
        maximum = Select(graph, {
            'name': name + '/first_adj_max'
        }).create_node([min_greater_zero, max_minus_min, maximum])

        # if min < max < 0: min_adj = min - max and max_adj = 0
        max_less_zero = Less(graph, {
            'name': name + '/if_max_less_zero'
        }).create_node([maximum, zero])
        min_minus_max = Sub(graph, {
            'name': name + '/min_minus_max'
        }).create_node([minimum, maximum])
        minimum = Select(graph, {
            'name': name + '/second_adj_min'
        }).create_node([max_less_zero, min_minus_max, minimum])
        maximum = Select(graph, {
            'name': name + '/second_adj_max'
        }).create_node([max_less_zero, zero, maximum])

        # scale = (max - min) / (2 ^ num_bits - 1),
        float_range = Sub(graph, {
            'name': name + '/float_range'
        }).create_node([maximum, minimum])
        quant_min_value, quant_max_value = int(
            node.narrow_range), 2**node.num_bits - 1
        int_range = Const(
            graph,
            dict(name=name + '/int_range',
                 value=quant_max_value - quant_min_value)).create_node()
        scale = Div(graph, {
            'name': name + '/scale'
        }).create_node([float_range, int_range])
        # min_adj = scale * round(min / scale)
        descaled_min = Div(graph, {
            'name': name + '/descaled_min'
        }).create_node([minimum, scale])
        rounded_descaled_min = Round(graph, {
            'name': name + '/rounded_descaled_min'
        }).create_node([descaled_min])
        min_adj = Mul(graph, {
            'name': name + '/min_adj'
        }).create_node([scale, rounded_descaled_min])
        # max_adj = max + min_adj - min.
        adjustment = Sub(graph, {
            'name': name + '/limits_adjustment'
        }).create_node([min_adj, minimum])
        max_adj = Add(graph, {
            'name': name + '/max_adj'
        }).create_node([maximum, adjustment])

        # FakeQuantize has 5 inputs, while the TensorFlow operation has 3
        node.add_input_port(3, skip_if_exist=True)
        node.add_input_port(4, skip_if_exist=True)

        node.in_port(1).connect(min_adj.out_port(0))
        node.in_port(2).connect(max_adj.out_port(0))
        node.in_port(3).connect(min_adj.out_port(0))
        node.in_port(4).connect(max_adj.out_port(0))

        FakeQuantize.update_node_stat(node, {'levels': node['levels']})
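
In scalar form, the limit adjustment above does the following (hypothetical limits, num_bits=8, narrow_range=False):

import numpy as np

minimum, maximum = 0.2, 1.0        # the 0 < min < max case
quant_min, quant_max = 0, 2 ** 8 - 1

if minimum > 0:                    # first_adj_*: shift the range to include zero
    minimum, maximum = 0.0, maximum - minimum
if maximum < 0:                    # second_adj_*: the mirrored case
    minimum, maximum = minimum - maximum, 0.0

scale = (maximum - minimum) / (quant_max - quant_min)
min_adj = scale * np.round(minimum / scale)
max_adj = maximum + min_adj - minimum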
Example 19
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']

        if node.name == 'iteration_number_out':
            return

        # calculate the context length at which the inference state becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in',
                    dict(op='Memory',
                         index=1,
                         shape=int64_array([context_len]))),
                   ('mem_in_data', dict()),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()),
                   ('mem_out',
                    dict(op='Memory',
                         index=0,
                         shape=int64_array([context_len]))),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            mem_out = Memory(
                graph, {
                    'name': 'iteration_number',
                    'size': 2,
                    'index': 1,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len]),
                    'dst_type': np.int32
                }).create_node()
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Memory(
                graph, {
                    'name': 'iteration_number_out',
                    'size': 2,
                    'index': 0,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len])
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        select_node.in_port(0).connect(input_port)
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Example 20
    def replace(node: Node, const: Node):
        graph = node.graph
        shape = const.shape
        const_name = const.soft_get('name', const.id)

        non_one_dims = np.argwhere(shape != 1).flatten()
        one_dims = np.argwhere(shape == 1).flatten()

        if not (non_one_dims.size == 1 and 5 < np.prod(shape) < 500):
            # the (5, 500) range is chosen empirically to affect fewer models
            return
            return

        value = const.value
        if not np.array_equal(
                np.arange(0, np.prod(shape), 1).reshape(shape), value):
            return

        positive_idx = non_one_dims.item(0)
        negative_idx = positive_idx - len(shape)

        node_name = node.soft_get('name', node.id)
        gather = create_op_with_const_inputs(
            graph, Gather, {
                1: int64_array(negative_idx),
                2: int64_array(0)
            }, {'name': node_name + '/BroadcastingDim'})
        gather_for_const = create_op_with_const_inputs(
            graph, Gather, {
                1: int64_array(negative_idx),
                2: int64_array(0)
            }, {'name': const_name + '/BroadcastingDim'})
        shapeof_node = Shape(graph, {
            'name': const_name + '/ShapeOf'
        }).create_node()
        shapeof_node.out_port(0).connect(gather_for_const.in_port(0))

        equal_node = create_op_with_const_inputs(
            graph, Equal, {1: int64_array(1)},
            {'name': node_name + '/ConstOne'})
        gather.out_port(0).connect(equal_node.in_port(0))

        select_node = Select(graph, {
            'name': node_name + '/Select',
            'auto_broadcast': 'numpy'
        }).create_node([equal_node, gather_for_const, gather])

        const.out_port(0).connect(shapeof_node.in_port(0))

        range_node = create_op_with_const_inputs(
            graph, Range, {
                0: np.array(0, dtype=value.dtype),
                2: np.array(1, dtype=value.dtype)
            }, {
                'name': const_name + '/Range',
                'dtype': value.dtype
            })
        select_node.out_port(0).connect(range_node.in_port(1))

        node.in_port(1).get_connection().add_destination(gather.in_port(0))

        node.in_port(0).get_connection().set_source(range_node.out_port(0))

        if one_dims.size:
            unsqueeze = create_op_node_with_second_input(
                graph, Unsqueeze, one_dims,
                {'name': const_name + '/KeepShape'})
            range_node.out_port(0).get_connection().insert_node(unsqueeze)
            rename_nodes([(const, const_name + '/ToBeDeleted'),
                          (unsqueeze, const_name)])
        else:
            rename_nodes([(const, const_name + '/ToBeDeleted'),
                          (range_node, const_name)])
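A quick way to see why this rewrite is safe: a constant whose value is np.arange over its element count, with a single non-one dimension, is just a Range along that dimension unsqueezed back to the original rank. A NumPy sketch with an illustrative shape:

    import numpy as np

    shape = (1, 7, 1)
    const_value = np.arange(np.prod(shape)).reshape(shape)

    d = int(np.argwhere(np.array(shape) != 1).flatten()[0])
    one_dims = tuple(i for i in range(len(shape)) if i != d)
    rebuilt = np.expand_dims(np.arange(shape[d]), axis=one_dims)

    assert np.array_equal(const_value, rebuilt)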
Esempio n. 21
0
    def extract(cls, node: Node):
        Select.update_node_stat(node, {
            'format': 'tf',
        })
        return cls.enabled
Esempio n. 22
0
    def extract(node):
        Select.update_node_stat(node, {})
        return __class__.enabled
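These extract() snippets are normally methods of a FrontExtractorOp subclass that registers the Select operation for a given framework. A minimal sketch of the surrounding class (import paths differ between Model Optimizer versions and are an assumption here, as is the class name):

    from mo.front.extractor import FrontExtractorOp  # path varies by MO version
    from extensions.ops.select import Select         # path varies by MO version

    class SelectExtractor(FrontExtractorOp):
        op = 'Select'
        enabled = True

        @classmethod
        def extract(cls, node):
            Select.update_node_stat(node, {})
            return cls.enabled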
Esempio n. 23
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        seq_len_tf = match['seq_len']
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']

        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get(
            'name', sparse_to_dense_tf.id)
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get(
            'name', ctc_greedy_decoder_tf.id)

        log.debug(
            'Found CTCLossFrontReplacer pattern after {} with name {}'.format(
                ctc_greedy_decoder_tf.op, ctc_greedy_decoder_tf.name))

        # create sequence mask node, sub-graph for transforming into sequence length and connect with consumers
        seq_len_tf_shape = seq_len_tf.soft_get('shape', None)
        if seq_len_tf_shape is None or len(seq_len_tf_shape) != 2:
            raise Error(
                'The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
                ' must be specified in a mask format.'.format(
                    ctc_greedy_decoder_tf_name))
        log.error(
            'The format of input sequence length has been changed to a mask format',
            extra={'is_warning': True})
        seq_len_tf_type = seq_len_tf.soft_get('data_type', None)
        seq_len_tf_name = seq_len_tf.soft_get('name', seq_len_tf.id)
        seq_mask_placeholder = Parameter(
            graph, {
                'name': seq_len_tf_name,
                'shape': seq_len_tf_shape,
                'data_type': seq_len_tf_type
            }).create_node()
        reduce_to_seq_len_node = create_op_with_const_inputs(
            graph, ReduceSum, {1: np.array(1, dtype=np.int32)}, {
                'name': seq_len_tf_name + '/ReduceToSeqLen',
                'keep_dims': False
            })
        reduce_to_seq_len_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        seq_len_tf.out_port(0).get_connection().set_source(
            reduce_to_seq_len_node.out_port(0))

        cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        casted_seq_mask_node = Cast(graph, {
            'name': seq_len_tf_name + '/CastToFP32',
            'dst_type': cast_fp_type
        }).create_node()
        casted_seq_mask_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        permuted_casted_seq_mask = create_op_with_const_inputs(
            graph, Transpose, {1: int64_array([1, 0])},
            {'name': seq_len_tf_name + '/Permute'})
        permuted_casted_seq_mask.in_port(0).connect(
            casted_seq_mask_node.out_port(0))
        rename_nodes([(seq_len_tf, seq_len_tf_name + '/AbandonedName'),
                      (seq_mask_placeholder, seq_len_tf_name)])

        # create CTCGreedyDecoder node and set mask node
        ctc_merge_repeated_i = ctc_greedy_decoder_tf.soft_get(
            'ctc_merge_repeated', ctc_greedy_decoder_tf.id)
        ctc_greedy_decoder = CTCGreedyDecoderOp(
            graph, {
                'name': output_sparse_to_dense_name,
                'ctc_merge_repeated': ctc_merge_repeated_i
            }).create_node()
        ctc_greedy_decoder.in_port(1).connect(
            permuted_casted_seq_mask.out_port(0))
        rename_nodes([(sparse_to_dense_tf,
                       output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])

        # create CTCLoss node and set attributes
        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
            'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
            'The CTCLoss node "{}" misses "ctc_merge_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('unique'), \
            'The CTCLoss node "{}" misses "unique" attribute'.format(output_ctc_loss_name)
        preprocess_collapse_repeated = ctc_loss_tf.preprocess_collapse_repeated
        ctc_merge_repeated = ctc_loss_tf.ctc_merge_repeated
        unique = ctc_loss_tf.unique
        ctc_loss = CTCLoss(
            graph, {
                'name': output_ctc_loss_name,
                'preprocess_collapse_repeated': preprocess_collapse_repeated,
                'ctc_merge_repeated': ctc_merge_repeated,
                'unique': unique
            }).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'),
                      (ctc_loss, output_ctc_loss_name)])

        # connect logits
        ctc_greedy_decoder_tf.in_port(0).get_connection().set_destination(
            ctc_greedy_decoder.in_port(0))
        ctc_loss.in_port(0).disconnect()
        transpose_tf.in_port(0).get_connection().add_destination(
            ctc_loss.in_port(0))

        # connect logit lengths
        ctc_greedy_decoder_tf.in_port(1).disconnect()
        ctc_loss.in_port(1).connect(reduce_to_seq_len_node.out_port(0))

        # connect labels to ctc_loss
        squeeze_op = create_op_with_const_inputs(graph, Squeeze,
                                                 {1: int64_array([2, 3])})
        cast_labels_op = Cast(
            graph, {
                'name': output_sparse_to_dense_name + '/CastLabels',
                'dst_type': np.int32
            }).create_node()
        squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
        ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))

        # connect label lengths
        equal_op = create_op_with_const_inputs(
            graph, Equal, {1: np.array([-1], dtype=np.int32)},
            {'name': output_sparse_to_dense_name + '/Equal'})
        equal_op.in_port(0).connect(cast_labels_op.out_port(0))
        labels_shape_op = Shape(
            graph, {
                'name': output_sparse_to_dense_name + '/ShapeOf'
            }).create_node()
        labels_shape_op.in_port(0).connect(equal_op.out_port(0))
        broadcast_one = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([1], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/One'
            })
        broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
        broadcast_zero = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([0], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/Zero'
            })
        broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))

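        # Select(labels == -1, zeros, ones) builds a 0/1 mask of real
        # (non-padding) labels; the ReduceSum over axis 1 below turns the mask
        # into per-batch label lengths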
        select_node = Select(graph, {
            'name': output_sparse_to_dense_name + '/Select'
        }).create_node()
        select_node.in_port(0).connect(equal_op.out_port(0))
        select_node.in_port(1).connect(broadcast_zero.out_port(0))
        select_node.in_port(2).connect(broadcast_one.out_port(0))
        label_length_node = create_op_with_const_inputs(
            graph,
            ReduceSum, {1: int64_array([1])},
            op_attrs={
                'name': output_sparse_to_dense_name + '/LabelLength',
                'keep_dims': False
            })
        label_length_node.in_port(0).connect(select_node.out_port(0))
        ctc_loss.in_port(3).connect(label_length_node.out_port(0))

        # set source for output of new sub-graph and remove old nodes
        ctc_loss_tf.out_port(0).get_connection().set_source(
            ctc_loss.out_port(0))
        graph.remove_nodes_from([
            ctc_greedy_decoder_tf.id, ctc_loss_tf.id, cast_tf.id,
            sparse_to_dense_tf.id
        ])
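The label-length sub-graph built above (Equal(-1) -> Select(0, 1) -> ReduceSum(axis=1)) amounts to counting the non-padding entries per row. A NumPy sketch with illustrative data:

    import numpy as np

    labels = np.array([[3, 5, 1, -1, -1],
                       [2, -1, -1, -1, -1]], dtype=np.int32)
    label_length = np.where(labels == -1, 0, 1).sum(axis=1)
    print(label_length)  # [3 1]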
Esempio n. 24
0
    def extract(node: Node):
        Select.update_node_stat(node, {
            'format': 'tf',
        })
        return __class__.enabled
Esempio n. 25
0
    def dequantize_data(fake_quantize: Node, dst_type: type,
                        quantized_type: type) -> Node:
        graph = fake_quantize.graph
        quantized_data = fake_quantize.in_port(0).get_source().node
        name = fake_quantize.soft_get('name', fake_quantize.id)

        assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
            "Weights aren't compressed as expected for node {}".format(fake_quantize.soft_get('name', fake_quantize.id))

        dequantizing_cast = Cast(
            graph,
            dict(name=quantized_data.name +
                 "/to_{}".format(np_data_type_to_destination_type(dst_type)),
                 dst_type=dst_type,
                 stop_value_propagation=True)).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(
            dequantizing_cast.in_port(0))

        # limits of dequantize
        in_low = fake_quantize.in_port(1).get_source()
        in_high = fake_quantize.in_port(2).get_source()
        out_low = fake_quantize.in_port(3).get_source()
        out_high = fake_quantize.in_port(4).get_source()

        # scale calculation
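        # i.e. scale = (out_high - out_low) / (in_high - in_low)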
        output_range = Sub(graph, {
            'name': name + '/output_range'
        }).create_node()
        output_range.in_port(0).connect(out_high)
        output_range.in_port(1).connect(out_low)

        input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
        input_range.in_port(0).connect(in_high)
        input_range.in_port(1).connect(in_low)

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        scale.in_port(0).connect(output_range.out_port(0))
        scale.in_port(1).connect(input_range.out_port(0))

        # shift calculation
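        # i.e. shift = in_low - out_low / scale (the scale == 0 case is handled below)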
        descaled_output_low = Div(graph, {
            'name': name + '/descaled_output_low'
        }).create_node()
        descaled_output_low.in_port(0).connect(out_low)
        descaled_output_low.in_port(1).connect(scale.out_port(0))

        shift = Sub(graph, {'name': name + '/shift'}).create_node()
        shift.in_port(0).connect(in_low)
        shift.in_port(1).connect(descaled_output_low.out_port(0))

        zero = Const(graph, {
            'name': name + '/zero',
            'value': np.array(0, dtype=dst_type)
        }).create_node()
        scale_eq_zero = Equal(graph, {
            'name': name + '/scale_eq_zero'
        }).create_node()
        scale_eq_zero.in_port(0).connect(scale.out_port(0))
        scale_eq_zero.in_port(1).connect(zero.out_port(0))

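        # where scale == 0 the shift computation above divides by zero, so the
        # Select substitutes a zero zero_point in that case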
        zero_point = Select(graph, {
            'name': name + '/zero_point'
        }).create_node()
        zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
        zero_point.in_port(1).connect(zero.out_port(0))
        zero_point.in_port(2).connect(shift.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
        sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
        sub_zp.in_port(1).connect(zero_point.out_port(0))

        mul_scale = Mul(graph, {
            'name': name + '/multiply_by_scale'
        }).create_node()
        mul_scale.in_port(0).connect(sub_zp.out_port(0))
        mul_scale.in_port(1).connect(scale.out_port(0))

        fake_quantize.out_port(0).get_connection().set_source(
            mul_scale.out_port(0))

        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0).id])
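The whole sub-graph implements DeQuantize(x) == Mul(Sub(x, zero_point), scale), as the comment above notes. An end-to-end NumPy sketch with illustrative int8 ranges:

    import numpy as np

    q = np.array([-128, 0, 127], dtype=np.int8).astype(np.float32)  # 'Convert' output
    in_low, in_high = -128.0, 127.0     # quantized range
    out_low, out_high = -1.0, 1.0       # original float range

    scale = (out_high - out_low) / (in_high - in_low)
    zero_point = 0.0 if scale == 0 else in_low - out_low / scale
    print((q - zero_point) * scale)     # approx [-1., 0.00392, 1.]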