Ejemplo n.º 1
0
    def replace_identityN(node: Node):
        """
        Replace an IdentityN node by one Identity node per connected input/output port pair.

        The Identity created for port ``i`` gets ``node.data_types[i]`` as its ``data_type`` and is named
        ``<IdentityN name>/<i>_port``. Port pairs where either the input or the output side is not
        connected get no Identity. Finally every input and output port of the IdentityN node is
        disconnected.

        :param node: the IdentityN node to replace; must carry a valid ``data_types`` attribute
        """
        graph = node.graph
        node_name = node.soft_get('name', node.id)

        assert node.has_valid('data_types'), 'IdentityN {} has no `data_types` attribute'.format(node_name)
        data_types = node.data_types

        for port_idx, in_port in node.in_ports().items():
            if not (node.is_in_port_connected(port_idx) and node.is_out_port_connected(port_idx)):
                # half-connected pair: nothing to bridge, it is dropped by the disconnect below
                continue
            assert port_idx < len(data_types), \
                'IdentityN {} has inconsistent `data_types` attribute {}'.format(node_name, data_types)

            identity = Identity(graph, {'name': '{}/{}_port'.format(node_name, port_idx),
                                        'data_type': data_types[port_idx]}).create_node()
            # route producer -> Identity -> consumers instead of through the IdentityN node
            in_port.get_connection().set_destination(identity.in_port(0))
            node.out_port(port_idx).get_connection().set_source(identity.out_port(0))

        # detach the IdentityN node completely (inputs first, then outputs)
        for port in list(node.in_ports().values()) + list(node.out_ports().values()):
            port.disconnect()
Ejemplo n.º 2
0
    def re_numerate_input_ports(loop_node: Node):
        """
        Compact the input port ids of a Loop node so they become 0 .. num_input_ports - 1.

        Existing input ports are visited in ascending id order and each is moved to the lowest id not yet
        taken. The ``'external_port_id'`` entries of ``loop_node.input_port_map`` are rewritten to match.
        Afterwards the leftover ports with ids at or above the new port count are deleted.

        :param loop_node: the Loop node
        :return: None
        """
        def move_input_port(src_id: int, dst_id: int):
            # create the target port if needed, move the connection onto it and patch the port map
            loop_node.add_input_port(dst_id, skip_if_exist=True)
            loop_node.in_port(src_id).get_connection().set_destination(loop_node.in_port(dst_id))
            Loop.update_port_map_value(loop_node.input_port_map, 'external_port_id', src_id, dst_id)

        if not loop_node.in_ports():
            return

        highest_port_id = max(loop_node.in_ports().keys())
        next_free_id = 0
        for port_id in range(highest_port_id + 1):
            # in_ports() is re-read on purpose: move_input_port may add ports while we iterate
            if port_id not in loop_node.in_ports():
                continue
            if port_id != next_free_id:
                move_input_port(port_id, next_free_id)
            next_free_id += 1

        # drop the now-unused tail of port ids, highest first
        for stale_id in range(highest_port_id, next_free_id - 1, -1):
            if stale_id in loop_node.in_ports():
                loop_node.delete_input_port(stale_id)
Ejemplo n.º 3
0
def compute_unsqueeze_map_for_eltwise(eltwise_node: Node):
    """
    Compute, for every producer of the eltwise node, the dimensions to unsqueeze to normalize its shape.

    Tail dimensions are added first. This happens only when the node has an ``axis`` attribute and the
    producer rank is non-zero and differs from the largest known input rank. Head dimensions are then
    prepended until the producer rank reaches the eltwise output rank.

    :param eltwise_node: the eltwise node whose inputs are being normalized
    :return: dict mapping each producer output port to an int64 numpy array of unsqueeze dimensions
             (tail dimensions first, then head dimensions)
    """
    eltwise_shape = eltwise_node.out_port(0).data.get_shape()
    known_input_shapes = [port.data.get_shape() for port in eltwise_node.in_ports().values()]
    max_dims = max([len(shape) for shape in known_input_shapes if shape is not None])
    axis = eltwise_node.soft_get('axis', None)

    unsqueeze_dims_map = {}
    for consumer_port in eltwise_node.in_ports().values():
        producer_port = consumer_port.get_source()
        producer_rank = len(producer_port.data.get_shape())

        # 1. dimensions appended at the tail
        tail_dims = int64_array([])
        if producer_rank != max_dims and producer_rank > 0 and axis is not None:
            tail_count = max_dims - axis - producer_rank
            if tail_count > 0:
                tail_dims = np.arange(producer_rank, producer_rank + tail_count, dtype=np.int64)

        # 2. dimensions prepended at the head
        head_dims = np.arange(len(eltwise_shape) - producer_rank - len(tail_dims), dtype=np.int64)

        # the order is significant: the shape is normalized at the tail first and at the head afterwards
        unsqueeze_dims_map[producer_port] = np.concatenate((tail_dims, head_dims))

    return unsqueeze_dims_map
Ejemplo n.º 4
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace the node with one Unsqueeze per input feeding a single Concat along ``node.axis``.

        Each input is unsqueezed at ``node.axis`` and connected to the Concat input with the same index.
        The Concat takes over the original node name; the original node is renamed with a ``/TBR`` suffix.
        (Judging by ``pack_name``, the node is presumably a Pack/Stack operation — confirm against the
        replacer's ``op`` declaration.)

        :param graph: graph containing the node
        :param node: the node being replaced
        :return: single-element list with the id of the new Concat node
        """
        pack_name = node.soft_get('name', node.id)
        concat = Concat(graph, {'axis': node.axis, 'in_ports_count': len(node.in_ports())}).create_node()

        for port_idx in node.in_ports():
            unsqueeze = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([node.axis])},
                                                    {'name': pack_name + '/Unsqueeze'})
            node.in_port(port_idx).get_connection().set_destination(unsqueeze.in_port(0))
            unsqueeze.out_port(0).connect(concat.in_port(port_idx))

        rename_nodes([(node, pack_name + '/TBR'), (concat, pack_name)])
        return [concat.id]
Ejemplo n.º 5
0
    def infer(node: Node):
        """
        Validate the inputs and set the output shape to ``[batch_size]``.

        Expects 4 or 5 connected inputs:
        - 0: logits (rank 3)
        - 1: logit lengths (rank 1)
        - 2: labels (rank 2)
        - 3: label lengths (rank 1)
        - 4: optional blank index (rank 0)

        Dimension 0 of inputs 0-3 must be compatible (batch). Dimension 1 of logits and labels must be
        compatible (time).
        """
        node_name = node.soft_get('name', node.id)
        num_connected = len([p for p in node.in_ports().values() if not p.disconnected()])
        assert num_connected in [4, 5], "Incorrect number of inputs for {} node".format(node_name)

        logits_shape, logit_length_shape, labels_shape, label_length_shape = \
            [node.in_port(port_idx).data.get_shape() for port_idx in range(4)]
        blank_index_shape = node.in_port(4).data.get_shape() if len(node.in_nodes()) == 5 else int64_array([])

        # the generator keeps short-circuit evaluation: checks stop at the first bad rank
        expected_ranks = ((logits_shape, 3), (logit_length_shape, 1), (labels_shape, 2),
                          (label_length_shape, 1), (blank_index_shape, 0))
        assert all(len(shape) == rank for shape, rank in expected_ranks), \
            'Incorrect rank of some input tensor for {} node'.format(node_name)

        batch_size = logits_shape[0]
        assert all(compatible_dims(batch_size, shape[0])
                   for shape in (logit_length_shape, labels_shape, label_length_shape)), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
        assert compatible_dims(logits_shape[1], labels_shape[1]), \
            'Time dimensions of input tensors must be the same for {} node'.format(node_name)

        node.out_port(0).data.set_shape([batch_size])
Ejemplo n.º 6
0
    def infer(self, node: Node):
        """
        Infer the output shape of a (I)DFT node.

        Input 0 is the data tensor: rank >= 2 and last dimension equal to 2. The axes are read through
        ``FFTBase.get_axes``; they must be unique and must not include the last dimension. An optional
        signal_size is read through ``FFTBase.get_signal_size`` when port 2 is connected.

        The output keeps the input shape. When port 2 is connected, the dimensions at ``axes`` are
        replaced by the canonicalized signal sizes.
        """
        # 'name' must be the literal attribute key. The former ``soft_get(node.name, ...)`` raised for
        # nodes without a name and otherwise looked up an unrelated attribute, so messages showed the id.
        node_name = node.soft_get('name', node.id)
        assert len([p for p in node.in_ports().values() if not p.disconnected()]) in [2, 3], \
            '(I)DFT node {} must have 2 or 3 inputs'.format(node_name)

        src_shape = node.in_port(0).data.get_shape()
        assert src_shape is not None, 'The input data shape of (I)DFT node {} must not be None'.format(node_name)
        assert src_shape[-1] == 2, \
            'The last dimension of input shape of (I)DFT node {} should be equal to 2'.format(node_name)

        input_rank = len(src_shape)
        assert input_rank >= 2, 'The input rank of (I)DFT node {} should be greater or equal to 2'.format(node_name)

        axes = FFTBase.get_axes(node)
        assert input_rank >= len(axes) + 1, \
            'The input rank must be greater than number of (I)DFT node {} axes'.format(node_name)
        axes = FFTBase.canonicalize_axes(axes, input_rank)
        assert (input_rank - 1) not in axes, '(I)DFT node {} axes cannot contain the last axis'.format(node_name)
        assert len(set(axes)) == len(axes), '(I)DFT node {} axes must be unique.'.format(node_name)

        output_shape = src_shape.copy()
        if node.is_in_port_connected(2):
            signal_size = FFTBase.get_signal_size(node)
            signal_size = FFTBase.canonicalize_signal_size(signal_size, axes, src_shape)
            output_shape[axes] = signal_size

        node.out_port(0).data.set_shape(output_shape)
    def get_non_interpolate_concat_sources(self, concat: Node):
        """
        Return the sources of Concat inputs that are not fed by an Interpolate node.

        Each connected input is traced upwards through nodes flagged ``identity``. The trace moves on
        only while ``self.get_single_input_source_safely`` returns a producer. Inputs whose trace ends
        at an Interpolate are excluded. The port index stored in ``concat.N`` (when present) is never
        considered.

        :param concat: a node of type Concat
        :return: list of source (output) ports of the qualifying Concat inputs
        """
        assert concat.soft_get('type') == 'Concat'
        # TODO: should be removed after Concat operation normalization
        skipped_ports = [concat.N] if concat.has_valid('N') else []

        sources = []
        for in_port in concat.in_ports().values():
            if in_port.disconnected() or in_port.idx in skipped_ports:
                continue

            producer = in_port.get_source().node
            while producer.soft_get('type') != 'Interpolate' and producer.has_and_set('identity'):
                upstream = self.get_single_input_source_safely(producer)
                if upstream is None:
                    break
                producer = upstream

            if producer.soft_get('type') != 'Interpolate':
                sources.append(in_port.get_connection().get_source())
        return sources
    def infer(node: Node):
        """
        Validate the inputs and set the output shapes.

        Expects 2 or 3 connected inputs: logits (rank 3), sequence lengths (rank 1) and an optional
        blank index (rank 1). Dimension 0 of logits and sequence lengths must be compatible.

        Output 0 gets ``[batch, time]`` (logits dims 0 and 1) and output 1 gets ``[batch]``. Each is
        set only if that output port is connected.
        """
        node_name = node.soft_get('name', node.id)
        num_connected = sum(1 for port in node.in_ports().values() if not port.disconnected())
        assert num_connected in [2, 3], "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        sequence_len_shape = node.in_port(1).data.get_shape()
        if len(node.in_nodes()) == 3:
            assert len(node.in_port(2).data.get_shape()) == 1, \
                'Incorrect rank of blank_index for {} node'.format(node_name)

        assert len(logits_shape) == 3, 'Incorrect rank of logits for {} node'.format(node_name)
        assert len(sequence_len_shape) == 1, \
            'Incorrect rank of sequence length tensor for {} node'.format(node_name)

        batch_size, time_size = logits_shape[0], logits_shape[1]
        assert compatible_dims(batch_size, sequence_len_shape[0]), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)

        shapes_by_output = {0: [batch_size, time_size], 1: [batch_size]}
        for out_idx, out_shape in shapes_by_output.items():
            if node.is_out_port_connected(out_idx):
                node.out_port(out_idx).data.set_shape(out_shape)
Ejemplo n.º 9
0
def arg_ops_infer(node: Node):
    """
    Shape inference for ArgMax-like operations.

    When a second input is connected (TensorFlow style), its value is stored as ``node.axis`` and that
    input is disconnected. If the value is still unknown, inference returns early without setting a shape.

    With a valid ``axis``, the output is the input shape with ``shape[axis] = top_k``, and ``axis`` is
    registered as a permutable attribute. Without an axis, the output has rank ``max(input rank, 3)``
    and is filled with ones, except:
    - ``[0]`` is the input batch;
    - ``[2]`` is ``top_k``;
    - ``[1]`` is 2 when ``out_max_val`` is set.
    """
    input_shape = node.in_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)
    assert input_shape is not None, "Input shape for the node {} is None".format(node_name)

    # TensorFlow passes the axis for ArgMax as a second input
    if sum(not port.disconnected() for port in node.in_ports().values()) == 2:
        axis_value = node.in_port(1).data.get_value()
        if axis_value is None:
            log.debug('The second argument to {} is None'.format(node_name))
            return
        node.axis = axis_value
        # the axis now lives in the attribute, so the extra input is removed
        node.in_port(1).disconnect()

    if node.has_valid('axis'):
        axis = get_canonical_axis_index(input_shape, node.axis)
        node.axis = axis
        out_shape = input_shape.copy()
        out_shape[axis] = node.top_k
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
    else:
        out_shape = np.ones(max(input_shape.size, 3), dtype=np.int64)
        out_shape[0] = input_shape[0]
        out_shape[2] = node.top_k
        if node.has_and_set('out_max_val'):
            out_shape[1] = 2

    node.out_port(0).data.set_shape(out_shape)
Ejemplo n.º 10
0
    def infer(node: Node):
        """
        Validate the two inputs and set the output shape to ``[batch_size, time_size, 1, 1]``.

        Inputs are logits (rank 3) and a sequence mask (rank 2). For both, dimension 0 is treated as
        time and dimension 1 as batch, i.e. a time-major layout.
        """
        node_name = node.soft_get('name', node.id)
        connected = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected) == 2, "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        sequence_mask_shape = node.in_port(1).data.get_shape()

        assert len(logits_shape) == 3, 'Incorrect rank of logits for {} node'.format(node_name)
        assert len(sequence_mask_shape) == 2, \
            'Incorrect rank of sequence length tensor for {} node'.format(node_name)

        time_size, batch_size = logits_shape[0], logits_shape[1]
        assert compatible_dims(batch_size, sequence_mask_shape[1]), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
        assert compatible_dims(time_size, sequence_mask_shape[0]), \
            'Time dimensions of input tensors must be the same for {} node'.format(node_name)

        node.out_port(0).data.set_shape([batch_size, time_size, 1, 1])
Ejemplo n.º 11
0
    def tf_resize_infer(node: Node):
        """
        Infer the output shape of a TensorFlow-style resize node with a 4D input.

        Input 1 must hold a known pair ``[new_height, new_width]``. It replaces dimensions 1 and 2 of
        the input shape (height and width per the error message below, i.e. an NHWC layout).
        Inference is skipped while the input shape is unknown.
        """
        input_shape = node.in_port(0).data.get_shape()
        if input_shape is None:
            return

        node_name = node.soft_get('name', node.id)
        attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                    "the attribute align_corners must be False"
        # equivalent to the former `not h or (h and not a)`: the two flags are mutually exclusive
        assert not (node.half_pixel_centers and node.align_corners), attrs_msg.format(node_name, node.op)

        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_in_ports) == 2, \
            "Node {} with op {} number of inputs must be equal to 2.".format(node_name, node.op)

        new_sizes_value = node.in_port(1).data.get_value()
        assert new_sizes_value is not None, "Node {} with op {} has no value in input port 1".format(node_name, node.op)

        input_rank = len(input_shape)
        assert input_rank == 4, \
            "Resized input data of the node {} with op {} must be 4D tensor".format(node_name, node.op)

        len_msg = "Op {} with name {} supports only resize with respect to height and width dimension simultaneously"
        # the template expects the op first and the node name second (they used to be swapped)
        assert len(new_sizes_value) == 2, len_msg.format(node.op, node_name)

        output_shape = input_shape.copy()
        output_shape[1] = new_sizes_value[0]
        output_shape[2] = new_sizes_value[1]

        node.out_port(0).data.set_shape(output_shape)
Ejemplo n.º 12
0
    def infer(node: Node):
        """
        Run the generic ``Scatter.infer`` and then constant-fold the output when all inputs are known.

        Four connected inputs are required: data, indices, updates and axis (a single value). When all
        four values are known, the output value is a copy of data. For every position ``idx`` of
        indices, the element at ``idx`` with ``idx[axis]`` replaced by ``indices[idx]`` is set to
        ``updates[idx]``.
        """
        Scatter.infer(node)

        node_name = node.soft_get('name', node.id)
        num_connected = len([port for port in node.in_ports().values() if not port.disconnected()])
        assert num_connected == 4, "Incorrect number of inputs for {} node".format(node_name)

        data = node.in_port(0).data.get_value()
        indices = node.in_port(1).data.get_value()
        indices_shape = node.in_port(1).data.get_shape()
        updates = node.in_port(2).data.get_value()
        axis = node.in_port(3).data.get_value()

        if any(value is None for value in (data, indices, updates, axis)):
            # nothing to fold yet; shape inference above is all we can do
            return

        assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(
            node_name, axis.size)
        axis = axis.item()

        result = data.copy()
        for position in np.ndindex(*indices_shape):
            target = list(position)
            target[axis] = indices[position]
            result[tuple(target)] = updates[position]
        node.out_port(0).data.set_value(result)
Ejemplo n.º 13
0
    def infer(node: Node):
        """
        Infer the Transpose output shape, or the output value when the input value is known.

        The permutation is either the reversed dimension order (``reverse_order`` attribute, only input 0
        connected) or the known value of input 1. In the latter case input 1 is registered with
        ``PermuteInputs`` as the ``order`` of ``input:0``.
        """
        in_ports = node.in_ports()
        connected_ports = [port for port in in_ports.values() if not port.disconnected()]
        input_shape = node.in_port(0).data.get_shape()

        if node.has_and_set('reverse_order'):
            assert len(connected_ports) == 1 and 0 in in_ports, \
                'Cannot infer `{}` due to both order and reverse_order was set'.format(node.soft_get('name'))
            order = np.arange(len(input_shape))[::-1]
        else:
            # imported locally: PermuteInputs uses Transpose itself, so a module-level import would be circular
            from openvino.tools.mo.graph.perm_inputs import PermuteInputs
            assert len(connected_ports) == 2 and 0 in in_ports and 1 in in_ports, \
                "{} node `{}` should have 2 input ports, where 0-input is a data input and 1-input represents " \
                "Transpose `order`".format(node.op, node.id)
            order = node.in_port(1).data.get_value()
            assert order is not None, 'Cannot infer `{}` because order is None'.format(node.soft_get('name'))
            PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'order')

        input_value = node.in_port(0).data.get_value()
        if input_value is None:
            node.out_port(0).data.set_shape(input_shape[order])
        else:
            node.out_port(0).data.set_value(np.transpose(input_value, axes=order))
Ejemplo n.º 14
0
    def infer(self, node: Node):
        """
        Infer the output shape of an RDFT node.

        Input 0 is the real data tensor (rank >= 1). The axes are read through ``RDFT.get_axes`` and must
        be unique. An optional signal_size is read through ``RDFT.get_signal_size`` when port 2 is connected.

        The output shape is built from the input shape as follows:
        - the dimensions at ``axes`` are replaced by the signal sizes, if given;
        - the dimension at the last axis becomes ``n // 2 + 1``;
        - a trailing dimension of 2 is appended.
        """
        # 'name' must be the literal attribute key. The former ``soft_get(node.name, ...)`` raised for
        # nodes without a name and otherwise looked up an unrelated attribute, so messages showed the id.
        node_name = node.soft_get('name', node.id)
        assert len([p for p in node.in_ports().values() if not p.disconnected()]) in [2, 3], \
            'RDFT node {} must have 2 or 3 inputs'.format(node_name)

        src_shape = node.in_port(0).data.get_shape()
        assert src_shape is not None, 'The input data shape of RDFT node {} must not be None'.format(node_name)

        input_rank = len(src_shape)
        assert input_rank >= 1, 'The input rank of RDFT node {} should be greater or equal to 1'.format(node_name)

        axes = RDFT.get_axes(node)
        assert input_rank >= len(axes), \
            'The input rank must be greater than or equal to number of RDFT node {} axes'.format(node_name)
        axes = RDFT.canonicalize_axes(axes, input_rank)
        assert len(set(axes)) == len(axes), 'RDFT node {} axes must be unique.'.format(node_name)

        output_shape = src_shape.copy()
        if node.is_in_port_connected(2):
            signal_size = RDFT.get_signal_size(node)
            signal_size = RDFT.canonicalize_signal_size(signal_size, axes, src_shape)
            output_shape[axes] = signal_size
        output_shape[axes[-1]] = output_shape[axes[-1]] // 2 + 1
        output_shape = np.ma.append(output_shape, 2)

        node.out_port(0).data.set_shape(output_shape)
def align_frame_time(graph: Graph, node: Node, frame_time_max):
    """
    Align the ``frame_time`` of every connected producer of ``node`` to ``frame_time_max``.

    Let ``delta = frame_time - frame_time_max``. Producers behind the maximum (``delta < 0``) that are
    not Const are handled as follows:
    - an existing MemoryOffset gets ``t = frame_time = delta``;
    - any other producer gets a new MemoryOffset with ``t = delta`` inserted between it and ``node``.

    A MemoryOffset producer whose ``frame_time`` already equals ``frame_time_max`` is bypassed and
    removed from the graph.

    :param graph: graph that owns ``node``
    :param node: consumer node whose inputs are aligned
    :param frame_time_max: target frame time
    """
    for port_idx in node.in_ports():
        consumer_port = node.in_port(port_idx)
        if consumer_port.disconnected():
            continue
        producer_port = consumer_port.get_source()
        producer = producer_port.node

        if producer.frame_time < frame_time_max and producer.op != 'Const':
            delta = producer.frame_time - frame_time_max
            if producer.op == 'MemoryOffset':
                # reuse the existing MemoryOffset instead of stacking another one
                producer.t = delta
                producer.frame_time = producer.t
            else:
                offset_name = graph.unique_id("align_" + node.id)
                memory_offset = MemoryOffset(graph, attrs={'id': offset_name,
                                                           'name': offset_name,
                                                           'pair_name': offset_name + "_pair",
                                                           't': delta,
                                                           'splitted': False}).create_node()
                if producer.op == 'Parameter':
                    # infer of a MemoryOffset placed right after a Parameter needs element_size
                    memory_offset['element_size'] = producer.shape
                consumer_port.get_connection().set_source(memory_offset.out_port(0))
                memory_offset.in_port(0).connect(producer_port)
                memory_offset['frame_time'] = memory_offset.t
        elif producer.frame_time == frame_time_max and producer.op == 'MemoryOffset':
            # this MemoryOffset carries the maximum delay already: bypass it and drop it
            producer_port.get_connection().set_source(producer.in_port(0).get_source())
            graph.remove_node(producer.id)
Ejemplo n.º 16
0
    def infer(node: Node):
        """
        Infer the output shape of ExtractImagePatches for a 4D input with only port 0 connected.

        The patch extent per spatial dimension is ``rates * (sizes - 1) + 1``. The spatial output
        dimensions come from ``tf_window_op_pad_infer`` using ``strides`` and ``auto_pad``. The feature
        dimension becomes ``C * prod(sizes[spatial_dims])`` and the batch is kept. Dimension positions
        follow ``graph.graph['layout']``.
        """
        name = node.soft_get('name', node.id)
        connected_idxs = [port.idx for port in node.in_ports().values() if not port.disconnected()]
        assert connected_idxs == [0], \
            'Wrong input nodes number for node {} with type ExtractImagePatches'.format(name)

        input_shape = node.in_port(0).data.get_shape()
        assert input_shape is not None, 'Input shape is not set for node {} with type ExtractImagePatches'.format(name)
        assert len(input_shape) == 4, 'ExtractImagePatches operation supports only 4D tensors'

        layout = node.graph.graph['layout']
        batch = input_shape[get_batch_dim(layout, 4)]
        channels = input_shape[get_features_dim(layout, 4)]

        spatial = node.spatial_dims
        patch_sizes = shape_array(node.sizes)[spatial]
        patch_extent = node.rates[spatial] * (patch_sizes - 1) + 1

        _, out_spatial = tf_window_op_pad_infer(input_shape[spatial], patch_extent, node.strides[spatial],
                                                node.auto_pad, False)

        out_shape = shape_for_layout(layout, batch=batch, features=channels * np.prod(patch_sizes),
                                     height=out_spatial[0], width=out_spatial[1])
        node.out_port(0).data.set_shape(out_shape)
Ejemplo n.º 17
0
    def infer(node: Node):
        """
        Infer the FullyConnected output shape: ``input_shape[:-1] + [out-size]``.

        Inputs:
        - 0: data (must be connected);
        - 1: 2D weights whose first dimension is compatible with the ``out-size`` attribute
          (must be connected);
        - 2: optional biases of shape ``[out-size]`` or ``[1, out-size]``, whose shape and value must
          already be known.
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) >= 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            'FullyConnected should have 2 connected input ports, but it doesn\'t for node: `{}`. Ports: {}' \
            ''.format(name, connected_in_ports)

        assert node.has_valid('out-size'), \
            'FullyConnected node {} has no valid `out-size` attribute'.format(name)
        input_shape = node.in_port(0).data.get_shape()
        weights_shape = node.in_port(1).data.get_shape()
        assert input_shape is not None and weights_shape is not None, \
            'Incorrect FullyConnected input shapes. Node: {}. Shapes: {}'.format(name, [input_shape, weights_shape])
        assert weights_shape.size == 2, \
            'Weights of FullyConnected {} must be a 2D tensor, but shape is {}'.format(name, weights_shape)
        out_size = node.soft_get('out-size')
        assert compatible_dims(weights_shape[0], out_size), \
            'weights_shape={}, out-size={}'.format(weights_shape, out_size)

        if 2 in connected_in_ports:
            bias_value = node.in_port(2).data.get_value()
            bias_shape = node.in_port(2).data.get_shape()
            assert bias_shape is not None, 'Shape was not inferred for biases of FullyConnected {}'.format(name)
            assert bias_value is not None, 'Value was not inferred for biases of FullyConnected {}'.format(name)
            # report the node name (not the Node object) like every other message in this method
            assert compatible_shapes(bias_shape, [out_size]) or compatible_shapes(bias_shape, [1, out_size]), \
                'Incorrect FullyConnected bias shape `{}` for node {}. `out-size`={}'.format(bias_shape, name, out_size)

        node.out_port(0).data.set_shape([*input_shape[:-1], out_size])
Ejemplo n.º 18
0
    def lift_up_through_eltwise(node: Node, reverse_channels: Node):
        r"""
        Move a ReverseChannels operation from the output of an Eltwise node onto its inputs.

        BEFORE                      AFTER

                                    previous_op              previous_op'
                                          \                    /
        previous_op  previous_op'     ReverseChannels     ReverseChannels
                 \     /                            \     /
                Eltwise                             Eltwise
                   |                                  |
             ReverseChannels                       next_op
                  |
                next_op

        :param node: the Eltwise node that feeds ``reverse_channels``
        :param reverse_channels: the ReverseChannels node consuming the Eltwise output

        returns two objects:
        first - boolean value whether we should continue propagating current ReverseChannels operation up or not
        second - list of new ReverseChannels operations that were produced while propagating reverse_channels up

        All inputs are checked before any graph modification, so the ``False, []`` early return
        leaves the graph untouched.
        """
        # shape of the ReverseChannels input, i.e. of the Eltwise output
        before_shape = reverse_channels.in_port(0).data.get_shape()

        # (Eltwise input port, axis to flip on that input) for every input that needs a copy
        port_axis = []
        for idx, port in node.in_ports().items():
            shape = port.data.get_shape()

            # indices of the dimensions that are not 1
            non_one_dims = np.where(shape != 1)[0]
            # NOTE(review): indexing with reverse_channels.axis assumes this input has the same rank as
            # before_shape (e.g. eltwise inputs already normalized) — confirm for lower-rank inputs
            if shape[reverse_channels.axis] == 1:
                continue  # nothing to flip for this input
            if len(non_one_dims) == 1 and shape[non_one_dims.item()] == reverse_channels.order.size:
                # a single non-one dimension with as many elements as the channel order: flip along it
                axis = non_one_dims.item()
            elif np.array_equal(before_shape, shape):
                # input has the full output shape: flip along the same axis as the original operation
                axis = reverse_channels.axis
            else:
                # shape has multiple non-one values and shape is not fully broadcasted to value port shape
                # it is safe not to propagate reverse channels
                return False, []
            port_axis.append((port, axis))

        copies = []
        for port, axis in port_axis:
            reverse_channels_copy = reverse_channels.copy_node({'axis': mo_array(axis)})

            # insert the copy between the input's producer and the Eltwise input port
            src = port.get_connection().get_source()
            if src.node.soft_get('type') == 'Parameter':
                # For Parameter nodes tensor debug attributes should not move to the last node
                # of subgraph. It is needed for the proper mapping of input framework name.
                # For this reason "source" mode is used to keep tensor debug attributes at Parameter node.
                port.get_connection().set_source(reverse_channels_copy.out_port(0), attributes_save_mode="source")
            else:
                port.get_connection().set_source(reverse_channels_copy.out_port(0))
            src.connect(reverse_channels_copy.in_port(0))

            copies.append(reverse_channels_copy)

        # bypass the original ReverseChannels: its consumers now read the Eltwise output directly
        reverse_channels.out_port(0).get_connection().set_source(
            reverse_channels.in_port(0).get_connection().get_source())
        reverse_channels.in_port(0).disconnect()

        # propagated reverse_channels successfully through current node, will continue propagation
        return True, copies
Ejemplo n.º 19
0
def infer_for_opset4(node: Node):
    """
    Shape inference for Interpolate-4.

    Inputs:
    - 0: data
    - 1: target sizes (used when ``shape_calculation_mode == 'sizes'``)
    - 2: scales (used otherwise)
    - 3: optional axes; all axes are used when it is not connected

    The padded shape ``src_shape + pads_begin + pads_end`` is updated along ``axes``, either with the
    target sizes or with ``floor(scale * dim + 1e-5)``. A dimension stays dynamic when either the
    padded dimension or its scale is dynamic.

    Side effects:
    - stores the rank-corrected ``pads_begin``/``pads_end`` on the node;
    - in 'sizes' mode, calls ``correct_scales_using_dst_shape``;
    - registers a connected axes input with ``PermuteInputs``.
    """
    node_name = node.soft_get('name', node.id)
    # the attribute key is the literal 'name' (formerly ``soft_get(node.name, ...)``, which looked up the wrong key)
    assert len([p for p in node.in_ports().values() if not p.disconnected()]) in [3, 4], \
        "Interpolate-4 node {} must have 3 or 4 inputs".format(node_name)
    assert node.has_valid('mode')
    assert node.has_valid('shape_calculation_mode')
    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None

    input_rank = len(src_shape)

    pads_begin = correct_pad(node.soft_get('pads_begin', [0]), input_rank)
    pads_end = correct_pad(node.soft_get('pads_end', [0]), input_rank)
    node['pads_begin'] = pads_begin
    node['pads_end'] = pads_end

    # Decide by connectivity rather than by the number of ports: a node may own a disconnected port 3,
    # which previously crashed on ``get_source()`` returning None. This also matches the check below.
    if node.is_in_port_connected(3):
        axes = node.in_port(3).get_source().data.get_value()
        assert axes is not None, \
            "Interpolate-4 node with name {} has None as 'axes' input".format(node_name)
    else:
        axes = list(range(0, input_rank))

    axes = int64_array(axes)
    output_shape = src_shape + pads_begin + pads_end
    if node.shape_calculation_mode == 'sizes':
        dst_shape = node.in_port(1).data.get_value()
        assert dst_shape is not None
        correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
        for i, axis in enumerate(axes):
            output_shape[axis] = dst_shape[i]
    else:
        scales = node.in_port(2).data.get_value()
        assert scales is not None
        for i, axis in enumerate(axes):
            if output_shape[axis] is not dynamic_dimension and scales[i] is not dynamic_dimension:
                # the small epsilon guards against float error just below an integer result
                output_shape[axis] = math.floor(scales[i] * output_shape[axis] + 1.0e-5)
            else:
                output_shape[axis] = dynamic_dimension_value

    if node.is_in_port_connected(3):
        PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'axis')

    node.out_port(0).data.set_shape(output_shape)
Ejemplo n.º 20
0
    def infer(node: Node):
        """
        Infer the Einsum output shape from the ``equation`` attribute and the input shapes.

        Each input shape is matched against its input subscript. Dimensions covered by the ellipsis
        label "..." are merged across inputs with bidirectional broadcasting. Every other label must
        map to equal dimension sizes in all inputs. The output shape concatenates the shapes of the
        output subscript labels in order.
        """
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        num_inputs = len(connected_in_ports)
        assert node.has_valid('equation'), "Einsum node {} must contain `equation` attribute".format(node_name)
        equation = node.equation

        # parse the equation and extract input and output subscripts
        input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)

        # check that each operand has the corresponding input subscript
        assert len(input_subscripts) == num_inputs, "The number of input operands of Einsum node {} " \
                                                    "must match the number of input subscripts " \
                                                    "in `equation`".format(node_name)

        # check compatibility of dimension sizes with the same label and generate a dictionary of shapes for labels
        label_to_shape = {}
        # NOTE(review): ports are addressed as 0..num_inputs-1, which assumes the connected inputs are
        # exactly the first num_inputs ports — confirm no gaps can occur here
        for input_ind in range(num_inputs):
            input_shape = node.in_port(input_ind).data.get_shape()
            input_subscript = input_subscripts[input_ind]
            labels = Einsum.extract_subscript_labels(node_name, input_subscript)
            num_dims = len(input_shape)
            num_labels = len(labels)
            # number of dimensions the "..." label stands for (every other label covers exactly one);
            # only used when the subscript contains "..."
            num_broadcasted_dims = num_dims - num_labels + 1
            dim_ind = 0
            label_ind = 0
            # walk labels and dimensions in lockstep; "..." consumes num_broadcasted_dims dimensions
            while label_ind < num_labels and dim_ind < num_dims:
                label = labels[label_ind]
                if label == "...":
                    sub_shape = input_shape[dim_ind:dim_ind + num_broadcasted_dims]
                    if label in label_to_shape.keys():
                        # ellipsis dimensions from different inputs only need to be broadcastable
                        common_shape = bi_directional_shape_broadcasting(sub_shape, label_to_shape[label])
                        assert common_shape is not None, "The dimensions labeled of ellipsis must be broadcastable " \
                                                         "for Einsum node {}".format(node_name)
                        label_to_shape[label] = common_shape
                    else:
                        label_to_shape[label] = sub_shape
                    dim_ind += num_broadcasted_dims
                else:
                    dim_size = input_shape[dim_ind]
                    sub_shape = shape_array([dim_size])
                    # regular labels must have exactly equal sizes (np.array_equal, no broadcasting)
                    assert label not in label_to_shape.keys() or np.array_equal(label_to_shape[label], sub_shape), \
                        "Sizes of dimensions with the same label of Einsum node {} " \
                        "must be compatible".format(node_name)
                    label_to_shape[label] = sub_shape
                    dim_ind += 1
                label_ind += 1

        # generate output shape based on the output subscript
        output_shape = shape_array([])
        labels = Einsum.extract_subscript_labels(node_name, output_subscript)
        for label in labels:
            assert label in label_to_shape.keys(), "The label in the output subscript must appear" \
                                                   " in input subscripts in equation {} " \
                                                   "of Einsum node {}".format(equation, node_name)
            # masked concatenation keeps dynamic (masked) dimensions intact
            output_shape = np.ma.concatenate((output_shape, label_to_shape[label]))

        node.out_port(0).data.set_shape(output_shape)
Ejemplo n.º 21
0
def get_node_input_ports(node: Node):
    """
    Collect the source ports that feed the input ports of a node.

    Input ports are visited in the order returned by node.in_ports(). Ports whose get_source() returns None are
    skipped, so a position in the result equals the input port number only if no earlier port lacked a source.
    :param node: node to collect input sources for
    :return: list of source ports, one per input port that has a source
    """
    sources = []
    for in_port in node.in_ports().values():
        source = in_port.get_source()
        if source is not None:
            sources.append(source)
    return sources
Ejemplo n.º 22
0
 def get_single_input_source_safely(node: Node, idx: int = 0):
     """
     Return the producer node of input port `idx` when that port is the node's only connected input.

     :param node: node to inspect
     :param idx: index of the input port that must be the single connected one
     :return: the node feeding input port `idx`, or None when the node has no connected inputs, several connected
              inputs, or when its single connected input is a different port
     """
     used_ports = [p for p in node.in_ports().values() if not p.disconnected()]
     if len(used_ports) != 1 or used_ports[0].idx != idx:
         return None
     return node.in_port(idx).get_source().node
Ejemplo n.º 23
0
 def infer(node: Node):
     """
     Infer the output shape of a LogSoftmax node and normalize its `axis` attribute.

     A negative `axis` is rewritten in place to its non-negative equivalent using the rank of input 0. The output
     shape is produced by copy_shape_infer, and `axis` is registered for permutation relative to input 0.
     :param node: LogSoftmax node to infer
     :return: None
     """
     connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
     # the check also fails for zero connected ports, so the message must not say "more than one"
     assert len(connected_ports) == 1, \
         'LogSoftmax node with id {} must have exactly one input port connected, but has {}'.format(
             node.id, len(connected_ports))
     in_shape = node.in_port(0).data.get_shape()
     # fail with a clear message instead of a TypeError from len(None)
     assert in_shape is not None, 'LogSoftmax node with id {} has unknown input shape'.format(node.id)
     rank = len(in_shape)
     if node.axis < 0:
         node.axis = rank + node.axis
     assert 0 <= node.axis < rank, 'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)
     copy_shape_infer(node)
     PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Ejemplo n.º 24
0
def get_input_shape(node: Node, in_port: int):
    """
    Return the shape of the data on input port `in_port` of a node.
    :param node: node to get the input shape of
    :param in_port: input port number
    :return: the result of the port's data.get_shape(); may be None while the shape is still unknown
    :raises Exception: if the node has no input port with number `in_port`
    """
    if in_port not in node.in_ports():
        # soft_get keeps the error path working for nodes that have no 'name' attribute
        raise Exception('Can\'t get shape for {} port of {} node. No such port in node'.format(
            in_port, node.soft_get('name', node.id)))
    return node.in_port(in_port).data.get_shape()
def set_reshape_new_output_shape(reshape_node: Node, new_output_shape: np.array):
    """
    Assign a new output shape to a Reshape node.

    The shape of output port 0 is always replaced. When exactly two input ports are connected, the value of input
    port 1 is also set to the new shape so the second input stays in sync with the output.
    :param reshape_node: Reshape node to update
    :param new_output_shape: shape to assign
    :return: None
    """
    reshape_node.out_port(0).data.set_shape(new_output_shape)
    connected_count = sum(1 for port in reshape_node.in_ports().values() if not port.disconnected())
    if connected_count != 2:
        return
    reshape_node.in_port(1).data.set_value(new_output_shape)
Ejemplo n.º 26
0
def reduce_infer(node: Node):
    """
    Infer the output of a Reduce operation node.

    Input 0 is the data tensor and input 1 holds the reduction axes. If input 1 holds a single None element, the
    reduction is done over all dimensions of input 0 and input 1 is overwritten with the explicit list of axes.
    If the value of input 0 is known, the output value is computed with reduce_helper using reduce_map[node.op].
    Otherwise only the output shape is computed: every reduced dimension becomes 1 and, when `keep_dims` is false,
    it is removed.
    :param node: Reduce node with `op` and `keep_dims` attributes
    :return: None
    """
    connected_in_ports = [
        port for port in node.in_ports().values() if not port.disconnected()
    ]
    assert len(connected_in_ports) == 2, \
        "{} node `{}` should have 2 input ports, where 0-input is data input and 1-input represent " \
        "`reduction_indices`".format(node.op, node.id)

    in_data = node.in_port(0).data
    in_shape = in_data.get_shape()
    # checked before any use of len(in_shape), including the "reduce over all axes" case below
    assert in_shape is not None, "Can not infer {} node `{}`: shape of 0-input unknown".format(
        node.op, node.id)

    axis = node.in_port(1).data.get_value()
    assert axis is not None, "Can not infer {} node `{}`: value of 1-input unknown".format(
        node.op, node.id)

    # If the axis is None then reduce over all the dimensions of the input tensor
    if axis.size == 1 and axis.item() is None:
        axis = int64_array(list(range(len(in_shape))))
        node.in_port(1).data.set_value(axis)

    axis = axis.copy()
    if axis.size == 1:
        axis = int64_array([axis.item()])

    in_value = in_data.get_value()

    if in_value is not None:
        value = reduce_helper(reduce_map[node.op],
                              in_value.copy(),
                              axis=tuple(axis),
                              keepdims=node.keep_dims)
        node.out_port(0).data.set_value(value)
    else:
        # builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24
        used_dims = np.zeros(len(in_shape), dtype=bool)
        output_shape = in_shape.copy()

        for dim in axis:
            used_dims[dim] = True
            output_shape[dim] = 1

        # In case if keep dims == False, we should remove all 1 dims that was used in reduction
        if not node.keep_dims:
            output_shape = output_shape[np.invert(used_dims)]

        node.out_port(0).data.set_shape(output_shape)

    # if the operation changes the rank of the output tensor then it is necessary to insert Permute if the input is 4D
    # or 5D
    if not node.keep_dims:
        node['reinterp_shape'] = True

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0',
                                          'axis')
def find_max_frame_time(node: Node):
    """
    Find the largest `frame_time` among the producers of a node's connected inputs.

    :param node: node whose connected input producers are inspected
    :return: tuple (max_frame_time, should_align).
             max_frame_time is the largest producer `frame_time`, but never less than 0.
             should_align is True when max_frame_time is non-zero and at least one connected producer
             has a different `frame_time`.
    """
    producer_times = [
        node.in_port(port_idx).get_source().node.frame_time
        for port_idx in node.in_ports()
        if not node.in_port(port_idx).disconnected()
    ]
    # 0 comes first so it wins ties, exactly like a running maximum initialised with 0
    max_time = max([0] + producer_times)
    if max_time == 0:
        return max_time, False
    return max_time, any(max_time != t for t in producer_times)
Ejemplo n.º 28
0
 def generate_port_map(node: Node, src_port_map, dir: str):
     """
     Prepare port map records for serialization, updating each record in place.

     For output port maps (`dir == 'out'`), records whose external port id is not -1 get that id shifted by the
     number of input ports of `node`. For every record, `internal_layer_id` is replaced with the value returned by
     TensorIterator.find_internal_layer_id for `node.body`.
     :param node: node owning the port map and the `body` graph
     :param src_port_map: iterable of mapping-like records; they are modified in place
     :param dir: 'out' for output port maps; any other value leaves external port ids unchanged
     :return: list of the same record objects, in input order
     """
     processed = []
     for item in src_port_map:
         # -1 marks a not-connected output used by the Loop operation only, so its id is left as is;
         # output ids are shifted past the input ids for proper generation of a "ports" section
         if item['external_port_id'] != -1 and dir == 'out':
             item['external_port_id'] += len(node.in_ports())
         item['internal_layer_id'] = TensorIterator.find_internal_layer_id(node.body, item['internal_layer_id'])
         processed.append(item)
     return processed
Ejemplo n.º 29
0
    def reverse_infer(node: Node):
        """
        Propagate a shape from the node's output back to inputs whose shape is still unknown.

        A copy of the output shape with its `axis` dimension set to dynamic_dimension is assigned to every input
        port that has no shape yet. Nothing happens if the output shape is unknown.
        :param node: node with an `axis` attribute
        :return: None
        """
        assert hasattr(node, 'axis')
        out_shape = node.out_port(0).data.get_shape()

        if out_shape is None:
            return

        # work on a copy so the output port's own shape is not overwritten in place
        in_shape_guess = out_shape.copy()
        in_shape_guess[node.axis] = dynamic_dimension

        for in_port in node.in_ports().values():
            if in_port.data.get_shape() is None:
                # each input gets its own array so a later in-place edit on one input does not leak into others
                in_port.data.set_shape(in_shape_guess.copy())
def mark_as_correct_data_layout(node: Node):
    """
    Mark every input and output port of an operation node as being in the correct data layout.

    This is the operation-node analogue of the 'correct_data_layout' attribute.
    :param node: operation node to mark
    :return: None
    """
    assert node.soft_get('kind') == 'op', 'The function work with operation nodes only'
    for in_idx in node.in_ports():
        mark_input_as_in_correct_layout(node, in_idx)

    for out_idx in node.out_ports():
        mark_output_as_in_correct_layout(node, out_idx)