예제 #1
def transpose_infer(node):
    """Infer the output shape (and constant value) of a Transpose node.

    The permutation is taken from ``node.order`` or, when ``node.reverse_order``
    is set, built as the reversed list of input axes. Exactly one of the two may
    be specified; otherwise an error is logged and inference is skipped.

    Side effects: may set ``node.order``; sets the output data node ``shape``
    (int64) and, when the input data node has a valid ``value``, the transposed
    output ``value``; registers ``order`` as a permutable attribute bound to
    input port 0.
    """
    name = node.soft_get('name')
    reverse_order = node.has_valid('reverse_order') and bool(node.reverse_order)

    if node.order is None and not reverse_order:
        log.error('Cannot infer {} because order is None'.format(name))
        return

    if reverse_order and node.has_valid('order'):
        log.error('Cannot infer {} due to both order and reverse_order was set'.format(name))
        return

    input_shape = node.in_node(0).shape

    if reverse_order:
        node.order = np.arange(len(input_shape))[::-1]  # Reverse order

    output_shape = np.array([input_shape[i] for i in node.order], dtype=np.int64)
    node.out_node(0).shape = output_shape
    if node.in_node().has_valid('value'):
        node.out_node().value = np.transpose(node.in_node().value, axes=node.order)
    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
예제 #2
    def infer(node):
        """Infer the output shape/value of an AttributedTile node.

        The single input (port 0) is repeated ``tiles`` times along ``axis``;
        both come from node attributes. Sets the output port shape, propagates a
        constant input value with ``np.tile`` and registers ``axis`` as a
        permutable attribute bound to input port 0.
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 1 and 0 in connected_in_ports, \
            "AttributedTile should have 1 connected input port, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        shape = node.in_port(0).data.get_shape()
        assert shape is not None, "Undefined input shape for AttributedTile node '{}'.".format(name)
        axis = node.soft_get('axis', None)
        assert axis is not None
        tiles = node.soft_get('tiles', None)
        assert tiles is not None, "Undefined `tiles` attribute of Tile node '{}'".format(name)

        # Repeat count per dimension: 1 everywhere except the tiled axis.
        tile_array = np.ones(shape.size, dtype=np.int64)
        tile_array[axis] = tiles

        node.out_port(0).data.set_shape(shape * tile_array)
        input_value = node.in_port(0).data.get_value()
        if input_value is not None:
            node.out_port(0).data.set_value(np.tile(input_value, tile_array))

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
예제 #3
    def test_permute_begin_end_ellipsis(self):
        """Begin/end masks extended through an ellipsis must permute NHWC -> NCHW correctly.

        Each mask is first extended to rank 4 with ``extend_mask_according_ellipsis``
        and then permuted in place with ``permute_masks``.
        """
        # Testing constant path case
        graph = build_graph(nodes_attributes,
                            [('input', 'data_1'),
                             ('data_1', 'strided_slice'),
                             ('begin', 'begin_data'),
                             ('begin_data', 'strided_slice'),
                             ('end', 'end_data'),
                             ('end_data', 'strided_slice'),
                             ('stride', 'stride_data'),
                             ('stride_data', 'strided_slice'),
                             ('strided_slice', 'data_2')],
                            {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             'begin': {'value': [0, 1], 'shape': [2]},
                             'end': {'value': [1, 0], 'shape': [2]},
                             'stride': {'value': [1, 2], 'shape': [2]},
                             'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
                                               'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
                                               'ellipsis_mask': np.array([1, 0])},
                             'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             })

        slice_node = Node(graph, 'strided_slice')
        nhwc_to_nchw = PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1])

        slice_node['begin_mask'] = int64_array(extend_mask_according_ellipsis(slice_node['ellipsis_mask'],
                                                                              slice_node['shrink_axis_mask'], 4,
                                                                              list(slice_node['begin_mask']), 0))
        permute_masks(slice_node, nhwc_to_nchw, 'begin_mask')
        self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([0, 0, 0, 0])))

        slice_node['end_mask'] = int64_array(extend_mask_according_ellipsis(slice_node['ellipsis_mask'],
                                                                            slice_node['shrink_axis_mask'], 4,
                                                                            list(slice_node['end_mask']), 0))
        permute_masks(slice_node, nhwc_to_nchw, 'end_mask')
        self.assertTrue(np.array_equal(slice_node.end_mask, np.array([1, 0, 0, 0])))
    def find_and_replace_pattern(self, graph: Graph):
        """Attach an NHWC -> NCHW permutation to the edges of every eligible data node.

        A data node is skipped when it is marked with ``nchw_layout`` or when any
        of its incoming or outgoing edges already carries a ``permutation``
        attribute. Otherwise the permutation for ``len(node.shape)`` dimensions is
        set on all its incoming and outgoing edges via
        ``PermuteAttrs.set_permutation``.
        """
        for node in graph.get_data_nodes():
            if node.has_and_set('nchw_layout'):
                continue

            # Get NHWC to NCHW permutation for N dims, where N = len(node.shape)
            permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape))

            # Check that data node already has permutation
            already_permuted = any(
                'permutation' in node.graph.get_edge_data(in_node.id, node.id)[0]
                for in_node in node.in_nodes()
            ) or any(
                'permutation' in node.graph.get_edge_data(node.id, out_node.id)[0]
                for out_node in node.out_nodes()
            )
            if already_permuted:
                continue

            # Set permutation to all in/out edges
            for in_node in node.in_nodes():
                PermuteAttrs.set_permutation(in_node, node, permutation)

            for out_node in node.out_nodes():
                PermuteAttrs.set_permutation(node, out_node, permutation)
예제 #5
    def _one_input_infer(node: Node):
        """Infer the output shape of a single-input Crop node.

        The output is the input shape with the dimensions listed in ``node.axis``
        either replaced by ``node.dim`` or reduced by ``crop_begin + crop_end``.
        On missing/inconsistent attributes an error is logged and inference stops
        without setting the output shape.
        """
        raw_shape = node.in_node().shape
        # Must be checked before wrapping: np.array(None) is never None.
        if raw_shape is None:
            log.error('input_shape is none for {} node'.format(node.name))
            return
        input_shape = np.array(raw_shape)

        if not node.has_valid('axis'):
            log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
            return

        output_shape = input_shape
        if node.has_valid('dim'):
            if len(node.dim) != len(node.axis):
                log.error('number of axis should match number of dim')
                return
            output_shape[node.axis] = node.dim
        elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
            if len(node.crop_begin) != len(node.axis) or len(node.crop_end) != len(node.axis):
                log.error('number of crop_begin/crop_end should match number of axis')
                return
            if isinstance(node.axis, (list, tuple)):
                for axis, begin, end in zip(node.axis, node.crop_begin, node.crop_end):
                    output_shape[axis] = output_shape[axis] - begin - end
            else:
                # array-like axis: crop all listed dimensions at once
                output_shape[node.axis] = output_shape[node.axis] - node.crop_begin - node.crop_end
        else:
            log.error('Crop node {} should have either dim or crop_begin and crop_end attributes'.format(node.name))
            return

        node.out_node().shape = np.array(output_shape)
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
예제 #6
def caffe_inner_product(node):
    """Infer shapes for a Caffe InnerProduct node.

    The input is treated as ``(input_shape[0], prod(input_shape[1:]))`` and the
    output shape is ``(input_shape[0], out-size)``. When ``out-size`` is not set
    it is derived from the number of weight elements. Weights stored in the
    transposed (IO) layout are transposed when ``transpose_weights`` is True,
    and the weights data node is reshaped to ``(out-size, input_channels)``.
    Nothing is inferred when the input shape is undefined.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return
    batches = input_shape[0]
    input_channels = np.prod(input_shape[1:])
    if not node.has_valid('out-size'):
        # The weights hold out-size * input_channels elements.
        node['out-size'] = np.int64(np.prod(node.in_node(1).shape) // input_channels)
    output_channels = node['out-size']

    weights_shape = np.array([output_channels, input_channels], dtype=np.int64)

    # In case if original weight layout is IO we transpose them
    if np.array_equal(node.in_node(1).shape, weights_shape[::-1]) and node.soft_get('transpose_weights') is True:
        node.in_node(1).value = np.transpose(node.in_node(1).value)

    node.out_node().shape = np.array([batches, output_channels], dtype=np.int64)
    # Back propagation of shape to weights
    node.in_node(1).shape = np.array(weights_shape)
    node.in_node(1).value.shape = node.in_node(1).shape

    assign_dims_to_weights(node.in_node(1), None, 1, 0, 2)
    PermuteAttrs.set_permutation(node.in_node(1), node, None)
예제 #7
    def infer(node):
        """Infer output shapes of a TopK node.

        Input 0 is the data and input 1 holds ``k``. Every connected output gets
        the input shape with the ``axis`` dimension replaced by ``k``; a negative
        ``axis`` is normalized in place. ``axis`` is registered as a permutable
        attribute bound to input port 0. Value propagation is not implemented.

        Raises:
            Error: if the value of ``k`` is not known.
        """
        name = node.soft_get('name')
        connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_ports) == 2, 'The number of inputs to the TopK layer name "{}" must be equal to 2.' \
                                          ''.format(name)

        k = node.in_port(1).data.get_value()
        if k is None:
            raise Error('The value defining number of output elements for layer "{}" is not defined'
                        ''.format(name))
        assert node.has_valid('axis'), 'The "axis" attribute is not defined for node {}'.format(node.name)

        input_shape = node.in_port(0).data.get_shape()
        node.axis = len(input_shape) + node.axis if node.axis < 0 else node.axis
        output_shape = input_shape.copy()
        output_shape[node.axis] = k

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

        # setting shape (values and indices outputs share it) for every connected output
        for out_idx in (0, 1):
            if not node.out_port(out_idx).disconnected():
                node.out_port(out_idx).data.set_shape(output_shape)
        # TODO implement value propagation for a constant input 0
예제 #8
파일: argmax.py 프로젝트: cooper-chiou/dldt
    def argmax_infer(node: Node):
        """Infer the output shape of an ArgMax node.

        With two inputs (TensorFlow form), the axis is read from the second
        input's value, stored in ``node.axis`` and the second input edge is
        removed. With a valid ``axis``, the output is the input shape with that
        dimension replaced by ``top_k`` and ``axis`` becomes a permutable
        attribute bound to input port 0. Otherwise the output has
        ``max(rank, 3)`` dimensions of 1, except index 0 (copied from the input),
        index 1 (2 when ``out_max_val`` is set) and index 2 (``top_k``).
        Nothing is inferred when the input shape or the axis value is undefined.
        """
        shape = node.in_node(0).shape
        if shape is None:
            return

        # there are two inputs in TensorFlow. The second input is the axis for ArgMax
        if len(node.in_nodes()) == 2:
            if node.in_node(1).value is None:
                log.debug('The second argument to ArgMax is None')
                return
            node.axis = node.in_node(1).value.item()
            # remove the unnecessary input
            node.graph.remove_edge(node.in_node(1).id, node.id)

        if node.has_valid('axis'):
            axis = get_canonical_axis_index(shape, node.axis)
            node.axis = axis
            out_shape = np.array(shape)
            out_shape[axis] = node.top_k
            PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
        else:
            out_shape = np.ones(max(shape.size, 3), dtype=int)
            out_shape[0] = shape[0]
            out_shape[2] = node.top_k
            if node.out_max_val:
                out_shape[1] = 2

        node.out_node().shape = out_shape
예제 #9
def tf_split_infer(node):
    """
    Partial infer of split node similar to Split op of TF.

    Input 0 holds the scalar split dimension and input 1 the data, which is split
    into ``node.num_split`` equally sized parts. Afterwards the split-dim input
    edge is removed, ``input_port`` is set to 1 and ``axis`` is registered as a
    permutable attribute bound to input port 1. When an input is undefined or the
    dimension is not evenly divisible, an error is logged and inference stops.
    """
    name = node.soft_get('name')
    # Two inputs: [split_dim, input]
    assert len(node.in_nodes()) == 2, 'Node "{}" must have exactly two inputs'.format(name)
    split_dim = node.in_node(0).value
    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(name))
        return

    assert split_dim.ndim == 0, 'The split dimension for node "{}" must be a scalar.'.format(name)
    split_dim = split_dim.item()
    data_node = node.in_node(1)

    if data_node.shape is None:
        log.error('Input shape for node {} is not defined'.format(name))
        return

    log.debug('input shape for split: {}, should be split along {} dim'.format(data_node.shape, split_dim))
    split_dim_size = data_node.shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    # split_dim is now a plain integer axis; every part has the same size along it
    log.debug('split_dim_size = {}, node.num_split = {}, div = {}, typeof div = {}'.format(
        split_dim_size, node.num_split, split_dim_size / node.num_split, type(split_dim_size / node.num_split)))
    part_size = int(split_dim_size // node.num_split)
    split(data_node, node, split_dim, [part_size] * node.num_split)
    node.graph.remove_edge(node.in_node(0).id, node.id)
    node['input_port'] = 1

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:1')])
예제 #10
def conv_flatten_concat_action(graph: Graph, match: dict):
    """Insert an NCHW -> NHWC Permute between a convolution output and the following Reshape.

    Requires an NHWC graph. ``match`` must provide the 'reshape', 'reshape_data',
    'conv' and 'conv_data' nodes. The conv data -> reshape edge is replaced by
    conv data -> Permute -> reshape, permutation of the reshape output edge is
    disabled, and the reshape node is marked with ``nchw_layout``.
    """
    assert graph.graph['layout'] == 'NHWC'
    reshape_node = match['reshape']
    reshape_data_node = match['reshape_data']
    conv_name = match['conv'].name
    conv_data_node = match['conv_data']
    # the pattern should be applied only in case when the reshape operation changes number of dimensions
    if len(reshape_data_node.shape) == len(
            conv_data_node.shape) or reshape_node.has_and_set('nchw_layout'):
        # NOTE(review): the body of this `if` is missing in SOURCE (presumably an early `return`) — TODO restore.

    # NOTE(review): the following condition is truncated in SOURCE. The last operand after
    # `type == 'FullyConnected' and \` and the call that emits the message below are missing.
    # Presumably the branch logs that the FullyConnected weights will be repacked and returns
    # without inserting a Permute — TODO restore from upstream.
    if len(reshape_data_node.out_nodes()) == 1 and reshape_data_node.out_node().has_valid('type') and \
        reshape_data_node.out_node().type == 'FullyConnected' and \
            'There is a FullyConnected layer after the node "{}" which weights will be repacked. So there is no '
            'need to insert Permute'.format(reshape_node.soft_get('name')))
    graph.remove_edge(conv_data_node.id, reshape_node.id)

    # NOTE(review): the argument of get_nchw_to_nhwc_permutation (presumably the rank of conv_data_node)
    # and the closing of this call are missing in SOURCE — TODO restore.
    permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(
    new_permute_op = Permute(graph, {'order': permutation_order})
    permute_data_node = new_permute_op.create_node_with_data(
        [conv_data_node], dict(name=conv_name + '/Permute_'))
    graph.create_edge(permute_data_node, reshape_node)
    # Disable permutation for Reshape and Concat layers attributes
    PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
    reshape_node['nchw_layout'] = True
예제 #11
def tf_squeeze_infer(node):
    """Infer the output shape/value of a TF-style Squeeze node.

    Removes the dimensions listed in ``node.squeeze_dims`` (each must be 1) and
    sets ``node['dim']`` to the resulting shape, or to ``[0, -1]`` when
    ``is_spatial_squeeze`` holds for the graph layout. A constant input value is
    reshaped accordingly and ``dim`` is registered as a permutable attribute
    bound to output port 0. Nothing is inferred when ``squeeze_dims`` or the
    input shape is undefined.

    Raises:
        Error: if a listed dimension is not equal to 1.
    """
    if node.squeeze_dims is None:
        # TODO: implement; there is no implementation now because no test
        return

    input_shape = node.in_node().shape
    if input_shape is None:
        return
    # UGLY
    output_shape = input_shape.copy()
    real_squeeze_dims = []
    for n in node.squeeze_dims:
        if output_shape[n] != 1:
            raise Error('Trying to squeeze dimension not equal to 1 for node "{}"'.format(node.soft_get('name')))
        real_squeeze_dims.append(get_canonical_axis_index(output_shape, n))

    output_shape = np.delete(output_shape, real_squeeze_dims)
    node.out_node().shape = output_shape

    if is_spatial_squeeze(node.graph.graph['layout'], input_shape, output_shape):
        output_shape = np.array([0, -1], dtype=np.int64)
    node['dim'] = output_shape
    if node.in_node().value is not None:
        node.out_node().value = np.array(np.reshape(node.in_node().value, output_shape))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
예제 #12
파일: expand_dims.py 프로젝트: pc2/CustoNN2
def tf_expand_dims_infer(node):
    """Infer the output shape/value of a TF-style ExpandDims node.

    The position of the new size-1 axis comes from the second input's value
    (whose edge is then removed) or, with a single input, from the
    ``expand_axis`` attribute. A negative position counts from the end of the
    *output* shape, so ``-1`` appends a trailing axis. Sets the int64 output
    shape, ``node['dim']`` and, for a constant input, the reshaped output value;
    ``dim`` is registered as a permutable attribute bound to output port 0.
    On undefined inputs nothing is inferred.
    """
    input_node = node.in_nodes()[0]
    output_node = node.out_node()
    if input_node.shape is None:
        return

    # TensorFlow style with dynamic input
    if len(node.in_nodes()) > 1:
        axis_node = node.in_nodes()[1]
        if isinstance(axis_node.value, np.ndarray) and axis_node.value.size > 1:
            log.error("ExpandDims operation : axis should be scalar")
            return
        expand_axis = axis_node.value.item()
        node.graph.remove_edge(axis_node.id, node.id)
    else:
        if not node.has_valid('expand_axis'):
            log.error("ExpandDims axis is not defined")
            return
        expand_axis = node.expand_axis

    if expand_axis is None:
        return

    # np.insert treats a negative index as "before that element", but the new axis index
    # refers to the output rank (input rank + 1), so normalize it first.
    if np.ndim(expand_axis) == 0 and expand_axis < 0:
        expand_axis = expand_axis + len(input_node.shape) + 1

    output_node.shape = np.insert(input_node.shape, expand_axis, [1])
    # convert data type of the shape to int64 explicitly
    output_node.shape = output_node.shape.astype(np.int64)
    if input_node.value is not None:
        output_node.value = np.array(np.reshape(input_node.value, output_node.shape))

    node['dim'] = output_node.shape

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
예제 #13
def apply_nhwc_to_nchw_permutation(graph: Graph):
    """Attach an NHWC -> NCHW permutation to the edges of every eligible data node.

    Does nothing for graphs whose layout is already 'NCHW'. A data node is
    skipped when it is marked with ``nchw_layout`` or when any of its incoming
    or outgoing edges already carries a ``permutation`` attribute; otherwise the
    permutation for ``len(node.shape)`` dimensions is set on all its edges.
    """
    # Add NHWC to NCHW permutation for all data nodes (only for nodes without permutation)
    if graph.graph['layout'] == 'NCHW':
        return

    for node in graph.get_data_nodes():
        if node.has_and_set('nchw_layout'):
            continue

        # Get NHWC to NCHW permutation for N dims, where N = len(node.shape)
        permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape))

        # Check that data node already has permutation
        already_permuted = any(
            'permutation' in node.graph.get_edge_data(in_node.id, node.id)[0]
            for in_node in node.in_nodes()
        ) or any(
            'permutation' in node.graph.get_edge_data(node.id, out_node.id)[0]
            for out_node in node.out_nodes()
        )
        if already_permuted:
            continue

        # Set permutation to all in/out edges
        for in_node in node.in_nodes():
            PermuteAttrs.set_permutation(in_node, node, permutation)

        for out_node in node.out_nodes():
            PermuteAttrs.set_permutation(node, out_node, permutation)
예제 #14
    def infer(node: Node):
        """Infer the output shape/value of an AttributedGather node.

        Gathers from input 0 (data) with input 1 (indices) along the ``axis``
        attribute, which is normalized with ``get_canonical_axis_index`` and
        written back to the node. The output shape is
        ``data_shape[:axis] + indices_shape + data_shape[axis + 1:]``; when both
        inputs are constant the value is computed with ``np.take``.
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            "AttributedGather should have 2 connected input port, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        axis = node.soft_get('axis', None)
        assert axis is not None

        data_shape = node.in_port(0).data.get_shape()
        assert data_shape is not None
        indices_shape = node.in_port(1).data.get_shape()
        assert indices_shape is not None

        # Convert negative axis
        axis = get_canonical_axis_index(data_shape, axis)
        node.axis = axis

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

        # The gathered axis is replaced by all dimensions of the indices.
        out_shape = np.concatenate((data_shape[:axis], indices_shape, data_shape[axis + 1:]))
        node.out_port(0).data.set_shape(np.array(out_shape, dtype=np.int64))

        data_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        if data_value is not None and indices_value is not None:
            node.out_port(0).data.set_value(np.array(np.take(data_value, indices_value, axis), dtype=data_value.dtype))

예제 #15
    def infer(node: Node):
        """Prepare a StridedSlice node for NHWC -> NCHW layout permutation.

        Only acts when the graph layout is 'NHWC' and the output port 0 value is
        not known. In that case the (truncated) attribute list pairs each mask
        with 'input:0' and ``permute_masks``. Constant inputs 1.. whose
        ``shape[0]`` is greater than 3 are permuted with
        ``permute_array_with_ellipsis``, and every nonzero ``ellipsis_mask``
        entry is reset to 0.
        """

        if node.graph.graph['layout'] == 'NHWC' and node.out_port(
                0).data.get_value() is None:
            # NOTE(review): the call that consumes the attribute tuples below is missing in SOURCE
            # (presumably `PermuteAttrs.create_permute_attrs(node, attrs=[...])`) — TODO restore.
                    ('shrink_axis_mask', 'input:0', permute_masks),
                    ('new_axis_mask', 'input:0', permute_masks),
                    ('ellipsis_mask', 'input:0', permute_masks),
                    ('begin_mask', 'input:0', permute_masks),
                    ('end_mask', 'input:0', permute_masks),
            for i in range(1, len(node.in_nodes())):
                if node.in_node(
                        i).value is not None and node.in_node(i).shape[0] > 3:
                    # NOTE(review): the argument of get_nhwc_to_nchw_permutation (presumably the input rank)
                    # is missing in SOURCE — TODO restore.
                    perm = PermuteAttrs.get_nhwc_to_nchw_permutation(
                    node.in_node(i).value = permute_array_with_ellipsis(
                        node, perm,
                        node.in_node(i).value, 0)

            # due to permutation from nhwc to nchw we will extend all masks and inputs
            idx = np.nonzero(node.ellipsis_mask)
            node.ellipsis_mask[idx] = 0
예제 #16
    def infer(node: Node):
        """Infer the output shape of a 4D resize node from a target spatial size.

        Input 0 provides the source shape, and input 1's value provides
        ``[out_height, out_width]``. The output keeps the batch and feature
        dimensions of the source (located according to the graph layout) and uses
        the target height/width. ``axes`` is registered as a permutable attribute
        bound to input port 0.
        """
        layout = node.graph.graph['layout']

        assert len(layout) == 4
        assert len(
            [p for p in node.in_ports().values() if not p.disconnected()])
        assert node.has_valid('mode')
        assert node.has_valid('axes')

        src_shape = node.in_port(0).data.get_shape()
        assert src_shape is not None
        dst_shape = node.in_port(1).data.get_value()
        assert dst_shape is not None

        out_height = dst_shape[0]
        out_width = dst_shape[1]

        node.out_node().shape = shape_for_layout(
            layout,
            batch=src_shape[get_batch_dim(layout, 4)],
            features=src_shape[get_features_dim(layout, 4)],
            height=out_height,
            width=out_width)

        PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])
    def find_and_replace_pattern(self, graph: Graph):
        """Normalize every StridedSlice op node and configure its layout permutation rules.

        For each op node of type 'StridedSlice', this applies
        ``StridedSliceNormalizer.normalize_strided_slice``. The mask attributes are
        paired with 'input:0', and inputs 1, 2 and (when port 3 is connected) 3 get
        a 'slice' input permutation.
        """
        for node in graph.get_op_nodes(type='StridedSlice'):
            StridedSliceNormalizer.normalize_strided_slice(graph, node)
            # NOTE(review): the call consuming the attribute list below, and its first entry for
            # 'begin_mask', are missing in SOURCE (presumably
            # `PermuteAttrs.create_permute_attrs(node, attrs=[('begin_mask', ...`) — TODO restore.
                     'input:0'),  # but indeed depends from slice_rank
                    ('end_mask', 'input:0'),
                    ('new_axis_mask', 'input:0'),
                    ('shrink_axis_mask', 'input:0'),
                    ('ellipsis_mask', 'input:0')

            # StridedSliceNormalizer inserted nodes that changed original begin, end, and strides data nodes
            # Until now it was not possible to set correct permutations
            # NOTE(review): each set_input_permutation call below is truncated after 'slice' — the final
            # argument(s) and closing parenthesis are missing in SOURCE — TODO restore.
            PermuteInputs().set_input_permutation(node.in_node(1), node,
                                                  'input:1', 'slice',
            PermuteInputs().set_input_permutation(node.in_node(2), node,
                                                  'input:2', 'slice',
            if node.is_in_port_connected(3):
                PermuteInputs().set_input_permutation(node.in_node(3), node,
                                                      'input:3', 'slice',
예제 #18
    def infer(node):
        """Infer output shapes/values for Split (2 inputs) and AttributedSplit (1 input) nodes.

        The input on port 0 is split into ``num_splits`` equal parts along the
        axis. For 'Split' the axis comes from the value of input 1; for
        'AttributedSplit' it comes from the ``axis`` attribute. Every connected
        output gets the split shape and, for a constant input, its part of the
        value. Finally the axis permutation rule is registered (input permutation
        for 'Split', permutable attribute for 'AttributedSplit').
        """
        name = node.soft_get('name', node.id)

        op = node.soft_get('op', None)
        assert op is not None and op in ['Split', 'AttributedSplit'], \
            'Unexpected `op`={} attribute for Split-like node {}'.format(op, name)

        num_in_ports = 1 if op == 'AttributedSplit' else 2 if op == 'Split' else None
        assert num_in_ports in [1, 2], \
            'SplitBase supports AttributedSplit with 1 input and Split with 2 inputs, but it is {} for {} node {}' \
            ''.format(num_in_ports, op, name)

        connected_inputs = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_inputs) == num_in_ports and all([i in connected_inputs for i in range(num_in_ports)]), \
            "{} should have {} connected input ports, but it doesn't for node: `{}`. Ports: {}" \
            "".format(op, num_in_ports, name, connected_inputs)

        input_shape = node.in_port(0).data.get_shape()
        assert input_shape is not None, 'Input shape is unknown for node {}'.format(name)
        assert node.has_valid('num_splits'), 'Parameter `num_splits` is unknown for node {}'.format(name)
        num_splits = node.num_splits

        axis = node.in_port(1).data.get_value() if op == 'Split' else node.soft_get('axis', None)
        assert axis is not None, '{} `axis` is unknown for node {}'.format(op, name)
        # np.ndim also accepts a plain int attribute (AttributedSplit), unlike `.ndim`
        assert np.ndim(axis) == 0, '{} `axis` should be scalar, but it`s not for node {}'.format(op, name)

        assert input_shape[axis] % num_splits == 0, \
            'Input shape is not evenly divided by `num_splits` of {} node {}. `input_shape`={}, `axis`={}, ' \
            '`num_splits`={}'.format(op, name, input_shape, axis, num_splits)

        out_shape = input_shape.copy()
        out_shape[axis] = np.int64(input_shape[axis] // num_splits)

        input_value = node.in_port(0).data.get_value()
        output_value = np.split(input_value.copy(), axis=axis, indices_or_sections=num_splits) \
            if input_value is not None else None

        for idx, port in node.out_ports().items():
            if idx in node.out_nodes():
                port.data.set_shape(out_shape)
                if output_value is not None:
                    port.data.set_value(output_value[idx])

        if op == 'Split':
            PermuteInputs().set_input_permutation(node.in_node(1), node,
                                                  'input:0', 'axis')
        elif op == 'AttributedSplit':
            PermuteAttrs.create_permute_attrs(node,
                                              attrs=[('axis', 'input:0')])
예제 #19
    def find_and_replace_pattern(self, graph: Graph):
        """Keep nodes marked ``reinterp_shape=True`` in NHWC layout by adding Transposes around them.

        Only runs for NHWC graphs. For every such node:
        * if its input 0 is not marked as being in correct layout and has rank >= 4,
          a Transpose (NC(D)HW -> N(D)HWC) is created for input 0;
        * if its output 0 is not marked as being in correct layout and has rank >= 4,
          a Transpose (N(D)HWC -> NC(D)HW) is created for output 0.
        The affected ports are marked as being in correct layout, and the created
        Transposes get ``need_shape_inference = False``.
        """
        if graph.graph['layout'] != 'NHWC':
            # we check it here because this transformation is called explicitly from the pipeline
            # NOTE(review): the `return` belonging to this branch is missing in SOURCE — TODO restore.

        # reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
        for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
            reinterp_shape_node = Node(graph, reinterp_shape_node_id)
            assert 0 in reinterp_shape_node.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
                reinterp_shape_node_id, graph.dump_graph_for_graphviz())
            input_shape = reinterp_shape_node.in_node(0).shape
            if not is_input_data_in_correct_layout(reinterp_shape_node, 0) and len(input_shape) >= 4:
                # NOTE(review): the two constructor calls below are truncated in SOURCE (closing
                # `}).create_node()` is missing). The lines that insert the Transpose into input 0 and
                # connect `order_const` to its port 1 are also missing — TODO restore.
                order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
                permute_node = Transpose(graph,
                                         {'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'

                # do not infer the Transpose node because it should have input data node in NCHW layout (but currently
                # it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
                # (which is true at this moment)
                permute_node['need_shape_inference'] = False
                # mark the Transpose output data node having correct layout so its shape will not be permuted
                mark_output_as_in_correct_layout(permute_node, 0)

                # keep the reinterp_shape_node in NHWC layout
                mark_input_as_in_correct_layout(reinterp_shape_node, 0)
                mark_input_as_in_correct_layout(reinterp_shape_node, 1)

        # reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
        for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
            reinterp_shape_node = Node(graph, reinterp_shape_node_id)
            assert 0 in reinterp_shape_node.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
                reinterp_shape_node_id, graph.dump_graph_for_graphviz())
            output_shape = reinterp_shape_node.out_node(0).shape
            if not is_output_data_in_correct_layout(reinterp_shape_node, 0) and len(output_shape) >= 4:
                order_const = Const(graph, {
                    'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
                permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
                # NOTE(review): the lines that insert the Transpose after output 0 and connect `order_const`
                # to its port 1 are missing in SOURCE — TODO restore.

                # the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
                # will convert it to the NCHW
                mark_input_as_in_correct_layout(permute_node, 0)
                mark_input_as_in_correct_layout(permute_node, 1)
                # do not set Transpose output data node 'correct_data_layout' attribute so the data node shape will be
                # permuted

                # keep the reinterp_shape_node in NHWC layout
                mark_output_as_in_correct_layout(reinterp_shape_node, 0)
                mark_input_as_in_correct_layout(reinterp_shape_node, 1)

                # do not re-infer the Transpose node because its output data node should be in NHWC layout to make the
                # rest of the graph consistent
                permute_node['need_shape_inference'] = False
예제 #20
    def infer(node):
        """Set the output shape of a Parameter (model input) node from its ``shape`` attribute.

        ``shape`` is registered as a permutable attribute bound to output port 0.
        """
        name = node.soft_get('name', node.id)
        assert node.has_valid('shape'), \
            'Parameter node {} should have `shape` attribute. Please use cli options to set model input shape' \
            ''.format(name)
        node.out_port(0).data.set_shape(node.shape)

        PermuteAttrs.create_permute_attrs(node, attrs=[('shape', 'output:0')])
예제 #21
    def infer(node: Node):
        """Infer the output shape/value of a Tile node.

        Two forms are handled:
          * two inputs: input 1 holds per-dimension repeat counts. With at most
            one non-1 count, the ``axis``/``tiles`` attributes are derived from it.
            Otherwise ``type`` is reset to None and the counts are stored in
            ``tile_array``, but shape/value are still inferred. The repeats input
            edge is removed afterwards;
          * one input: the ``axis`` and ``tiles`` attributes define the tiling.

        The output shape is ``input_shape * repeats``, and a constant input value
        is tiled with ``np.tile``. On undefined or inconsistent inputs an error is
        logged and inference stops.
        """
        shape = node.in_node().shape
        if shape is None:
            log.error("Undefined shape for the input tiles for the Tile operation '{}'.".format(node.name))
            return
        shape = np.copy(shape)

        if len(node.in_nodes()) == 2:
            tile_array = node.in_node(1).value
            if tile_array is None:
                log.error('A tile values are None for a node "{}".'.format(node.name))
                return
            if len(shape) != len(tile_array):
                log.error('Shape mismatch for a node "{}": {} vs {}.'.format(
                    node.name, shape.shape, tile_array.shape))
                return
            non_one_tile = np.argwhere(tile_array != 1)
            if len(non_one_tile) == 0:
                log.info(
                    'Redundant "Tile" operation "{}" with tile values for all dimensions equal to 1.'.format(node.name))
                node['axis'] = 0
                node['tiles'] = 1
            elif len(non_one_tile) == 1:
                node['axis'] = non_one_tile[0][0]
                node['tiles'] = tile_array[node['axis']]
            else:
                node['type'] = None
                node['tile_array'] = tile_array
                log.warning("Tile operation with more than one dimension not equal to 1 is not supported.")
                # do not return here to allow infer shape and values for the constant propagation case
            node.graph.remove_edge(node.in_node(1).id, node.id)
        elif len(node.in_nodes()) == 1:  # case when tiled dimension and count are specified in node attributes
            if not node.has_valid('axis') or not node.has_valid('tiles'):
                log.error('Mandatory attributes "axis" or "tiles" are not specified for a Tile node "{}"'.format(
                    node.name))
                return
            tile_array = np.ones([len(shape)], dtype=np.int64)
            tile_array[node.axis] = node.tiles
        else:
            log.error('Unsupported number of input parameters to Tile node "{}"'.format(node.name))
            return

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
        node.out_node().shape = shape * tile_array
        if node.in_node(0).value is not None:
            node.out_node().value = np.tile(node.in_node(0).value, tile_array)
예제 #22
 def infer(node: Node):
     """Set the shape of every output data node of a Split-like node.

     Each output gets a shallow copy of the input 0 shape with the ``node.axis``
     dimension divided by a divisor that is truncated in SOURCE (presumably the
     number of splits — TODO confirm). ``axis`` is registered as a permutable
     attribute bound to input port 0.
     """
     input_node = node.in_node(0)
     outputs = node.out_nodes()
     out_shape = copy.copy(input_node.shape)
     # NOTE(review): the divisor operand and closing parentheses of the expression below are missing
     # in SOURCE — TODO restore.
     # NOTE(review): all outputs are assigned the same `out_shape` object; confirm nothing mutates it in place.
     out_shape[node.axis] = np.int64(input_node.shape[node.axis] /
     for idx, output in outputs.items():
         output.shape = out_shape
     PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
예제 #23
 def infer(node: Node):
     """Validate a LogSoftmax node and normalize its ``axis`` attribute in place.

     Exactly one input port must be connected. A negative ``axis`` is rewritten
     as ``rank + axis``, where ``rank`` is the length of the input port 0 shape.
     The result must satisfy ``0 <= axis < rank``. ``axis`` is then registered as
     a permutable attribute bound to input port 0.
     """
     n_connected = sum(1 for port in node.in_ports().values() if not port.disconnected())
     assert n_connected == 1, \
         'LogSoftmax node with id {} have more than one port connected'.format(node.id)

     if node.axis < 0:
         node.axis = node.axis + len(node.in_port(0).data.get_shape())

     input_rank = len(node.in_port(0).data.get_shape())
     assert 0 <= node.axis < input_rank, \
         'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)
     PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
예제 #24
    def infer(node: Node):
        """Infer a node that swaps the input axes ``dim1`` and ``dim2``.

        ``node.order`` is set to the identity permutation of the input rank with
        positions ``dim1`` and ``dim2`` exchanged. The output shape is the input
        shape permuted by ``order``. A constant input value is transposed with
        the same order, and ``order`` is registered as a permutable attribute
        bound to input port 0.
        """
        order = list(range(node.in_node().shape.size))
        order[node.dim2], order[node.dim1] = order[node.dim1], order[node.dim2]
        node['order'] = order

        input_shape = node.in_port(0).data.get_shape().copy()
        node.out_port(0).data.set_shape(input_shape[node.order])
        input_value = node.in_port(0).data.get_value()
        if input_value is not None:
            node.out_port(0).data.set_value(np.transpose(input_value, axes=node.order))

        PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
예제 #25
파일: reshape.py 프로젝트: pc2/CustoNN2
def tf_reshape_shape_infer(node):
    """Compute the output shape of a TF-style Reshape node.

    The target shape is the value of input 1 when the node has more than one
    input, otherwise the node's ``dim`` attribute. In the target shape a ``0``
    copies the input dimension at the same position, and a ``-1`` is replaced
    by the size that keeps the total element count unchanged.

    Returns:
        The output shape converted with ``int64_array``, or ``None`` when the
        input shape or the target shape is not known yet.

    Raises:
        Error: if the element counts of the input and output shapes differ.

    Side effects: registers ``dim`` as a permutable attribute tied to
    output 0 and, when ``dim`` is not valid yet, stores the computed shape in it.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return None

    reshape_output = node.in_node(1).value if len(node.in_nodes()) > 1 else node.dim
    # A non-constant target shape cannot be resolved yet.
    if reshape_output is None:
        return None

    total = 1
    for size in input_shape:
        total *= size

    # Product of every target dimension that is known explicitly (-1 excluded).
    known = 1
    for index, x in enumerate(reshape_output):
        if x == 0:
            known *= input_shape[index]
        elif x != -1:
            known *= x

    new_dim = total // known
    output_shape = []
    for index, x in enumerate(reshape_output):
        if x == 0:
            output_shape.append(input_shape[index])
        elif x == -1:
            output_shape.append(new_dim)
        else:
            output_shape.append(x)

    out_shape_total = 1
    for size in output_shape:
        assert size != -1
        out_shape_total *= size

    if total != out_shape_total:
        raise Error(
            "Number of elements in input {} and output {} of reshape node {} mismatch"
            .format(input_shape, output_shape, node.name))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])

    output_shape = int64_array(output_shape)

    # A Reshape created with two inputs may lack `dim`; record the computed shape there.
    if not node.has_valid('dim'):
        node['dim'] = output_shape

    return output_shape
예제 #26
파일: crop.py 프로젝트: pc2/CustoNN2
    def _two_inputs_infer(node: Node):
        """Infer the output shape of a Crop node whose target sizes come from a second input.

        Dimensions before ``node.axis`` keep the input size. Each dimension from
        ``node.axis`` onward is cropped to the size of the reference input
        (input 1), starting at its offset. ``offset`` holds either a single value
        shared by every cropped axis or one value per cropped axis.

        On success the node's ``axis``, ``offset`` and ``dim`` attributes are
        rewritten to per-cropped-axis lists, the output shape is set, and
        ``axis`` is registered as permutable w.r.t. input 0. On a validation
        failure an error is logged and inference stops early without setting
        the output shape.
        """
        N = len(node.in_nodes())

        shapes = [node.in_node(i).shape for i in range(N)]
        if any(s is None for s in shapes):
            log.error('Not all input shapes were defined for {} node'.format(node.name))
            return

        if not node.has_valid('axis'):
            log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
            return

        if not node.has_valid('offset'):
            log.error('offset attribute is missing for {} node. should be set in crop extractor'.format(node.name))
            return

        input_shape = np.array(shapes[0])
        # presumably maps a negative axis to its non-negative index -- see get_canonical_axis_index
        start_axis = get_canonical_axis_index(input_shape, node.axis)
        node.axis = start_axis

        reference_shape = np.array(shapes[1])
        input_dim = input_shape.size

        # Start from the input shape; only cropped axes are overwritten below.
        new_shape = input_shape.copy()
        ir_axis = []
        ir_offset = []
        dim = []

        for i in range(input_dim):
            if i < start_axis:
                # Axes before the crop start pass through unchanged.
                continue

            # One shared offset, or one offset per axis counted from start_axis.
            crop_offset = 0
            if len(node.offset) == 1:
                crop_offset = node.offset[0]
            elif len(node.offset) > 1:
                crop_offset = node.offset[i - start_axis]

            if input_shape[i] - crop_offset < reference_shape[i]:
                log.error('The crop for dimension is out of bounds in {}'.format(node.name))
                return

            dim.append(reference_shape[i])
            ir_axis.append(i)
            ir_offset.append(crop_offset)
            new_shape[i] = reference_shape[i]

        node.axis = ir_axis
        node.offset = ir_offset
        node['dim'] = dim
        node.out_node().shape = new_shape
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
예제 #27
    def infer(node):
        """Infer the output shape of a Pad node.

        Paddings come either from the ``pads`` attribute, in which case only the
        data input may be connected, or from the inputs. With two inputs, input 1
        supplies the whole ``pads`` array. With 3 or 4 inputs, inputs 1 and 2
        supply the begin and end paddings, and the optional input 3 supplies the
        fill value (default 0.0).

        The output shape is the input shape plus the per-dimension sum of begin
        and end paddings; non-positive input dimensions are kept unchanged.
        When the paddings or the input shape are unknown, an error is logged and
        inference stops.
        """
        PermuteAttrs.create_permute_attrs(node, attrs=[('pads', 'input:0')])

        num_of_inputs = len(node.in_nodes())
        if node.has_valid('pads'):
            assert num_of_inputs == 1, "Pad operation has pads attribute and unexpected additional input " \
                                       "argument for node {}.".format(node.name)
        else:
            assert num_of_inputs >= 2, "Missing required second input argument for node {} and pads attribute " \
                                       "is missing.".format(node.name)
            node['pads'] = node.in_node(1).value
            if num_of_inputs in [3, 4]:
                pads_begin = node.in_node(1).value
                pads_end = node.in_node(2).value
                # Stack begin/end as columns to obtain the Nx2 layout used below.
                node['pads'] = np.concatenate(
                    (pads_begin.reshape(-1, 1), pads_end.reshape(-1, 1)), 1)
                node['fill_value'] = node.in_node(
                    3).value if num_of_inputs == 4 else 0.0
        padding = node.pads

        input_shape = node.in_node(0).shape
        if padding is None or input_shape is None:
            log.error('The paddings are not defined for node "{}"'.format(node.name))
            return

        # paddings can be defined, partially defined or undefined
        # TODO for now we only handle fully defined paddings
        # That means that intermediate tensor that delivers padding
        # should have defined value and size Nx2
        # TODO possible broadcasts are not supported
        assert (padding.ndim == 2 and padding.shape[1] == 2)

        # make sure that input has the same number of dimensions as the number of padding dimensions
        assert (padding.shape[0] == len(input_shape)), \
            "Input tensor shape {} and pads values {} do not match for Pad node {}".format(
                input_shape, padding.shape, node.name)

        # sum low and high padding values to calculate the shape modification vector
        shape_change = np.add.reduce(padding, 1)
        assert (shape_change.shape == input_shape.shape)

        # preserve non-positive values in the input shape, because it has a special meaning
        shape = np.array([
            shape_change[i] + input_shape[i] if input_shape[i] > 0 else input_shape[i]
            for i in range(len(input_shape))
        ])

        assert len(node.out_nodes()) == 1

        node.out_node().shape = shape
예제 #28
    def infer(node: Node):
        """Canonicalize a StridedSlice node once its output shape is known.

        1. Each connected input other than input 0 must be constant. Its value is
           extended with ``extend_mask_according_ellipsis`` to ``len(out_shape)``
           entries, filling with 1 for port 3 (stride) and 0 otherwise
           (presumably begin/end at ports 1/2). The input is rewritten only when
           the value actually changes.
        2. The new_axis/shrink_axis/begin/end/ellipsis masks are extended the
           same way (fill 0), and all set ellipsis bits are then cleared.
        3. For an NHWC graph whose output is not constant, the masks are
           registered as permutable attributes. When input 0 has rank > 3, the
           constant inputs are also permuted with ``permute_array``.
        """
        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None, \
            'Output shape was not calculated for node {}'.format(node.name)
        # extend inputs according to ellipsis mask and/or input_shape
        for i_port in node.in_ports().values():
            if i_port.idx == 0 or i_port.disconnected():
                continue
            old_value = i_port.data.get_value()
            # additional check for non-const input
            # error will be return in shape inference if non-const will be added
            # it is paranoid check for case if shape inference will be changed
            # The parentheses make `.format` apply to the whole message, not just its last literal.
            assert old_value is not None, \
                ('{} input of {} node is not constant: \'value\' attribute for edge '
                 'contains None').format(i_port.idx, node.name)
            # insert 0 for begin and end and 1 for stride
            new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                   len(out_shape), list(old_value),
                                                                   int(i_port.idx == 3)))
            # set_value additionally set_shape and propagate value to Const node
            if not np.array_equal(new_value, old_value):
                i_port.data.set_value(new_value)

        # extend masks before removing ellipsis
        for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
            node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                    len(out_shape), list(node[attr]), 0))

        # we will extend all masks and inputs to simplify future transformations
        idx = np.nonzero(node.ellipsis_mask)
        node.ellipsis_mask[idx] = 0

        if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
            PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                           ('new_axis_mask', 'input:0', permute_masks),
                                                           ('ellipsis_mask', 'input:0', permute_masks),
                                                           ('begin_mask', 'input:0', permute_masks),
                                                           ('end_mask', 'input:0', permute_masks),
                                                           ])
            # permute inputs
            in_shape = node.in_port(0).get_source().data.get_shape()
            assert in_shape is not None, \
                'Input shape is unknown for 0 input of node {}'.format(node.name)
            input_rank = len(in_shape)
            if input_rank > 3:
                for i_port in node.in_ports().values():
                    if i_port.idx == 0 or i_port.disconnected():
                        continue
                    new_value = permute_array(node, i_port.data.get_value())
                    # set_value additionally set_shape and propagate value to Const node
                    i_port.data.set_value(new_value)
예제 #29
    def replace_pattern(self, graph: Graph, match: dict):
        """Rewrite a matched Mean reduction as an AvgPool over the reduced axes.

        Writes AvgPool attributes into the graph attributes of the matched
        ``mean`` node:
        - the original op name is kept in ``TF_op``;
        - unit ``stride``, zero ``pad``, ceil rounding, ``exclude_pad='true'``;
        - a ``window`` equal to the full input extent on every axis listed in
          the constant ``axis`` input, and 1 elsewhere.
        The edge from the ``axis`` node to ``mean`` is removed.

        When ``keep_dims`` is False, a new data node is created for the pool
        output, holding the output shape with a 1 inserted at each reduced axis.
        A Reshape node then restores the original output shape and feeds the
        original output data node.

        NOTE(review): this copy appears truncated in several places:
        - the first ``if`` has no body (presumably an early ``return``);
        - the first ``update_attrs`` entry is the bare string ``('input:0')``
          rather than a tuple (presumably ``('pad', 'input:0')``);
        - the ``create_node`` call is missing its leading arguments.
        Restore these from the upstream source before relying on this code.
        """
        # Only applicable when both the reduction axes and the input shape are known.
        if match['axis'].value is None or match['input'].shape is None:
        dims = len(match['input'].shape)
        ones = np.ones(dims, dtype=np.int64)
        axis = np.array(match['axis'].value)
        # A scalar axis is wrapped into a 1-element int64 array so it can be used for indexing.
        axis = axis if axis.ndim != 0 else np.array([axis], dtype=np.int64)

        mean = graph.node[match['mean'].node]
        mean['stride'] = np.array(ones)
        # TODO: need to check axis with real layout
        spatial_dims = np.array(axis)
        mean['spatial_dims'] = spatial_dims
        mean['pad'] = np.zeros((dims, 2), np.int64)
        mean['pad_spatial_shape'] = np.array(mean['pad'][spatial_dims])
        # Pool over the whole extent of each reduced axis; other axes use a window of 1.
        window = np.array(ones)
        window[spatial_dims] = match['input'].shape[spatial_dims]
        mean['window'] = window
        mean['TF_op'] = mean['op']
        mean['op'] = 'AvgPool'
        mean['pool_method'] = 'avg'
        mean['rounding_type'] = 'ceil'
        mean['exclude_pad'] = 'true'
        mean['kernel_spatial'] = window[spatial_dims]
        # The axes are now encoded in the pooling attributes, so the axis input is detached.
        graph.remove_edge(match['axis'].node, match['mean'].node)
        mean['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[(
            'input:0'), ('stride',
                         'input:0'), ('window',
                                      'input:0'), ('spatial_dims', 'input:0')])

        # Without keep_dims the Mean output lacks the reduced axes, so AvgPool needs a Reshape after it.
        if match['mean'].keep_dims == False:
            output = match['mean'].out_node()
            pool_node = match['mean']

            # Keep dims for AvgPool
            shape = np.array(output.shape)
            for idx in spatial_dims:
                shape = np.insert(shape, idx, 1)

            graph.remove_edge(pool_node.id, output.id)
            # Create new data for pool with all dims
            pool_data = Op.create_data_node(graph, pool_node,
                                            {'shape': np.array(shape)})
            # Create and connect reshape node
            reshape_op = Reshape(graph, {'dim': np.array(output.shape)})
            reshape_node = reshape_op.create_node(
                         attrs=[('dim', 'output:0')])))
            graph.create_edge(reshape_node, output)
예제 #30
    def infer(node):
        input_data_shape = node.in_port(0).data.get_shape()
        assert input_data_shape is not None
        assert node.has_valid('seq_axis')
        assert node.has_valid('batch_axis')

        assert len(node.out_nodes()) == 1

                                          attrs=[('seq_axis', 'input:0')])
                                          attrs=[('batch_axis', 'input:0')])