def replace_pattern(graph: Graph, match: dict):
        """
        Replace a time-offset node pair with a Splice followed by Crop nodes.

        ``match['op']`` and the node named by its ``pair_name`` are the two
        halves of one pair (presumably a MemoryOffset pair, cf. the op check
        in the loop below). One half has the real input connected; the other
        feeds the downstream consumers.

        Three nodes are removed: both halves and the node attached to the
        input half's output (``op_output_id``). They are replaced by a Splice
        over the input with context ``[t, ..., 0]`` (t < 0) or ``[0, ..., t]``
        (t >= 0), plus Crops that pick single frames out of the Splice output.

        Nothing is done when the pair node has ``has_default`` set.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is one half of the pair
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if pair_node.has_default:
            return

        # Work out which half of the pair carries the real input:
        #  - input_node_out_port: producer feeding the pair,
        #  - op_output_id: node attached to the input half's output (removed below),
        #  - out_node_in_ports: consumers of the other half, re-wired to the new Crop.
        if node.in_port(0).get_source() is not None:
            input_node_out_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_node_in_ports = pair_node.out_port(0).get_destinations()
        else:
            input_node_out_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_node_in_ports = node.out_port(0).get_destinations()

        # Copied, presumably so later in-place shape edits do not leak into it.
        in_shape = input_node_out_port.data.get_shape().copy()

        node_id = node.id
        node_name = node.name
        node_t = node.t

        # Splice context spans every frame between the offset t and frame 0:
        # [t, ..., 0] for negative t, [0, ..., t] otherwise.
        splice = Splice(graph, {'name': node_name,
                                'id': node_id,
                                'context': int64_array(range(node_t, 1))
                                if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
        splice.in_port(0).connect(input_node_out_port)

        # offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
        # i.e. this Crop extracts the frame at context position t, assuming Splice
        # stacks its context frames along axis 1 in context order (TODO confirm).
        crop = Crop(graph, {'name': 'Splice_Crop',
                            'axis': int64_array([1]),
                            'offset': int64_array([max(0, in_shape[1] * node_t)]),
                            'dim': int64_array([in_shape[1]])}).create_node()

        splice.out_port(0).connect(crop.in_port(0))
        # The Splice output holds abs(t) + 1 frames of in_shape[1] features each.
        splice.out_port(0).data.set_shape(int64_array([in_shape[0], (abs(node_t) + 1) * in_shape[1]]))

        # Every other consumer of the original input (anything that is not a
        # MemoryOffset or Splice, which also skips the Splice created above) is
        # re-routed through its own Crop. That Crop takes the frame at context
        # position 0 from the Splice output: offset abs(t) * in_shape[1] for a
        # negative t, and 0 otherwise.
        outs = input_node_out_port.get_destinations()
        for in_port in outs:
            out_ = in_port.node
            if out_['op'] != 'MemoryOffset' and out_['op'] != 'Splice':
                crop_input = Crop(graph, {'name': 'Splice_Crop',
                                          'axis': int64_array([1]),
                                          'offset': int64_array([-min(0, in_shape[1] * node_t)]),
                                          'dim': int64_array([in_shape[1]])}).create_node()
                splice.out_port(0).connect(crop_input.in_port(0))

                in_port.disconnect()
                crop_input.out_port(0).connect(in_port)
                crop_input.out_port(0).data.set_shape(in_shape)

        # Consumers of the output half now read the offset frame from the Crop.
        for dest_port in out_node_in_ports:
            dest_port.connect(crop.out_port(0))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace a time-offset node pair with a ReadValue/Assign state pair.

        ``match['op']`` and the node named by its ``pair_name`` are the two
        halves of one pair. One half has the real input connected; the other
        feeds the consumers. Only negative offsets are supported.

        The state (variable id ``node_name + pair_name``) holds ``abs(t)``
        frames of ``in_shape[1]`` features. It is initialised with zeros whose
        batch size is taken from the input at runtime.

        - ``abs(t) == 1``: the input is written to the state as is, and the
          consumers read the ReadValue output.
        - ``abs(t) > 1``: on every step the oldest frame is cropped off the
          state, the input is appended, and the result is written back. The
          consumers read the frame at offset 0 of the ReadValue output.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is one half of the pair
        :raises Error: if ``node.t >= 0``
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            # t == 0 is rejected as well, so the message must say ">=".
            raise Error('Does not support IfDefined with t >= 0')

        # Pick the half that has the real input; the other half's output port
        # is the one whose consumers get re-wired.
        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        # Zero-initialised state of shape [batch, in_shape[1] * |t|].
        init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
        memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            # Drop the oldest frame of the stored state and append the new input.
            crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
                                       'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            # Consumers receive the oldest frame (offset 0) of the stored state.
            crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
                                    'offset': np.array([0]), 'axis': np.array([1])}).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            # A single frame of delay: store the input directly.
            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
Exemplo n.º 3
0
def create_zero_value_with_batch_from_input(input_out_port: Port,
                                            second_dim,
                                            precision=float):
    """
    Build a sub-graph that produces a zero tensor of shape ``[batch, second_dim]``.

    The batch size is read at runtime from element 0 of the shape of the tensor
    produced by ``input_out_port`` (Shape -> Crop). It is concatenated with the
    constant ``second_dim``, and the resulting shape is used as the target of a
    Broadcast of a single zero value. The sub-graph is meant to initialise a
    ReadValue node.

    :param input_out_port: output port whose tensor provides the batch dimension
    :param second_dim: size of the second dimension of the result
    :param precision: dtype of the zero fill value. Defaults to the builtin
        ``float`` (i.e. float64). ``np.float`` was only an alias of the builtin
        and was removed in NumPy 1.24, where the previous default made this
        module fail at import time.
    :return: the Broadcast node producing the zero tensor
    """
    # create init_graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name
    shape_of_input = Shape(graph, {
        'name': 'shape/' + input_name
    }).create_node()
    shape_of_input.in_port(0).connect(input_out_port)
    # Crop 1 element at offset 0 of the shape vector -> batch size.
    dim_for_get_batch = Const(
        graph, {
            'name': 'dim/crop_batch/' + shape_of_input.name,
            'value': int64_array([1]),
            'shape': int64_array([1])
        }).create_node()
    get_batch = Crop(
        graph, {
            'name': 'crop_batch/' + shape_of_input.name,
            'axis': int64_array([0]),
            'offset': int64_array([0])
        }).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
    mem_shape_2nd_dim = Const(
        graph, {
            'name': 'gifo_r_weights_shape/' + input_name,
            'value': int64_array([second_dim]),
            'shape': int64_array([1])
        }).create_node()
    # Target shape = [batch, second_dim].
    mem_shape = Concat(
        graph, {
            'name': 'gather_memory_shape/' + input_name,
            'axis': 0,
            'in_ports_count': 2
        }).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
    fill_value = Const(
        graph, {
            'name': 'fill_value/' + input_name,
            'value': np.array([0.0], precision),
            'shape': int64_array([1])
        }).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {
        'name': 'init_value/' + input_name,
    }).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
    return init_value_prev_lstm_output
Exemplo n.º 4
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Merge Splices whose context is a subset of ``match['op']``'s context.

        ``match['op']`` (``mem``, presumably a Splice) is compared with every
        other Splice that reads from the same producer. When such a Splice's
        context is a subset of ``mem['context']``, its consumers are moved onto
        ``mem``'s output and the smaller Splice is removed. Along the way, an
        existing Crop consumer gets its offset shifted; any other consumer gets
        a new Crop inserted in front of it.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is the larger Splice
        """
        mem = match['op']
        mem_shape = mem.in_port(0).data.get_shape()
        mem_parent = mem.in_port(0).get_source()
        context = mem['context']

        for child_port in mem_parent.get_destinations():
            child = child_port.node
            # check if we find Splice containing context 'context'
            if child['op'] == 'Splice' and child.id != mem.id and set(
                    child['context']).issubset(set(context)):
                # First context index of the smaller and of the larger Splice.
                # Assumes contexts are sorted ascending -- TODO confirm.
                left_cont_out = child['context'][0]
                left_cont = context[0]

                for child_of_child in child.out_port(0).get_destinations():
                    out_transfer = child_of_child.node
                    out_transfer_port = child_of_child
                    if out_transfer['op'] == 'Crop':
                        # modify existing Crop to get right data from larger Splice
                        # (shift by the frame difference times the frame width mem_shape[-1])
                        out_transfer['offset'] = out_transfer['offset'] + (
                            left_cont_out - left_cont) * mem_shape[-1]
                    else:
                        # insert Crop if we have not one
                        # NOTE(review): 'offset' is a plain scalar here while other
                        # Crops in this file use 1-element arrays -- confirm intended.
                        child_of_child.disconnect()
                        crop_node = Crop(
                            graph, {
                                'name':
                                graph.unique_id(prefix='Splice_crop_'),
                                'offset':
                                (left_cont_out - left_cont) * mem_shape[-1],
                                'dim':
                                np.array(
                                    [len(child['context']) * mem_shape[-1]]),
                                'axis':
                                np.array([-1])
                            }).create_node()
                        child.out_port(0).connect(crop_node.in_port(0))
                        crop_node.out_port(0).connect(child_of_child)
                        crop_node.out_port(0).data.set_shape(
                            child.out_port(0).data.get_shape())

                        # The new Crop's input is the edge that must move to mem.
                        out_transfer_port = crop_node.in_port(0)

                    # move edge to child from old Splice to larger
                    out_transfer_port.disconnect()
                    mem.out_port(0).connect(out_transfer_port)

                graph.remove_node(child.id)
Exemplo n.º 5
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Materialise the prior-box variances as a Const and append them to the Concat.

        The four variances default to 0.1, 0.1, 0.2, 0.2 (x1, y1, x2, y2). Each
        is overridden by ``value[0]`` of the second input of the graph node
        matched to the corresponding pattern node of ``self.variants_pattern``.
        The values are tiled to the shape of the SliceLike's first input and
        cropped with the SliceLike's own attributes. The result is connected
        to a new last input port of the Concat that consumes
        ``match['reshape3']``.
        """
        slice_like = match['slice_like']
        slice_like_inputs = slice_like.in_nodes()
        const = slice_like_inputs[0]
        crop_shape = slice_like_inputs[1]

        # Default variances; replaced below by scalars found in the graph.
        scale = dict(mul_scalar1x=0.1, mul_scalar2x=0.2, mul_scalar1y=0.1, mul_scalar2y=0.2)
        pattern = self.variants_pattern
        found_matches = find_pattern_matches(graph, pattern['nodes'], pattern['edges'], None, None)
        for found in found_matches:
            for graph_node_id, role in found.items():
                if role in scale:
                    scale[role] = Node(graph, graph_node_id).in_nodes()[1].value[0]

        row = [scale[key] for key in ('mul_scalar1x', 'mul_scalar1y', 'mul_scalar2x', 'mul_scalar2y')]
        repeats = int(const.value.size / 4)
        variants = np.array(row * repeats).reshape(const.value.shape)

        variants_name = const.id + '/priorbox_variants'
        priorbox_variants = Const(graph, dict(value=variants,
                                              symbol_dict={'name': variants_name})).create_node()

        crop_attrs = dict(axis=slice_like.axis,
                          offset=slice_like.offset,
                          dim=slice_like.dim,
                          axes=slice_like.axes,
                          symbol_dict={'name': slice_like.id + '/variants_slice_like'})
        variants_crop = Crop(graph, crop_attrs).create_node()
        variants_crop.in_port(0).connect(priorbox_variants.out_port(0))
        variants_crop.in_port(1).connect(crop_shape.out_port(0))

        concat = match['reshape3'].out_port(0).get_destination().node
        assert concat.op == 'Concat'
        next_port = len(concat.in_nodes())
        concat.add_input_port(next_port)
        concat.in_port(next_port).get_connection().set_source(variants_crop.out_port(0))
Exemplo n.º 6
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Bypass ``match['op']`` with a Crop that keeps element 0 along axis 0.

        The Crop is named ``'Result_for_' + variable_id``. It is placed between
        the matched node's producer and its output consumer. The matched node
        is disconnected on both sides but is not removed from the graph here.
        """
        op = match['op']
        variable = op['variable_id']

        consumer = op.out_port(0).get_destination()
        producer = op.in_port(0).get_source()
        for port in (op.in_port(0), op.out_port(0)):
            port.disconnect()

        crop_attrs = {
            'name': 'Result_for_' + variable,
            'dim': np.array([1]),
            'offset': np.array([0]),
            'axis': np.array([0]),
        }
        bypass = Crop(graph, crop_attrs).create_node()
        producer.connect(bypass.in_port(0))
        bypass.out_port(0).connect(consumer)
Exemplo n.º 7
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Cut the graph at ``match['op']``.

        If the node's input is connected, the node is bypassed by a Crop named
        ``'Result_for_<id>'`` that keeps element 0 along axis 0.

        If the node's input is disconnected, every consumer of its output is
        fed by a fresh Parameter with the consumer's data shape instead. A
        warning-level message is logged for each new Parameter.
        """
        op = match['op']
        op_id = op['id']

        if not op.in_port(0).disconnected():
            consumer = op.out_port(0).get_destination()
            producer = op.in_port(0).get_source()
            op.in_port(0).disconnect()
            op.out_port(0).disconnect()
            bypass = Crop(graph, {'name': 'Result_for_' + op_id,
                                  'dim': np.array([1]),
                                  'offset': np.array([0]),
                                  'axis': np.array([0])}).create_node()
            producer.connect(bypass.in_port(0))
            bypass.out_port(0).connect(consumer)
            return

        for idx, dest in enumerate(op.out_port(0).get_destinations()):
            param = Parameter(graph, {'name': "Parameter_" + str(idx) + "_for_" + op_id,
                                      'shape': dest.data.get_shape()}).create_node()
            dest.disconnect()
            param.out_port(0).connect(dest)
            log.error("Add input/output mapped {} -> {} ".format(param.name, "Result_for_" + op_id),
                      extra={'is_warning': True})
Exemplo n.º 8
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Convert a Slice node (``match['slice']``) into a StridedSlice or a Crop.

        - With at least 3 inputs and ``format == 'onnx'``, the work is delegated
          to ``self.convert_onnx_slice_opset10``.
        - Without valid ``start``/``end`` attributes, nothing is done.
        - Otherwise the axes that are really cut are counted. With at most one
          cut axis the Slice becomes a StridedSlice with constant begin/end
          inputs. With more than one it becomes a Crop over the cut axes.

        The new node takes over the Slice's name. The Slice is renamed with a
        ``_delete`` suffix, and its input/output connections are moved to the
        new node.
        """
        node = match['slice']

        input = node.in_node(0)
        output_data = node.out_node()

        # ONNX 10 opset case
        if len(node.in_nodes()) >= 3 and node.has_valid(
                'format') and node['format'] == 'onnx':
            self.convert_onnx_slice_opset10(node)
            return

        # Caffe case
        if not node.has_valid('start') or not node.has_valid('end'):
            return

        begin = node.start
        end = node.end
        # Without an explicit 'axis', begin/end entries apply to axes 0..n-1.
        axis = node.axis if node.has_valid('axis') else np.arange(begin.size)

        # Check whether operation use only one axis or not.
        # An axis counts as cut when begin != 0 or end is before the dimension end.
        # axes_begin/axes_end get a 1 for each axis whose begin/end value is set
        # in ss_begin/ss_end; 'axes' marks which begin/end entries are used.
        axes_begin = np.zeros(len(input.shape), dtype=np.int32)
        axes_end = np.zeros(len(input.shape), dtype=np.int32)
        ss_begin = np.zeros(len(input.shape), dtype=np.int32)
        ss_end = np.zeros(len(input.shape), dtype=np.int32)
        dims = 0
        axes = np.zeros(begin.size)
        for i in range(len(axis)):
            if begin[i] != 0 or end[i] < input.shape[axis[i]]:
                dims += 1
                axes[i] = 1
                if begin[i] != 0:
                    axes_begin[axis[i]] = 1
                    ss_begin[axis[i]] = begin[i]
                if end[i] < input.shape[axis[i]]:
                    axes_end[axis[i]] = 1
                    ss_end[axis[i]] = end[i]
        axes = np.array(axes, dtype=bool)

        slice_node_name = node.soft_get('name', node.id)

        if dims == 1 or dims == 0:
            # If Slice use only one axis or no axis, than
            # convert Slice to StridedSlice
            ss = StridedSlice(
                graph,
                dict(new_axis_mask=np.zeros(len(output_data.shape),
                                            dtype=np.int32),
                     shrink_axis_mask=np.zeros(len(output_data.shape),
                                               dtype=np.int32),
                     ellipsis_mask=np.zeros(len(output_data.shape),
                                            dtype=np.int32),
                     begin_mask=axes_begin,
                     end_mask=axes_end)).create_node()

            convert_negative_indices(ss_begin, input.shape)
            convert_negative_indices(ss_end, input.shape)

            begin_node = Const(graph, {
                'value': ss_begin,
                'name': slice_node_name + '/begin'
            }).create_node()
            end_node = Const(graph, {
                'value': ss_end,
                'name': slice_node_name + '/end'
            }).create_node()

            rename_nodes([(node, slice_node_name + '_delete'),
                          (ss, slice_node_name)])

            node.in_port(0).get_connection().set_destination(ss.in_port(0))
            begin_node.out_port(0).connect(ss.in_port(1))
            end_node.out_port(0).connect(ss.in_port(2))
            node.out_port(0).get_connection().set_source(ss.out_port(0))
        else:
            # If Slice use more than one axis use Crop layer
            # NOTE(review): negative begin/end values are not normalised on this
            # path (convert_negative_indices is only applied above) -- confirm.
            crop = Crop(
                graph,
                dict(axis=axis[axes],
                     offset=begin[axes],
                     dim=end[axes] - begin[axes])).create_node()
            rename_nodes([(node, slice_node_name + '_delete'),
                          (crop, slice_node_name)])

            node.in_port(0).get_connection().set_destination(crop.in_port(0))
            node.out_port(0).get_connection().set_source(crop.out_port(0))
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace a time-offset node pair with a ReadValue/Assign state pair.

        ``match['op']`` and the node named by its ``pair_name`` are the two
        halves of one pair. One half has the real input connected; the other
        feeds the consumers. Only negative offsets are supported.

        The state (variable id ``node_name + pair_name``) holds ``abs(t)``
        frames of ``in_shape[1]`` features. It is initialised with zeros whose
        batch size is taken from the input at runtime.

        - ``abs(t) == 1``: the input is written to the state as is, and the
          consumers read the ReadValue output.
        - ``abs(t) > 1``: on every step the oldest frame is cropped off the
          state, the input is appended, and the result is written back. The
          consumers read the frame at offset 0 of the ReadValue output.

        The resulting model cannot be reshaped safely. If
        ``graph.graph['cmd_params'].static_shape`` is not already set, it is
        turned on and a warning is logged.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is one half of the pair
        :raises Error: if ``node.t >= 0``
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            # t == 0 is rejected as well, so the message must say ">=".
            raise Error('Does not support IfDefined with t >= 0')

        # Pick the half that has the real input; the other half's output port
        # is the one whose consumers get re-wired.
        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        # Zero-initialised state of shape [batch, in_shape[1] * |t|].
        init_value_memory_out = create_zero_value_with_batch_from_input(
            input_port, in_shape[1] * node_t)
        memory_out = ReadValue(graph, {
            'name': pair_name,
            'variable_id': node_name + pair_name
        }).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            # Drop the oldest frame of the stored state and append the new input.
            crop_concat = Crop(
                graph, {
                    'name': 'Memory_crop',
                    'dim': np.array([in_shape[1] * (node_t - 1)]),
                    'offset': np.array([in_shape[1]]),
                    'axis': np.array([1])
                }).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {
                'name': node_name,
                'variable_id': node_name + pair_name
            }).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            # Consumers receive the oldest frame (offset 0) of the stored state.
            crop_out = Crop(
                graph, {
                    'name': 'Memory_crop_out',
                    'dim': np.array([in_shape[1]]),
                    'offset': np.array([0]),
                    'axis': np.array([1])
                }).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            # A single frame of delay: store the input directly.
            memory_in = Assign(graph, {
                'name': node_name,
                'variable_id': node_name + pair_name
            }).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        if not graph.graph['cmd_params'].static_shape:
            log.error(
                "Model can not be translated in a reshape-able way.\n"
                "Model Optimizer key static_shape was turned on to prevent related errors.\n"
                "There will be no success changing input shapes of the model with the help of "
                "InferenceEngine reshape method",
                extra={'is_warning': True})
            graph.graph['cmd_params'].static_shape = True

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
Exemplo n.º 10
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace a Splice node with a Memory-based sliding window.

        The window stores ``len(node.context)`` frames of
        ``memory_element = in_shape[1] - node.const_dim`` features. On every
        step the oldest frame is cropped off the stored window and the new input
        is concatenated at the end. The result is written back to memory
        (Memory index 0) and also used as the Splice output.

        When ``node.const_dim != 0``, the input is first split along axis 1 with
        VariadicSplit into ``[memory_element, const_dim]`` features. The
        ``const_dim`` part goes through a second window of the same length, and
        only that window's oldest frame (offset 0) is appended to the output.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is presumably the Splice
            node (it reads ``context`` and ``const_dim``)
        """
        node = match['op']
        in_shape = node.in_port(0).data.get_shape().copy()
        # Features per frame that go through the main window.
        memory_element = in_shape[1] - node.const_dim
        memory_size = memory_element * len(node.context)

        memory_pair_id = unique_id('id')
        # Memory(in)
        input_memory = Memory(
            graph, {
                'name': 'prev_splice_memory',
                'id': memory_pair_id,
                'index': 1,
                'size': 2,
                'shape': int64_array([memory_size])
            }).create_node()
        # Memory(in)  \
        #             Crop
        # Input(temp) /
        # (the Crop drops the oldest frame, i.e. the first memory_element values)
        crop = Crop(
            graph, {
                'name': 'Splice_Crop',
                'axis': int64_array([1]),
                'offset': int64_array([memory_element]),
                'dim': int64_array([memory_size - memory_element])
            }).create_node()
        crop.in_port(0).connect(input_memory.out_port(0))

        # Crop   \
        #         Concat
        # Input  /
        concat_node = Concat(graph, {
            'name': 'Splice_Concat',
            'in_ports_count': 2,
            'axis': 1
        }).create_node()
        concat_node.in_port(0).connect(crop.out_port(0))

        # Concat -> Memory(out)
        mem_out = Memory(
            graph, {
                'name': 'out_splice_memory',
                'id': memory_pair_id,
                'index': 0,
                'size': 2,
                'shape': int64_array([memory_size])
            }).create_node()
        mem_out.in_port(0).connect(concat_node.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

        if node.const_dim != 0:
            memory_element_constdim = node.const_dim
            memory_size_constdim = memory_element_constdim * len(node.context)

            # Split the input along axis 1 into [memory_element, const_dim] features.
            split = create_op_with_const_inputs(
                graph, VariadicSplit, {
                    1: int64_array(1),
                    2: int64_array([memory_element, memory_element_constdim])
                }, {
                    'name': node.id + '_split_const',
                    'out_ports_count': 2
                })

            split.out_port(0).connect(concat_node.in_port(1))

            # create separate splice construction for const_dim
            memory_pair_id = unique_id('memory_for_const_dim')
            input_memory_const_dim = Memory(
                graph, {
                    'name': 'const_dim_in_memory',
                    'id': memory_pair_id,
                    'index': 1,
                    'size': 2,
                    'shape': int64_array([memory_size_constdim])
                }).create_node()
            crop_const_dim = Crop(
                graph, {
                    'name':
                    'const_dim_crop',
                    'axis':
                    int64_array([1]),
                    'offset':
                    int64_array([memory_element_constdim]),
                    'dim':
                    int64_array(
                        [memory_size_constdim - memory_element_constdim])
                }).create_node()
            crop_const_dim.in_port(0).connect(
                input_memory_const_dim.out_port(0))

            concat_node_const_dim = Concat(graph, {
                'name': 'const_dim_concat',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat_node_const_dim.in_port(0).connect(
                crop_const_dim.out_port(0))

            mem_out_const_dim = Memory(
                graph, {
                    'name': 'const_dim_out_memory',
                    'id': memory_pair_id,
                    'index': 0,
                    'size': 2,
                    'shape': int64_array([memory_size_constdim])
                }).create_node()
            mem_out_const_dim.in_port(0).connect(
                concat_node_const_dim.out_port(0))
            Result(graph).create_node().in_port(0).connect(
                mem_out_const_dim.out_port(0))

            # connect splice to Split as begin and Concat as the end
            split.out_port(1).connect(concat_node_const_dim.in_port(1))
            # Only the oldest const_dim frame (offset 0) of that window is kept.
            crop_first = Crop(
                graph, {
                    'name': 'const_dim_crop_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([memory_element_constdim])
                }).create_node()
            crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

            # Output = [spliced main features, oldest const_dim frame].
            concat_const = Concat(graph, {
                'name': node.id + '_concat_const',
                'axis': 1,
                'in_ports_count': 2
            }).create_node()
            concat_const.in_port(1).connect(crop_first.out_port(0))
            concat_const.in_port(0).connect(concat_node.out_port(0))

            node.in_port(0).get_connection().set_destination(split.in_port(0))
            node.out_port(0).get_connection().set_source(
                concat_const.out_port(0))
        else:
            node.in_port(0).get_connection().set_destination(
                concat_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                concat_node.out_port(0))

        # to avoid re-inference of shape and touching in next replacements
        graph.remove_node(node.id)
    def replace_pattern(graph: Graph, match: dict):
        """
        Guard a state write with a Select until enough frames have been seen.

        ``context_len`` is 1 plus ``len(context) - 1`` for every Splice in the
        sub-graph between the consumers of the model's Parameters and the
        node's input. If it is greater than 1, a Select is inserted in front of
        ``match['op']``. The Select passes the real data when an iteration
        counter reports that the context has been filled, and zeros otherwise.

        The counter is a ReadValue/Assign state of ``context_len`` int32 values,
        zero-initialised. On every step it is shifted left by one and a 1 is
        appended, so its first element is 1 from the ``context_len``-th step on.
        An existing counter with the same structure and length, found with
        ``find_pattern_matches``, is reused instead of building a new one.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is presumably an Assign
            node whose saved state should be guarded -- TODO confirm
        """
        node = match['op']

        # Skip the counter's own Assign created further below.
        if node.name == 'iteration_number_out':
            return

        # calculate length of context when state of inference becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        # (port 1 = real data, port 2 = zeros; np.zeros defaults to float64 --
        # NOTE(review): confirm this matches the input precision)
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if we have already appropriate iteration counter
        # (same structure as the counter built in the else-branch below)
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            # Reuse the existing counter's Const of ones and its first-element Crop.
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            # Build a new counter: state = concat(state[:, 1:], [[1]]).
            init_value_mem_out = create_zero_value_with_batch_from_input(
                in_node_port, context_len, np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            # First element of the updated counter drives the Select condition.
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Exemplo n.º 12
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Merge the matched Splice with a neighbouring Splice on the same input.

        A sibling Splice (another consumer of the same producer port) is a
        neighbour when its context begins where the matched context ends, or
        ends where the matched context begins. The Splice holding the
        left-hand context is kept (``new_node``) and receives the sorted union
        of both contexts; the other one (``rem_node``) is deleted.

        Every consumer of ``rem_node`` is re-attached to the kept Splice:
        an existing Crop consumer has its offset shifted, any other consumer
        gets a new Crop selecting the slice ``rem_node`` used to produce.
        Non-Crop consumers of the kept Splice also get a Crop (offset 0) so
        they keep receiving only their original slice.

        At most one neighbour is merged per call: after a merge the local
        ``context`` no longer describes the kept node, and the matched node
        itself may have been removed.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is the Splice node
        """
        mem = match['op']
        mem_shape = mem.in_port(0).data.get_shape()
        mem_parent = mem.in_port(0).get_source()
        context = mem['context']

        for child_port in mem_parent.get_destinations():
            child = child_port.node
            if child['op'] != 'Splice' or child.id == mem.id:
                continue

            child_context = child['context']
            if child_context[0] == context[-1]:
                # child continues mem to the right: keep mem, drop child
                new_node, rem_node = mem, child
            elif child_context[-1] == context[0]:
                # child ends where mem starts: keep child, drop mem
                new_node, rem_node = child, mem
            else:
                continue

            new_context = sorted(set(context) | set(child_context))
            # rem_node is always the right-hand part of the merged output, so
            # its slice starts after the frames contributed only by new_node.
            rem_offset = (len(new_context) - len(rem_node.context)) * mem_shape[-1]

            # reset edges from rem_node to new_node
            for out_port_rem in rem_node.out_port(0).get_destinations():
                out_transfer = out_port_rem.node
                out_transfer_shape = out_port_rem.data.get_shape().copy()

                out_port_rem.disconnect()

                if out_transfer['op'] == 'Crop':
                    # modify existing Crop to get right data from larger Splice
                    out_transfer['offset'] = out_transfer['offset'] + rem_offset
                    out_port_rem.connect(new_node.out_port(0))
                else:
                    # insert Crop if we have not one; offset is a 1-element
                    # array like every other Crop attribute here
                    crop_node = Crop(
                        graph, {
                            'name': graph.unique_id(prefix='Splice_crop_'),
                            'offset': np.array([rem_offset]),
                            'dim': np.array([len(rem_node['context']) * mem_shape[-1]]),
                            'axis': np.array([-1]),
                        }).create_node()
                    new_node.out_port(0).connect(crop_node.in_port(0))
                    crop_node.out_port(0).connect(out_port_rem)
                    crop_node.out_port(0).data.set_shape(out_transfer_shape)

            # the kept Splice's own non-Crop consumers must now read only the
            # leading part of the (wider) merged output
            for out_port_rem in new_node.out_port(0).get_destinations():
                out_transfer = out_port_rem.node
                out_transfer_shape = out_port_rem.data.get_shape().copy()

                if out_transfer['op'] != 'Crop':
                    # insert Crop if we have not one
                    crop_node = Crop(
                        graph, {
                            'name': graph.unique_id(prefix='Splice_crop_'),
                            'offset': np.array([0]),
                            'dim': np.array([len(new_node['context']) * mem_shape[-1]]),
                            'axis': np.array([-1]),
                        }).create_node()
                    new_node.out_port(0).connect(crop_node.in_port(0))
                    out_port_rem.disconnect()
                    crop_node.out_port(0).connect(out_port_rem)
                    crop_node.out_port(0).data.set_shape(out_transfer_shape)

            # widen the kept output by the frames rem_node added on top of
            # its input width (copy so the stored shape is not edited in place)
            new_shape = new_node.out_port(0).data.get_shape().copy()
            new_shape[1] += (rem_node.out_port(0).data.get_shape()[1] -
                             rem_node.in_port(0).data.get_shape()[1])
            new_node.out_port(0).data.set_shape(new_shape)
            new_node.context = new_context

            graph.remove_node(rem_node.id)
            # `context` is stale now and `mem` may be gone: stop here
            break
Exemplo n.º 13
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Guard the matched Memory write with a Select fed by an iteration counter.

        The number of frames needed before the state is meaningful
        (``context_len``) is 1 plus, for every Splice found between the
        Parameter consumers and the Memory input, its context length minus
        one. When that is 1 nothing is done. Otherwise the Memory input is
        routed through ``Select(counter, data, zeros)``. An existing counter
        subgraph of the same width is reused if the graph already has one;
        otherwise a new Memory-based counter is built.

        :param graph: graph modified in place
        :param match: pattern match; ``match['op']`` is the Memory node
        """
        memory_node = match['op']

        # never guard the counter's own output Memory
        if memory_node.name == 'iteration_number_out':
            return

        # names of every direct consumer of every Parameter: the far end of
        # the backward search from the Memory input
        parameter_consumers = [
            destination.node.name
            for parameter in graph.get_op_nodes(op='Parameter')
            for destination in parameter.out_port(0).get_destinations()
        ]

        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [memory_node.in_port(0).get_source().node.name],
                parameter_consumers)
        except Error:
            return

        visited = (Node(graph, name) for name in subgraph)
        context_len = 1 + sum(
            len(op_node.context) - 1
            for op_node in visited
            if op_node.kind == 'op' and op_node.op == 'Splice')

        if context_len == 1:
            return

        source_port = memory_node.in_port(0).get_source()
        state_shape = memory_node.in_port(0).data.get_shape()
        memory_node.in_port(0).disconnect()

        # Select(counter, real data, zeros) keeps garbage out of the state
        # until enough frames have been seen
        select_node = Select(graph, {'name': 'select_' + memory_node.name}).create_node()
        zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(state_shape)}).create_node()
        select_node.in_port(1).connect(source_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # shape of a counter subgraph this transformation builds for the
        # same context_len
        counter_nodes = [
            ('mem_in', dict(op='Memory', index=1, shape=int64_array([context_len]))),
            ('mem_in_data', dict()),
            ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                 offset=int64_array([1]),
                                 dim=int64_array([context_len - 1]))),
            ('crop_mem_in_data', dict()),
            ('concat', dict(op='Concat', axis=1)),
            ('concat_data', dict()),
            ('const_1', dict(op='Const')),
            ('const_1_data', dict()),
            ('mem_out', dict(op='Memory', index=0, shape=int64_array([context_len]))),
            ('crop_out', dict(op='Crop', axis=int64_array([1]),
                              offset=int64_array([0]),
                              dim=int64_array([1]))),
            ('crop_out_data', dict()),
            ('select', dict(op='Select')),
        ]
        counter_edges = [
            ('mem_in', 'mem_in_data'),
            ('mem_in_data', 'crop_mem_in'),
            ('crop_mem_in', 'crop_mem_in_data'),
            ('crop_mem_in_data', 'concat', {'in': 0}),
            ('const_1', 'const_1_data'),
            ('const_1_data', 'concat', {'in': 1}),
            ('concat', 'concat_data'),
            ('concat_data', 'mem_out'),
            ('concat_data', 'crop_out'),
            ('crop_out', 'crop_out_data'),
            ('crop_out_data', 'select'),
        ]

        def build_counter():
            # Memory read -> drop oldest -> append 1 -> Memory write; the
            # returned port carries the first element of the shifted vector
            counter_read = Memory(graph, {
                'name': 'iteration_number',
                'size': 2,
                'index': 1,
                'id': 'iteration_' + memory_node.name,
                'shape': int64_array([context_len]),
                'dst_type': np.int32,
            }).create_node()
            drop_oldest = Crop(graph, {
                'name': 'cut_first',
                'axis': int64_array([1]),
                'offset': int64_array([1]),
                'dim': int64_array([context_len - 1]),
            }).create_node()
            drop_oldest.in_port(0).connect(counter_read.out_port(0))
            one = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32),
            }).create_node()
            shifted = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1,
            }).create_node()
            shifted.in_port(0).connect(drop_oldest.out_port(0))
            shifted.in_port(1).connect(one.out_port(0))
            counter_write = Memory(graph, {
                'name': 'iteration_number_out',
                'size': 2,
                'index': 0,
                'id': 'iteration_' + memory_node.name,
                'shape': int64_array([context_len]),
            }).create_node()
            counter_write.in_port(0).connect(shifted.out_port(0))
            sink = Result(graph, {}).create_node()
            counter_write.out_port(0).connect(sink.in_port(0))
            first_element = Crop(graph, {
                'name': 'cut_last',
                'axis': int64_array([1]),
                'offset': int64_array([0]),
                'dim': int64_array([1]),
            }).create_node()
            first_element.in_port(0).connect(shifted.out_port(0))
            return first_element.out_port(0)

        counter_match = next(
            find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges),
            None)
        if counter_match is None:
            condition_port = build_counter()
        else:
            condition_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)

        select_node.in_port(0).connect(condition_port)
        select_node.out_port(0).connect(memory_node.in_port(0))
        select_node.out_port(0).data.set_shape(state_shape)
Exemplo n.º 14
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Converts specific for NasNet topology subgraph Pad->StridedSlice->AvgPool to Conv->Crop->AvgPool

        Only one exact configuration is accepted:
        - Pad with pads_begin [0, 0, 0, 0] and pads_end [0, 1, 1, 0];
        - StridedSlice with begin [0, 1, 1, 0], begin_mask [0, 1, 1, 0],
          end_mask [0, 0, 0, 0] and unit steps.

        The Pad becomes a grouped 1x1 Convolution (one group per channel of
        axis 3, all-ones weights) with end padding 1 on axes 1 and 2. The
        StridedSlice becomes a Crop starting at offset 1 on axes 1 and 2.
        Any mismatch is logged with ``log.error`` and the graph is left
        unchanged.

        NOTE(review): the Pad's fill value / mode is not checked here; the
        replacement assumes it is equivalent to the convolution's own padding.

        :param graph: graph modified in place
        :param match: pattern match with keys 'input', 'pad_op' and 'sslice'
        """
        input_node = match['input']

        pad_node = match['pad_op']
        pad_node_name = pad_node.soft_get('name', pad_node.id)

        sslice_node = match['sslice']
        begin = [s.start for s in sslice_node.slices]
        strides = [s.step for s in sslice_node.slices]

        pads_begin = pad_node.in_port(1).data.get_value()
        pads_end = pad_node.in_port(2).data.get_value()
        if pads_begin is None or pads_end is None:
            log.error('Pad values for node "{}" are not constants'.format(
                pad_node_name))
            return

        if not np.array_equal(pads_begin, int64_array([0, 0, 0, 0])):
            log.error('Pad begin values doesn\'t match for node {}!'.format(
                pad_node_name))
            return

        if not np.array_equal(pads_end, int64_array([0, 1, 1, 0])):
            log.error('Pad end values doesn\'t match for node {}!'.format(
                pad_node_name))
            return

        if not np.array_equal(begin, int64_array([0, 1, 1, 0])):
            log.error("StridedSlice has wrong begin")
            return

        if not np.array_equal(sslice_node.end_mask, int64_array(
            [0, 0, 0, 0])) or not np.array_equal(sslice_node.begin_mask,
                                                 int64_array([0, 1, 1, 0])):
            log.error("StridedSlice has wrong masks")
            return

        # A Crop takes a contiguous range; a slice with a step other than 1
        # cannot be reproduced by it.
        if any(step is not None and step != 1 for step in strides):
            log.error("StridedSlice has wrong strides")
            return

        # Pad -> Conv
        conv_name = graph.unique_id(pad_node.name + '/Conv_')
        conv_weights_name = graph.unique_id(pad_node.name + '/ConvW_')
        conv_weights = np.ones((input_node.shape[3], 1, 1, 1))
        # padding one element at the end of axes 1 and 2 grows them by one
        output_shape = int64_array([
            input_node.shape[0], input_node.shape[1] + 1, input_node.shape[2] + 1,
            input_node.shape[3]
        ])

        conv_node = Convolution(
            graph,
            dict(
                name=conv_name,
                stride=int64_array([1, 1, 1, 1]),
                dilation=int64_array([1, 1, 1, 1]),
                group=input_node.shape[3],
                bias_addable=True,
                bias_term=False,
                spatial_dims=int64_array([1, 2]),
                kernel_spatial=int64_array([1, 1]),
                pad=int64_array([[0, 0], [0, 1], [0, 1], [0, 0]]),
                output_shape=output_shape,
                batch_dims=int64_array([0]),
                channel_dims=int64_array([3]),
                output=input_node.shape[3],
                input_feature_channel=1,
                output_feature_channel=0,
            )).create_node()

        weights_const_node = Const(
            graph,
            dict(name=conv_weights_name,
                 value=conv_weights,
                 shape=int64_array(conv_weights.shape))).create_node()

        # StridedSlice -> Crop (drop the first element of axes 1 and 2)
        crop_node = Crop(
            graph,
            dict(name=sslice_node.name + '/Crop_',
                 axis=int64_array([1, 2]),
                 dim=int64_array([output_shape[1] - 1, output_shape[2] - 1]),
                 offset=int64_array([1, 1]))).create_node()

        # Connect nodes
        pad_node.in_port(0).get_connection().set_destination(
            conv_node.in_port(0))
        weights_const_node.out_port(0).connect(conv_node.in_port(1))
        conv_node.out_port(0).connect(crop_node.in_port(0))
        sslice_node.out_port(0).get_connection().set_source(
            crop_node.out_port(0))

        conv_node.in_port(1).bin = 'weights'

        # Remove Pad and StridedSlice nodes from graph
        graph.remove_node(pad_node.id)
        graph.remove_node(sslice_node.id)
Exemplo n.º 15
0
    def insert_select(graph: Graph, node: Node):
        """
        Guard input 0 of ``node`` with a Select driven by an iteration counter.

        ``context_len`` is ``node.frame_time + 1``; when it is 1 nothing is
        done. Otherwise the value flowing into ``node`` becomes
        ``Select(Equal(ones, counter_first), data, zero_else)``.

        The counter is a ReadValue/Assign variable of width ``context_len``.
        Each inference drops its element at offset 0 (Crop from offset 1),
        appends ``ones`` (Concat on axis 1) and writes the result back. The
        first element of that result is compared with ``ones``. A counter
        with the same ``context_len`` built by an earlier call is reused
        instead of creating a new one.

        ``zero_else`` is built by ``create_const_with_batch_from_input``;
        presumably it holds zeros (TODO confirm against that helper).

        :param graph: graph modified in place
        :param node: node whose input 0 gets guarded; must provide
            ``frame_time``
        """
        context_len = node.frame_time + 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
        zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # Check if we have already appropriate iteration counter. The pattern
        # mirrors the subgraph built in the `else` branch below. That branch
        # wires Crop -> Equal -> Select, so the Equal (fed by the shared
        # 'ones' const on input 0) must be part of the pattern; a direct
        # Crop -> Select edge never exists in a counter built here.
        # Pattern names follow the variable's point of view: 'mem_in' is the
        # ReadValue, 'mem_out' is the Assign.
        # NOTE(review): 'mem_in_data' is also required to have shape
        # [context_len]; if the counter's data carries a batch dimension or
        # has no shape set yet, reuse is still blocked — confirm.
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                        offset=int64_array([1]),
                                        dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()),
                   ('const_1', dict(op='Const')),
                   ('const_1_data', dict()),
                   ('mem_out', dict(op='Assign')),
                   ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                     offset=int64_array([0]),
                                     dim=int64_array([1]))),
                   ('crop_out_data', dict()),
                   ('equal', dict(op='Equal')),
                   ('equal_data', dict()),
                   ('select', dict(op='Select'))
                   ],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {'in': 0}),
                   ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {'in': 1}),
                   ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('const_1_data', 'equal', {'in': 0}),
                   ('crop_out_data', 'equal', {'in': 1}),
                   ('equal', 'equal_data'),
                   ('equal_data', 'select', {'in': 0})])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(graph, {'name': 'iteration_number',
                                        'variable_id': 'iteration_' + node.name}).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                     'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
            concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(graph, {'name': 'iteration_number_out',
                                    'variable_id': 'iteration_' + node.name}).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                    'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)