Example #1
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(type='StridedSlice'):
            StridedSliceNormalizer.normalize_strided_slice(graph, node)
            PermuteAttrs.create_permute_attrs(
                node,
                attrs=[
                    ('begin_mask', 'input:0'),  # actually depends on slice_rank
                    ('end_mask', 'input:0'),
                    ('new_axis_mask', 'input:0'),
                    ('shrink_axis_mask', 'input:0'),
                    ('ellipsis_mask', 'input:0')
                ])

            # StridedSliceNormalizer inserted nodes that changed the original begin, end, and strides data nodes,
            # so it was not possible to set the correct permutations until now
            PermuteInputs().set_input_permutation(node.in_node(1), node,
                                                  'input:1', 'slice',
                                                  'dim_size')
            PermuteInputs().set_input_permutation(node.in_node(2), node,
                                                  'input:2', 'slice',
                                                  'dim_size')
            if node.is_in_port_connected(3):
                PermuteInputs().set_input_permutation(node.in_node(3), node,
                                                      'input:3', 'slice',
                                                      'dim_size')
Example #2
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(type='StridedSlice'):
            StridedSliceNormalizer.normalize_strided_slice(graph, node)
            PermuteAttrs.create_permute_attrs(
                node,
                attrs=[
                    ('begin_mask', 'input:0'),  # actually depends on slice_rank
                    ('end_mask', 'input:0'),
                    ('new_axis_mask', 'input:0'),
                    ('shrink_axis_mask', 'input:0'),
                    ('ellipsis_mask', 'input:0')
                ])

            # StridedSliceNormalizer inserted nodes that changed the original begin, end, and strides data nodes,
            # so it was not possible to set the correct permutations until now
            PermuteInputs().set_input_permutation(node.in_node(1), node,
                                                  'input:1', 'slice',
                                                  'dim_size')
            PermuteInputs().set_input_permutation(node.in_node(2), node,
                                                  'input:2', 'slice',
                                                  'dim_size')
            if node.is_in_port_connected(3):
                PermuteInputs().set_input_permutation(node.in_node(3), node,
                                                      'input:3', 'slice',
                                                      'dim_size')

            # If new_axis_mask or shrink_axis_mask is set, StridedSlice should be performed in the
            # original layout, the same as for Squeeze, Unsqueeze, Reshape, and Gather
            if np.count_nonzero(node['new_axis_mask']) > 0 or np.count_nonzero(
                    node['shrink_axis_mask']) > 0:
                node['reinterp_shape'] = True
                node['nchw_layout'] = True
Example #3
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_data_nodes():
            if node.has_and_set('nchw_layout'):
                continue

            # Get NHWC to NCHW permutation for N dims, where N = len(node.shape)
            permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape))

            # Check whether the data node already has a permutation on any of its edges
            skip_permutation = False
            for in_node in node.in_nodes():
                edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
                if 'permutation' in edge_attrs:
                    skip_permutation = True
            for out_node in node.out_nodes():
                edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
                if 'permutation' in edge_attrs:
                    skip_permutation = True

            if skip_permutation:
                continue

            # Set permutation to all in/out edges
            for in_node in node.in_nodes():
                PermuteAttrs.set_permutation(in_node, node, permutation)

            for out_node in node.out_nodes():
                PermuteAttrs.set_permutation(node, out_node, permutation)
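
A minimal illustration of the permutation above, using a hypothetical 4-D NHWC shape; it assumes get_nhwc_to_nchw_permutation(4) encodes the usual [0, 3, 1, 2] order (pure NumPy, not part of the original transform):

import numpy as np

# Hypothetical 4-D data node shape in NHWC layout.
nhwc_shape = np.array([1, 224, 224, 3])
# Assumed NHWC -> NCHW permutation for rank 4 and its inverse.
perm = np.array([0, 3, 1, 2])
inv = np.argsort(perm)            # -> [0, 2, 3, 1]
nchw_shape = nhwc_shape[perm]     # -> [1, 3, 224, 224]
restored = nchw_shape[inv]        # -> [1, 224, 224, 3]
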
Example #4
    def _one_input_infer(node: Node):
        input_shape = node.in_port(0).data.get_shape()
        node_name = node.soft_get('name', node.id)
        if input_shape is None:
            raise Error('input_shape is None for {} node'.format(node_name))

        if not node.has_valid('axis'):
            raise Error('axis attribute is missing for {} node. It should be set in the Crop extractor'.format(node_name))

        output_shape = input_shape.copy()
        if node.has_valid('dim'):
            if len(node.dim) != len(node.axis):
                raise Error('Number of axis "{}" should match number of dim "{}" for node "{}"'
                            ''.format(node.axis, node.dim, node_name))
            output_shape[node.axis] = node.dim
        elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
            if len(node.crop_begin) != len(node.axis) or len(node.crop_end) != len(node.axis):
                raise Error('number of crop_begin({})/crop_end({}) should match number of axis "{}" for node "{}"'
                            ''.format(node.crop_begin, node.crop_end, node.axis, node_name))
            if type(node.axis) in [list, tuple]:
                for i in range(len(node.axis)):
                    output_shape[node.axis[i]] = output_shape[node.axis[i]] - node.crop_begin[i] - node.crop_end[i]
            else:
                output_shape[node.axis] = output_shape[node.axis] - node.crop_begin - node.crop_end
        else:
            raise Error('Crop node {} should have either dim or crop_begin and crop_end attributes'.format(node_name))

        node.out_port(0).data.set_shape(output_shape)
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
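
A worked example of the crop_begin/crop_end branch above, with hypothetical shapes and attributes (pure NumPy sketch):

import numpy as np

# Hypothetical Crop node: axis=[2, 3], crop_begin=[1, 2], crop_end=[1, 2].
input_shape = np.array([1, 3, 227, 227])
axis, crop_begin, crop_end = [2, 3], np.array([1, 2]), np.array([1, 2])
output_shape = input_shape.copy()
for i in range(len(axis)):
    output_shape[axis[i]] = output_shape[axis[i]] - crop_begin[i] - crop_end[i]
# -> [1, 3, 225, 223]
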
Example #5
    def infer(node):
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 1 and 0 in connected_in_ports, \
            "AttributedTile should have 1 connected input port, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        shape = node.in_port(0).data.get_shape()
        assert shape is not None, "Undefined input shape for AttributedTile node '{}'.".format(
            name)
        axis = node.soft_get('axis', None)
        assert axis is not None
        tiles = node.soft_get('tiles', None)
        assert tiles is not None, "Undefined `tiles` attribute of Tile node '{}'".format(
            name)

        tile_array = int64_array(np.ones(shape.size))
        tile_array[node.axis] = node.tiles

        node.out_port(0).data.set_shape(shape * tile_array)
        if node.in_port(0).data.get_value() is not None:
            node.out_port(0).data.set_value(
                np.tile(node.in_port(0).data.get_value(), tile_array))

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #6
    def infer(node):
        in_ports = node.in_ports()
        connected_ports = [
            port for port in in_ports.values() if not port.disconnected()
        ]
        assert len(connected_ports) == 2, 'The number of inputs to the TopK layer named "{}" must be equal to 2.' \
                                          ''.format(node.soft_get('name'))

        k = node.in_port(1).data.get_value()
        if k is None:
            raise Error(
                'The value defining number of output elements for layer "{}" is not defined'
                ''.format(node.soft_get('name')))
        assert node.has_valid(
            'axis'), 'The "axis" attribute is not defined for node {}'.format(
                node.name)

        input_shape = node.in_port(0).data.get_shape()
        node.axis = len(
            input_shape) + node.axis if node.axis < 0 else node.axis
        output_shape = input_shape.copy()
        output_shape[node.axis] = k

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

        # setting shape and value if applicable
        if not node.out_port(0).disconnected():
            node.out_port(0).data.set_shape(output_shape)
        if not node.out_port(1).disconnected():
            node.out_port(1).data.set_shape(output_shape)
        if node.in_port(0).data.get_value() is not None:
            # TODO implement value propagation
            pass
Example #7
def arg_ops_infer(node: Node):
    shape = node.in_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)
    assert shape is not None, "Input shape for the node {} is None".format(node_name)

    # there are two inputs in TensorFlow. The second input is the axis for ArgMax
    connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_in_ports) == 2:
        axis = node.in_port(1).data.get_value()
        if axis is None:
            log.debug('The second argument to {} is None'.format(node.soft_get('name', node.id)))
            return
        node.axis = axis
        # remove the unnecessary input
        node.in_port(1).disconnect()

    num_top_axes = shape.size
    if num_top_axes < 3:
        num_top_axes = 3

    out_shape = np.ones(num_top_axes, dtype=np.int64)

    if node.has_valid('axis'):
        axis = get_canonical_axis_index(shape, node.axis)
        node.axis = axis
        out_shape = shape.copy()
        out_shape[axis] = node.top_k
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
    else:
        out_shape[0] = shape[0]
        out_shape[2] = node.top_k
        if node.has_and_set('out_max_val'):
            out_shape[1] = 2

    node.out_port(0).data.set_shape(out_shape)
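
A worked example of the branch without an `axis` attribute above, which builds the Caffe-style output shape; the input shape and top_k value are hypothetical:

import numpy as np

# Hypothetical ArgMax: input shape [8, 1000], top_k=5, out_max_val set.
shape, top_k = np.array([8, 1000]), 5
num_top_axes = max(shape.size, 3)
out_shape = np.ones(num_top_axes, dtype=np.int64)
out_shape[0] = shape[0]  # batch size
out_shape[2] = top_k     # number of top elements
out_shape[1] = 2         # indices and values when out_max_val is set
# -> [8, 2, 5]
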
Example #8
    def infer(node):
        name = node.soft_get('name', node.id)

        op = node.soft_get('op', None)
        assert op is not None and op in ['Split', 'AttributedSplit'], \
            'Unexpected `op`={} attribute for Split-like node {}'.format(op, name)

        num_in_ports = 1 if op == 'AttributedSplit' else 2 if op == 'Split' else None
        assert num_in_ports in [1, 2], \
            'SplitBase supports AttributedSplit with 1 input and Split with 2 inputs, but it is {} for {} node {}' \
            ''.format(num_in_ports, op, name)

        connected_inputs = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_inputs) == num_in_ports and all([i in connected_inputs for i in range(num_in_ports)]), \
            "{} should have {} connected input ports, but it doesn't for node: `{}`. Ports: {}" \
            "".format(op, num_in_ports, name, connected_inputs)

        input_shape = node.in_port(0).data.get_shape()
        assert input_shape is not None, 'Input shape is unknown for node {}'.format(
            name)
        assert node.has_valid(
            'num_splits'
        ), 'Parameter `num_splits` is unknown for node {}'.format(name)
        num_splits = node.num_splits

        axis = node.in_port(1).data.get_value(
        ) if op == 'Split' else node.soft_get('axis', None)
        assert axis is not None, '{} `axis` is unknown for node {}'.format(
            op, name)
        assert axis.ndim == 0, '{} `axis` should be a scalar, but it is not for node {}'.format(
            op, name)

        assert not is_fully_defined(input_shape[axis]) or input_shape[axis] % num_splits == 0, \
            'Input shape is not evenly divided by `num_splits` of {} node {}. `input_shape`={}, `axis`={}, ' \
            '`num_splits`={}'.format(op, name, input_shape, axis, num_splits)

        out_shape = input_shape.copy()
        out_shape[axis] = input_shape[axis] // num_splits

        input_value = node.in_port(0).data.get_value()
        output_value = np.split(input_value.copy(), axis=axis, indices_or_sections=num_splits) \
            if input_value is not None else None

        for idx, port in node.out_ports().items():
            if idx in node.out_nodes():
                port.data.set_shape(out_shape)
                if output_value is not None:
                    port.data.set_value(output_value[idx])

        if op == 'Split':
            PermuteInputs().set_input_permutation(node.in_node(1), node,
                                                  'input:0', 'axis')
        elif op == 'AttributedSplit':
            PermuteAttrs.create_permute_attrs(node,
                                              attrs=[('axis', 'input:0')])
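
A worked example of the shape and value computation above, with a hypothetical input and attributes:

import numpy as np

# Hypothetical Split: input shape [2, 6, 4], axis=1, num_splits=3.
input_value = np.arange(2 * 6 * 4).reshape(2, 6, 4)
outputs = np.split(input_value, indices_or_sections=3, axis=1)
# Each output has shape [2, 2, 4], i.e. input_shape[axis] // num_splits along the split axis.
assert all(out.shape == (2, 2, 4) for out in outputs)
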
Example #9
    def infer(node):
        name = node.soft_get('name', node.id)
        assert node.has_valid('shape'), \
            'Parameter node {} should have `shape` attribute. Please use cli options to set model input shape' \
            ''.format(name)
        node.out_port(0).data.set_shape(node.shape)

        PermuteAttrs.create_permute_attrs(node, attrs=[('shape', 'output:0')])
Example #10
    def infer(node: Node):
        assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 1, \
            'LogSoftmax node with id {} must have exactly one input port connected'.format(node.id)
        if node.axis < 0:
            node.axis = len(node.in_port(0).data.get_shape()) + node.axis
        assert 0 <= node.axis < len(node.in_port(0).data.get_shape()), \
            'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)
        copy_shape_infer(node)
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #11
    def test_transpose_insert_with_two_result_nodes(self, nhwc_to_nchw_order, nchw_to_nhwc_order,
                                                    add_permutation_attrs, fft_kind):
        shape_len = len(nhwc_to_nchw_order) if add_permutation_attrs else 3
        shape = np.array(range(shape_len))
        add_shape = shape if nhwc_to_nchw_order is None else shape[nhwc_to_nchw_order]
        graph = build_graph(nodes_attrs=nodes_for_case_with_two_results,
                            edges=edges_for_case_with_two_results,
                            update_attributes={
                                'placeholder1_data': {'shape': int64_array(shape)},
                                'placeholder1': {'shape': int64_array(shape), 'rt_info': RTInfo()},
                                'transpose_parameter_order': {
                                    'value': np.array(nhwc_to_nchw_order),
                                    'shape': int64_array(np.array(nhwc_to_nchw_order).shape)
                                },
                                'transpose_parameter_order_data': {
                                    'value': np.array(nhwc_to_nchw_order),
                                    'shape': int64_array(np.array(nhwc_to_nchw_order).shape)
                                },
                                'fft': {'op': fft_kind, 'type': fft_kind},
                                'add_data': {'shape': add_shape},
                                'fft_data': {'shape': add_shape},
                                'result1': {'shape': shape, 'rt_info': RTInfo()},
                                'result2': {'shape': shape, 'rt_info': RTInfo()},
                            })

        if add_permutation_attrs:
            graph_ref = build_graph(nodes_for_case_with_two_results, edges_with_transpose_for_case_with_two_results)
        else:
            graph_ref = build_graph(nodes_for_case_with_two_results, edges_for_case_with_two_results)

        param1_node = Node(graph, 'placeholder1')
        result1_node = Node(graph, 'result1')
        result2_node = Node(graph, 'result2')

        if add_permutation_attrs:
            shape_len = len(nhwc_to_nchw_order)
            param1_node['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
            param1_node.out_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)
            result1_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)
            result2_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)

        PreserveRuntimeInfo().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result1')
        self.assertTrue(flag, resp)

        self.assertFalse(param1_node.has_valid('permute_attrs'))
        self.assertFalse(param1_node.out_node(0).has_valid('permutation'))

        if add_permutation_attrs:
            rt_info = param1_node.rt_info.info
            old_api_map = rt_info[('old_api_map_order', 0)].info
            self.assertTrue(np.array_equal(old_api_map['inverse_order'], nchw_to_nhwc_order))
Example #12
    def infer(node: Node):
        node['order'] = list(range(node.in_node().shape.size))
        node.order[node.dim2], node.order[node.dim1] = node.order[
            node.dim1], node.order[node.dim2]

        input_shape = node.in_port(0).data.get_shape().copy()
        node.out_port(0).data.set_shape(input_shape[node.order])
        if node.in_port(0).data.get_value() is not None:
            node.out_port(0).data.set_value(
                np.transpose(node.in_port(0).data.get_value(),
                             axes=node.order))

        PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
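
A worked example of how the `order` attribute is built above, with hypothetical dim1/dim2 values and input shape:

import numpy as np

# Hypothetical node: dim1=1, dim2=3, input shape [1, 3, 224, 224].
dim1, dim2 = 1, 3
input_shape = np.array([1, 3, 224, 224])
order = list(range(input_shape.size))
order[dim2], order[dim1] = order[dim1], order[dim2]  # -> [0, 3, 2, 1]
output_shape = input_shape[order]                    # -> [1, 224, 224, 3]
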
Example #13
    def infer(node):
        input_data_shape = node.in_port(0).data.get_shape()
        assert input_data_shape is not None
        assert node.has_valid('seq_axis')
        assert node.has_valid('batch_axis')

        assert len(node.out_nodes()) == 1
        node.out_port(0).data.set_shape(input_data_shape)

        PermuteAttrs.create_permute_attrs(node,
                                          attrs=[('seq_axis', 'input:0')])
        PermuteAttrs.create_permute_attrs(node,
                                          attrs=[('batch_axis', 'input:0')])
Example #14
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        # IE DetectionOutput layer consumes flattened confidences and locations tensors.
        # That is why we add reshapes before them.
        locs_node = match.single_input_node(0)
        conf_node = match.single_input_node(1)
        prior_boxes_node = match.single_input_node(2)

        locs_out_nodes = locs_node[0].out_nodes()
        assert len(locs_out_nodes) == 1
        locs_out_node = locs_out_nodes[list(locs_out_nodes.keys())[0]]
        assert locs_out_node.op == "Result", locs_out_node.op
        graph.remove_node(locs_out_node.id)

        conf_out_nodes = conf_node[0].out_nodes()
        assert len(conf_out_nodes) == 1
        conf_out_node = conf_out_nodes[list(conf_out_nodes.keys())[0]]
        assert conf_out_node.op == "Result", conf_out_node.op
        graph.remove_node(conf_out_node.id)

        # reshape operation to flatten locations tensor
        const = Const(graph, {'value': int64_array([0, -1])}).create_node()
        reshape_loc_node = Reshape(graph, {}).create_node(
            [locs_node, const], dict(name='DetectionOutput_Reshape_loc_'))

        # reshape operation to flatten confidence tensor
        reshape_conf_node = Reshape(graph, {}).create_node(
            [conf_node, const], dict(name='DetectionOutput_Reshape_conf_'))

        # remove the Result node after the priors node
        assert prior_boxes_node[0].out_node().op == "Result"
        graph.remove_node(prior_boxes_node[0].out_node().id)

        # reshape operation for prior boxes tensor
        const = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        reshape_priors_node = Reshape(graph, {}).create_node(
            [prior_boxes_node, const],
            dict(name='DetectionOutput_Reshape_priors_'))
        # create Detection Output node with three inputs: locations, confidences and prior boxes
        detection_output_op = DetectionOutput(
            graph, match.custom_replacement_desc.custom_attributes)
        detection_output_node = detection_output_op.create_node(
            [reshape_loc_node, reshape_conf_node, reshape_priors_node],
            dict(name=detection_output_op.attrs['type'] + '_'))
        PermuteAttrs.set_permutation(reshape_priors_node,
                                     detection_output_node, None)

        # create Output node to mark DetectionOutput as a graph output operation
        output_op = Result(graph)
        output_op.create_node([detection_output_node], dict(name='sink_'))
        return {}
Example #15
    def reorgyolo_infer(node: Node):
        input_shape = node.in_node(0).shape
        if input_shape is None:
            raise Error('Input shape for operation "{}" is None'.format(node.soft_get('name', node.id)))

        stride = node.stride

        output_shape = input_shape.copy()
        output_shape[node.batch_dims] = input_shape[node.batch_dims]  # pylint: disable=unsupported-assignment-operation
        output_shape[node.channel_dims] = input_shape[node.channel_dims] * stride ** 2  # pylint: disable=unsupported-assignment-operation
        # Round as in Caffe
        output_shape[node.spatial_dims] = np.ma.round(input_shape[node.spatial_dims] / stride)  # pylint: disable=unsupported-assignment-operation

        node.out_port(0).data.set_shape(output_shape)
        PermuteAttrs.create_permute_attrs(node, attrs=[('channel_dims', 'input:0'), ('spatial_dims', 'input:0')])
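
A worked example of the ReorgYolo shape arithmetic above, with a hypothetical NCHW input:

import numpy as np

# Hypothetical ReorgYolo: input shape [1, 64, 26, 26] (NCHW), stride=2.
input_shape = np.array([1, 64, 26, 26])
stride = 2
output_shape = input_shape.copy()
output_shape[1] = input_shape[1] * stride ** 2        # channel dim: 64 -> 256
output_shape[[2, 3]] = input_shape[[2, 3]] // stride  # spatial dims: 26 -> 13 (the original rounds; same result here)
# -> [1, 256, 13, 13]
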
Example #16
    def extract(cls, node):
        shape = shape_array([])
        # Extract output shape from `shape` attribute
        extracted_shape = tf_tensor_shape(node.pb.attr["shape"].shape)
        if len(extracted_shape) != 0:
            shape = extracted_shape
        else:
            # Extract output shape from `_output_shapes` attribute if it is possible
            extracted_output_shapes = node.pb.attr["_output_shapes"].list.shape
            if len(extracted_output_shapes) == 1:   # check if attribute not empty
                extracted_output_shapes = tf_tensor_shape(extracted_output_shapes[0])

                # Check equality of extracted shapes. We know some cases when the Placeholder operation has an empty
                # `shape` attribute value and a non-empty `_output_shapes` attribute value, and we need to handle and
                # support it.
                if len(extracted_output_shapes) > len(extracted_shape):
                    log.warning('Extracted shapes for Placeholder operation {} have different lengths: `shape` {} and '
                                '`_output_shapes` {}. Please check that the model is consistent'.format(
                        node.pb.name, extracted_shape, extracted_output_shapes))
                    if len(extracted_output_shapes) != 0:
                        shape = extracted_output_shapes

        attrs = {
            'data_type': tf_dtype_extractor(node.pb.attr["dtype"].type),
            'shape': shape,
            'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
        }
        if node.pb.attr["shape"].shape.unknown_rank:
            attrs['shape'] = None
        Parameter.update_node_stat(node, attrs)
        return cls.enabled
Example #17
    def extract(cls, node):
        attrs = tf_create_attrs(node, 2, 3)

        def get_num_groups(node):
            if 'group' in node:
                return node.group
            elif node.in_node(0).shape is not None and node.kernel_shape is not None \
                    and node.in_node(0).shape[node.channel_dims[0]] is not dynamic_dimension \
                    and node.kernel_shape[node.input_feature_channel] is not dynamic_dimension:
                # if group attribute is not defined, number of groups is calculated
                # from number of input channels and filter channel size
                return node.in_node(0).shape[
                    node.channel_dims] // node.kernel_shape[
                        node.input_feature_channel]
            else:
                return 1

        attrs.update({
            'op':
            __class__.op,
            'get_group':
            get_num_groups,
            'get_output_feature_dim':
            lambda node: node.kernel_shape[node.output_feature_channel],
            'get_weights_permute':
            PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]),
                                     inv=int64_array([2, 3, 1, 0]))
        })

        # update the attributes of the node
        Convolution.update_node_stat(node, attrs)
        return cls.enabled
Example #18
    def test_transpose_insert(self, nhwc_to_nchw_order, nchw_to_nhwc_order, add_permutation_attrs):
        graph_nodes = {
            **valued_const_with_data('transpose_parameter_order', np.array(nhwc_to_nchw_order)),
            **valued_const_with_data('transpose_result_order', np.array(nchw_to_nhwc_order))
        }
        graph_nodes.update(nodes)
        shape_len = len(nhwc_to_nchw_order) if add_permutation_attrs else 3
        shape = np.array(range(shape_len))
        add_shape = shape if nhwc_to_nchw_order is None else shape[nhwc_to_nchw_order]
        graph_nodes.update(
            {
                **regular_op_with_shaped_data('placeholder1', shape,
                                              {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}),
                **regular_op_with_shaped_data('result', shape, {'type': 'Result', 'rt_info': RTInfo(), 'shape': shape}),
                **regular_op_with_shaped_data('add', add_shape,
                                              {'type': 'Add', 'op': 'Add', 'infer': copy_shape_infer}),
            }
        )

        graph = build_graph(graph_nodes, edges)
        graph_ref = build_graph(graph_nodes, edges_with_transpose if add_permutation_attrs else edges)

        param_node = Node(graph, 'placeholder1')
        result_node = Node(graph, 'result')

        if add_permutation_attrs:
            shape_len = len(nhwc_to_nchw_order)
            param_node['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
            param_node.out_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)
            result_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(shape_len)

        PreserveRuntimeInfo().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        self.assertFalse(param_node.has_valid('permute_attrs'))
        self.assertFalse(param_node.out_node(0).has_valid('permutation'))

        if add_permutation_attrs:
            rt_info = param_node.rt_info.info
            old_api_map = rt_info[('old_api_map_order', 0)].info
            self.assertTrue(np.array_equal(old_api_map['inverse_order'], nchw_to_nhwc_order))

            rt_info = result_node.rt_info.info
            old_api_map = rt_info[('old_api_map_order', 0)].info
            self.assertTrue(np.array_equal(old_api_map['order'], nhwc_to_nchw_order))
Example #19
def reverse_permute(output_shape: np.array, order: np.array):
    """
    Calculates Transpose op input shape based on output shape and permute order.
    :param output_shape: Transpose output shape
    :param order: permute order
    :return: Transpose input shape corresponding to the specified output shape
    """
    return int64_array(output_shape[PermuteAttrs.get_inverse_permutation(order)])
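
A usage sketch of reverse_permute with hypothetical shapes; np.argsort of a permutation yields its inverse, which is the same idea as PermuteAttrs.get_inverse_permutation:

import numpy as np

# A Transpose with order [0, 2, 3, 1] maps the NCHW shape [1, 3, 224, 224]
# to the NHWC output shape [1, 224, 224, 3].
order = np.array([0, 2, 3, 1])
output_shape = np.array([1, 224, 224, 3])
inverse_order = np.argsort(order)          # -> [0, 3, 1, 2]
input_shape = output_shape[inverse_order]  # -> [1, 3, 224, 224], what reverse_permute returns
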
Example #20
def infer_for_opset1(node: Node):
    assert len([p for p in node.in_ports().values() if not p.disconnected()]) == 2
    assert node.has_valid('mode')
    assert node.has_valid('axes')

    src_shape = node.in_port(0).data.get_shape()

    assert src_shape is not None
    dst_shape = node.in_port(1).data.get_value()
    assert dst_shape is not None

    output_shape = src_shape.copy()
    for ind, axis in enumerate(node.axes):
        output_shape[axis] = dst_shape[ind]

    node.out_port(0).data.set_shape(output_shape)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])
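
A worked example of how `axes` and the second input define the output shape above, with hypothetical values:

import numpy as np

# Hypothetical Interpolate-1: input shape [1, 3, 10, 10], axes=[2, 3],
# second input value (target sizes) = [20, 30].
src_shape = np.array([1, 3, 10, 10])
axes, dst_shape = [2, 3], np.array([20, 30])
output_shape = src_shape.copy()
for ind, axis in enumerate(axes):
    output_shape[axis] = dst_shape[ind]
# -> [1, 3, 20, 30]
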
Example #21
    def extract(cls, node):
        attrs = {
            'data_type': tf_dtype_extractor(node.pb.attr["dtype"].type),
            'shape': tf_tensor_shape(node.pb.attr["shape"].shape),
            'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
        }
        if node.pb.attr["shape"].shape.unknown_rank:
            attrs['shape'] = None
        Parameter.update_node_stat(node, attrs)
        return cls.enabled
Example #22
    def infer(node: Node):
        data_shape = node.in_port(0).data.get_shape()
        indices_shape = node.in_port(1).data.get_shape()
        axis = node.axis
        data_rank = len(data_shape)

        assert data_rank >= 1, 'data_rank must be >= 1'
        assert data_rank == len(indices_shape), 'data and indices inputs for node {} must be of the ' \
                                                'same rank. Instead got {} and {}'. \
            format(node.name, data_rank, len(indices_shape))
        assert -data_rank <= axis < data_rank, 'axis for node {0} must be within interval ' \
                                               '[-{1},  {1} - 1]. Instead got: axis={2}'. \
            format(node.name, data_rank, axis)
        if axis < 0:
            axis += data_rank
        out_shape = indices_shape.copy()
        for idx, (data_sz, ind_sz) in enumerate(zip(data_shape,
                                                    indices_shape)):
            out_shape[
                idx] = ind_sz if ind_sz is not dynamic_dimension or idx == axis else data_sz
            if idx != axis and data_sz != ind_sz:
                raise Error(
                    'Sizes along axis {} for node {} do not match. data and indices must have '
                    'equal size along all axes except for axis {}'.format(
                        idx, node.name, axis))

        data = node.in_port(0).data.get_value()
        indices = node.in_port(1).data.get_value()

        if data is not None and indices is not None:
            out_value = np.empty(indices_shape, dtype=data.dtype)
            for idx in np.ndindex(*indices_shape):
                data_idx = list(idx)
                data_idx[node.axis] = indices[idx]
                out_value[idx] = data[tuple(data_idx)]
            node.out_port(0).data.set_value(out_value)
        else:
            node.out_port(0).data.set_shape(out_shape)

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #23
    def infer(node: Node):
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            "AttributedGather should have 2 connected input port, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        axis = node.soft_get('axis', None)
        assert axis is not None

        data_shape = node.in_port(0).data.get_shape()
        assert data_shape is not None
        indices_shape = node.in_port(1).data.get_shape()
        assert indices_shape is not None

        # Convert negative axis
        axis = get_canonical_axis_index(data_shape, axis)
        node.axis = axis

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

        data_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        if data_value is not None and indices_value is not None:
            node.out_port(0).data.set_value(
                mo_array(np.take(data_value, indices_value, axis),
                         dtype=data_value.dtype))
            return

        shape = np.concatenate((data_shape[:axis], indices_shape))
        if axis < len(data_shape) - 1:
            shape = np.concatenate((shape, data_shape[axis + 1:]))

        node.out_port(0).data.set_shape(int64_array(shape))
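
A worked example of the output shape concatenation above, with hypothetical data and indices shapes:

import numpy as np

# Hypothetical AttributedGather: data shape [2, 3, 4], indices shape [5], axis=1.
data_shape, indices_shape, axis = np.array([2, 3, 4]), np.array([5]), 1
shape = np.concatenate((data_shape[:axis], indices_shape, data_shape[axis + 1:]))
# -> [2, 5, 4]: the gathered axis is replaced by the indices shape.
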
Example #24
    def extract(cls, node):
        attrs = tf_create_attrs(node, 4, 3)
        attrs.update({
            'op':
            cls.op,
            'get_weights_permute':
            PermuteAttrs.Permutation(perm=int64_array([4, 3, 0, 1, 2]),
                                     inv=int64_array([2, 3, 4, 1, 0])),
            'swap_0_and_2_inputs':
            True,
            'shape_input':
            True,
        })

        # update the attributes of the node
        Deconvolution.update_node_stat(node, attrs)
        return cls.enabled
Example #25
    def extract(cls, node):
        attrs = tf_create_attrs(node, 3, 4)
        attrs.update({
            'op':
            __class__.op,
            'get_group':
            lambda node: 1,
            'get_output_feature_dim':
            lambda node: node.kernel_shape[node.output_feature_channel],
            'get_weights_permute':
            PermuteAttrs.Permutation(perm=int64_array([4, 3, 0, 1, 2]),
                                     inv=int64_array([2, 3, 4, 1, 0]))
        })

        # update the attributes of the node
        Convolution.update_node_stat(node, attrs)
        return cls.enabled
Example #26
    def extract(cls, node):
        attrs = tf_create_attrs(node, 2, 2)
        attrs.update({
            'op':
            __class__.op,
            'kernel_spatial_idx':
            np.array([0, 1], dtype=np.int64),
            'get_group':
            lambda node: node.kernel_shape[node.output_feature_channel],
            'get_output_feature_dim':
            lambda node: node.kernel_shape[-1] * node.kernel_shape[-2],
            'get_weights_permute':
            PermuteAttrs.Permutation(perm=int64_array([2, 3, 0, 1]),
                                     inv=int64_array([2, 3, 0, 1]))
        })

        # update the attributes of the node
        Convolution.update_node_stat(node, attrs)
        return cls.enabled
Example #27
    def extract(cls, node):
        attrs = tf_create_attrs(node, 2, 3)
        attrs.update({
            'op':
            __class__.op,
            'get_group':
            lambda node: node.group
            if 'group' in node and node.group is not None else node.in_node(0).
            shape[node.channel_dims] // node.kernel_shape[
                node.input_feature_channel],
            'get_output_feature_dim':
            lambda node: node.kernel_shape[node.output_feature_channel],
            'get_weights_permute':
            PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]),
                                     inv=int64_array([2, 3, 1, 0]))
        })

        # update the attributes of the node
        Convolution.update_node_stat(node, attrs)
        return cls.enabled
Example #28
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        if not node.has_port('in', 2) or node.in_port(2).disconnected() or not node.has_and_set('shape_input'):
            return

        if node.has_valid('layout') and not node.layout.startswith('NC') and graph.graph['layout'] == 'NCHW':
            input_shape_rank = len(node.in_port(0).data.get_shape())
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(input_shape_rank)

            data_node = node.in_node(2)

            name = node.soft_get('name', node.id) + '/ShapeGather'
            const = Const(graph, {'value': permutation.perm, 'name': name + '/Const',
                                  'need_shape_inference': True}).create_node_with_data()
            axis_const = Const(graph, {'value': int64_array(0), 'name': name + '/Axis'}).create_node_with_data()
            gather = Gather(graph, {'name': name,
                                    'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
            attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()

            graph.add_edge(gather.id, node.id, **attrs)
            graph.remove_edge(data_node.id, node.id)
Example #29
def strided_slice(op_node: Node, port_info: str, input_port: int):
    """
    StridedSlice must be permuted even if the input or output tensors have rank less than 4,
    e.g. input_shape = (1, 10, 10), out = input[:, 0:10, :, new_axis], input_rank < 4
    input_shape = (1, 10, 10, 3), out = input[:, 0:5, 0:4, 0], output_rank < 4
    In both examples slice_rank >= 4.
    slice_rank is defined by the length of begin, end, and strides (they all have the same length)
    """
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    assert permutation_data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
                                                             'port_info "{}".'.format(permutation_data_node.id,
                                                                                      op_node.id, port_info)
    permute_indices_for_gather = permutation_data_node.permutation.perm
    if len(permute_indices_for_gather) == 0:
        return
    from openvino.tools.mo.ops.op import PermuteAttrs

    slice_rank = op_node.in_port(input_port).data.get_shape()[
        0]  # length of begin, end or strides
    permute_indices_for_gather = PermuteAttrs.get_nhwc_to_nchw_permutation(
        slice_rank).perm
    reorder_inputs_for_shape_or_slice(op_node, input_port,
                                      permute_indices_for_gather)
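
A small sketch of the reordering the docstring describes, with hypothetical values: the gather permutation is built from slice_rank, not from the input rank:

import numpy as np

# slice_rank == 4 because begin/end/strides have length 4 (e.g. due to new_axis),
# even though the input tensor itself has rank 3.
permute_indices_for_gather = np.array([0, 3, 1, 2])  # assumed NHWC -> NCHW permutation for rank 4
begin_nhwc = np.array([0, 2, 3, 0])                  # per-dimension begin values, NHWC order
begin_nchw = begin_nhwc[permute_indices_for_gather]  # -> [0, 0, 2, 3], NCHW order
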
Example #30
def convert_graph_inputs_to_parameters(internal_graph, internal_graph_proto):
    # create Parameter nodes for the body graph
    body_parameters = []
    body_parameter_names = []
    for idx, pb_node in enumerate(internal_graph_proto['input_arg']):
        param_id = internal_graph.unique_id(pb_node.name)
        internal_graph.add_node(param_id,
                                name=param_id,
                                kind='op',
                                op='Parameter',
                                pb=None,
                                shape=None)
        parameter_node = Node(internal_graph, pb_node.name)
        Parameter.update_node_stat(
            parameter_node, {
                'data_type':
                tf_dtype_extractor(pb_node.type),
                'permute_attrs':
                PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
            })
        body_parameters.append(parameter_node)
        body_parameter_names.append(param_id)
    return body_parameters, body_parameter_names