Example #1
    def infer(node: Node):
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) >= 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            'FullyConnected should have at least 2 connected input ports (0 and 1), but node `{}` has: {}' \
            ''.format(name, connected_in_ports)

        assert node.has_valid('out-size')
        input_shape = node.in_port(0).data.get_shape()
        weights_shape = node.in_port(1).data.get_shape()
        assert input_shape is not None and weights_shape is not None, \
            'Incorrect FullyConnected input shapes. Node: {}. Shapes: {}'.format(name, [input_shape, weights_shape])
        assert weights_shape.size == 2
        out_size = node.soft_get('out-size')
        assert compatible_dims(weights_shape[0], out_size), \
            'weights_shape={}, out-size={}'.format(weights_shape, out_size)

        if 2 in connected_in_ports:
            bias_value = node.in_port(2).data.get_value()
            bias_shape = node.in_port(2).data.get_shape()
            assert bias_shape is not None, 'Shape was not inferred for biases of FullyConnected {}'.format(
                name)
            assert bias_value is not None, 'Value was not inferred for biases of FullyConnected {}'.format(
                name)
            assert compatible_shapes(bias_shape, [out_size]) or compatible_shapes(bias_shape, [1, out_size]), \
                'Incorrect FullyConnected bias shape `{}` for node {}. `out-size`={}'.format(bias_shape, node, out_size)

        node.out_port(0).data.set_shape([*input_shape[:-1], out_size])
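As a quick sanity check of the shape arithmetic above, here is a standalone sketch with made-up shapes (the helper name is illustrative, not part of the Model Optimizer API): the weights are expected to be [out-size, K], K should match the last input dimension, and the output keeps every input dimension except the last one, which becomes out-size.

def fully_connected_out_shape(input_shape, weights_shape, out_size):
    # weights_shape is [out_size, K]; K must match the last input dimension
    assert len(weights_shape) == 2 and weights_shape[0] == out_size
    assert weights_shape[1] == input_shape[-1]
    return list(input_shape[:-1]) + [out_size]

# e.g. input [8, 512] with weights [1000, 512] and out-size 1000 -> [8, 1000]
print(fully_connected_out_shape([8, 512], [1000, 512], 1000))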
Example #2
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)

        input_shape = node.in_port(0).data.get_shape()
        indices_shape = node.in_port(1).data.get_shape()
        updates_shape = node.in_port(2).data.get_shape()
        assert input_shape is not None and updates_shape is not None and indices_shape is not None, \
            'The node "{}" input shape is None'.format(node_name)

        # check that shapes are correct
        # 1. ranks of both input and indices must be at least 1
        assert len(input_shape) >= 1 and len(indices_shape) >= 1, \
            'The node "{}" input and indices ranks must be at least 1'.format(node_name)

        # 2. the last dimension of indices shape must be at most a rank of input
        assert not is_fully_defined(indices_shape[-1]) or indices_shape[-1] <= len(input_shape), \
            'The last dimension of indices shape must be at most a rank of input for the node "{}"'.format(node_name)

        # 3. updates is a tensor of shape indices_shape[:-1] + input_shape[indices_shape[-1]:]
        # if expected updates shape is scalar, updates can be tensor with the single element (for example, of shape
        # [1], [[1]], etc.)
        expected_updates_shape = np.ma.concatenate((indices_shape[:-1], input_shape[indices_shape[-1]:]), axis=0)
        assert compatible_shapes(updates_shape, expected_updates_shape) or \
               (strict_compare_tensors(expected_updates_shape, []) and
                strict_compare_tensors(updates_shape, np.ones(len(updates_shape), dtype=np.int64))), \
            'The updates shape must be equal to indices_shape[:-1] + input_shape[indices_shape[-1]:] for the node ' \
            '"{}"'.format(node_name)

        node.out_port(0).data.set_shape(input_shape)
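A small worked example of the updates-shape rule asserted above (the numbers are made up): with input_shape = [4, 4, 4] and indices_shape = [2, 1], the last indices dimension is 1, so updates must have shape indices_shape[:-1] + input_shape[1:] = [2, 4, 4].

import numpy as np

input_shape = np.array([4, 4, 4])
indices_shape = np.array([2, 1])

# expected updates shape: indices_shape[:-1] + input_shape[indices_shape[-1]:]
expected_updates_shape = np.concatenate((indices_shape[:-1], input_shape[indices_shape[-1]:]))
print(expected_updates_shape)  # [2 4 4]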
Example #3
    def merge_infer(node: Node):
        # we infer only through executable input nodes
        inferred_nodes = [
            n for n in node.in_nodes().values() if n['is_partial_inferred']
        ]
        assert len(inferred_nodes) != 0
        tensor = inferred_nodes[0]

        if len(inferred_nodes) < len(node.in_nodes()):
            node['is_not_fully_inferred'] = True
        else:
            node['is_not_fully_inferred'] = False
            # all inferred input shapes must be compatible with each other
            assert all(
                compatible_shapes(n.shape, inferred_nodes[0].shape)
                for n in inferred_nodes)

            inferred_and_executable = [
                n for n in node.in_nodes().values() if n['is_partial_inferred']
                and 'executable' in n and n['executable']
            ]
            if len(inferred_and_executable) > 0:
                tensor = inferred_and_executable[0]

                if all([
                        tensor.has_valid('value') and n.has_valid('value')
                        and strict_compare_tensors(tensor.value, n.value)
                        for n in inferred_and_executable
                ]):
                    node.out_node().value = tensor.value.copy()
                else:
                    node.out_node().value = None

        # do not use set_shape(tensor.shape) here because input port shape may be different from the calculated output
        # shape and `set_shape` will raise an error that shape has changed
        node.out_node(0).shape = shape_array(tensor.shape)
Example #4
    def infer(node: Node):
        Scatter.infer(node)

        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_in_ports) == 4, \
            "Incorrect number of inputs for {} node".format(node_name)

        input_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        updates_value = node.in_port(2).data.get_value()

        input_shape = node.in_port(0).data.get_shape()
        indices_shape = node.in_port(1).data.get_shape()
        updates_shape = node.in_port(2).data.get_shape()

        assert len(input_shape) == len(indices_shape), 'data and indices inputs for node "{}" must be of the ' \
                                                       'same rank. Instead got {} and {}'.format(node_name,
                                                                                                 len(input_shape),
                                                                                                 len(indices_shape))
        assert compatible_shapes(indices_shape, updates_shape), \
            'updates and indices shapes for node "{}" must be equal. Instead got {} and {}.' \
            ''.format(node_name, indices_shape, updates_shape)

        axis = node.in_port(3).data.get_value()
        if input_value is not None and indices_value is not None and updates_value is not None and axis is not None:
            assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(
                node_name, axis.size)
            axis = axis.item()
            out_value = input_value.copy()
            for idx in np.ndindex(*indices_shape):
                data_idx = list(idx)
                data_idx[axis] = indices_value[idx]
                out_value[tuple(data_idx)] = updates_value[idx]
            node.out_port(0).data.set_value(out_value)
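The loop above performs ScatterElements-style updates. A standalone NumPy illustration of the same logic with made-up tensors (axis = 0): each position of indices/updates writes updates[idx] into a copy of the data tensor, with the coordinate along axis replaced by indices[idx].

import numpy as np

data = np.zeros((3, 3), dtype=np.int64)
indices = np.array([[1, 0, 2], [0, 2, 1]])
updates = np.array([[10, 20, 30], [40, 50, 60]])
axis = 0

out = data.copy()
for idx in np.ndindex(*indices.shape):
    data_idx = list(idx)
    data_idx[axis] = indices[idx]       # replace the coordinate along `axis`
    out[tuple(data_idx)] = updates[idx]
print(out)
# [[40 20  0]
#  [10  0 60]
#  [ 0 50 30]]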
Example #5
def shape_inference(graph):
    for node in graph.pseudo_topological_sort():
        if node.has_and_set('need_shape_inference'):
            old_out_shapes = [
                port.data.get_shape() for port in node.out_ports().values()
                if not port.disconnected()
            ]
            node.infer(node)
            new_out_shapes = [
                port.data.get_shape() for port in node.out_ports().values()
                if not port.disconnected()
            ]
            if not node.has_and_set('override_output_shape'):
                for shape1, shape2 in zip(old_out_shapes, new_out_shapes):
                    # do not use strict shapes comparison because after applying transformation the output shape may be
                    # specialized and some dynamic dimension become static
                    if shape1 is not None and not compatible_shapes(
                            shape1, shape2):
                        raise Error(
                            "After partial shape inference a shape collision was found for node {} "
                            "(old shape: {}, new shape: {})".format(
                                node.name, shape1, shape2))
            else:
                del node['override_output_shape']
            node.need_shape_inference = False
Example #6
    def infer(node: Node):
        # check a number of input/output edges
        assert len(node.in_nodes()) == 3
        assert len(node.out_nodes()) == 1

        data_shape = node.in_port(0).data.get_shape()
        indices_shape = node.in_port(1).data.get_shape()
        segment_ids_shape = node.in_port(2).data.get_shape()
        data_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        segment_ids_value = node.in_port(2).data.get_value()

        # check input shapes
        assert data_shape is not None, \
            "Shape for input data tensor to SparseSegmentSqrtN must be defined"
        assert indices_shape is not None and indices_shape.size == 1, \
            "SparseSegmentSqrtN supports only 1D indices tensor"
        assert segment_ids_shape is not None and segment_ids_shape.size == 1, \
            "SparseSegmentSqrtN supports only 1D segment IDs tensor"
        assert compatible_shapes(segment_ids_shape, indices_shape), \
            "Indices and segment IDs tensors must have compatible shapes"

        # compute the output shape
        output_shape = data_shape.copy()
        output_shape[0] = segment_ids_shape[0]
        node.out_port(0).data.set_shape(output_shape)

        # value inference is possible only when all inputs are constant
        if data_value is None or indices_value is None or segment_ids_value is None:
            return

        # check that values in segment_ids are sorted
        for i in range(1, len(segment_ids_value)):
            assert segment_ids_value[i-1] <= segment_ids_value[i], \
                "Values in segment IDs are not sorted"
        num_segments = int(segment_ids_value[-1]) + 1

        # check that indices are in a range [0, data_shape[0])
        assert np.all(indices_value >= 0) and np.all(indices_value < data_shape[0]), \
            "Some value in indices tensor is out of range"

        # infer
        num_adds = np.zeros(num_segments, dtype=np.int64)
        output_value = np.zeros([num_segments] + data_shape[1:].tolist(),
                                dtype=np.float32)
        output_shape = output_value.shape
        for i in range(len(segment_ids_value)):
            segment_id = int(segment_ids_value[i])
            index = int(indices_value[i])
            output_value[segment_id, :] += data_value[index, :]
            num_adds[segment_id] += 1

        num_adds = np.sqrt(num_adds)
        for segment_id in range(num_segments):
            if num_adds[segment_id] != 0:
                output_value[segment_id, :] /= num_adds[segment_id]
        node.out_port(0).data.set_shape(output_shape)
        node.out_port(0).data.set_value(output_value)
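For clarity, a minimal standalone sketch of the same computation with made-up inputs: rows of data selected by indices are summed per segment, and each segment sum is divided by the square root of the number of added rows.

import numpy as np

# hypothetical inputs: three data rows gathered into two segments
data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
indices = np.array([0, 1, 2])
segment_ids = np.array([0, 0, 1])        # must be sorted

num_segments = int(segment_ids[-1]) + 1
output = np.zeros((num_segments,) + data.shape[1:])
counts = np.zeros(num_segments)
for i, segment_id in enumerate(segment_ids):
    output[segment_id] += data[indices[i]]
    counts[segment_id] += 1
output /= np.sqrt(counts)[:, None]       # divide each segment by sqrt(N)
print(output)  # [[2.828... 4.242...] [5. 6.]]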
Example #7
    def shape_alignment(node: Node):
        """
        Specification of MatMul operation allows inputs to be aligned together before matrix multiplication.
        Current method raises an error if input shapes are not valid at any step of alignment process
        :return: aligned copies of both input shapes
        """
        node_name = node.soft_get('name', str(node.id))
        input_shapes = [node.in_port(i).data.get_shape() for i in range(2)]
        transpose_a = node.has_and_set('transpose_a')
        transpose_b = node.has_and_set('transpose_b')

        transformed_shapes = []
        for i, shape in enumerate(input_shapes):
            input_shape = shape.copy()
            # prerequisites check
            assert input_shape is not None, "MatMul has shape=`None` for {} input of `{}` node".format(
                i, node_name)
            assert input_shape.ndim == 1, "MatMul doesn't support scalar inputs. {} input of `{}` node has shape {}" \
                                          "".format(i, node_name, input_shape)
            assert input_shape.size >= 1, "MatMul doesn't support inputs with rank lower than 1. {} input of `{}` " \
                                          "node has shape {}".format(i, node_name, input_shape)
            rank = input_shape.size
            # shape alignment
            if rank != 1 and ((i == 0 and transpose_a) or
                              (i == 1 and transpose_b)):
                input_shape[-2], input_shape[-1] = input_shape[-1], input_shape[-2]
            if rank == 1:
                input_shape = shape_insert(input_shape, int(i == 1), 1)

            max_shape_length = max(input_shapes[0].size, input_shapes[1].size)
            input_shape = shape_insert(input_shape, 0, [1] *
                                       (max_shape_length - input_shape.size))
            transformed_shapes.append(input_shape)

        A_shape = shape_array(transformed_shapes[0])
        B_shape = shape_array(transformed_shapes[1])

        assert A_shape.size == B_shape.size, \
            "Shapes were not aligned by length for MatMul `{}`. Shapes: `{}`".format(node_name, transformed_shapes)

        # batch broadcasting
        batch_len = A_shape.size - 2
        for i in range(batch_len):
            if A_shape[i] != B_shape[i]:
                if A_shape[i] == 1:
                    A_shape[i] = B_shape[i]
                if B_shape[i] == 1:
                    B_shape[i] = A_shape[i]

        assert compatible_shapes(A_shape[:-2], B_shape[:-2]), \
            "MatMul input shapes are incorrect. BATCH_DIMs are not equal. Node: {}. Aligned shapes: {}" \
            "".format(node_name, transformed_shapes)

        return A_shape, B_shape
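A simplified, plain-list sketch of the alignment performed above for fully static shapes (it ignores dynamic dimensions and the shape_array/shape_insert helpers; the function name is illustrative): 1D inputs are promoted to 2D, transpose flags swap the two innermost dimensions, and both shapes are left-padded with ones to a common rank.

def align_matmul_shapes(a_shape, b_shape, transpose_a=False, transpose_b=False):
    shapes = []
    for i, (shape, transpose) in enumerate([(list(a_shape), transpose_a),
                                            (list(b_shape), transpose_b)]):
        if len(shape) == 1:
            # a 1D LHS becomes a row vector, a 1D RHS becomes a column vector
            shape.insert(int(i == 1), 1)
        elif transpose:
            shape[-2], shape[-1] = shape[-1], shape[-2]
        shapes.append(shape)
    max_rank = max(len(s) for s in shapes)
    return [[1] * (max_rank - len(s)) + s for s in shapes]

# e.g. A: [5, 3] with transpose_a=True, B: [2, 4, 5, 6] -> [[1, 1, 3, 5], [2, 4, 5, 6]]
print(align_matmul_shapes([5, 3], [2, 4, 5, 6], transpose_a=True))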
Example #8
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)

        input_shape = node.in_port(0).data.get_shape()
        indices_shape = node.in_port(1).data.get_shape()
        updates_shape = node.in_port(2).data.get_shape()
        assert input_shape is not None and updates_shape is not None and indices_shape is not None, \
            'The node "{}" input shape is None'.format(node_name)
        assert len(input_shape) == len(indices_shape), 'data and indices inputs for node "{}" must be of the ' \
            'same rank. Instead got {} and {}'.format(node_name, len(input_shape), len(indices_shape))
        assert compatible_shapes(indices_shape, updates_shape), \
            'updates and indices shapes for node "{}" must be equal. Instead got {} and {}.' \
            ''.format(node_name, indices_shape, updates_shape)

        node.out_port(0).data.set_shape(input_shape)
Example #9
def find_common_partial_shape(shape1, shape2):
    """
    Extracts maximum information from partially defined shapes.

    For example, if shape1 = [4, dyn] and shape2 = [dyn, 13]
    then the resulting shape will be [4, 13]
    :param shape1: partially defined shape
    :param shape2: partially defined shape
    :return: the most defined shape compatible with both inputs
    """
    assert compatible_shapes(shape1, shape2), 'shapes must be compatible'
    res = shape1.copy()
    for i, (d1, d2) in enumerate(zip(shape1, shape2)):
        if np.ma.is_masked(res[i]):
            res[i] = d2
    return res
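In the Model Optimizer, dynamic dimensions are masked elements of a NumPy masked array, so the merge above simply fills masked positions of shape1 from shape2. A minimal self-contained illustration of that idea (the masked-array construction here is for illustration, not taken from the MO helpers):

import numpy as np

shape1 = np.ma.masked_array([4, 0], mask=[False, True])   # [4, dyn]
shape2 = np.ma.masked_array([0, 13], mask=[True, False])  # [dyn, 13]

res = shape1.copy()
for i, d2 in enumerate(shape2):
    if np.ma.is_masked(res[i]):
        res[i] = d2
print(res)  # [4 13]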
Example #10
    def test_compare_shapes(self, input1, input2, result):
        self.assertEqual(compatible_shapes(input1, input2), result)
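For context, this test exercises compatible_shapes, which (as used throughout these examples) treats two shapes as compatible when they have the same rank and every dimension pair is either equal or contains a dynamic dimension. Below is a minimal re-implementation sketch of that rule; it is an assumption about the semantics for illustration, not the Model Optimizer source.

import numpy as np

def compatible_dims_sketch(d1, d2):
    # a dynamic (masked) dimension is compatible with anything
    return np.ma.is_masked(d1) or np.ma.is_masked(d2) or d1 == d2

def compatible_shapes_sketch(shape1, shape2):
    if len(shape1) != len(shape2):
        return False
    return all(compatible_dims_sketch(d1, d2) for d1, d2 in zip(shape1, shape2))

dyn = np.ma.masked
print(compatible_shapes_sketch([2, dyn, 3], [2, 7, 3]))  # True
print(compatible_shapes_sketch([2, dyn, 3], [2, 7, 4]))  # False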
Example #11
def _fuse_linear_sequence(graph: Graph, start_node: Node):
    """
    This function finds the sequence of Mul/Add operations and replaces this sequence with two ops (Mul->Add).
    :param graph:
    :param start_node: The first operation of the sequence
    """
    fnodes = [start_node]
    while True:
        node = fnodes[-1]
        destinations = node.out_port(0).get_destinations()
        if len(destinations) != 1:
            break
        dst_node = destinations[0].node
        if dst_node.soft_get('op') in ['Mul', 'Add'] and get_value_in_port(dst_node) is not None and \
                dst_node.soft_get('can_be_fused') is True:
            fnodes.append(dst_node)
        else:
            break

    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul' and fnodes[1].op == 'Add'):
        return False

    input_shape = get_tensor_in_port(start_node).data.get_shape()

    init_dims_cnt = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1

    first_value = get_value_in_port(fnodes[0]).data.get_value()
    if not isinstance(first_value, np.ndarray):
        first_value = mo_array(first_value)
    first_value_type = first_value.dtype

    mul = np.ones([1 for x in range(init_dims_cnt)], dtype=first_value_type)
    add = np.zeros([1 for x in range(init_dims_cnt)], dtype=first_value_type)

    first_mul_name = None
    first_add_name = None

    for node in fnodes:
        const_port_value = get_value_in_port(node).data.get_value()
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            mul = mul * const_port_value
            add = add * const_port_value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + const_port_value

    # If mul is scalar we broadcast it to biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = mo_array([mul[0] for x in range(add.shape[0])])

    assert (compatible_shapes(get_tensor_in_port(fnodes[0]).data.get_shape(), fnodes[-1].out_port(0).data.get_shape()))

    mul_op = Mul(graph, dict(name='{}/Fused_Mul_'.format(first_mul_name or '')))
    add_op = Add(graph, dict(name='{}/Fused_Add_'.format(first_add_name or '')))

    in_port = get_tensor_in_port(fnodes[0])
    out_port = fnodes[-1].out_port(0)

    """
    Four cases considered below:
        1. Mul and Add have valid values (mul value != 1 and add value != 0)
        2. Only Mul has valid values, so we add only Mul node
        3. Only Add has valid values, so we add only Add node
        4. When Mul and Add has not valid values we just merge two data nodes
    """
    if any([x != 0 for x in np.nditer(add)]) and any([x != 1 for x in np.nditer(mul)]):
        #  Const\    Const\
        #  ----->Mul------>Add-->
        mul_const = Const(graph, dict(name="data_mul_", value=mo_array(mul))).create_node()
        add_const = Const(graph, dict(name="data_add_", value=mo_array(add))).create_node()

        mul_node = mul_op.create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))

        mul_node.out_port(0).connect(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    elif any([x != 1 for x in np.nditer(mul)]):
        #  Const\
        #  ----->Mul-->
        mul_const = Const(graph, dict(name="data_mul_", value=mo_array(mul))).create_node()
        mul_node = mul_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))
        out_port.get_connection().set_source(mul_node.out_port(0))
    elif any([x != 0 for x in np.nditer(add)]):
        #  Const\
        #  ----->Add-->
        add_const = Const(graph, dict(name="data_add_", value=mo_array(add))).create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    else:
        source_node = in_port.get_source()
        in_port.disconnect()
        out_port.get_connection().set_source(source_node)

    # Remove fused nodes
    for node in fnodes:
        graph.remove_node(node.id)

    log.debug('Fused {} operations'.format(len(fnodes)))
    return True
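A quick arithmetic check of the folding loop above (with a made-up constant chain): the sequence x*2, +3, *4, +5 accumulates mul = 8 and add = 17, so the whole chain collapses to x*8 + 17.

mul, add = 1.0, 0.0
for op, const in [('Mul', 2), ('Add', 3), ('Mul', 4), ('Add', 5)]:
    if op == 'Mul':
        mul *= const
        add *= const
    else:
        add += const
print(mul, add)  # 8.0 17.0, i.e. (((x * 2) + 3) * 4) + 5 == x * 8 + 17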
Example #12
    def infer(node: Node):
        """
        Deconvolution has an input argument that explicitly determines output shape, so in contrast
        to the forward Conv2d we shouldn't infer output shape. We just use this output shape as
        an input shape and pass it to our utilities that computes numeric values for padding.
        They also deliver output shape that is interpreted here as input shape for convolution.
        We need to check that the real input shape and shape inferred by those utility functions match.
        """
        output_shape = shape_array(node.in_node(2).value)
        output_shape[0] = node.in_port(0).data.get_shape()[0]
        kernel_shape = node.in_port(1).data.get_shape()
        node['kernel_shape'] = kernel_shape
        if output_shape is None or kernel_shape is None or node.spatial_dims is None or node.stride is None:
            return

        if not node.has_valid('kernel_spatial_idx'):
            node['kernel_spatial_idx'] = np.delete(
                [x for x in range(len(kernel_shape))],
                (node.input_feature_channel, node.output_feature_channel))

        if not node.has_valid('dilation'):
            node['dilation'] = np.full([len(output_shape)], 1, dtype=np.int64)

        if node.has_valid('get_group'):
            node['group'] = node.get_group(node)

        spatial_dims = node.spatial_dims
        output_spatial = shape_array(output_shape[spatial_dims])
        stride_spatial = shape_array(node.stride[spatial_dims])
        node['kernel_spatial'] = shape_array(
            kernel_shape[node.kernel_spatial_idx])
        node.pad_spatial_shape, input_spatial_for_check = tf_window_op_pad_infer(
            output_spatial, node.kernel_spatial, stride_spatial, node.auto_pad)

        assert compatible_shapes(input_spatial_for_check,
                                 node.in_node(0).shape[spatial_dims])

        pad = np.zeros((len(output_shape), 2), dtype=np.int64)
        pad[spatial_dims] = node.pad_spatial_shape
        node.pad = pad

        node.output = output_shape[node.channel_dims][0]
        node.output_shape = output_shape
        node.out_port(0).data.set_shape(output_shape)

        mark_input_bins(node, ['weights'], 1)
        assign_dims_to_weights(node.in_node(1), node.kernel_spatial_idx,
                               node.input_feature_channel,
                               node.output_feature_channel, len(kernel_shape))

        # OK, now we are sure this is a supported Deconvolution layer
        node.type = 'Deconvolution'
        node.op = 'Deconv2D'

        # Add permute_attrs
        PermuteAttrs.create_permute_attrs(
            node,
            attrs=[
                ('pad', 'input:0'),
                ('stride', 'input:0'),
                ('output_shape', 'input:0'),
                ('batch_dims', 'input:0'),
                ('channel_dims', 'input:0'),
                ('spatial_dims', 'input:0'),
                ('kernel_shape', 'input:1'),
                ('kernel_spatial_idx', 'input:1'),
                ('input_feature_channel', 'input:1'),
                ('output_feature_channel', 'input:1'),
            ])

        # needed to permute Deconv weights from the original TF [H, W, C_OUT, C_IN] into IE [C_IN, C_OUT, H, W];
        # for other nodes in the weights subgraph permutations must be turned off
        # by marking them with MarkSubGraphsWithCorrectLayout even if the graph layout is NCHW.
        PermuteAttrs.set_permutation(
            node.in_node(1), node, node.soft_get('get_weights_permute', None))
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1',
                                              'transpose')
        PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0',
                                              'shape')

        node['force_precision_in_ports'] = {2: 'int64'}
Example #13
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
            "Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors for node {}".format(node_name)

        condition_value = node.in_port(0).data.get_value()
        condition_shape = node.in_port(0).data.get_shape()
        resulting_tensors = [
            node.in_port(1).data.get_value(),
            node.in_port(2).data.get_value()
        ]

        a_shape = node.in_port(1).data.get_shape()
        b_shape = node.in_port(2).data.get_shape()
        broadcast_rule = node.soft_get('auto_broadcast', 'numpy')

        if broadcast_rule == 'numpy':
            msg = "In Select node '{}' condition and then/else shapes must be broadcastable. " \
                  "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                    node_name, condition_shape, a_shape, b_shape)

            output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
            assert output_shape is not None, msg

            output_is_scalar = len(output_shape) == 0

            # if Select was created from TF Where operations then 1D condition must have the same size
            # as 0-index dimension of output_shape. This condition is different from being numpy compatible
            # but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
            if node.has_valid('format') and node['format'] == 'tf' and len(
                    condition_shape) == 1:
                # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py#L4596-L4598
                msg_tf = "In Select node '{}' if 'condition' is a 1D tensor then it's size " \
                         "must be matching with the first dimension of then/else branches. " \
                         "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                            node_name, condition_shape, a_shape, b_shape)

                # check equality only if both values are non-dynamic
                if is_fully_defined(condition_shape[0]) and not output_is_scalar \
                        and is_fully_defined(output_shape[0]):
                    assert condition_shape[0] == output_shape[0], msg_tf
                ones_shape = len(output_shape) if output_is_scalar else len(output_shape) - 1
                condition_shape = np.concatenate(
                    (condition_shape, np.ones(ones_shape, dtype=np.int64)))

            output_shape = bi_directional_shape_broadcasting(
                output_shape, condition_shape)
            assert output_shape is not None, msg

        elif broadcast_rule == 'pdpd':
            # todo: add pdpd broadcasting rule
            # note that additionally to output_shape resulting_tensors must be broadcasted as well
            raise Error("PDPD broadcasting rule is not implemented yet")
        else:  # broadcasting is not allowed
            assert compatible_shapes(a_shape, b_shape) and compatible_shapes(condition_shape, a_shape), \
                'In node \'{}\' for Select operation when broadcasting is off all inputs must be of the same shape. ' \
                'But instead got: cond_shape={}, then_shape={}, else_shape={}'.format(
                    node_name, condition_shape, a_shape, b_shape)
            output_shape = shape_array([
                i if i is not dynamic_dimension else j
                for i, j in zip(a_shape, b_shape)
            ])

        node.out_port(0).data.set_shape(output_shape)

        if condition_value is not None:
            if is_fully_defined(condition_value) and np.all(
                    condition_value == condition_value.item(0)):
                # in some graphs Select condition is always True[False] and
                # one of the branches is None (which is not selected)
                # if we use np.where for such cases then dtype of output_value will be object (non numeric type)
                # and subsequent numpy operation on such tensors will fail
                output_value = resulting_tensors[not bool(condition_value.item(0))]
                if output_value is None:
                    return
                if broadcast_rule == 'numpy':
                    output_value = bi_directional_broadcasting(
                        output_value, output_shape)
                elif broadcast_rule == 'pdpd':
                    # todo: add pdpd broadcasting rule
                    raise Error(
                        "PDPD broadcasting rule is not implemented yet")

                node.out_port(0).data.set_value(output_value)
            elif resulting_tensors[0] is not None and resulting_tensors[1] is not None:
                output_value = np.ma.where(condition_value,
                                           resulting_tensors[0],
                                           resulting_tensors[1])
                node.out_port(0).data.set_value(output_value)
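The constant-folding branch above relies on np.ma.where, which selects elements from the then/else tensors according to the broadcasted condition. A tiny standalone example with made-up values:

import numpy as np

condition = np.array([[True], [False]])
then_vals = np.array([[1, 2], [3, 4]])
else_vals = np.array([[10, 20], [30, 40]])

# the condition column is broadcast across the rows of then/else
print(np.ma.where(condition, then_vals, else_vals))
# [[1 2]
#  [30 40]]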
Example #14
def eltwise_reverse_infer(node: Node):
    input_1_shape = node.in_port(0).data.get_shape()
    input_2_shape = node.in_port(1).data.get_shape()
    if input_1_shape is not None and input_2_shape is not None:
        return

    output_shape = node.out_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)

    if node['auto_broadcast'] == 'none':
        # input_1, input_2 and output shapes must match
        # therefore undefined partial shapes can be exactly defined from output shape
        if output_shape is not None:
            most_defined_shape = output_shape

            # if out_shape = [4, dyn] and input_1_shape = [dyn, 13]
            # then missing shape must be [4, 13]
            if input_1_shape is not None and not compatible_shapes(
                    output_shape, input_1_shape):
                raise Error("shapes are not compatible for node '{}'".format(
                    node_name))
            elif input_1_shape is not None:
                most_defined_shape = find_common_partial_shape(
                    output_shape, input_1_shape)

            if input_2_shape is not None and not compatible_shapes(
                    output_shape, input_2_shape):
                raise Error("shapes are not compatible for node '{}'".format(
                    node_name))
            elif input_2_shape is not None:
                most_defined_shape = find_common_partial_shape(
                    most_defined_shape, input_2_shape)

            if input_1_shape is None:
                node.in_port(0).data.set_shape(most_defined_shape)
            if input_2_shape is None:
                node.in_port(1).data.set_shape(most_defined_shape)
    elif node['auto_broadcast'] == 'numpy':
        if output_shape is not None:
            out_rank = len(output_shape)
            deduced_in_shape = undefined_shape_of_rank(out_rank)

            if input_1_shape is not None and input_2_shape is None and out_rank > len(input_1_shape):
                in_port_to_update = 1
                defined_in_shape = input_1_shape
            elif input_2_shape is not None and input_1_shape is None and out_rank > len(input_2_shape):
                in_port_to_update = 0
                defined_in_shape = input_2_shape
            else:
                return
            defined_in_rank = len(defined_in_shape)

            for i in range(-1, -defined_in_rank - 1, -1):
                assert defined_in_shape[i] == 1 or np.ma.is_masked(defined_in_shape[i]) \
                       or compatible_dims(defined_in_shape[i], output_shape[i]), \
                    "Shapes of Elementwise node '{}' are not compatible for reverse_infer.".format(node_name)

                # if defined_input_shape = [1] and output_shape = [N, 400, 400, 3]
                # partial shape information about sizes should not be lost
                if defined_in_shape[i] == 1 or output_shape[i] == 1:
                    deduced_in_shape[i] = output_shape[i]
            deduced_in_shape[:-defined_in_rank] = output_shape[:-defined_in_rank]

            node.in_port(in_port_to_update).data.set_shape(deduced_in_shape)
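To illustrate the numpy-broadcast branch with made-up static shapes: if the output shape is [2, 400, 400, 3] and the defined input has shape [1], the undefined input is deduced to have rank 4, with trailing positions pinned from the output wherever the defined input (or the output) has a 1, and the leading positions copied from the output. A self-contained sketch of that deduction using masked arrays for dynamic dimensions:

import numpy as np

output_shape = np.array([2, 400, 400, 3])
defined_in_shape = np.array([1])

# start from a fully dynamic (masked) shape of the output rank
deduced = np.ma.masked_array(np.zeros(len(output_shape), dtype=np.int64),
                             mask=[True] * len(output_shape))
rank = len(defined_in_shape)
for i in range(-1, -rank - 1, -1):
    if defined_in_shape[i] == 1 or output_shape[i] == 1:
        deduced[i] = output_shape[i]
deduced[:-rank] = output_shape[:-rank]
print(deduced)  # [2 400 400 3]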