def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap the matched TFSlice node for a Slice node fed by computed ends.

        The matched node carries data, begin and size on input ports 0, 1
        and 2. The new Slice node receives data and begin unchanged, while
        its port 2 is fed by Select(size == -1, ShapeOf(data),
        Cast(Add(size, begin), int64)).

        :param graph: graph in which the replacement nodes are created
        :param match: match dictionary; ``match['op']`` is the TFSlice node
        """
        tf_slice = match['op']
        base_name = tf_slice.soft_get('name', tf_slice.id)
        new_slice = Slice(graph).create_node()
        # The new node takes over the original name; the old one is marked.
        rename_nodes([
            (tf_slice, base_name + '/to_be_removed'),
            (new_slice, base_name),
        ])
        ends = Add(graph, dict(name=base_name + '/ends')).create_node()

        # Data (port 0) and begin (port 1) keep their port index on Slice.
        for port_idx in (0, 1):
            tf_slice.in_port(port_idx).get_connection().set_destination(
                new_slice.in_port(port_idx))
        # size moves to the Add node, which also receives begin.
        tf_slice.in_port(2).get_connection().set_destination(ends.in_port(0))
        new_slice.in_port(1).get_connection().add_destination(ends.in_port(1))

        # ShapeOf is attached to the same source that feeds the Slice data.
        shape_of = Shape(graph, dict(name=base_name + '/ShapeOf')).create_node()
        new_slice.in_port(0).get_connection().add_destination(
            shape_of.in_port(0))

        # Elementwise test size[i] == -1; the constant -1 sits on port 0 of
        # Equal and size is attached to port 1
        # (len(size) = len(begin) = input_rank).
        size_is_minus_one = create_op_with_const_inputs(
            graph, Equal, {0: int64_array(-1)},
            dict(name=base_name + '/where_max_ends_is_needed'))
        ends.in_port(0).get_connection().add_destination(
            size_is_minus_one.in_port(1))

        # Select requires equal dtypes, so the Add output is cast to int64.
        ends_as_i64 = Cast(
            graph,
            dict(name=base_name + '/CastToI64', dst_type=np.int64),
        ).create_node([ends])

        # Where size[i] == -1 take the ShapeOf value, otherwise the Add value.
        chosen_ends = Select(
            graph, dict(name=base_name + '/chosen_ends'),
        ).create_node([size_is_minus_one, shape_of, ends_as_i64])
        chosen_ends.out_port(0).connect(new_slice.in_port(2))

        # Consumers of the TFSlice output now read from the Slice output.
        tf_slice.out_port(0).get_connection().set_source(
            new_slice.out_port(0))
# Example #2
# 0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap the matched node for a Slice node fed by computed ends.

        Inputs 0 and 1 of the matched node (data and begin) are moved to the
        same ports of a new Slice node. Input 2 (size) is added to begin,
        and the sum, cast to int64, becomes the Slice ends except where
        size equals -1, in which case the int32 maximum is selected instead.

        :param graph: graph in which the replacement nodes are created
        :param match: match dictionary; ``match['op']`` is the node replaced
        """
        tf_slice = match['op']
        base_name = tf_slice.soft_get('name', tf_slice.id)
        new_slice = Slice(graph).create_node()
        # The new node takes over the original name; the old one is marked.
        rename_nodes([
            (tf_slice, base_name + '/to_be_removed'),
            (new_slice, base_name),
        ])

        size_is_minus_one = Equal(
            graph, dict(name=base_name + '/equal')).create_node()
        minus_one = Const(
            graph,
            dict(name=base_name + '/minus_one', value=np.array(-1)),
        ).create_node()
        int32_max = Const(
            graph,
            dict(name=base_name + '/int32_max',
                 value=np.iinfo(np.int32).max),
        ).create_node()
        ends_select = Select(
            graph, dict(name=base_name + '/select')).create_node()

        # Add node that turns sizes into ends (begin + size).
        ends_sum = Add(graph, dict(name=base_name + '/end_const')).create_node()

        # Move data (port 0) and begin (port 1) from the old node to Slice.
        for port_idx in (0, 1):
            tf_slice.in_port(port_idx).get_source().connect(
                new_slice.in_port(port_idx))
            tf_slice.in_port(port_idx).disconnect()

        # begin goes to Add port 0, size goes to Add port 1.
        new_slice.in_port(1).get_source().connect(ends_sum.in_port(0))
        tf_slice.in_port(2).get_source().connect(ends_sum.in_port(1))
        tf_slice.in_port(2).disconnect()

        # Equal compares size (the source of Add port 1) against -1.
        ends_sum.in_port(1).get_source().connect(size_is_minus_one.in_port(0))
        minus_one.out_port(0).connect(size_is_minus_one.in_port(1))

        # Select inputs in port order: condition, int32 max, begin + size.
        select_inputs = (size_is_minus_one, int32_max, ends_sum)
        for select_idx, producer in enumerate(select_inputs):
            producer.out_port(0).connect(ends_select.in_port(select_idx))
        # The selected ends feed port 2 of Slice.
        ends_select.out_port(0).connect(new_slice.in_port(2))

        # Insert an int64 Cast on the connection from Add into Select port 2.
        to_i64 = Cast(
            graph,
            dict(name=ends_sum.name + '/CastToI64', dst_type=np.int64),
        ).create_node()
        ends_select.in_port(2).get_connection().insert_node(to_i64)

        # Consumers of the old node's output now read from the Slice output.
        tf_slice.out_port(0).get_connection().set_source(
            new_slice.out_port(0))