Esempio n. 1
0
    def replace_pattern(graph: Graph, match: dict):
        """Merge the matched Splice with a neighbouring Splice on the same input.

        Two Splices reading the same producer are neighbours when the last
        context offset of one equals the first context offset of the other
        (e.g. ``[-5, 0]`` and ``[0, 3]``). The left Splice (``new_node``) is
        widened to the union of both contexts and the right one (``rem_node``)
        is removed. Every consumer of either Splice then reads its original
        slice of the widened output through a Crop: an existing Crop has its
        offset shifted, otherwise a new Crop is inserted.

        :param graph: graph transformed in place
        :param match: pattern match; ``match['op']`` is the Splice to merge
        """
        mem = match['op']
        mem_shape = mem.in_port(0).data.get_shape()
        mem_parent = mem.in_port(0).get_source()
        context = mem['context']

        for child_port in mem_parent.get_destinations():
            child = child_port.node
            if child['op'] != 'Splice' or child.id == mem.id:
                continue

            child_context = child['context']
            # Check both orientations: `mem` on the left ([.., c] + [c, ..])
            # or the child on the left ([.., c] + [c, ..] with roles swapped).
            mem_is_left = child_context[0] == context[-1]
            child_is_left = child_context[-1] == context[0]
            if not (mem_is_left or child_is_left):
                continue

            new_context = sorted(set(context).union(child_context))
            if mem_is_left:
                new_node, rem_node = mem, child
            else:
                new_node, rem_node = child, mem

            frame_size = mem_shape[-1]
            # rem_node is always the right-hand Splice, so its frames sit at
            # the end of the merged output.
            rem_offset = (len(new_context) - len(rem_node['context'])) * frame_size
            rem_dim = len(rem_node['context']) * frame_size
            new_dim = len(new_node['context']) * frame_size

            # reroute every consumer of rem_node to the widened new_node
            for out_port_rem in rem_node.out_port(0).get_destinations():
                out_transfer = out_port_rem.node
                out_transfer_shape = out_port_rem.data.get_shape().copy()

                out_port_rem.disconnect()

                if out_transfer['op'] == 'Crop':
                    # shift the existing Crop so it reads rem_node's slice of the larger Splice
                    out_transfer['offset'] = out_transfer['offset'] + rem_offset
                    out_port_rem.connect(new_node.out_port(0))
                else:
                    # insert a Crop that extracts rem_node's slice
                    crop_node = Crop(
                        graph, {
                            'name': graph.unique_id(prefix='Splice_crop_'),
                            'offset': rem_offset,
                            'dim': np.array([rem_dim]),
                            'axis': np.array([-1])
                        }).create_node()
                    new_node.out_port(0).connect(crop_node.in_port(0))
                    crop_node.out_port(0).connect(out_port_rem)
                    crop_node.out_port(0).data.set_shape(out_transfer_shape)

            # new_node's own non-Crop consumers must keep reading only its original slice
            for out_port_new in new_node.out_port(0).get_destinations():
                out_transfer = out_port_new.node
                out_transfer_shape = out_port_new.data.get_shape().copy()

                if out_transfer['op'] == 'Crop':
                    continue
                crop_node = Crop(
                    graph, {
                        'name': graph.unique_id(prefix='Splice_crop_'),
                        'offset': np.array([0]),
                        'dim': np.array([new_dim]),
                        'axis': np.array([-1])
                    }).create_node()
                new_node.out_port(0).connect(crop_node.in_port(0))
                out_port_new.disconnect()
                crop_node.out_port(0).connect(out_port_new)
                crop_node.out_port(0).data.set_shape(out_transfer_shape)

            # copy before mutating so the stored shape is not changed in place
            new_shape = new_node.out_port(0).data.get_shape().copy()
            new_shape[1] += (rem_node.out_port(0).data.get_shape()[1] -
                             rem_node.in_port(0).data.get_shape()[1])
            new_node.out_port(0).data.set_shape(new_shape)
            new_node.context = new_context

            graph.remove_node(rem_node.id)

            if rem_node is mem:
                # `mem` was folded into its left neighbour and no longer exists
                break
            # `mem` now covers the merged context; compare further children against it
            context = new_context
Esempio n. 2
0
    def insert_select(graph: Graph, node: Node):
        """Put a Select in front of ``node`` so it only stores real data.

        The Select passes the original input through when an iteration counter
        reports that enough frames (``node.frame_time + 1``) have been gathered,
        and zeros otherwise. An existing counter for the same context length is
        reused when the pattern below finds one. Otherwise a new
        ReadValue/Crop/Concat/Assign counter is built. Nothing is done when the
        context length is 1.

        :param graph: graph transformed in place
        :param node: node whose input gets gated by the Select
        """
        context_len = node.frame_time + 1
        if context_len == 1:
            return

        source_port = node.in_port(0).get_source()
        source_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # gate the stored value so garbage is never written into the state
        select = Select(graph, {'name': 'select_' + node.name}).create_node()
        zeros = create_const_with_batch_from_input(source_port, source_shape[1])
        select.in_port(1).connect(source_port)
        select.in_port(2).connect(zeros.out_port(0))

        # look for an iteration counter already built for this context length
        counter_nodes = [
            ('mem_in', dict(op='ReadValue')),
            ('mem_in_data', dict(shape=int64_array([context_len]))),
            ('crop_mem_in', dict(op='Crop',
                                 axis=int64_array([1]),
                                 offset=int64_array([1]),
                                 dim=int64_array([context_len - 1]))),
            ('crop_mem_in_data', dict()),
            ('concat', dict(op='Concat', axis=1)),
            ('concat_data', dict()),
            ('const_1', dict(op='Const')),
            ('const_1_data', dict()),
            ('mem_out', dict(op='Assign')),
            ('crop_out', dict(op='Crop',
                              axis=int64_array([1]),
                              offset=int64_array([0]),
                              dim=int64_array([1]))),
            ('crop_out_data', dict()),
            ('select', dict(op='Select')),
        ]
        counter_edges = [
            ('mem_in', 'mem_in_data'),
            ('mem_in_data', 'crop_mem_in'),
            ('crop_mem_in', 'crop_mem_in_data'),
            ('crop_mem_in_data', 'concat', {'in': 0}),
            ('const_1', 'const_1_data'),
            ('const_1_data', 'concat', {'in': 1}),
            ('concat', 'concat_data'),
            ('concat_data', 'mem_out'),
            ('concat_data', 'crop_out'),
            ('crop_out', 'crop_out_data'),
            ('crop_out_data', 'select'),
        ]
        counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

        if counter_match is None:
            variable_id = 'iteration_' + node.name

            counter_init = create_const_with_batch_from_input(source_port, context_len, precision=np.int32)
            counter_read = ReadValue(graph, {'name': 'iteration_number', 'variable_id': variable_id}).create_node()
            counter_read.in_port(0).connect(counter_init.out_port(0))

            drop_first = Crop(graph, {'name': 'cut_first',
                                      'axis': int64_array([1]),
                                      'offset': int64_array([1]),
                                      'dim': int64_array([context_len - 1])}).create_node()
            drop_first.in_port(0).connect(counter_read.out_port(0))

            ones = create_const_with_batch_from_input(source_port, 1, 1, np.int32)
            shifted = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
            shifted.in_port(0).connect(drop_first.out_port(0))
            shifted.in_port(1).connect(ones.out_port(0))

            counter_write = Assign(graph, {'name': 'iteration_number_out', 'variable_id': variable_id}).create_node()
            counter_write.in_port(0).connect(shifted.out_port(0))
            sink = Result(graph, {}).create_node()
            counter_write.out_port(0).connect(sink.in_port(0))

            take_first = Crop(graph, {'name': 'cut_last',
                                      'axis': int64_array([1]),
                                      'offset': int64_array([0]),
                                      'dim': int64_array([1])}).create_node()
            take_first.in_port(0).connect(shifted.out_port(0))
            flag_port = take_first.out_port(0)
        else:
            matched_ids = inverse_dict(counter_match)
            ones = Node(graph, matched_ids['const_1'])
            flag_port = Node(graph, matched_ids['crop_out']).out_port(0)

        # The counter value equals 1 once the context has been gathered; only then
        # does the real data go to the state, otherwise the initial state is kept.
        is_ready = Equal(graph, {'name': flag_port.node.name + '/cast_to_bool'}).create_node()
        is_ready.in_port(0).connect(ones.out_port(0))
        is_ready.in_port(1).connect(flag_port)
        select.in_port(0).connect(is_ready.out_port(0))
        select.out_port(0).connect(node.in_port(0))
        select.out_port(0).data.set_shape(source_shape)
    def replace_pattern(graph: Graph, match: dict):
        """Replace a MemoryOffset pair by a Splice over the offset input plus Crops.

        The Splice covers frames ``t..0`` for a negative offset ``t`` and
        ``0..t`` otherwise. A Crop picks frame ``t`` out of it and feeds the
        former consumers of the MemoryOffset output. If the first of those
        consumers is a Concat that also reads the input directly, that direct
        edge is rerouted through a second Crop that picks frame 0. Afterwards
        both MemoryOffset nodes and the node fed by the input-side one
        (``op_output_id``) are removed. Nothing is done when the pair node has
        ``has_default`` set.

        :param graph: graph transformed in place
        :param match: pattern match; ``match['op']`` is one MemoryOffset of the pair
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)
        if pair_node.has_default:
            return

        node_source = node.in_port(0).get_source()
        if node_source is not None:
            # `node` receives the input, `pair_node` feeds the consumers
            input_side, output_side = node, pair_node
            input_node_out_port = node_source
        else:
            input_side, output_side = pair_node, node
            input_node_out_port = pair_node.in_port(0).get_source()
        op_output_id = input_side.out_port(0).get_destination().node.id
        out_node_in_ports = output_side.out_port(0).get_destinations()

        in_shape = input_node_out_port.data.get_shape().copy()
        frame_size = in_shape[1]
        node_t = node.t

        first, last = (node_t, 0) if node_t < 0 else (0, node_t)
        splice = Splice(graph, {
            'name': node.name,
            'id': node.id,
            'context': int64_array(range(first, last + 1)),
        }).create_node()
        splice.in_port(0).connect(input_node_out_port)

        # frame t is the first Splice element for t < 0 and the last one for t > 0
        crop = Crop(graph, {
            'name': 'Splice_Crop',
            'axis': int64_array([1]),
            'offset': int64_array([max(0, frame_size * node_t)]),
            'dim': int64_array([frame_size]),
        }).create_node()
        splice.out_port(0).connect(crop.in_port(0))
        splice.out_port(0).data.set_shape(
            int64_array([in_shape[0], (abs(node_t) + 1) * frame_size]))

        for in_port in input_node_out_port.get_destinations():
            consumer = in_port.node
            if not (consumer.op == 'Concat' and consumer == out_node_in_ports[0].node):
                continue
            # this Concat reads frame 0 from the Splice instead of the raw input
            crop_input = Crop(graph, {
                'name': 'Splice_Crop',
                'axis': int64_array([1]),
                'offset': int64_array([-min(0, frame_size * node_t)]),
                'dim': int64_array([frame_size]),
            }).create_node()
            splice.out_port(0).connect(crop_input.in_port(0))
            in_port.disconnect()
            crop_input.out_port(0).connect(in_port)
            crop_input.out_port(0).data.set_shape(in_shape)

        for dest_port in out_node_in_ports:
            dest_port.connect(crop.out_port(0))

        for obsolete_id in (op_output_id, node.id, pair_node.id):
            graph.remove_node(obsolete_id)
    def replace_pattern(graph: Graph, match: dict):
        """Replace a Splice by ReadValue/Assign state that buffers its context.

        The state holds ``len(context)`` frames of the spliced part of the input
        (``in_shape[1] - const_dim`` values per frame), zero-initialised. Each
        step the first frame is cropped away, the current input is appended with
        Concat, and the result is written back (Assign) and used as the Splice
        output. When ``const_dim`` is non-zero, the trailing ``const_dim`` columns
        are split off with VariadicSplit and buffered in a second state of their
        own. Only the first ``const_dim`` values of that buffer are appended to
        the output.

        :param graph: graph transformed in place
        :param match: pattern match; ``match['op']`` is the Splice to replace
        """
        splice = match['op']
        in_shape = splice.in_port(0).data.get_shape().copy()
        batch = in_shape[0]
        frame_size = in_shape[1] - splice.const_dim
        num_frames = len(splice.context)
        state_size = frame_size * num_frames

        def zero_const(name, width):
            # zero-filled Const of shape [batch, width], used to initialise a ReadValue
            return Const(graph, {
                'name': name,
                'value': np.zeros(int64_array([batch, width])),
                'shape': int64_array([batch, width]),
            }).create_node()

        variable_id = unique_id('id')
        # Memory(in)
        read_state = ReadValue(graph, {'name': 'prev_splice_memory', 'variable_id': variable_id}).create_node()

        # Memory(in)  \
        #             Crop
        # Input(temp) /
        drop_first = Crop(graph, {
            'name': 'Splice_Crop',
            'axis': int64_array([1]),
            'offset': int64_array([frame_size]),
            'dim': int64_array([state_size - frame_size]),
        }).create_node()
        drop_first.in_port(0).connect(read_state.out_port(0))

        # Crop   \
        #         Concat
        # Input  /
        append = Concat(graph, {'name': 'Splice_Concat', 'in_ports_count': 2, 'axis': 1}).create_node()
        append.in_port(0).connect(drop_first.out_port(0))

        # Concat -> Memory(out)
        write_state = Assign(graph, {'name': 'out_splice_memory', 'variable_id': variable_id}).create_node()
        write_state.in_port(0).connect(append.out_port(0))
        Result(graph).create_node().in_port(0).connect(write_state.out_port(0))

        if splice.const_dim == 0:
            splice_in_port = append.in_port(1)
            splice_out_port = append.out_port(0)
        else:
            const_size = splice.const_dim
            const_state_size = const_size * num_frames

            split = create_op_with_const_inputs(
                graph, VariadicSplit,
                {1: int64_array(1), 2: int64_array([frame_size, const_size])},
                {'name': splice.id + '_split_const', 'out_ports_count': 2})
            split.out_port(0).connect(append.in_port(1))

            # separate buffer for the const_dim columns
            const_variable_id = unique_id('memory_for_const_dim')
            const_init = zero_const('init_value_const_dim_in_memory', const_state_size)
            read_const_state = ReadValue(graph, {'name': 'const_dim_in_memory',
                                                 'variable_id': const_variable_id}).create_node()
            const_init.out_port(0).connect(read_const_state.in_port(0))

            drop_first_const = Crop(graph, {
                'name': 'const_dim_crop',
                'axis': int64_array([1]),
                'offset': int64_array([const_size]),
                'dim': int64_array([const_state_size - const_size]),
            }).create_node()
            drop_first_const.in_port(0).connect(read_const_state.out_port(0))

            append_const = Concat(graph, {'name': 'const_dim_concat', 'in_ports_count': 2, 'axis': 1}).create_node()
            append_const.in_port(0).connect(drop_first_const.out_port(0))

            write_const_state = Assign(graph, {'name': 'const_dim_out_memory',
                                               'variable_id': const_variable_id}).create_node()
            write_const_state.in_port(0).connect(append_const.out_port(0))
            Result(graph).create_node().in_port(0).connect(write_const_state.out_port(0))

            # Split opens the construction, the final Concat closes it
            split.out_port(1).connect(append_const.in_port(1))
            first_const = Crop(graph, {
                'name': 'const_dim_crop_first',
                'axis': int64_array([1]),
                'offset': int64_array([0]),
                'dim': int64_array([const_size]),
            }).create_node()
            first_const.in_port(0).connect(append_const.out_port(0))

            join = Concat(graph, {'name': splice.id + '_concat_const', 'axis': 1, 'in_ports_count': 2}).create_node()
            join.in_port(1).connect(first_const.out_port(0))
            join.in_port(0).connect(append.out_port(0))

            splice_in_port = split.in_port(0)
            splice_out_port = join.out_port(0)

        state_init = zero_const('init_value_' + splice.name, state_size)
        state_init.out_port(0).connect(read_state.in_port(0))
        splice.in_port(0).get_connection().set_destination(splice_in_port)
        splice.out_port(0).get_connection().set_source(splice_out_port)

        # drop the Splice so later shape inference and replacements never see it
        graph.remove_node(splice.id)
    def replace_pattern(graph: Graph, match: dict):
        """Replace a MemoryOffset pair with negative offset t by ReadValue/Assign state.

        The state holds ``|t|`` frames of the input (``in_shape[1] * |t|``
        columns), zero-initialised and shared by a ReadValue/Assign pair through
        ``variable_id``.
        - For ``|t| > 1``, the first frame of the state is emitted through a
          Crop. The state is rewritten as (state without its first frame) + the
          current input.
        - For ``|t| == 1``, the ReadValue output is emitted directly and the
          current input is written back as the new state.
        Both MemoryOffset nodes and the node fed by the input-side one
        (``op_output_id``) are removed afterwards.

        :param graph: graph transformed in place
        :param match: pattern match; ``match['op']`` is one MemoryOffset of the pair
        :raises Error: if ``node.t >= 0`` (only negative offsets are supported)
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            # message previously said "t > 0" although t == 0 is rejected as well
            raise Error('Does not support IfDefined with t >= 0 (node "{}" has t = {})'.format(
                node.name, node.t))

        if node.in_port(0).get_source() is not None:
            # `node` receives the input, `pair_node` feeds the consumers
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        frame_size = in_shape[1]
        node_t = abs(node.t)
        state_width = frame_size * node_t
        # ReadValue and Assign are linked through this shared variable id
        variable_id = node_name + pair_name

        init_value_memory_out = Const(graph, {
            'name': 'init_value_' + pair_name,
            'value': np.zeros(int64_array([in_shape[0], state_width])),
            'shape': int64_array([in_shape[0], state_width]),
        }).create_node()
        memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': variable_id}).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            # new state = old state without its first frame + current input
            crop_concat = Crop(graph, {
                'name': 'Memory_crop',
                'dim': np.array([frame_size * (node_t - 1)]),
                'offset': np.array([frame_size]),
                'axis': np.array([1]),
            }).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {'name': node_name, 'variable_id': variable_id}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            # the consumers receive the first frame of the stored state
            crop_out = Crop(graph, {
                'name': 'Memory_crop_out',
                'dim': np.array([frame_size]),
                'offset': np.array([0]),
                'axis': np.array([1]),
            }).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            memory_in = Assign(graph, {'name': node_name, 'variable_id': variable_id}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)