Example #1
    def merge_infer(node: Node):
        # we infer only through executable input nodes
        inferred_nodes = [
            n for n in node.in_nodes().values() if n['is_partial_inferred']
        ]
        assert len(inferred_nodes) != 0
        tensor = inferred_nodes[0]

        if len(inferred_nodes) < len(node.in_nodes()):
            node['is_not_fully_inferred'] = True
        else:
            node['is_not_fully_inferred'] = False
            assert all(
                compatible_shapes(n.shape, inferred_nodes[0].shape)
                for n in inferred_nodes)

            inferred_and_executable = [
                n for n in node.in_nodes().values() if n['is_partial_inferred']
                and 'executable' in n and n['executable']
            ]
            if len(inferred_and_executable) > 0:
                tensor = inferred_and_executable[0]

                if all([
                        tensor.has_valid('value') and n.has_valid('value')
                        and strict_compare_tensors(tensor.value, n.value)
                        for n in inferred_and_executable
                ]):
                    node.out_node().value = tensor.value.copy()
                else:
                    node.out_node().value = None

        # do not use set_shape(tensor.shape) here because input port shape may be different from the calculated output
        # shape and `set_shape` will raise an error that shape has changed
        node.out_node(0).shape = shape_array(tensor.shape)
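A note on the shape-compatibility assert above: passing a bare generator to np.all is a classic pitfall, because numpy wraps the generator in a 0-d object array that is always truthy, so the check would silently always pass; the built-in all() (used above) actually iterates. A stand-alone demonstration:

import numpy as np

# np.all cannot iterate a generator: it becomes a 0-d object array,
# which is truthy, so the "check" always passes.
print(np.all(x > 10 for x in [1, 2, 3]))   # True (!) - the pitfall
print(all(x > 10 for x in [1, 2, 3]))      # False - the intended check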
Example #2
    def infer(node: Node):
        if node.has_and_set('extra_inputs'):
            assert len(node.in_nodes()) == 8
        else:
            assert len(node.in_nodes()) == 5
        assert len(node.out_nodes()) in [1, 2]

        hidden_shape = node.in_node(1).shape.copy()
        cell_shape = node.in_node(2).shape.copy()

        mark_input_bins(node, start_port=3)
        node.out_node(0).shape = hidden_shape
        if len(node.out_nodes()) == 2:
            node.out_node(1).shape = cell_shape

        hidden_size = hidden_shape[1]

        if node.has_valid('hidden_size'):
            if node.hidden_size != hidden_size:
                raise Error(
                    "Input shape {} for hidden size doesn't match pre-defined hidden_size in node {}"
                    .format(node.in_node(1).shape, node.soft_get('name')))
        else:
            node['hidden_size'] = hidden_size

        assert cell_shape[1] == hidden_size

        input_shape = node.in_node(0).shape
        assert input_shape is not None
        assert compatible_dims(hidden_shape[0], cell_shape[0]) and \
               compatible_dims(cell_shape[0], input_shape[0]), 'States are not broadcast-able by batch for node {}' \
                                                               ''.format(node.soft_get('name', node.id))
Example #3
def roipooling_infer(node: Node):
    """
    Sets shape of output node according specified parameters input blobs and node
    Sets number from the first input blob, channels from the second one, height and width are specified
    Parameters
    ----------
    node
    """
    shapes = [node.in_node(i).shape for i in range(len(node.in_nodes()))]
    if any(s is None for s in shapes):
        return
    if len(node.in_nodes()) == 4:  # TensorFlow case of CropAndResize operation
        crop_size = node.in_node(3).value
        if crop_size is None:
            log.error('The ROIPooling size is not known for node {}'.format(
                node.soft_get('name')))
            return
        if not isinstance(crop_size, np.ndarray) or len(crop_size) != 2:
            log.error(
                'The ROIPooling size should have 2 elements for node {}'.
                format(node.soft_get('name')))
            return
        node.pooled_h = crop_size[0]
        node.pooled_w = crop_size[1]
        node.graph.remove_edge(node.in_node(3).id, node.id)
        node.graph.remove_edge(node.in_node(2).id, node.id)

    layout = node.graph.graph['layout']
    assert len(layout) == 4

    node.out_port(0).data.set_shape(
        shape_for_layout(layout,
                         batch=shapes[1][get_batch_dim(layout, 4)],
                         features=shapes[0][get_features_dim(layout, 4)],
                         height=node.pooled_h,
                         width=node.pooled_w))
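For readers unfamiliar with the layout helpers used above, here is a minimal sketch of what shape_for_layout is assumed to do for 4D layouts: place the batch/features/height/width values according to the layout string. This is a hypothetical reimplementation for illustration, not the mo original:

import numpy as np

def shape_for_layout_sketch(layout: str, *, batch, features, height, width):
    # assumed behavior: NCHW and NHWC are the only supported 4D layouts
    assert layout in ('NCHW', 'NHWC')
    if layout == 'NCHW':
        return np.array([batch, features, height, width], dtype=np.int64)
    return np.array([batch, height, width, features], dtype=np.int64)

print(shape_for_layout_sketch('NCHW', batch=300, features=256, height=6, width=6))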
Example #4
    def find_and_replace_pattern(self, graph: Graph):
        mp = {}
        used = {}
        for node in graph.get_op_nodes(type='Concat'):
            in_nodes = tuple(
                [node.in_node(idx).id for idx in range(len(node.in_nodes()))])
            out_node = (node.id, node.out_node().id)
            if in_nodes in mp:
                log.warning("Concat nodes {} and {} have identical input sets".format(
                    node.id, mp[in_nodes]))
            else:
                mp.update({in_nodes: out_node})
                used.update({node.id: {x: False for x in in_nodes}})

        for key in mp.keys():
            replacers = []
            for i in range(len(key)):
                for j in range(i + 1, len(key)):
                    arr = tuple(key[i:j + 1])
                    if arr in mp.keys() and arr != key:
                        replacers.append((len(arr), arr))

            replacers.sort(reverse=True)

            concat_id = mp[key][0]
            for ln, arr in replacers:
                # check that none of these inputs has already been replaced
                we_can = True
                for x in arr:
                    if used[concat_id][x]:
                        we_can = False
                        break

                if not we_can:
                    continue

                for x in arr:
                    used[concat_id][x] = True

                edge_attrs = graph.get_edge_data(arr[0], concat_id)[0]
                for in_node in arr:
                    graph.remove_edge(in_node, concat_id)

                new_input = mp[arr][1]
                out_port = len(Node(graph, new_input).out_nodes()) + 1
                edge_attrs['out'] = out_port
                graph.add_edge(new_input, concat_id, **edge_attrs)

                # Renumber 'in' attrs
                concat_node = Node(graph, concat_id)
                ports = sorted(concat_node.in_nodes().keys())
                for p_id, p in enumerate(ports):
                    in_node = concat_node.in_nodes()[p]
                    graph[in_node.id][concat_id][0]['in'] = p_id
Example #5
    def infer(node: Node):
        # there are limitations coming from ONNX LSTM definition and normalization rules
        assert len(node.in_nodes()) >= 3  # X, W and R
        assert len(node.in_nodes()) <= 7
        assert len(node.out_nodes()) <= 3
        assert node.batch_dim <= 1
        assert node.sequence_dim <= 1
        assert node.batch_dim != node.sequence_dim

        assert node.direction in ['forward', 'reverse', 'bidirectional']

        if node.blobs_wrb:
            mark_input_bins(node, ['W', 'R', 'B'])
        else:
            mark_input_bins(node)
        input_shape = node.in_node(0).shape
        assert len(input_shape) == 3

        for port in [2, 3]:
            if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
               'zero_shapes' in node.in_node(port).in_node():
                for i in node.in_node(port).in_node().zero_shapes:
                    if node.in_node(port).shape[i] != input_shape[i]:
                        node.in_node(port).value = np.repeat(
                            node.in_node(port).value, input_shape[i], axis=i)
                        node.in_node(port).shape[i] = input_shape[i]

        out_shape = shape_array([
            input_shape[node.sequence_dim], input_shape[node.batch_dim],
            node.hidden_size
        ])
        assert not node.has_num_directions or node.sequence_dim == 0, \
            'If has_num_directions == True, then node.sequence_dim must be equal to 0, but it is {}'.format(
                node.sequence_dim)
        num_directions = 2 if node.direction in ['bidirectional'] else 1
        num_layers = node.num_layers
        if node.has_num_directions:
            # insert extra dimension to output shape for num_directions
            out_shape = shape_insert(out_shape, 1, np.int64(num_directions))
        node.out_node(0).shape = out_shape
        # extra outputs for hidden/cell states
        state_size = shape_array([input_shape[1], node.hidden_size])
        if node.has_num_directions:
            state_size = shape_insert(state_size, 0,
                                      num_directions * num_layers)
        for i in [1, 2]:
            if i not in node.out_nodes():
                data_node = Op._create_data_node(node.graph,
                                                 name=node.node +
                                                 '/ExtraOutput/' + str(i),
                                                 attrs={'executable': True})
                node.graph.add_edge(node.id, data_node.id, key=0, out=i)
                add_opoutput(node.graph, data_node.id, 0, False)
            else:
                data_node = node.out_node(i)
            data_node.shape = state_size.copy()
Example #6
    def test_partial_infer(self):
        graph = build_graph(nodes_attributes, [('node_1', 'concat'),
                                               ('node_2', 'concat'),
                                               ('concat', 'node_3'),
                                               ('node_3', 'op_output')],
                            {
                                'node_3': {
                                    'kind': 'data',
                                    'shape': None,
                                    'infer': None
                                },
                                'node_1': {
                                    'kind': 'data',
                                    'shape': np.array([1, 3, 227, 227]),
                                    'infer': None
                                },
                                'node_2': {
                                    'kind': 'data',
                                    'shape': np.array([1, 3, 227, 227]),
                                    'infer': None
                                },
                                'concat': {
                                    'kind': 'op',
                                    'axis': 2,
                                    'infer': concat_infer
                                }
                            },
                            nodes_with_edges_only=True)

        start_node = 'concat'
        partial_infer(graph, start_node)
        node = Node(graph, start_node)
        self.assertTrue(node.is_partial_inferred)
        self.assertTrue(node.out_node().is_partial_inferred)

        # check that the preceding nodes have not been inferred
        node = Node(graph, start_node)
        while True:
            # collect nodes in a list
            if isinstance(node.in_nodes(), list):
                in_nodes = node.in_nodes()
            else:
                in_nodes = [y for x, y in node.in_nodes().items()]

            # check parents and find next parent
            for n in in_nodes:
                if 'embedded_input_' not in n.id:
                    node = n
                self.assertFalse(n.has('is_partial_inferred'))

            if not len(in_nodes):
                break
Example #7
    def infer(node: Node):
        assert len(node.in_nodes()) == len(__class__.inputs) + len(
            __class__.extra_inputs)

        for axis in ['concat_axis', 'split_axis']:
            axis_node = __class__.extra_inputs.index(axis) + len(
                __class__.inputs)
            assert node.in_node(axis_node).has_valid('value')
            assert node.in_node(axis_node).value == 1

        shift_const = node.in_node(
            __class__.extra_inputs.index('shift_const') +
            len(__class__.inputs))
        assert shift_const.has_valid('value')
        shift_const = shift_const.value
        assert shift_const.ndim == 0  # expect scalar value
        node['shift_const'] = shift_const.copy()

        weights_node = node.in_node(__class__.inputs.index('weights'))
        biases_node = node.in_node(__class__.inputs.index('biases'))

        assert weights_node.has_valid('value')
        assert biases_node.has_valid('value')

        # Restore original infer function (to avoid calling previous code twice) and call it
        node.infer = node.old_infer
        node.infer(node)
Example #8
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        assert node.with_right_bound is not None, \
            "Attribute \"with_right_bound\" is not defined"
        assert len(node.in_nodes()) == 2, \
            "Incorrect number of inputs for {} node".format(node.id)
        if node.get_opset() != "extension":
            assert node.has_valid('output_type'), \
                '`output_type` attribute is not set for Bucketize node `{}`'.format(node_name)
            assert node.output_type in [np.int64, np.int32], \
                'Bucketize `output_type` attribute must be int32 or int64, `{}` found'.format(np.dtype(node.output_type).name)

        output_shape = node.in_port(0).data.get_shape()
        node.out_port(0).data.set_shape(output_shape)

        input_value = node.in_port(0).data.get_value()
        buckets_value = node.in_port(1).data.get_value()

        # compute if all input is constant
        if input_value is not None and buckets_value is not None:
            node.out_port(0).data.set_value(
                mo_array(np.digitize(input_value,
                                     buckets_value,
                                     right=node.with_right_bound),
                         dtype=node.output_type))
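For reference, the np.digitize call above assigns each input value the index of the bucket it falls into; with right=True the right boundary of each bucket is inclusive. A toy run with made-up data:

import numpy as np

values = np.array([0.5, 1.0, 1.5, 3.0])
boundaries = np.array([1.0, 2.0])
print(np.digitize(values, boundaries, right=False))  # [0 1 1 2]
print(np.digitize(values, boundaries, right=True))   # [0 0 1 2]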
Example #9
    def infer(node: Node):
        assert len(node.in_nodes()) == 4

        # check that shape value is defined that is needed for shape inference
        shape = node.in_node(2)
        assert shape.value is not None and shape.value.size == 2, \
            "SparseFillEmptyRows is supported only with constant shape value"

        shape_value = int64_array(shape.value)

        # check that default value is scalar
        default_value = node.in_node(3)
        assert default_value.shape is not None and len(default_value.shape) == 0, \
            "Default value for SparseFillEmptyRows must be scalar"

        if node.is_out_port_connected(0):  # set a shape for output indices
            if is_fully_defined(shape_value):
                node.out_port(0).data.set_shape([np.prod(shape_value), 2])
            else:
                node.out_port(0).data.set_shape([dynamic_dimension_value, 2])
        if node.is_out_port_connected(1):  # set a shape for output values
            if is_fully_defined(shape_value):
                node.out_port(1).data.set_shape([np.prod(shape_value)])
            else:
                node.out_port(1).data.set_shape([dynamic_dimension_value])
        if node.is_out_port_connected(2):  # set a shape for empty row indicator
            node.out_port(2).data.set_shape([shape_value[0]])
Example #10
    def test_remove_softmax_activation_input(self):
        graph = build_graph(
            {
                'node_1': {
                    'type': 'Identity',
                    'value': None,
                    'kind': 'op',
                    'op': 'Parameter'
                },
                'softmax': {
                    'type': 'SoftmaxActivation',
                    'value': None,
                    'kind': 'op',
                    'op': 'SoftmaxActivation'
                },
            }, [('node_1', 'softmax')])

        pattern = CheckSoftmaxNodeInputs()
        pattern.find_and_replace_pattern(graph)

        node_softmax = Node(graph, 'softmax')

        self.assertEqual(len(node_softmax.in_nodes()), 1)

        node_input1 = node_softmax.in_node(0)
        self.assertEqual(node_input1.name, 'node_1')
Example #11
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_in_ports) in [4, 5], \
            "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        logit_length_shape = node.in_port(1).data.get_shape()
        labels_shape = node.in_port(2).data.get_shape()
        label_length_shape = node.in_port(3).data.get_shape()
        blank_index_shape = int64_array([])
        if len(node.in_nodes()) == 5:
            blank_index_shape = node.in_port(4).data.get_shape()

        # check shapes of input tensors
        assert len(logits_shape) == 3 and len(logit_length_shape) == 1 and len(labels_shape) == 2\
            and len(label_length_shape) == 1 and len(blank_index_shape) == 0, \
            'Incorrect rank of some input tensor for {} node'.format(node_name)
        assert compatible_dims(logits_shape[0], logit_length_shape[0]) and \
               compatible_dims(logits_shape[0], labels_shape[0]) and \
               compatible_dims(logits_shape[0], label_length_shape[0]), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
        assert compatible_dims(logits_shape[1], labels_shape[1]), \
            'Time dimensions of input tensors must be the same for {} node'.format(node_name)

        batch_size = logits_shape[0]
        node.out_port(0).data.set_shape([batch_size])
Example #12
    def add_unsqueeze_for_new(graph: Graph, ss_node: Node):
        log.info(
            "StridedSlice op with new axis mask '{}' has been detected".format(
                ss_node.id))
        if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
            return

        shape_out = ss_node.out_node().shape
        dim = mo_array(range(len(ss_node['new_axis_mask'])))[mo_array(
            ss_node['new_axis_mask'], dtype=bool)]
        ss_shape = []
        for i in range(0, len(ss_node['new_axis_mask'])):
            if not ss_node['new_axis_mask'][i]:
                ss_shape.append(shape_out[i])
            else:
                ss_node['new_axis_mask'][i] = 0

        ss_node.out_port(0).data.set_shape(ss_shape)

        # insert Unsqueeze
        unsqueeze_node = Unsqueeze(graph,
                                   dict(name=ss_node.name +
                                        '/Unsqueeze_new')).create_node()
        ss_node.out_port(0).get_connection().insert_node(unsqueeze_node)
        unsqueeze_node.out_port(0).data.set_shape(shape_out)

        dims_node = Const(graph, {
            'name': unsqueeze_node.id + '/Indices',
            'value': int64_array(dim)
        }).create_node()
        dims_node.out_port(0).connect(unsqueeze_node.in_port(1))
Example #13
    def infer(node: Node):
        # there are limitations coming from ONNX LSTM definition and normalization rules
        assert len(node.in_nodes()) >= 3  # X, W and R
        assert len(node.in_nodes()) <= 7
        assert len(node.out_nodes()) <= 3

        rnn_infer(node, [1, 2])
Example #14
    def infer(node: Node):
        """
         MO input edges:   |   Description:
         -------------------------------------------------
                0          | x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs)
                1          | w: The weight matrix
                2          | b: The bias vector
                3          | h_prev: Previous/initial hidden state
                4          | cs_prev: Value of the initial cell state
         """
        assert len(node.in_nodes()) == 5
        """
        MO output edges:    |   Description:
                0           | cs: Output data / output hidden states concatenated over the whole time sequence
                1           | h: Output cell states concatenated over the whole time sequence
        """

        assert len(node.out_nodes()) in [1, 2]

        mark_input_bins(node)
        input_shape = node.in_node(0).shape

        assert len(input_shape) == 3
        out_shape = input_shape.copy()
        node.out_port(0).data.set_shape(out_shape)
        if node.is_out_port_connected(1):
            node.out_port(1).data.set_shape(out_shape)
Example #15
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [
            port for port in node.in_ports().values()
            if not port.disconnected()
        ]
        assert len(connected_in_ports) in [2, 3], \
            "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        sequence_len_shape = node.in_port(1).data.get_shape()
        if len(node.in_nodes()) == 3:
            blank_index_shape = node.in_port(2).data.get_shape()
            assert len(blank_index_shape) == 1, \
                'Incorrect rank of blank_index for {} node'.format(node_name)

        # check shapes of input tensors
        assert len(logits_shape) == 3, \
            'Incorrect rank of logits for {} node'.format(node_name)

        assert len(sequence_len_shape) == 1, \
            'Incorrect rank of sequence length tensor for {} node'.format(node_name)
        assert compatible_dims(logits_shape[0], sequence_len_shape[0]), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)

        batch_size = logits_shape[0]
        time_size = logits_shape[1]
        if node.is_out_port_connected(0):
            node.out_port(0).data.set_shape([batch_size, time_size])
        if node.is_out_port_connected(1):
            node.out_port(1).data.set_shape([batch_size])
Example #16
def restore_tensor_names(op: Node):
    for out_port in op.ports:
        # op.ports is our internal attribute, dictionary, where keys are numbers of output ports
        # and values are tuples with shape and tensor name:
        # {out_port_idx_1: (out_port_idx_1_shape, out_port_idx_1_tensor_name, out_port_idx_1_rt_info),
        #  out_port_idx_2: (out_port_idx_2_shape, out_port_idx_2_tensor_name, out_port_idx_2_rt_info)}
        out_tensor_names = op.ports[out_port][1]

        # handle Constant operations with old style output port numbering
        if op.soft_get('type') == 'Const':
            assert len(op.ports) == 1, 'Something wrong with Constant node: {}, wrong number ' \
                                       'of output ports: {}!'.format(op.soft_get('name'), len(op.ports))
            out_port = 0

        out_port = out_port - len(op.in_nodes())

        if out_tensor_names is not None:
            # handle tensor names with commas and add them to dictionary as separate items
            if out_tensor_names.find(',') >= 0:
                str_to_replace = '<comma_in_tensor_name>'
                out_tensor_names = (out_tensor_names.replace(
                    '\\,', str_to_replace)).split(',')
                op.out_node(out_port)['fw_tensor_debug_info'] = []
                for out_tensor_name in out_tensor_names:
                    out_tensor_name = out_tensor_name.replace(
                        str_to_replace, ',')
                    op.out_node(out_port)['fw_tensor_debug_info'].append(
                        (out_tensor_name, out_tensor_name))
            else:
                op.out_node(out_port)['fw_tensor_debug_info'] = [
                    (out_tensor_names, out_tensor_names)
                ]
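The escaped-comma handling above can be checked in isolation; the same algorithm on a plain string, with made-up tensor names:

names = 'conv1:0,my\\,odd\\,name:0'  # the second name contains escaped commas
placeholder = '<comma_in_tensor_name>'
parts = [p.replace(placeholder, ',')
         for p in names.replace('\\,', placeholder).split(',')]
print(parts)  # ['conv1:0', 'my,odd,name:0']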
Example #17
    def replace_op(self, graph: Graph, node: Node):
        matmul = MatMul(graph, dict(name=node.name, transpose_b=True)).create_node([node.in_node(0), node.in_node(1)])

        # Bias
        if len(node.in_nodes()) > 2:
            matmul = Add(graph, dict(name=node.name + '/bias')).create_node([matmul, node.in_node(2)])

        return [matmul.id]
Example #18
    def infer(node: Node):
        # check a number of input/output edges
        assert len(node.in_nodes()) == 3
        assert len(node.out_nodes()) == 1

        data_shape = node.in_port(0).data.get_shape()
        indices_shape = node.in_port(1).data.get_shape()
        segment_ids_shape = node.in_port(2).data.get_shape()
        data_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        segment_ids_value = node.in_port(2).data.get_value()

        # check input shapes
        assert data_shape is not None, \
            "Shape for input data tensor to SparseSegmentSqrtN must be defined"
        assert indices_shape is not None and indices_shape.size == 1, \
            "SparseSegmentSqrtN supports only 1D indices tensor"
        assert segment_ids_shape is not None and segment_ids_shape.size == 1, \
            "SparseSegmentSqrtN supports only 1D segment IDs tensor"
        assert compatible_shapes(segment_ids_shape, indices_shape), \
            "Indices and segment IDs tensors must have compatible shapes"

        # compute the output shape (on a copy, to keep the input shape intact)
        output_shape = data_shape.copy()
        output_shape[0] = segment_ids_shape[0]
        node.out_port(0).data.set_shape(output_shape)

        # value inference is possible only when all inputs are constant
        if data_value is None or indices_value is None or segment_ids_value is None:
            return

        # check that values in segment_ids are sorted
        for i in range(1, len(segment_ids_value)):
            assert segment_ids_value[i-1] <= segment_ids_value[i], \
                "Values in segment IDs are not sorted"
        num_segments = int(segment_ids_value[-1]) + 1

        # check that indices are in a range [0, data_shape[0])
        assert np.all(indices_value >= 0) and np.all(indices_value < data_shape[0]), \
            "Some value in indices tensor is out of range"

        # infer
        num_adds = np.zeros(num_segments, dtype=np.int64)
        output_value = np.zeros([num_segments] + data_shape[1:].tolist(),
                                dtype=np.float32)
        output_shape = output_value.shape
        for i in range(len(segment_ids_value)):
            segment_id = int(segment_ids_value[i])
            index = int(indices_value[i])
            output_value[segment_id, :] += data_value[index, :]
            num_adds[segment_id] += 1

        num_adds = np.sqrt(num_adds)
        for segment_id in range(num_segments):
            if num_adds[segment_id] != 0:
                output_value[segment_id, :] /= num_adds[segment_id]
        node.out_port(0).data.set_shape(output_shape)
        node.out_port(0).data.set_value(output_value)
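The constant-folding branch above appears to implement TensorFlow's SparseSegmentSqrtN semantics: rows gathered by indices are summed per segment, then each segment is divided by the square root of its row count. A compact numpy check with toy data:

import numpy as np

data = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32)
indices = np.array([0, 1, 2])
segment_ids = np.array([0, 0, 1])

out = np.zeros((int(segment_ids[-1]) + 1, data.shape[1]), dtype=np.float32)
counts = np.zeros(out.shape[0])
for idx, seg in zip(indices, segment_ids):
    out[seg] += data[idx]        # sum the gathered rows per segment
    counts[seg] += 1
out /= np.sqrt(np.maximum(counts, 1))[:, None]
print(out)  # segment 0: [4, 6] / sqrt(2); segment 1: [5, 6]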
Example #19
def get_value_id(node: Node):
    assert node.has_valid('op')
    value_id = None
    for port, in_node in node.in_nodes().items():
        if in_node.has_valid('value'):
            if value_id is not None:
                return None
            value_id = port
    return value_id
Example #20
def get_tensor_id(node: Node):
    assert node.has_valid('op')
    tensor_id = None
    for port, in_node in node.in_nodes().items():
        if not in_node.has_valid('value'):
            if tensor_id is not None:
                return None
            tensor_id = port
    return tensor_id
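These two helpers are complementary: for an op with exactly one constant and one non-constant input they return those ports, and None when the situation is ambiguous. A toy harness with hypothetical stand-in classes (not the real Node API), just enough to drive them:

class FakeIn:
    """Stand-in for an input data node; only has_valid('value') matters here."""
    def __init__(self, has_value):
        self._has_value = has_value
    def has_valid(self, attr):
        return attr == 'value' and self._has_value

class FakeNode:
    """Stand-in for Node with a fixed port -> input mapping."""
    def __init__(self, ins):
        self._ins = ins
    def has_valid(self, attr):
        return attr == 'op'
    def in_nodes(self):
        return self._ins

node = FakeNode({0: FakeIn(False), 1: FakeIn(True)})
assert get_value_id(node) == 1   # the single constant input sits on port 1
assert get_tensor_id(node) == 0  # the single tensor input sits on port 0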
Example #21
    def control_flow_infer(node: Node, is_executable: bool,
                           mark_executability: callable):
        in_data_nodes = node.in_nodes(control_flow=True)
        out_data_nodes = node.out_nodes(control_flow=True)

        is_executable = any(
            [d.has_and_set('executable') for i, d in in_data_nodes.items(
            )] if len(in_data_nodes) else [False])

        for i, d in out_data_nodes.items():
            mark_executability(d.id, is_executable)
Example #22
def get_fw_tensor_debug_info(node: Node):
    while not node.has_valid('fw_tensor_debug_info') and not node.has_valid('output_sort_order') \
            and len(node.in_nodes()):
        try:
            node = node.in_node()
        except Exception as e:
            log.warning('Was not able to determine tensor debug info for node {}'.format(node.name))
            return "dummy_node_name"
    if node.has_valid('output_sort_order'):
        return node.soft_get('output_sort_order')
    return node.soft_get('fw_tensor_debug_info')
Example #23
    def add_squeeze_for_shrink(graph: Graph, ss_node: Node):
        # add Squeeze for shrink_axis_mask
        log.info(
            "StridedSlice op with shrink mask '{}' has been detected".format(
                ss_node.id))

        if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
            return

        shape_out = ss_node.out_node().shape
        dim = mo_array(range(len(ss_node['shrink_axis_mask'])))[mo_array(
            ss_node['shrink_axis_mask'], dtype=bool)]
        ss_shape = []
        i = 0
        k = 0

        # Don't permute reshape if channels were squeezed
        dont_permute = graph.graph['layout'] == 'NCHW'
        if graph.graph['layout'] == 'NHWC' and ss_node['shrink_axis_mask'][
                -1] == 1:
            dont_permute = True

        while k < len(shape_out):
            if i >= len(ss_node['shrink_axis_mask']
                        ) or not ss_node['shrink_axis_mask'][i]:
                ss_shape.append(shape_out[k])
                k = k + 1
            else:
                ss_node['shrink_axis_mask'][i] = 0
                ss_shape.append(1)
            i = i + 1

        while i < len(ss_node['shrink_axis_mask']):
            ss_node['shrink_axis_mask'][i] = 0
            ss_shape.append(1)
            i = i + 1

        ss_node.out_port(0).data.set_shape(ss_shape)

        # insert Squeeze
        squeeze_node = Squeeze(
            graph,
            dict(name=ss_node.name + '/Squeeze_shrink',
                 nchw_layout=dont_permute,
                 correct_data_layout=dont_permute)).create_node()
        ss_node.out_port(0).get_connection().insert_node(squeeze_node)
        squeeze_node.out_port(0).data.set_shape(shape_out)

        dims_node = Const(graph, {
            'name': squeeze_node.id + '/Indices',
            'value': int64_array(dim)
        }).create_node()
        dims_node.out_port(0).connect(squeeze_node.in_port(1))
Example #24
    def infer(node: Node):
        assert len(node.in_nodes()) == 1
        assert node.fill_value is not None
        assert node.input_as_shape

        shape = node.in_port(0).data.get_value()
        assert shape is not None

        if is_fully_defined(shape):
            node.out_port(0).data.set_value(np.full(shape, node.fill_value, np.float32))
        else:
            node.out_port(0).data.set_shape(shape)
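Since input_as_shape is set, the single input carries the output shape rather than data; when that shape is fully defined, the constant branch reduces to np.full. A quick check with made-up values:

import numpy as np

shape = np.array([2, 3])
print(np.full(shape, 0.5, np.float32))  # 2x3 tensor filled with fill_value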
Example #25
    def array_infer(node: Node):
        assert len(node.in_nodes()) == 3

        handle = node.in_node(0)

        ta_node = Node(node.graph, str(handle.value))
        assert ta_node.has_valid('element_shape')

        for _, out_node in node.graph.out_edges(node.id):
            node.graph.node[out_node]['shape'] = shape_array(
                ta_node['element_shape'])
            node.graph.node[out_node]['value'] = None
Example #26
    def infer(node: Node):
        real_squeeze_dims = int64_array([])
        input_shape = node.in_port(0).data.get_shape()
        node_name = node.soft_get('name', node.id)
        if input_shape is None:
            raise Error(
                'Input shape is not defined for node {}'.format(node_name))

        output_shape = input_shape.copy()
        assert len(node.in_nodes(
        )) == 2, 'The Squeeze node {} must have 2 inputs'.format(node_name)

        # TODO: remove the following 'if' statement once IE supports 0D tensors
        squeeze_dims = node.in_port(1).data.get_value()
        if squeeze_dims.ndim == 0:
            squeeze_dims = squeeze_dims.reshape([1])

        for dim in squeeze_dims:
            if output_shape[dim] == 1 or output_shape[dim] is dynamic_dimension:
                real_squeeze_dims = np.ma.append(
                    real_squeeze_dims,
                    get_canonical_axis_index(output_shape, dim))
            else:
                raise Error(
                    'Trying to squeeze dimension not equal to 1 for node "{}"'.
                    format(node_name))

        # if squeeze_dims is empty, all dimensions equal to 1 should be removed (TF specification of the Squeeze op)
        if squeeze_dims.size == 0:
            for i in range(output_shape.size):
                if output_shape[i] == 1:
                    real_squeeze_dims = np.ma.append(
                        real_squeeze_dims,
                        get_canonical_axis_index(output_shape, i))

        assert is_fully_defined(
            real_squeeze_dims
        ), 'Squeeze dimension(s) is not defined for op "{}"'.format(node_name)
        output_shape = shape_delete(output_shape, real_squeeze_dims)
        node.out_port(0).data.set_shape(output_shape)

        # make dimensions positive to correctly translate from NHWC to NCHW layout
        if node.in_port(1).get_source().node.op == 'Const':
            node.in_port(1).data.set_value(real_squeeze_dims)

        if node.in_port(0).data.get_value() is not None:
            node.out_port(0).data.set_value(
                node.in_port(0).data.get_value().reshape(output_shape))

        # the squeeze_dim attribute will be converted to the second input in the end of the Middle phase
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0',
                                              'axis')
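The shape arithmetic above removes the squeezed axes from the output shape; for fully static shapes, np.delete can stand in for mo's shape_delete (assumed equivalent in that case) to see the effect:

import numpy as np

shape = np.array([1, 3, 1, 227])
real_squeeze_dims = np.array([0, 2])
print(np.delete(shape, real_squeeze_dims))  # [  3 227]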
Example #27
def find_input_port(node: Node, input_desc: list, search_node_name: str,
                    search_node_port: int):
    if input_desc is None:
        return len(node.in_nodes())

    for in_port, tensor_desc in enumerate(input_desc):
        for node_pattern, node_port in tensor_desc:
            if findall(node_pattern,
                       search_node_name) and node_port == search_node_port:
                return in_port
    raise Exception(
        'Did not find input port of the node "{}" with port "{}"'.format(
            search_node_name, search_node_port))
Example #28
    def array_infer(node: Node):
        assert len(node.in_nodes()) == 2

        handle = node.in_node(0)

        ta_node = Node(node.graph, str(handle.value))
        assert ta_node.has_valid('size')

        output_value = mo_array(ta_node['size'])

        for _, out_node in node.graph.out_edges(node.id):
            node.graph.node[out_node]['shape'] = shape_array(output_value.shape)
            node.graph.node[out_node]['value'] = output_value.copy()
Example #29
def convert_const_node_value_type(const_node: Node, np_data_type):
    assert const_node.type == 'Const'
    log.warning('Converting type of Const node "{}" to "{}"'.format(const_node.name, np_data_type))
    const_node.value = const_node.value.astype(np_data_type)
    const_node.data_type = np_data_type
    const_node.infer(const_node)
    const_node.type_infer(const_node)

    # if the Const node has an input data node then need to update it also
    if len(const_node.in_nodes()) == 1:
        input_data = const_node.in_node(0)
        assert input_data.kind == 'data'
        input_data.value = input_data.value.astype(const_node.data_type)
        input_data.data_type = const_node.data_type
Example #30
    def infer(node: Node):
        assert (len(node.in_nodes()) == 3), 'MaxPoolV2 node {} must have exactly 3 inputs: input, window size, and ' \
                                            'strides, but instead got {} inputs'.format(node.soft_get('name', node.id),
                                                                                        len(node.in_nodes()))
        node['window'] = node.in_port(1).data.get_value()
        node['stride'] = node.in_port(2).data.get_value()

        if node['window'] is None:
            raise Error(
                'The non-constant window size for MaxPoolV2 node {} is not supported'
                ''.format(node.soft_get('name', node.id)))
        if node['stride'] is None:
            raise Error(
                'The non-constant strides for MaxPoolV2 node {} are not supported'
                ''.format(node.soft_get('name', node.id)))

        Pooling.pool_infer(node)