Example #1
    def replace_pattern(self, graph: Graph, match: dict):
        conv = match['conv']
        stb = match['space_to_batch']
        bts = match['batch_to_space']

        block_size = match['stb_bs']

        input = match['input']
        output = match['output']
        stb_out = match['stb_output']
        conv_out = match['conv_output']

        in_edge_attrs = graph.get_edge_data(input.id, stb.id)[0]
        out_edge_attrs = graph.get_edge_data(bts.id, output.id)[0]

        graph.remove_edge(input.id, stb.id)
        graph.remove_edge(stb_out.id, conv.id)
        graph.remove_edge(conv.id, conv_out.id)
        graph.remove_edge(bts.id, output.id)

        conv.dilation[conv.spatial_dims] = block_size.value[conv.spatial_dims]

        pad_begin = match['stb_pad_begin'].value - match['bts_crop_begin'].value
        pad_end = match['stb_pad_end'].value - match['bts_crop_end'].value
        conv.pad[conv.spatial_dims] = [[pad_begin[x], pad_end[x]] for x in conv.spatial_dims]
        conv['auto_pad'] = None

        graph.add_edges_from([
            (input.id, conv.id, {'in': 0, **in_edge_attrs}),
            (conv.id, output.id, {'out': 0, **out_edge_attrs}),
        ])
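
A minimal sketch (plain networkx, not the MO Graph class; node names are illustrative) of why these examples index get_edge_data(...)[0]: the graph is a MultiDiGraph, so get_edge_data returns a dict keyed by parallel-edge key, and 0 is the first edge. The remove/re-add rewiring idiom used above is shown as well.

import networkx as nx

g = nx.MultiDiGraph()
g.add_edge('input', 'stb', **{'in': 0, 'out': 0})

edge_attrs = g.get_edge_data('input', 'stb')[0]   # attrs of the edge with key 0

# The rewiring idiom: drop the old edge, reattach the saved attrs elsewhere
g.remove_edge('input', 'stb')
g.add_edges_from([('input', 'conv', {'in': 0, **edge_attrs})])
print(g.get_edge_data('input', 'conv')[0])        # {'in': 0, 'out': 0}
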
Example #2
def _insert_pooling(graph: Graph, first_node: Node, second_node: Node,
                    spatial_dims):
    """
    This function inserts point wise pooling layer between two nodes
    """
    log.debug("STRIDE PROP: Insert pooling between {} and {}".format(
        first_node.name, second_node.name))
    stride_prop = second_node.stride_prop
    assert len(graph.get_edge_data(first_node.id, second_node.id)) == 1
    eattrs = graph.get_edge_data(first_node.id, second_node.id)[0]
    graph.remove_edge(first_node.id, second_node.id)

    pooling = Pooling(
        graph,
        dict(name='Pooling_',
             spatial_dims=spatial_dims,
             window=np.array([1, 1, 1, 1]),
             output_spatial_shape=None,
             stride=np.array(stride_prop),
             pad_spatial_shape=np.array([[0, 0], [0, 0]]),
             pad=np.array([[0, 0], [0, 0], [0, 0], [0, 0]]),
             pool_method='max',
             is_partial_inferred=False))
    pooling_data = pooling.create_node_with_data([first_node])

    _clean_fw_tensor_attrs(pooling_data)

    graph.add_edges_from([(pooling_data.id, second_node.id, eattrs)])
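
A hedged sketch of what the inserted pooling computes: a max pool with a 1x1 window and stride s is plain subsampling, so it carries the propagated stride without mixing values (the array here is illustrative).

import numpy as np

x = np.arange(16.0).reshape(4, 4)
stride = 2
# a 1x1 max pool with stride 2 keeps every second element in each spatial dim
print(x[::stride, ::stride])   # [[ 0.  2.]
                               #  [ 8. 10.]]
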
Example #3
    def reordering_inputs(graph: Graph, match: dict):
        """
        Reorder (renumber) inputs to the described format. We need to renumber the initial state ports.
        """
        rnn_layer = match['rnn_layer']
        assert 5 in rnn_layer.in_nodes()
        hidden_state_edge = graph.get_edge_data(rnn_layer.in_node(5).id, rnn_layer.id)
        hidden_state_edge[0]['in'] = 4

        if rnn_layer.op == 'LSTM':
            assert 6 in rnn_layer.in_nodes()
            cell_state_edge = graph.get_edge_data(rnn_layer.in_node(6).id, rnn_layer.id)
            cell_state_edge[0]['in'] = 5
Example #4
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Finds the pattern: SoftmaxActivation -> DetectionOutput.
        DetectionOutput in IE expects a flattened input from SoftMax, which is
        why a Flatten (implemented here as a Reshape) layer has to be added.

        Parameters
        ----------
        graph : Graph
           Graph with loaded model.
        match : dict
           Patterns which were found in graph structure.
        """
        softmax_activation = match['softmax_activation']
        multi_box_detection = match['multi_box_detection']
        softmax_activation['axis'] = -1
        edge_data = graph.get_edge_data(softmax_activation.id,
                                        multi_box_detection.id)
        out_port = edge_data[0]['out']
        in_port = edge_data[0]['in']
        graph.remove_edge(softmax_activation.id, multi_box_detection.id)
        new_reshape_node = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            dict(op='Reshape', name=multi_box_detection.name + '/Reshape_'),
            softmax_activation)
        graph.create_edge(new_reshape_node,
                          multi_box_detection,
                          in_port=in_port,
                          out_port=out_port)
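
For intuition, a small numpy sketch of what the inserted Reshape([0, -1]) produces: in IE Reshape semantics, 0 keeps a dimension and -1 infers the remainder, flattening everything after the batch axis (shapes are illustrative).

import numpy as np

x = np.zeros((2, 4, 9))                  # e.g. SoftMax output
flat = x.reshape(x.shape[0], -1)         # numpy equivalent of Reshape([0, -1])
print(flat.shape)                        # (2, 36)
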
Example #5
    def permute_data_nodes_attrs(graph: Graph):
        # Iterate over all data nodes and apply permutation if exists
        for node in graph.get_data_nodes():
            if not node.has_valid('permutation') or \
                    all([attrs.get('input_permutation', False) for u, v, attrs in graph.out_edges(node.id, data=True)]):
                continue

            if len(
                    node.in_nodes()
            ) != 0:  # there are data nodes without input operation node inside the TensorIterator
                edge_attrs = graph.get_edge_data(node.in_node(0).id,
                                                 node.id)[0]
                if is_output_data_in_correct_layout(node.in_node(0),
                                                    edge_attrs['out']):
                    log.debug(
                        'Do not permute data node attrs for node "{}" output port "{}"'
                        .format(node.in_node(0).id, edge_attrs['out']))
                    continue

            # Apply permutation for shape and value if exists
            if len(node.permutation.perm) == 0:
                continue
            node.shape = shape_array(node.shape)[node.permutation.perm]
            if node.has_valid('value'):
                assert len(node.value.shape) == len(node.permutation.perm), \
                    'Node {} has shape {} and permutation {} that does not match. Their lengths should be equal' \
                    ''.format(node.name, node.value.shape, node.permutation.perm)
                node.value = mo_array(
                    node.value.transpose(node.permutation.perm))
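
A small sketch of the permutation applied above, assuming an NHWC-to-NCHW perm of [0, 3, 1, 2]: the shape is gathered by perm and the value is transposed by the same perm.

import numpy as np

perm = [0, 3, 1, 2]                      # NHWC -> NCHW
value = np.zeros((1, 224, 224, 3))
print(np.array(value.shape)[perm])       # [  1   3 224 224] -- permuted shape
print(value.transpose(perm).shape)       # (1, 3, 224, 224)  -- permuted value
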
Example #6
    def add_reshape_after_data_node(graph: Graph, data_node_name: str):
        """
        Adds a reshape operation which changes the shape of the tensor produced by TFSubgraphCall from 4D to the real
        dimensions of the tensor. The data_node_name node contains the real dimensions of the tensor, but they will be
        changed to 4D in the add_reshapes_for_tf_subgraph_calls function because the IE TF call layer supports 4D output only.
        :param graph: graph to operate on.
        :param data_node_name: name of the data node to be reshaped to correct dimensions.
        :return: None
        """
        data_node = Node(graph, data_node_name)

        # if the data node was previously marked as output then we need to mark as output new reshaped data node
        is_out_node = False
        if len(data_node.out_nodes()) == 1 and data_node.out_node().has('op') and data_node.out_node().op == 'Result':
            is_out_node = True
            graph.remove_node(data_node.out_node().id)

        # save old consumers nodes with edge attributes
        old_consumer_nodes_with_attrs = list()
        for index, out_op in enumerate(data_node.out_nodes()):
            edge_attrs = graph.get_edge_data(data_node_name, out_op.name)[0]
            old_consumer_nodes_with_attrs.append((out_op.name, edge_attrs))

        # remove old consumers from the data node
        for out_op in list(data_node.out_nodes()):
            graph.remove_edge(data_node_name, out_op.name)

        # reshape operation node
        reshape_node_name = graph.unique_id("Reshape_")
        graph.add_node(reshape_node_name, kind='op', type='Reshape', name=reshape_node_name, op='Reshape',
                       data_type=data_node['data_type'])
        update_ie_fields(graph.node[reshape_node_name])

        # reshape shape data node
        reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
        graph.add_node(reshape_shape_data_node_name, kind='data', name=reshape_shape_data_node_name,
                       value=np.array(data_node['shape']), shape=[1])

        # reshaped data node
        reshaped_value = None
        if data_node['value'] is not None:
            reshaped_value = np.array(data_node['value'])
        reshaped_data_node_name = graph.unique_id("reshaped_data_")
        graph.add_node(reshaped_data_node_name, kind='data', name=reshaped_data_node_name,
                       shape=np.array(data_node['shape']), value=reshaped_value, nchw_layout=True)

        if is_out_node:
            add_opoutput(graph, reshaped_data_node_name, 0, False)

        graph.add_edges_from([
            (data_node_name, reshape_node_name, {'in': 0}),
            (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
            (reshape_node_name, reshaped_data_node_name, {'out': 0}),
        ])

        for out_node_name, edge_attrs in old_consumer_nodes_with_attrs:
            graph.add_edges_from([
                (reshaped_data_node_name, out_node_name, edge_attrs)
            ])
Example #7
    def find_and_replace_pattern(self, graph: Graph):
        mp = {}
        used = {}
        for node in graph.get_op_nodes(type='Concat'):
            in_nodes = tuple(
                [node.in_node(idx).id for idx in range(len(node.in_nodes()))])
            out_node = (node.id, node.out_node().id)
            if in_nodes in mp:
                log.warning("Something is weird! {} and {}".format(
                    node.id, mp[in_nodes]))
            else:
                mp.update({in_nodes: out_node})
                used.update({node.id: {x: False for x in in_nodes}})

        for key in mp.keys():
            replacers = []
            for i in range(len(key)):
                for j in range(i + 1, len(key)):
                    arr = tuple(key[i:j + 1])
                    if arr in mp.keys() and arr != key:
                        replacers.append((len(arr), arr))

            replacers.sort(reverse=True)

            concat_id = mp[key][0]
            for ln, arr in replacers:
                # Check that none of these inputs was already replaced for this concat
                we_can = True
                for x in arr:
                    if used[concat_id][x]:
                        we_can = False
                        break

                if not we_can:
                    continue

                for x in arr:
                    used[concat_id][x] = True

                edge_attrs = graph.get_edge_data(arr[0], concat_id)[0]
                for in_node in arr:
                    graph.remove_edge(in_node, concat_id)

                new_input = mp[arr][1]
                out_port = len(Node(graph, new_input).out_nodes()) + 1
                edge_attrs['out'] = out_port
                graph.add_edge(new_input, concat_id, **edge_attrs)

                # Renumber 'in' attrs
                concat_node = Node(graph, concat_id)
                ln = len(concat_node.in_nodes())
                ports = [x for x in concat_node.in_nodes().keys()]
                ports.sort()

                p_id = 0
                for p in ports:
                    in_node = concat_node.in_nodes()[p]
                    graph[in_node.id][concat_id][0]['in'] = p_id
                    p_id += 1
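
The core idea above, as a standalone sketch: enumerate all contiguous input runs of a Concat and prefer the longest runs already produced by another Concat (the inputs tuple and the existing map are illustrative).

inputs = ('a', 'b', 'c', 'd')
existing = {('b', 'c'): 'concat_bc/out'}   # another Concat already produces b+c

replacers = sorted(
    ((j - i + 1, inputs[i:j + 1])
     for i in range(len(inputs))
     for j in range(i + 1, len(inputs))
     if inputs[i:j + 1] in existing),
    reverse=True)
print(replacers)   # [(2, ('b', 'c'))] -- longest reusable run first
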
Example #8
    def check_init_states(graph: Graph, match: dict):
        """
        Check if the cell has initial states and create zero-filled states if not.
        Also renumber the ports for these states.
        """
        rnn_cell = match['rnn_layer']
        num_directions = 2 if rnn_cell.direction == 'bidirectional' else 1
        batch_size = rnn_cell.in_node(0).shape[rnn_cell.batch_dim]

        h_init_port = 5
        c_init_port = 6

        if 2 not in rnn_cell.in_nodes():
            h_shape = [num_directions, batch_size,
                       rnn_cell.hidden_size]  # from ONNX spec
            h_init = np.full(h_shape, 0, dtype=np.float32)
            Op.create_and_connect_input_data_node(
                graph, rnn_cell, {
                    'value': h_init,
                    'shape': int64_array(h_init.shape)
                }, {
                    'in': h_init_port,
                    'permutation': None
                })
        else:
            hidden_state_edge = graph.get_edge_data(
                rnn_cell.in_node(2).id, rnn_cell.id)
            hidden_state_edge[0]['in'] = h_init_port

        if rnn_cell.op == 'LSTM':
            if 3 not in rnn_cell.in_nodes():
                c_shape = [num_directions, batch_size,
                           rnn_cell.hidden_size]  # from ONNX spec
                c_init = np.full(c_shape, 0, dtype=np.float32)
                Op.create_and_connect_input_data_node(
                    graph, rnn_cell, {
                        'value': c_init,
                        'shape': int64_array(c_init.shape)
                    }, {
                        'in': c_init_port,
                        'permutation': None
                    })
            else:
                cell_state_edge = graph.get_edge_data(
                    rnn_cell.in_node(3).id, rnn_cell.id)
                cell_state_edge[0]['in'] = c_init_port
Example #9
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['softmax']
        if 'temperature' in node and node['temperature'] != 1.0:
            in_node = node.in_node()
            out_nodes = [node for node in node.out_nodes().values()]
            graph.remove_edge(node.in_node().id, node.id)
            temperature = mo_array([1.0 / node.temperature])
            scalar_value_op = Const(graph, dict(value=temperature, shape=temperature.shape,
                                                symbol_dict={'name': node.id + '/const'}))
            mul_op = Mul(graph, dict(name=node.id + '/mul_', symbol_dict={'name': node.id + '/mul_'}))
            mul_node = mul_op.create_node(inputs=[in_node, scalar_value_op.create_node()])
            edge_attrs = graph.get_edge_data(node.id, out_nodes[0].id)[0]
            graph.add_edges_from([(mul_node.id, node.id, edge_attrs)])
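
The rewrite is justified by a simple identity: SoftMax with temperature T equals plain SoftMax applied to the input scaled by 1/T, which is exactly the Mul being inserted. A quick numpy check:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x, t = np.array([1.0, 2.0, 3.0]), 2.0
print(np.allclose(softmax(x / t),             # temperature softmax
                  softmax(x * (1.0 / t))))    # plain softmax after Mul(1/t) -> True
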
Example #10
    def replace_pattern(graph: Graph, match: dict):
        """
        The DetectionOutput layer expects its inputs in a different order than MXNet,
        so the _contrib_MultiBoxDetection inputs need to be reordered
        for correct conversion to the DetectionOutput layer.

        Parameters
        ----------
        graph : Graph
           Graph with loaded model.
        """
        multi_box_detection_node = match['multi_box_detection']
        conf_node = multi_box_detection_node.in_node(0)
        loc_node = multi_box_detection_node.in_node(1)

        conf_edge_data = graph.get_edge_data(conf_node.id,
                                             multi_box_detection_node.id)
        conf_out_port = conf_edge_data[0]['out']
        conf_in_port = conf_edge_data[0]['in']

        loc_edge_data = graph.get_edge_data(loc_node.id,
                                            multi_box_detection_node.id)
        loc_out_port = loc_edge_data[0]['out']
        loc_in_port = loc_edge_data[0]['in']

        graph.remove_edge(conf_node.id, multi_box_detection_node.id)
        graph.remove_edge(loc_node.id, multi_box_detection_node.id)

        graph.create_edge(loc_node,
                          multi_box_detection_node,
                          in_port=conf_in_port,
                          out_port=conf_out_port)
        graph.create_edge(conf_node,
                          multi_box_detection_node,
                          in_port=loc_in_port,
                          out_port=loc_out_port)
Example #11
def pad_op_transform(graph: Graph, match: dict):
    op = match['op']
    pad_op = match['pad_op']
    input_data = pad_op.in_node(0)

    if pad_op.mode != 'constant':
        log.info(
            'The pad node "{}" with pad mode "{}" cannot be fused.'.format(
                pad_op.soft_get('name'), pad_op.mode))
        return

    if op.type == 'Pooling' and op.pool_method == 'max':
        return

    if pad_op.mode == 'constant':
        fill_value = pad_op.in_port(3).data.get_value()
        if fill_value is None or fill_value != 0.0:
            log.info(
                'The pad node "{}" with non-zero fill value cannot be fused.'.
                format(pad_op.soft_get('name')))
            return

    input_tensor_dims = len(match['pad_output'].shape)
    for in_port in [1, 2]:
        pads = pad_op.in_port(in_port).data.get_value()
        if pads[get_features_dim(op.graph.graph['layout'], input_tensor_dims)] != 0 or \
                pads[get_batch_dim(op.graph.graph['layout'], input_tensor_dims)] != 0:
            log.info(
                'The pad node "{}" with padding over feature/batch dimension cannot be fused.'
                .format(pad_op.soft_get('name')))
            return

    op.pad += np.concatenate([
        pad_op.in_port(1).data.get_value().reshape([-1, 1]),
        pad_op.in_port(2).data.get_value().reshape([-1, 1])
    ], axis=1)
    op.pad_spatial_shape = op.pad[op.spatial_dims]
    op['auto_pad'] = None
    if op.type == 'Pooling':
        op['exclude_pad'] = False
    assert (graph[match['pad_output'].node][match['op'].node][0]['in'] == 0)
    edge_attrs = graph.get_edge_data(match['pad_output'].id, match['op'].id)[0]
    graph.remove_edge(match['pad_output'].id, match['op'].id)
    graph.add_edge(input_data.id, match['op'].id, **{'in': 0, **edge_attrs})
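
The pad-merging arithmetic above, in isolation: per-axis begin and end pads are stacked column-wise into the [[begin, end], ...] layout that op.pad uses (values are illustrative).

import numpy as np

pads_begin = np.array([0, 0, 1, 1])      # N, C, H, W
pads_end = np.array([0, 0, 2, 2])
merged = np.concatenate([pads_begin.reshape([-1, 1]),
                         pads_end.reshape([-1, 1])], axis=1)
print(merged)   # rows are [begin, end] per axis:
                # [[0 0] [0 0] [1 2] [1 2]]
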
Example #12
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        if not node.has_port('in', 2) or node.in_port(2).disconnected() or not node.has_and_set('shape_input'):
            return

        if node.has_valid('layout') and not node.layout.startswith('NC') and graph.graph['layout'] == 'NCHW':
            input_shape_rank = len(node.in_port(0).data.get_shape())
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(input_shape_rank)

            data_node = node.in_node(2)

            name = node.soft_get('name', node.id) + '/ShapeGather'
            const = Const(graph, {'value': permutation.perm, 'name': name + '/Const',
                                  'need_shape_inference': True}).create_node_with_data()
            axis_const = Const(graph, {'value': int64_array(0), 'name': name + '/Axis'}).create_node_with_data()
            gather = Gather(graph, {'name': name,
                                    'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
            attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()

            graph.add_edge(gather.id, node.id, **attrs)
            graph.remove_edge(data_node.id, node.id)
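
What the inserted Gather computes, sketched with numpy (Gather along axis 0 is np.take): the shape input is reordered from NHWC to NCHW order before the op consumes it, assuming a 4D input.

import numpy as np

shape_nhwc = np.array([1, 224, 224, 3])
perm = np.array([0, 3, 1, 2])
print(np.take(shape_nhwc, perm, axis=0))   # [  1   3 224 224]
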
Example #13
    def find_and_replace_pattern(self, graph: Graph):
        # Iterate over all data nodes and find all with >= 1 consumers
        for input_data in list(graph.get_data_nodes()):
            # We don't use constant data nodes
            if input_data.value is not None:
                continue

            if input_data.shape is None:
                continue
            input_shape = shape_array(input_data.shape)

            # Get all unique StridedSlice consumers
            out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and
                         node.in_node(0).id == input_data.id]

            if len(out_nodes) <= 1:
                continue

            valid_for_replacement = True
            for n in out_nodes:
                if any(not isinstance(s, slice) for s in n.slices):
                    # this is a slice with dynamic dimension. Such operation is not valid for replacement
                    valid_for_replacement = False
            if not valid_for_replacement:
                continue

            sorted_out_nodes = sorted(out_nodes, key=lambda n: list(n.slices))
            out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)

            for node in out_nodes:
                if len(node.slices) != len(out_nodes[0].slices):
                    valid_for_replacement = False

            # Detect dimension for splitting
            split_channel_dim = None
            for dim_id, s in enumerate(out_nodes[0].slices):
                l, r, stride = s.start, s.stop, s.step
                # if both l and r are None then the dimension is not sliced
                if (l != 0 or r != input_shape[dim_id]) and (l is not None or r is not None):
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            if split_channel_dim is None:
                valid_for_replacement = False

            # split_dims contains tuples with split range and output data node
            split_dims = []
            for out_id, node in enumerate(out_nodes):
                # Check that StridedSlice op has stride eq 1 and splits only feature channel
                for id, s in enumerate(node.slices):
                    l, r, stride = s.start, s.stop, s.step
                    # We don't support StridedSlice with stride != 1
                    if stride != 1:
                        valid_for_replacement = False
                    if id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Check feature split intersection
            final_data_nodes_list = []
            sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

            # check if we have similar StridedSlice operations with different outputs
            prev_sd = sorted_split_dims[0]
            to_remove = []
            for i in range(1, len(sorted_split_dims)):
                if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and sorted_split_dims[i][2].name != prev_sd[2].name:
                    cur_node = sorted_split_dims[i][2]
                    for out in cur_node.out_nodes():
                        attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                        graph.remove_edge(cur_node.id, out.id)
                        graph.add_edge(prev_sd[2].id, out.id, **attrs)
                    to_remove.append(i)

            for ind in reversed(to_remove):
                sorted_split_dims.pop(ind)

            size_splits = []
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                prev_r = r

            if prev_r > input_shape[split_channel_dim]:
                valid_for_replacement = False

            if not valid_for_replacement:
                continue

            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Save missing tensor part
                if l > prev_r:
                    shape = mo_array(input_shape)
                    size_splits.append(l - prev_r)
                    shape[split_channel_dim] = l - prev_r
                    data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                    add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                    final_data_nodes_list.append(data_node)

                prev_r = r
                size_splits.append(r - l)
                final_data_nodes_list.append(out)

            if prev_r < input_shape[split_channel_dim]:
                # Add last part of tensor
                shape = input_shape.copy()
                shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
                size_splits.append(input_shape[split_channel_dim] - prev_r)
                data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                final_data_nodes_list.append(data_node)

            for node in out_nodes:
                if not np.all([x == 0 for x in node.shrink_axis_mask]):
                    out_node = node.out_node()
                    if np.any(node['shrink_axis_mask']):
                        self.add_squeeze_for_shrink(graph, node)
                    if np.any(node['new_axis_mask']):
                        self.add_unsqueeze_for_new(graph, node)

                    for i in range(len(final_data_nodes_list)):
                        if final_data_nodes_list[i].name == out_node.name:
                            final_data_nodes_list[i] = node.out_node()
                            break

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops
            out_data_nodes = []
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                out_data_nodes.append(node.out_node())
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer and reorder outputs
            name = name_for_future_split + "/Split"
            axis_const = Const(graph, {'value': int64_array(split_channel_dim),
                                       'name': name + '/Axis'}).create_node_with_data()
            size_splits_const = Const(graph, {'value': int64_array(size_splits),
                                              'name': name + '/Sizes'}).create_node_with_data()
            split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

            split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                        data_nodes=final_data_nodes_list)
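
A sketch of the size_splits bookkeeping above, for illustrative slices [0:2], [2:5], [5:8] over a channel dimension of size 8: sorted, non-overlapping ranges become split sizes, and any gap would get a 'fake' filler output.

split_ranges = [(0, 2), (2, 5), (5, 8)]    # (l, r) per StridedSlice, sorted
size_splits, prev_r = [], 0
for l, r in split_ranges:
    if l > prev_r:                         # gap: emit a filler chunk
        size_splits.append(l - prev_r)
    size_splits.append(r - l)
    prev_r = r
print(size_splits)   # [2, 3, 3]
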
Example #14
def build_graph_with_attrs(nodes_with_attrs: list,
                           edges_with_attrs: list,
                           new_nodes_with_attrs: list = [],
                           new_edges_with_attrs: list = [],
                           update_edge_attrs: dict = None,
                           update_nodes_attributes: list = None,
                           nodes_with_edges_only: bool = False,
                           add_nodes_from_edges: bool = False):
    """
    Build the Graph with specific nodes and edges. Also update of edge and node parameters is supported.
    :param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
    :param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
    :param new_nodes_with_attrs: same as nodes_with_attrs, but for nodes added on top of the base graph.
    :param new_edges_with_attrs: same as edges_with_attrs, but for edges added on top of the base graph.
    :param update_edge_attrs: optional dictionary like {('from_node', 'to_node', key): {edge_attrs}}.
    :param update_nodes_attributes: optional list of tuples which specifies nodes names and their attributes to be
    updated. The first element is a node name to update attribute and the second element is a dictionary with attribute
    name and its value.
    :param nodes_with_edges_only: add only nodes that have at least one incoming or outgoing edge.
    :param add_nodes_from_edges: whether nodes that are not listed in the node lists but appear in the edge lists are allowed.
    :return: generated graph.
    """
    if not_all_new([node[0] for node in nodes_with_attrs],
                   [node[0] for node in new_nodes_with_attrs]):
        raise Error(
            'Some nodes from new_nodes_with_attrs are already in nodes.'
            ' Please add only NEW nodes to new_nodes_with_attrs.')

    if not_all_new([(edge[0], edge[1]) for edge in edges_with_attrs],
                   [(edge[0], edge[1]) for edge in new_edges_with_attrs]):
        raise Error(
            'Some edges from new_edges_with_attrs are already in edges.'
            ' Please add only NEW edges to new_edges_with_attrs.')

    # Check that all nodes from list of edges are in nodes
    all_nodes = nodes_with_attrs + new_nodes_with_attrs
    all_edges = edges_with_attrs + new_edges_with_attrs
    all_nodes_names = [node[0] for node in all_nodes]
    if not add_nodes_from_edges and not all_edges_in_nodes(
            nodes=all_nodes_names, edges=all_edges):
        raise Error(
            "Some nodes from the list of edges are not in nodes. Please add all necessary nodes."
        )

    graph = Graph()

    # Create dict for nodes with attrs
    nodes_attrs = {}
    for node_name, attrs in all_nodes:
        nodes_attrs[node_name] = attrs
        if 'name' not in attrs:
            attrs['name'] = node_name

    if nodes_with_edges_only:
        # filter nodes to keep only ones with edges connected
        filtered_nodes = {}
        for edge in all_edges:
            node_1, node_2 = edge[0], edge[1]
            filtered_nodes[node_1] = nodes_attrs[node_1]
            filtered_nodes[node_2] = nodes_attrs[node_2]
        nodes_attrs = filtered_nodes

    # Create all nodes
    for node, attrs in nodes_attrs.items():
        graph.add_node(node, **deepcopy(attrs))

    # Connect nodes with edges (also unpack edge params)
    for edge in all_edges:
        node_1, node_2 = edge[0], edge[1]
        edge_attrs = edge[2] if len(edge) == 3 else {}
        graph.add_edge(node_1, node_2, **edge_attrs)

    # Update attributes of edges
    if update_edge_attrs:
        # it will work in 2.x networkx only
        for edge, attr in update_edge_attrs.items():
            for k, v in attr.items():
                nx.set_edge_attributes(G=graph, name=k, values={edge: v})

    # Update attributes of nodes
    if update_nodes_attributes is not None:
        for node_name, new_attrs in update_nodes_attributes:
            assert (node_name in graph.nodes())
            for attr, value in new_attrs.items():
                graph.node[node_name][attr] = value

    for node_id in graph.nodes():
        node = Node(graph, node_id)
        check_and_update_ports(node, [
            graph.get_edge_data(edge[0], node_id)[0]
            for edge in graph.in_edges(node_id)
        ], True)
        check_and_update_ports(node, [
            graph.get_edge_data(node_id, edge[1])[0]
            for edge in graph.out_edges(node_id)
        ], False)

    for node in graph.get_op_nodes():
        # Add in_ports attribute
        in_edges = node.in_edges()
        for i in range(len(in_edges)):
            node.add_input_port(idx=i)

        # Add out_ports attribute
        out_edges = node.out_edges()
        for i in range(len(out_edges)):
            node.add_output_port(idx=i)
    return graph
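
A hedged usage sketch (node and edge attrs are illustrative, not taken from a real model):

graph = build_graph_with_attrs(
    nodes_with_attrs=[('input', {'kind': 'data'}),
                      ('relu', {'kind': 'op', 'op': 'ReLU'}),
                      ('relu_data', {'kind': 'data'})],
    edges_with_attrs=[('input', 'relu', {'in': 0}),
                      ('relu', 'relu_data', {'out': 0})])
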
Example #15
    def find_and_replace_pattern(self, graph: Graph):
        for node in list(graph.nodes()):
            node = Node(graph, node)
            node_name = node.soft_get('name', node.id)
            # Check that node layout mismatch with graph layout
            # For example: NHWC and NCHW or NCDHW and NDHWC
            if node.kind == 'op' and node.has_valid(
                    'layout') and node.layout != indices_mapping[len(
                        node.layout)][graph.graph['layout']]:
                input = node.in_node()
                output = node.out_node()

                # Calculate permutation for further Transpose operations
                if graph.graph['layout'] == 'NCHW':
                    # if Node has NCHW and graph has NHWC layout
                    permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(
                        len(node.layout))
                else:
                    # if Node has NHWC and graph has NCHW layout
                    permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(
                        len(node.layout))

                # Schematic representation of the transformation below:
                #
                #   NHWC:   data --> Convolution(example) --> data
                #
                # becomes
                #
                #   data -> Transpose -> data -> Convolution -> data -> Transpose -> data
                #           (to NCHW)           (NCHW op)              (back to NHWC)

                # 1. Insert input Transpose
                #    This Transpose will permute input from original input layout to operation layout
                edge_attrs = graph.get_edge_data(input.id, node.id)[0]
                graph.remove_edge(input.id, node.id)

                input_permute_name = node_name + '/input_transpose'
                input_order_const = Const(
                    graph, {
                        'name': input_permute_name + '/order',
                        'value': permutation.perm
                    }).create_node_with_data()
                input_permute_op = Transpose(graph,
                                             {'name': input_permute_name})
                input_permute_data_node = input_permute_op.create_node_with_data(
                    [input, input_order_const])

                graph.add_edge(input_permute_data_node.id, node.id,
                               **edge_attrs)

                # 2. Insert output Transpose
                #    This Transpose will permute output from operation layout to original input layout
                edge_attrs = graph.get_edge_data(node.id, output.id)[0]
                graph.remove_edge(node.id, output.id)

                input_data_node = Op.create_data_node(
                    graph, node, {'shape': output.shape[permutation.perm]},
                    edge_attrs)

                output_permute_name = node_name + '/output_transpose'
                output_order_const = Const(
                    graph, {
                        'name': output_permute_name + '/order',
                        'value': permutation.inv
                    }).create_node_with_data()
                output_permute_op = Transpose(graph, {
                    'name': output_permute_name
                }).create_node_with_data([input_data_node, output_order_const],
                                         data_nodes=output)

                # 3. Add permutations for Node
                #    Here we use permutation mechanism where data nodes takes permutation attribute.
                #    And then we call permute_attrs method that permutes node attributes according to permutations on
                #    data nodes.
                node.in_node()['permutation'] = permutation
                node.out_node()['permutation'] = permutation
                node.permute_attrs.permute_attrs(node)

                node.in_node()['permutation'] = None
                node.out_node()['permutation'] = None
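
A quick numpy check of the perm/inv pair used above: transposing by perm and then by its inverse restores the original layout, which is why the output Transpose uses permutation.inv (shapes are illustrative).

import numpy as np

perm = np.array([0, 3, 1, 2])                   # NHWC -> NCHW
inv = np.empty_like(perm)
inv[perm] = np.arange(len(perm))                # inverse permutation
x = np.zeros((1, 224, 224, 3))
print(x.transpose(perm).transpose(inv).shape)   # (1, 224, 224, 3)
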
Example #16
    def replace_pattern(graph: Graph, match: dict):
        time_len = match['concatenated_hidden_states'].shape[0]
        r"""
        Working with the concatenated_cell_states_data part first, because the IE TensorIterator primitive doesn't have
        a concatenated cell states output, and if we cannot collapse it, this type of BlockLSTM is not supported.

        We simplify the sub-graph below by taking another output of BlockLSTM:
        concatenated cell states over the whole time sequence -> last cell state

        BlockLSTM
           || out 1 (concatenated cell states coming out of BlockLSTM)
           \/  in 1
        ConcatV2
           || (concatenation with initial state or another unused data)
           \/
        Reshape
           ||
           \/
         Gather (taking the last cell state from previous BlockLSTM, if Gather indexes == time_len)
        """
        # check that there are no other consumers of concatenated_cell_states_data data flow
        valid_output_names = [
            'concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data',
            'gather_1', 'gather_1_data'
        ]
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = [
            'concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data'
        ]
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    raise Error(
                        "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                        "time sequence. It is not replaceable by another output and is not supported "
                        "originally".format(match['BlockLSTM'].id))

        # check that we really take the last cell state data by Gather
        gather_indexes = match['gather_1'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            raise Error(
                "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                "time sequence. It is not replaceable by another output and is not supported "
                "originally".format(match['BlockLSTM'].id))
        if gather_index != time_len:
            raise Error(
                "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                "time sequence. It is not replaceable by another output and is not supported "
                "originally".format(match['BlockLSTM'].id))
        """
        We have passed stages #1 and #2 from the class description, which means we can translate the rest of the
        pattern to LSTMSequence even without the following optimizations.
        """

        node = match['BlockLSTM']
        weights_node = node.in_node(1)
        biases_node = node.in_node(2)
        shift_const = node.forget_bias

        # Assign temporary shape for them for easier manipulation
        # TF stores weights in IO order
        input_size = node.in_node(0).shape[-1]
        hidden_size = node.in_node(3).shape[-1]
        weights = weights_node.value
        biases = biases_node.value
        assert weights.shape[0] == input_size + hidden_size, \
            "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
            "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

        weights = weights.reshape([
            weights.shape[0],
            4,  # gates
            hidden_size
        ])

        biases = biases.reshape([
            4,  # gates
            hidden_size
        ])

        # Reorder gates icfo --> fico for both weights and biases
        gate_reorder = [2, 0, 1, 3]
        weights = np.take(weights, gate_reorder, axis=1)
        biases = np.take(biases, gate_reorder, axis=0)

        # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
        # Note: in case of moving this code up before gate reordering, the addition
        # should be applied at different place
        biases[0] += shift_const

        # Return to the original shapes
        weights = weights.reshape([weights.shape[0], -1])
        biases = biases.flatten()

        # TF stores weights in IO, but IE requires it in OI: transpose
        weights = weights.transpose()

        weights_node.value = weights
        weights_node.shape = int64_array(weights.shape)
        biases_node.value = biases
        biases_node.shape = int64_array(biases.shape)

        attrs = dict(
            graph.get_edge_data(match['gather_1'].id,
                                match['gather_1_data'].id)[0])
        attrs.update({'out': 2})
        graph.remove_edge(match['BlockLSTM'].id,
                          match['concatenated_cell_states_data'].id)
        graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)

        match['BlockLSTM'].add_output_port(attrs['out'])
        graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id,
                       **attrs)
        """
        #3 Renumbering h_init_state, c_init_state input ports to match RNNSequence ports order.
        """
        h_init_port = 4
        c_init_port = 5
        # c_init_state
        if 4 in node.in_nodes():
            assert c_init_port not in node.in_nodes()
            cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
            cell_state_edge[0]['in'] = c_init_port

        # h_init_state
        if 3 in node.in_nodes():
            assert h_init_port not in node.in_nodes()
            hidden_state_edge = graph.get_edge_data(
                node.in_node(3).id, node.id)
            hidden_state_edge[0]['in'] = h_init_port

        new_attrs = {
            'sequence_dim': 0,
            'batch_dim': 1,
            'direction': 'forward',
            'hidden_size': match['concatenated_hidden_states'].shape[-1],
            'format': 'tf',
        }

        LSTM.update_node_stat(match['BlockLSTM'], new_attrs)
        """
        Optional #4 optimization from class description following
        """
        data_to_mul = [
            n for n in match['mul'].in_nodes().values()
            if n.id != match['concatenated_hidden_states'].id
        ]
        if len(data_to_mul) != 1:
            return  # unexpected type of mul
        data_to_mul = data_to_mul[0]
        if not data_to_mul.has_valid('value'):
            return  # unexpected type of mul
        data_to_mul_value = data_to_mul.value
        if not np.all(data_to_mul_value == 1):
            return  # unexpected type of mul

        # remove useless mul
        attrs = dict(
            graph.get_edge_data(match['BlockLSTM'].id,
                                match['concatenated_hidden_states'].id)[0])
        graph.remove_edge(match['BlockLSTM'].id,
                          match['concatenated_hidden_states'].id)
        graph.remove_edge(match['mul'].id, match['mul_data'].id)
        graph.add_edge(match['BlockLSTM'].id, match['mul_data'].id, **attrs)

        # find true usages of concatenated hidden states data (not last hidden state)
        valid_output_names = [
            'mul_data', 'concat_0', 'concat_0_data', 'reshape_0',
            'reshape_0_data', 'gather_0', 'gather_0_data'
        ]
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = [
            'mul_data', 'concat_0_data', 'reshape_0_data'
        ]

        list_of_concatenated_hidden_states_children_node_ids = []
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    list_of_concatenated_hidden_states_children_node_ids.append(
                        node.id)

        if len(list_of_concatenated_hidden_states_children_node_ids) != 1:
            return  # not supported placement of pattern
        concatenated_child_node_id = list_of_concatenated_hidden_states_children_node_ids[0]
        if concatenated_child_node_id != match['after_mul_op_to_the_rest_of_model'].id:
            return  # not supported placement of pattern

        gather_indexes = match['gather_0'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is
        if gather_index != time_len:
            return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is

        attrs = dict(
            graph.get_edge_data(match['gather_0'].id,
                                match['gather_0_data'].id)[0])
        attrs.update({'out': 1})
        graph.remove_edge(match['mul_data'].id, match['concat_0'].id)
        graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id)

        graph.add_edge(match['BlockLSTM'].id, match['gather_0_data'].id,
                       **attrs)
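
The gate reorder used above, isolated: taking indices [2, 0, 1, 3] from icfo order yields fico, i.e. the forget gate moves to the front while each per-gate block stays intact.

import numpy as np

gates_icfo = np.array(['i', 'c', 'f', 'o'])
gate_reorder = [2, 0, 1, 3]
print(np.take(gates_icfo, gate_reorder))   # ['f' 'i' 'c' 'o']
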