Example #1
def transpose_infer(node):
    if node.order is None and not (node.has_valid('reverse_order') and node.reverse_order):
        log.error('Cannot infer {} because order is None'.format(node.soft_get('name')))
        return

    if node.has_valid('order') and node.has_valid('reverse_order') and node.reverse_order:
        log.error('Cannot infer {} because both order and reverse_order are set'.format(node.soft_get('name')))
        return

    input_shape = node.in_node(0).shape

    if node.has_valid('reverse_order') and node.reverse_order:
        node.order = np.arange(len(input_shape))[::-1]  # Reverse order

    output_shape = np.array([input_shape[i] for i in node.order],
                            dtype=np.int64)
    node.out_node(0).shape = output_shape
    if node.in_node().has_valid('value'):
        node.out_node().value = np.transpose(node.in_node().value,
                                             axes=node.order)
    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
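For intuition, a minimal NumPy-only sketch of the shape arithmetic above for the reverse_order case (hypothetical values; the Node/graph machinery is not reproduced):

import numpy as np

input_shape = np.array([1, 3, 224, 224], dtype=np.int64)
order = np.arange(len(input_shape))[::-1]  # reverse_order gives [3, 2, 1, 0]
output_shape = np.array([input_shape[i] for i in order], dtype=np.int64)
print(output_shape)  # [224 224   3   1]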
Example #2
    def infer(node):
        in_ports = node.in_ports()
        connected_ports = [port for port in in_ports.values() if not port.disconnected()]
        assert len(connected_ports) == 2, \
            'The number of inputs to the TopK layer named "{}" must be equal to 2.'.format(node.soft_get('name'))

        k = node.in_port(1).data.get_value()
        if k is None:
            raise Error('The value defining number of output elements for layer "{}" is not defined'
                        ''.format(node.soft_get('name')))
        assert node.has_valid('axis'), 'The "axis" attribute is not defined for node {}'.format(node.name)

        input_shape = node.in_port(0).data.get_shape()
        node.axis = len(input_shape) + node.axis if node.axis < 0 else node.axis
        output_shape = input_shape.copy()
        output_shape[node.axis] = k

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

        # setting shape and value if applicable
        if not node.out_port(0).disconnected():
            node.out_port(0).data.set_shape(output_shape)
        if not node.out_port(1).disconnected():
            node.out_port(1).data.set_shape(output_shape)
        if node.in_port(0).data.get_value() is not None:
            # TODO implement value propagation
            pass
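The shape rule above reduces to replacing one entry of the input shape with k; a standalone sketch with hypothetical values:

import numpy as np

input_shape = np.array([8, 100], dtype=np.int64)
axis, k = -1, 5
axis = len(input_shape) + axis if axis < 0 else axis  # normalize a negative axis
output_shape = input_shape.copy()
output_shape[axis] = k
print(output_shape)  # [8 5]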
Example #3
    def argmax_infer(node: Node):
        shape = node.in_node(0).shape
        if shape is None:
            return

        # there are two inputs in TensorFlow. The second input is the axis for ArgMax
        if len(node.in_nodes()) == 2:
            if node.in_node(1).value is None:
                log.debug('The second argument to ArgMax is None')
                return
            node.axis = node.in_node(1).value.item()
            # remove the unnecessary input
            node.graph.remove_edge(node.in_node(1).id, node.id)

        num_top_axes = shape.size
        if num_top_axes < 3:
            num_top_axes = 3

        out_shape = np.ones(num_top_axes, dtype=int)

        if node.has_valid('axis'):
            axis = get_canonical_axis_index(shape, node.axis)
            node.axis = axis
            out_shape = np.array(shape)
            out_shape[axis] = node.top_k
            PermuteAttrs.create_permute_attrs(node,
                                              attrs=[('axis', 'input:0')])
        else:
            out_shape[0] = shape[0]
            out_shape[2] = node.top_k
            if node.out_max_val:
                out_shape[1] = 2

        node.out_node().shape = out_shape
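get_canonical_axis_index is assumed here to map a possibly negative axis into the range [0, rank); a minimal sketch of that assumption:

def canonical_axis(shape, axis):
    # negative axes count from the end, as in NumPy
    return axis + len(shape) if axis < 0 else axis

assert canonical_axis([1, 3, 224, 224], -1) == 3
assert canonical_axis([1, 3, 224, 224], 2) == 2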
Example #4
def tf_split_infer(node):
    """
    Partial infer of split node similar to Split op of TF.
    """
    # Two inputs: [split_dim, input]
    assert len(node.in_nodes()) == 2, 'Node "{}" must have exactly two inputs'.format(node.soft_get('name'))
    split_dim = node.in_node(0).value
    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    assert split_dim.ndim == 0, 'The split dimension for node "{}" must be a scalar.'.format(node.soft_get('name'))
    split_dim = split_dim.item()
    input = node.in_node(1)

    if input.shape is None:
        log.error('Input shape for node {} is not defined'.format(node.soft_get('name')))
        return

    log.debug('input shape for split: {}, should be split along {} dim'.format(input.shape, split_dim))
    split_dim_size = input.shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    # split_dim was reduced to a scalar above and is used directly as the axis
    log.debug('split_dim_size = {}, node.num_split = {}, div = {}, typeof div = {}'.format(
        split_dim_size, node.num_split, split_dim_size / node.num_split, type(split_dim_size / node.num_split)))
    split(input, node, split_dim, [int(split_dim_size / node.num_split)] * node.num_split)
    node.graph.remove_edge(node.in_node(0).id, node.id)
    node['input_port'] = 1

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:1')])
Example #5
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(type='StridedSlice'):
            StridedSliceNormalizer.normalize_strided_slice(graph, node)
            PermuteAttrs.create_permute_attrs(node,
                                              attrs=[('begin_mask', 'input:0'),  # in fact depends on slice_rank
                                                     ('end_mask', 'input:0'),
                                                     ('new_axis_mask', 'input:0'),
                                                     ('shrink_axis_mask', 'input:0'),
                                                     ('ellipsis_mask', 'input:0')])

            # StridedSliceNormalizer inserted nodes that replaced the original begin, end, and strides data nodes,
            # so the correct input permutations could not be set until this point
            PermuteInputs().set_input_permutation(node.in_node(1), node,
                                                  'input:1', 'slice',
                                                  'dim_size')
            PermuteInputs().set_input_permutation(node.in_node(2), node,
                                                  'input:2', 'slice',
                                                  'dim_size')
            if node.is_in_port_connected(3):
                PermuteInputs().set_input_permutation(node.in_node(3), node,
                                                      'input:3', 'slice',
                                                      'dim_size')
Example #6
def tf_squeeze_infer(node):
    if node.squeeze_dims is None:
        # TODO: implement; there is no implementation now because no test
        return

    real_squeeze_dims = []
    input_shape = node.in_node().shape
    if input_shape is None:
        return
    # work on a copy so that the input shape is not modified in place
    output_shape = input_shape.copy()
    for n in node.squeeze_dims:
        if output_shape[n] == 1:
            real_squeeze_dims.append(get_canonical_axis_index(output_shape, n))
        else:
            raise Error('Trying to squeeze dimension not equal to 1 for node "{}"'.format(node.soft_get('name')))

    output_shape = np.delete(output_shape, real_squeeze_dims)
    node.out_node().shape = output_shape

    if is_spatial_squeeze(node.graph.graph['layout'], input_shape, output_shape):
        output_shape = int64_array([0, -1])
    node['dim'] = output_shape
    if node.in_node().value is not None:
        node.out_node().value = np.array(np.reshape(node.in_node().value, output_shape))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
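The core of the squeeze shape math is np.delete over axes that are verified to equal 1; a self-contained sketch with hypothetical values:

import numpy as np

input_shape = np.array([1, 3, 1, 224], dtype=np.int64)
squeeze_dims = [0, 2]  # each squeezed dimension must be 1, otherwise an error is raised
assert all(input_shape[d] == 1 for d in squeeze_dims)
output_shape = np.delete(input_shape, squeeze_dims)
print(output_shape)  # [  3 224]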
Example #7
    def infer(node):
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 1 and 0 in connected_in_ports, \
            "AttributedTile should have 1 connected input port, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        shape = node.in_port(0).data.get_shape()
        assert shape is not None, "Undefined input shape for AttributedTile node '{}'.".format(name)
        axis = node.soft_get('axis', None)
        assert axis is not None, "Undefined `axis` attribute of AttributedTile node '{}'".format(name)
        tiles = node.soft_get('tiles', None)
        assert tiles is not None, "Undefined `tiles` attribute of AttributedTile node '{}'".format(name)

        tile_array = int64_array(np.ones(shape.size))
        tile_array[axis] = tiles

        node.out_port(0).data.set_shape(shape * tile_array)
        if node.in_port(0).data.get_value() is not None:
            node.out_port(0).data.set_value(np.tile(node.in_port(0).data.get_value(), tile_array))

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #8
def tf_expand_dims_infer(node):
    input_node = node.in_nodes()[0]
    output_node = node.out_node()
    if input_node.shape is None:
        return

    # TensorFlow style with dynamic input
    if len(node.in_nodes()) > 1:
        axis_node = node.in_nodes()[1]
        if isinstance(axis_node.value, np.ndarray) and axis_node.value.size > 1:
            log.error("ExpandDims operation: axis must be a scalar")
            return
        expand_axis = axis_node.value.item() if axis_node.value is not None else None
        node.graph.remove_edge(axis_node.id, node.id)
    else:
        if not node.has_valid('expand_axis'):
            log.error("ExpandDims axis is not defined")
            return
        expand_axis = node.expand_axis

    if expand_axis is None:
        return

    output_node.shape = np.insert(input_node.shape, expand_axis, [1])
    # convert data type of the shape to int64 explicitly
    output_node.shape = output_node.shape.astype(np.int64)
    if input_node.value is not None:
        output_node.value = np.array(np.reshape(input_node.value, output_node.shape))

    node['dim'] = output_node.shape

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
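The shape update above reduces to a single np.insert; a minimal sketch assuming a scalar axis:

import numpy as np

input_shape = np.array([3, 224, 224], dtype=np.int64)
expand_axis = 0
output_shape = np.insert(input_shape, expand_axis, 1).astype(np.int64)
print(output_shape)  # [  1   3 224 224]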
Example #9
    def infer(node: Node):
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            "AttributedGather should have 2 connected input ports, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        axis = node.soft_get('axis', None)
        assert axis is not None

        data_shape = node.in_port(0).data.get_shape()
        assert data_shape is not None
        indices_shape = node.in_port(1).data.get_shape()
        assert indices_shape is not None

        # Convert negative axis
        axis = get_canonical_axis_index(data_shape, axis)
        node.axis = axis

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

        data_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        if data_value is not None and indices_value is not None:
            node.out_port(0).data.set_value(np.array(np.take(data_value, indices_value, axis), dtype=data_value.dtype))
            return

        shape = np.concatenate((data_shape[:axis], indices_shape))
        if axis < len(data_shape) - 1:
            shape = np.concatenate((shape, data_shape[axis + 1:]))

        node.out_port(0).data.set_shape(int64_array(shape))
Example #10
    def infer(node: Node):
        tf_strided_slice_infer(node)

        if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
            PermuteAttrs.create_permute_attrs(node,
                                              attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                     ('new_axis_mask', 'input:0', permute_masks),
                                                     ('ellipsis_mask', 'input:0', permute_masks),
                                                     ('begin_mask', 'input:0', permute_masks),
                                                     ('end_mask', 'input:0', permute_masks)])
            for i in range(1, len(node.in_nodes())):
                if node.in_node(i).value is not None and node.in_node(i).shape[0] > 3:
                    perm = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.in_node(0).shape))
                    node.in_node(i).value = permute_array_with_ellipsis(node, perm, node.in_node(i).value, 0)

            # due to the permutation from NHWC to NCHW all masks and inputs were extended, so clear the ellipsis mask
            idx = np.nonzero(node.ellipsis_mask)
            node.ellipsis_mask[idx] = 0
Example #11
    def _one_input_infer(node: Node):
        input_shape = node.in_node().shape

        if input_shape is None:
            log.error('input_shape is None for node {}'.format(node.name))
            return
        input_shape = np.array(input_shape)

        if not node.has_valid('axis'):
            log.error('The "axis" attribute is missing for node {}; it should be set by the Crop extractor'.format(node.name))
            return

        output_shape = input_shape
        if node.has_valid('dim'):
            if len(node.dim) != len(node.axis):
                log.error('the number of axes must match the number of dims')
                return
            output_shape[node.axis] = node.dim
        elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
            if len(node.crop_begin) != len(node.axis) or len(node.crop_end) != len(node.axis):
                log.error('the number of crop_begin/crop_end values must match the number of axes')
                return
            if type(node.axis) in [list, tuple]:
                for i in range(len(node.axis)):
                    output_shape[node.axis[i]] = output_shape[node.axis[i]] - node.crop_begin[i] - node.crop_end[i]
            else:
                output_shape[node.axis] = output_shape[node.axis] - node.crop_begin - node.crop_end
        else:
            log.error('Crop node {} should have either dim or crop_begin and crop_end attributes'.format(node.name))
            return

        node.out_node().shape = np.array(output_shape)
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #12
    def infer(node: Node):
        layout = node.graph.graph['layout']

        assert len(layout) == 4
        assert len([p for p in node.in_ports().values() if not p.disconnected()])
        assert node.has_valid('mode')
        assert node.has_valid('axes')

        src_shape = node.in_port(0).data.get_shape()
        assert src_shape is not None
        dst_shape = node.in_port(1).data.get_value()
        assert dst_shape is not None

        out_height = dst_shape[0]
        out_width = dst_shape[1]

        node.out_node().shape = shape_for_layout(
            layout,
            batch=src_shape[get_batch_dim(layout, 4)],
            features=src_shape[get_features_dim(layout, 4)],
            height=out_height,
            width=out_width)

        PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])
Example #13
    def infer(node):
        name = node.soft_get('name', node.id)

        op = node.soft_get('op', None)
        assert op is not None and op in ['Split', 'AttributedSplit'], \
            'Unexpected `op`={} attribute for Split-like node {}'.format(op, name)

        num_in_ports = 1 if op == 'AttributedSplit' else 2 if op == 'Split' else None
        assert num_in_ports in [1, 2], \
            'SplitBase supports AttributedSplit with 1 input and Split with 2 inputs, but it is {} for {} node {}' \
            ''.format(num_in_ports, op, name)

        connected_inputs = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_inputs) == num_in_ports and all([i in connected_inputs for i in range(num_in_ports)]), \
            "{} should have {} connected input ports, but it doesn't for node: `{}`. Ports: {}" \
            "".format(op, num_in_ports, name, connected_inputs)

        input_shape = node.in_port(0).data.get_shape()
        assert input_shape is not None, 'Input shape is unknown for node {}'.format(name)
        assert node.has_valid('num_splits'), 'Parameter `num_splits` is unknown for node {}'.format(name)
        num_splits = node.num_splits

        axis = node.in_port(1).data.get_value() if op == 'Split' else node.soft_get('axis', None)
        assert axis is not None, '{} `axis` is unknown for node {}'.format(op, name)
        assert axis.ndim == 0, "{} `axis` should be a scalar, but it isn't for node {}".format(op, name)

        assert input_shape[axis] % num_splits == 0, \
            'Input shape is not evenly divided by `num_splits` of {} node {}. `input_shape`={}, `axis`={}, ' \
            '`num_splits`={}'.format(op, name, input_shape, axis, num_splits)

        out_shape = input_shape.copy()
        out_shape[axis] = np.int64(input_shape[axis] / num_splits)

        input_value = node.in_port(0).data.get_value()
        output_value = np.split(input_value.copy(), axis=axis, indices_or_sections=num_splits) \
            if input_value is not None else None

        for idx, port in node.out_ports().items():
            if idx in node.out_nodes():
                port.data.set_shape(out_shape)
                if output_value is not None:
                    port.data.set_value(output_value[idx])

        if op == 'Split':
            PermuteInputs().set_input_permutation(node.in_node(1), node,
                                                  'input:0', 'axis')
        elif op == 'AttributedSplit':
            PermuteAttrs.create_permute_attrs(node,
                                              attrs=[('axis', 'input:0')])
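The value propagation above relies on np.split with an integer section count, which requires the axis length to be evenly divisible; a standalone sketch with hypothetical data:

import numpy as np

data = np.arange(12).reshape(2, 6)
num_splits, axis = 3, 1
assert data.shape[axis] % num_splits == 0
parts = np.split(data, indices_or_sections=num_splits, axis=axis)
print([p.shape for p in parts])  # [(2, 2), (2, 2), (2, 2)]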
Example #14
    def infer(node):
        name = node.soft_get('name', node.id)
        assert node.has_valid('shape'), \
            'Parameter node {} should have `shape` attribute. Please use cli options to set model input shape' \
            ''.format(name)
        node.out_port(0).data.set_shape(node.shape)

        PermuteAttrs.create_permute_attrs(node, attrs=[('shape', 'output:0')])
Example #15
 def infer(node: Node):
     input_node = node.in_node(0)
     outputs = node.out_nodes()
     out_shape = copy.copy(input_node.shape)
     out_shape[node.axis] = np.int64(input_node.shape[node.axis] / node.num_split)
     for idx, output in outputs.items():
         output.shape = out_shape
     PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #16
    def infer(node: Node):
        shape = node.in_node().shape
        if shape is None:
            log.error("Undefined shape for the input tiles for the Tile operation '{}'.".format(node.name))
            return
        shape = np.copy(shape)

        if len(node.in_nodes()) == 2:
            tile_array = node.in_node(1).value
            if tile_array is None:
                log.error('The tile values are None for node "{}".'.format(node.name))
                return
            if len(shape) != len(tile_array):
                log.error('Shape mismatch for node "{}": {} vs {}.'.format(node.name, shape.shape, tile_array.shape))
                return
            non_one_tile = np.argwhere(tile_array != 1)
            if len(non_one_tile) == 0:
                log.info('Redundant "Tile" operation "{}" with tile values for all dimensions equal to 1.'.format(node.name))
                node['axis'] = 0
                node['tiles'] = 1
            elif len(non_one_tile) == 1:
                node['axis'] = non_one_tile[0][0]
                node['tiles'] = tile_array[node['axis']]
            else:
                node['type'] = None
                node['tile_array'] = tile_array
                log.warning("Tile operation with more than one dimension not equal to 1 is not supported.")
                # do not return here to allow infer shape and values for the constant propagation case
            node.graph.remove_edge(node.in_node(1).id, node.id)
        elif len(node.in_nodes()) == 1:
            # the tiled dimension and the tile count are specified in node attributes
            if not node.has_valid('axis') or not node.has_valid('tiles'):
                log.error('Mandatory attributes "axis" or "tiles" are not specified for Tile node "{}"'.format(node.name))
                return
            tile_array = np.ones([len(shape)], dtype=np.int64)
            tile_array[node.axis] = node.tiles
        else:
            log.error('Unsupported number of input parameters to Tile node "{}"'.format(node.name))
            return

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
        node.out_node().shape = shape * tile_array
        if node.in_node(0).value is not None:
            node.out_node().value = np.tile(node.in_node(0).value, tile_array)
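Shape and value propagation for Tile boil down to an elementwise product of the shape with the tile array and np.tile of the value; a minimal sketch:

import numpy as np

shape = np.array([2, 3], dtype=np.int64)
tile_array = np.array([1, 4], dtype=np.int64)  # tile only the second axis
print(shape * tile_array)                          # [ 2 12]
print(np.tile(np.ones((2, 3)), tile_array).shape)  # (2, 12)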
Example #17
 def infer(node: Node):
     assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 1, \
         'LogSoftmax node with id {} must have exactly one input port connected'.format(node.id)
     if node.axis < 0:
         node.axis = len(node.in_port(0).data.get_shape()) + node.axis
     assert 0 <= node.axis < len(node.in_port(0).data.get_shape()),\
         'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)
     copy_shape_infer(node)
     PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #18
    def infer(node: Node):
        node['order'] = list(range(node.in_node().shape.size))
        node.order[node.dim2], node.order[node.dim1] = node.order[node.dim1], node.order[node.dim2]

        input_shape = node.in_port(0).data.get_shape().copy()
        node.out_port(0).data.set_shape(input_shape[node.order])
        if node.in_port(0).data.get_value() is not None:
            node.out_port(0).data.set_value(np.transpose(node.in_port(0).data.get_value(), axes=node.order))

        PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
Example #19
def tf_reshape_shape_infer(node):
    # TODO Make sure that all -1 are handled correctly
    # We cannot simply copy shape argument to the output,
    # because if -1 appears, it should be substituted by a real
    # value from input shape if input shape is completely defined.
    if node.in_node(0).shape is None:
        return None

    input_shape = node.in_node(0).shape
    reshape_output = node.in_node(1).value if len(node.in_nodes()) > 1 else node.dim

    total = 1
    for i in input_shape:
        total *= i

    res = 1
    for index, x in enumerate(reshape_output):
        if x == 0:
            res *= input_shape[index]
        elif x != -1:
            res *= x

    new_dim = total // res
    output_shape = []
    for index, x in enumerate(reshape_output):
        if x == 0:
            output_shape.append(input_shape[index])
        elif x == -1:
            output_shape.append(new_dim)
        else:
            output_shape.append(x)

    out_shape_total = 1
    for i in output_shape:
        assert i != -1
        out_shape_total *= i

    if total != out_shape_total:
        raise Error("Number of elements in input {} and output {} of reshape node {} do not match"
                    .format(input_shape, output_shape, node.name))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])

    output_shape = int64_array(output_shape)

    # If the Reshape operation was created with two inputs and the dim attribute wasn't set, we set it automatically
    if not node.has_valid('dim'):
        node['dim'] = output_shape

    return output_shape
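A compact sketch of the 0/-1 substitution rule implemented above (0 copies the corresponding input dimension, -1 absorbs all remaining elements); the values are hypothetical:

import numpy as np

input_shape = [2, 3, 4]
reshape_dims = [0, -1]  # keep dim 0, infer the rest
total = int(np.prod(input_shape))  # 24
known = int(np.prod([input_shape[i] if d == 0 else d
                     for i, d in enumerate(reshape_dims) if d != -1]))  # 2
output_shape = [input_shape[i] if d == 0 else (total // known if d == -1 else d)
                for i, d in enumerate(reshape_dims)]
print(output_shape)  # [2, 12]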
Example #20
    def _two_inputs_infer(node: Node):
        N = len(node.in_nodes())

        shapes = [node.in_node(i).shape for i in range(N)]
        if any(s is None for s in shapes):
            log.error('Not all input shapes were defined for {} node'.format(node.name))
            return

        if not node.has_valid('axis'):
            log.error('The "axis" attribute is missing for node {}; it should be set by the Crop extractor'.format(node.name))
            return

        if not node.has_valid('offset'):
            log.error('The "offset" attribute is missing for node {}; it should be set by the Crop extractor'.format(node.name))
            return

        input_shape = np.array(shapes[0])
        start_axis = get_canonical_axis_index(input_shape, node.axis)
        node.axis = start_axis

        reference_shape = np.array(shapes[1])
        input_dim = input_shape.size

        # set new shape to current shape
        new_shape = input_shape.copy()
        ir_axis = []
        ir_offset = []
        dim = []

        for i in range(0, input_dim):
            if i < start_axis:
                new_shape[i] = input_shape[i]
                continue

            crop_offset = 0
            if len(node.offset) == 1:
                crop_offset = node.offset[0]
            elif len(node.offset) > 1:
                crop_offset = node.offset[i - start_axis]

            if input_shape[i] - crop_offset < reference_shape[i]:
                log.error('The crop for dimension {} is out of bounds for node "{}"'.format(i, node.soft_get('name')))
                return

            dim.append(reference_shape[i])
            ir_axis.append(i)
            ir_offset.append(crop_offset)
            new_shape[i] = reference_shape[i]

        node.axis = ir_axis
        node.offset = ir_offset
        node['dim'] = dim
        node.out_node().shape = new_shape
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #21
    def infer(node):
        PermuteAttrs.create_permute_attrs(node, attrs=[('pads', 'input:0')])

        num_of_inputs = len(node.in_nodes())
        if node.has_valid('pads'):
            assert num_of_inputs == 1, "Pad operation has pads attribute and unexpected additional input " \
                                       "argument for node {}.".format(node.name)
        else:
            assert num_of_inputs >= 2, "Missing required second input argument for node {} and pads attribute " \
                                       "is missing.".format(node.name)
            node['pads'] = node.in_node(1).value
            if num_of_inputs in [3, 4]:
                pads_begin = node.in_node(1).value
                pads_end = node.in_node(2).value
                node['pads'] = np.concatenate((pads_begin.reshape(-1, 1), pads_end.reshape(-1, 1)), 1)
                node['fill_value'] = node.in_node(3).value if num_of_inputs == 4 else 0.0
        padding = node.pads

        input_shape = node.in_node(0).shape
        if padding is None or input_shape is None:
            log.error('The paddings are not defined for node "{}"'.format(node.soft_get('name')))
            return

        # paddings can be defined, partially defined or undefined
        # TODO for now we only handle fully defined paddings
        # That means that intermediate tensor that delivers padding
        # should have defined value and size Nx2
        # TODO possible broadcasts are not supported
        assert (padding.ndim == 2 and padding.shape[1] == 2)

        # make sure that input has the same number of dimensions as the number of padding dimensions
        assert (padding.shape[0] == len(input_shape)), \
            "Input tensor shape {} and pads values {} do not match for Pad node {}".format(
                input_shape, padding.shape, node.name
            )

        # sum low and high padding values to calculate the shape modification vector
        shape_change = np.add.reduce(padding, 1)
        assert (shape_change.shape == input_shape.shape)

        # preserve non-positive values in the input shape, because they have a special meaning
        shape = np.array([shape_change[i] + input_shape[i] if input_shape[i] > 0 else input_shape[i]
                          for i in range(len(input_shape))])

        assert len(node.out_nodes()) == 1

        node.out_node().shape = shape
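The resulting shape is simply the input shape plus the per-axis sum of the begin and end pads; a self-contained sketch with hypothetical pads:

import numpy as np

input_shape = np.array([1, 3, 224, 224], dtype=np.int64)
pads = np.array([[0, 0], [0, 0], [1, 1], [2, 2]], dtype=np.int64)  # Nx2: (begin, end)
output_shape = input_shape + np.add.reduce(pads, 1)
print(output_shape)  # [  1   3 226 228]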
Example #22
    def infer(node: Node):
        tf_strided_slice_infer(node)

        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None, \
            'Output shape was not calculated for node {}'.format(node.name)
        # extend inputs according to ellipsis mask and/or input_shape
        for i_port in node.in_ports().values():
            if i_port.idx == 0 or i_port.disconnected():
                continue
            old_value = i_port.data.get_value()
            # additional check for a non-const input:
            # shape inference would have raised an error if a non-const input had been added;
            # this is a paranoid check in case shape inference changes
            assert old_value is not None, \
                "{} input of {} node is not constant: 'value' attribute for edge " \
                "contains None".format(i_port.idx, node.name)
            # insert 0 for begin and end and 1 for stride
            new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                   len(out_shape), list(old_value),
                                                                   int(i_port.idx == 3)))
            # set_value additionally set_shape and propagate value to Const node
            if not np.array_equal(new_value, old_value):
                i_port.data.set_value(new_value)

        # extend masks before removing ellipsis
        for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
            node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                    len(out_shape), list(node[attr]), 0))

        # we will extend all masks and inputs to simplify future transformations
        idx = np.nonzero(node.ellipsis_mask)
        node.ellipsis_mask[idx] = 0

        if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
            PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                           ('new_axis_mask', 'input:0', permute_masks),
                                                           ('ellipsis_mask', 'input:0', permute_masks),
                                                           ('begin_mask', 'input:0', permute_masks),
                                                           ('end_mask', 'input:0', permute_masks),
                                                           ])
            # permute inputs
            in_shape = node.in_port(0).get_source().data.get_shape()
            assert in_shape is not None, \
                'Input shape is unknown for 0 input of node {}'.format(node.name)
            input_rank = len(in_shape)
            if input_rank > 3:
                for i_port in node.in_ports().values():
                    if i_port.idx == 0 or i_port.disconnected():
                        continue
                    new_value = permute_array(node, i_port.data.get_value())
                    # set_value additionally set_shape and propagate value to Const node
                    i_port.data.set_value(new_value)
Example #23
    def infer(node):
        input_data_shape = node.in_port(0).data.get_shape()
        assert input_data_shape is not None
        assert node.has_valid('seq_axis')
        assert node.has_valid('batch_axis')

        assert len(node.out_nodes()) == 1
        node.out_port(0).data.set_shape(input_data_shape)

        PermuteAttrs.create_permute_attrs(node, attrs=[('seq_axis', 'input:0'),
                                                       ('batch_axis', 'input:0')])
Example #24
def concat_infer(node):
    if not node.has('axis'):
        N = node.N
        axis_input = node.in_node(N)
        if axis_input.has_valid('value') and axis_input.value.size == 1:
            node['axis'] = axis_input.value.item()
            node.graph.remove_edge(axis_input.node, node.node)  # TODO add a skip attribute instead of deleting the edge
        else:
            return
    else:
        N = len(node.in_nodes())

    shapes = [node.in_node(i).shape for i in range(N)]
    if any(s is None for s in shapes):
        return

    shape = np.array(shapes[0])

    axis = get_canonical_axis_index(shape, node.axis)
    node.axis = axis

    mask = np.zeros_like(shape, dtype=bool)
    mask[axis] = True  # pylint: disable=unsupported-assignment-operation
    not_mask = np.logical_not(mask)  # pylint: disable=assignment-from-no-return
    for s in shapes[1:]:
        s = int64_array(s)
        if np.all(shape[not_mask] == s[not_mask]):  # TODO handle -1 in a special way
            shape[mask] += s[mask]
        else:
            log.error('Concat input shapes do not match')
            return

    node.out_node(0).shape = shape
    if len(shape) != 4:
        # exclude it from NHWC to NCHW conversion
        if 'axis' in node.dim_attrs:
            node.dim_attrs.remove('axis')

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    values = [node.in_node(i).value for i in range(N)]
    if any(v is None for v in values):
        return

    node.out_node(0).value = np.concatenate(values, axis=node.axis).astype(values[0].dtype, copy=False)
    node.out_node(0).shape = np.array(node.out_node(0).value.shape, dtype=np.int64)
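The concat shape rule above requires all non-axis dimensions to match and sums the sizes along the axis; a standalone sketch:

import numpy as np

shapes = [np.array([2, 3, 4]), np.array([2, 5, 4])]
axis = 1
out = shapes[0].copy()
for s in shapes[1:]:
    mask = np.arange(len(out)) != axis
    assert np.all(out[mask] == s[mask]), 'Concat input shapes do not match'
    out[axis] += s[axis]
print(out)  # [2 8 4]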
Example #25
def tf_transpose_infer(node):
    if len(node.in_nodes()) != 2:
        log.error("Transpose should take 2 inputs")
        return

    node_inp, node_order = (node.in_node(0), node.in_node(1))
    order = node_order.value
    in_shape = np.array(node_inp.shape)
    node.graph.remove_edge(node_order.node, node.node)
    node.order = np.array(order)
    node.out_node().shape = in_shape[order]
    if node_inp.has_valid('value'):
        node.out_node().value = np.transpose(node_inp.value, axes=order)

    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
Example #26
    def reorgyolo_infer(node: Node):
        input_shape = node.in_node(0).shape
        if input_shape is None:
            return

        stride = node.stride

        output_shape = np.full_like(input_shape, -1, dtype=np.int64)
        output_shape[node.batch_dims] = input_shape[node.batch_dims]  # pylint: disable=unsupported-assignment-operation
        output_shape[node.channel_dims] = input_shape[node.channel_dims] * stride ** 2  # pylint: disable=unsupported-assignment-operation
        # Round as in caffe
        output_shape[node.spatial_dims] = np.round(input_shape[node.spatial_dims] / stride)  # pylint: disable=unsupported-assignment-operation

        node.out_node().shape = output_shape
        PermuteAttrs.create_permute_attrs(node, attrs=[('channel_dims', 'input:0'), ('spatial_dims', 'input:0')])
Example #27
    def infer(node):
        unsqueeze_dims = np.array(node.unsqueeze_dims)
        value = node.in_node(0).value
        shape = node.in_node(0).shape

        for dim in unsqueeze_dims:
            shape = np.insert(shape, dim, 1)

        node.out_node().shape = np.array(shape)
        node['dim'] = shape
        PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])

        if value is not None:
            value = np.reshape(value, shape)
            node.out_node().value = np.array(value)
Example #28
def tf_split_v_infer(node: Node):
    """
    Partial infer of split node similar to SplitV op of TF.
    """

    if len(node.in_nodes()) == 1 and not (node.has_valid('axis')
                                          and node.has_valid('size_splits')):
        return

    if len(node.in_nodes()) == 3 and (node.has_valid('axis')
                                      or node.has_valid('size_splits')):
        return

    # Three inputs: [input, size_splits, split_dim]
    if len(node.in_nodes()) == 3:
        split_dim = node.in_node(2).value
        assert split_dim.ndim == 0
        split_dim = split_dim.item()
        size_splits = node.in_node(1).value
        node.graph.remove_edge(node.in_node(1).id, node.id)
        node.graph.remove_edge(node.in_node(2).id, node.id)
    else:
        split_dim = node.axis
        size_splits = node.size_splits

    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    input = node.in_node(0)
    if input.shape is None or size_splits is None:
        log.error(
            'input shape or size of splits are not defined for node {}'.format(
                node.soft_get('name')))
        return

    log.debug(
        'split_dim = {}, input.shape = {}, size_splits.value = {}'.format(
            split_dim, input.shape, size_splits))

    # split_dim has been reduced to a scalar and is used directly as the axis
    split(input, node, split_dim, size_splits)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #29
    def test_from4D_to3D(self):
        input_shape = np.array([1, 2, 3, 4])
        new_shape = np.array([3, 4, 2])
        nhwc_shape = np.array([1, 3, 4, 2])
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                       edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('input_data', {'shape': input_shape}),
                                                                ('reshape', {'dim': new_shape}),
                                                                ('reshape_data', {'shape': new_shape})])
        graph.graph['layout'] = 'NHWC'
        # add permute attrs to reshape
        reshape = Node(graph, 'reshape')
        PermuteAttrs.create_permute_attrs(reshape, attrs=[('dim', 'output:0')])

        tested_pattern = PermuteForReshape()
        tested_pattern.find_and_replace_pattern(graph)
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes + self.permute_nodes,
                                           edges_with_attrs=self.edges[1:] + self.permute_edges,
                                           update_nodes_attributes=[('input_data', {'shape': input_shape}),
                                                                    ('reshape', {'dim': new_shape}),
                                                                    ('reshape_data', {'shape': new_shape}),
                                                                    ('permute_data', {'shape': nhwc_shape})])
        # check graphs equality
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='reshape_data')
        self.assertTrue(flag, resp)

        # check the right order in the new permutation node
        permute_order = graph.node['reshape/Permute_']['order']
        self.assertTrue(np.all(permute_order == np.array([0, 2, 3, 1])))  # from NCHW to NHWC
Example #30
def infer_for_opset1(node: Node):
    assert len([p for p in node.in_ports().values() if not p.disconnected()]) == 2
    assert node.has_valid('mode')
    assert node.has_valid('axes')

    src_shape = node.in_port(0).data.get_shape()

    assert src_shape is not None
    dst_shape = node.in_port(1).data.get_value()
    assert dst_shape is not None

    output_shape = src_shape.copy()
    for ind, axis in enumerate(node.axes):
        output_shape[axis] = dst_shape[ind]

    node.out_port(0).data.set_shape(output_shape)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])