Beispiel #1
0
    def infer(node: Node):
        """
        Infer the output shape (and the output value, when it can be propagated) of a Broadcast node.

        Input 0 is the tensor to broadcast and input 1 holds the target shape. The node's
        'mode' attribute selects 'numpy' (uni-directional) or 'bidirectional' broadcasting;
        any other mode raises ``Error``. The value is propagated only when the value of
        input 0 is known and 'stop_value_propagation' is not set on the node.

        :param node: Broadcast node to infer
        :raises Error: if 'mode' is neither 'numpy' nor 'bidirectional'
        """
        node_name = node.soft_get('name', node.id)

        src_shape = node.in_port(0).data.get_shape()
        src_value = node.in_port(0).data.get_value()
        dst_shape = node.in_port(1).data.get_value()
        assert dst_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
        assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)

        # pick the (shape broadcasting, value broadcasting) pair for the requested mode
        if node.mode == 'numpy':
            shape_fn, value_fn = uni_directional_shape_broadcasting, uni_directional_broadcasting
        elif node.mode == 'bidirectional':
            shape_fn, value_fn = bi_directional_shape_broadcasting, bi_directional_broadcasting
        else:
            raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))

        node.out_port(0).data.set_shape(shape_fn(src_shape, dst_shape))

        PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

        can_propagate = src_value is not None and not node.has_and_set('stop_value_propagation')
        if can_propagate:
            node.out_port(0).data.set_value(value_fn(src_value, dst_shape))
Beispiel #2
0
 def test_bi_directional_shape_broadcasting(self, input_shape, target_shape,
                                            expected_shape):
     """Check bidirectional shape broadcasting of two shapes against the expected result.

     A ``None`` *expected_shape* means the shapes must be reported as not broadcastable.
     """
     broadcasted = bi_directional_shape_broadcasting(input_shape, target_shape)
     if expected_shape is not None:
         self.assertTrue(strict_compare_tensors(broadcasted, expected_shape))
         return
     self.assertIsNone(broadcasted)
Beispiel #3
0
    def infer(node: Node):
        """
        Infer the output shape and, when the condition is constant, the output value of a Select node.

        Inputs: 0 - condition, 1 - 'then' tensor, 2 - 'else' tensor. The output shape is the
        bidirectional broadcast of the 'then' and 'else' shapes. When the condition value is
        known, the output value is computed with ``np.ma.where``; value propagation is skipped
        if the result would need an undefined ('then'/'else' value is None) element.

        :param node: Select node to infer
        """
        node_name = node.soft_get('name', node.id)
        assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
            "Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors for node {}".format(node_name)

        condition_value = node.in_port(0).data.get_value()
        resulting_tensors = [node.in_port(1).data.get_value(), node.in_port(2).data.get_value()]

        a_shape = node.in_port(1).data.get_shape()
        b_shape = node.in_port(2).data.get_shape()
        output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
        assert output_shape is not None, 'Input shapes for node {} are not broadcast-able'.format(node_name)
        node.out_port(0).data.set_shape(output_shape)

        if condition_value is None:
            return

        if resulting_tensors[0] is not None:
            resulting_tensors[0] = bi_directional_broadcasting(resulting_tensors[0], b_shape)
        if resulting_tensors[1] is not None:
            resulting_tensors[1] = bi_directional_broadcasting(resulting_tensors[1], a_shape)
        condition_value = bi_directional_broadcasting(condition_value, output_shape)

        output_value = np.ma.where(condition_value, resulting_tensors[0], resulting_tensors[1])
        if condition_value.size != 1:
            # If any element of output value is None that means that we use the value from the 'then' or the
            # 'else' tensor which is not defined, this means that we cannot perform value propagation.
            # `== None` is intentional: it is an element-wise comparison on the array.
            if np.any(output_value == None):  # noqa: E711
                output_value = None
        else:
            # A single-element condition selects one branch entirely. The builtin `bool` is used
            # because `np.bool` was removed in NumPy 1.24.
            selected_tensor = resulting_tensors[not bool(condition_value.item(0))]
            if selected_tensor is None:
                # the selected branch has no known value, so nothing can be propagated
                output_value = None
            else:
                output_value = output_value.astype(selected_tensor.dtype)

        if output_value is not None:
            node.out_port(0).data.set_value(output_value)
Beispiel #4
0
    def infer(node: Node):
        """
        Infer the output of a Broadcast node.

        Input 0 is the tensor to broadcast, input 1 holds the target shape and, for
        mode="explicit", input 2 holds the axes mapping. If the value of input 0 is known
        and 'stop_value_propagation' is not set, the broadcasted value is set on output 0;
        otherwise only the broadcasted shape is set.

        :param node: Broadcast node to infer
        :raises Error: if 'mode' is not one of 'numpy', 'bidirectional' or 'explicit'
        """
        node_name = node.soft_get('name', node.id)

        input_shape = node.in_port(0).data.get_shape()
        input_value = node.in_port(0).data.get_value()
        target_shape = node.in_port(1).data.get_value()
        assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
        assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)

        PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

        def explicit_axes_mapping():
            # mode="explicit" requires a constant axes_mapping on input 2; an 'axis' input
            # permutation is registered for that input (presumably so the mapping follows
            # layout changes of the output).
            axes_mapping = node.in_port(2).data.get_value()
            assert axes_mapping is not None, 'Broadcast(mode="explicit") with dynamic axes_mapping input ' \
                                             'is not supported. Node: `{}`'.format(node_name)
            PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
            return axes_mapping

        if input_value is not None and not node.has_and_set('stop_value_propagation'):
            if node.mode == 'numpy':
                new_value = uni_directional_broadcasting(input_value, target_shape)
            elif node.mode == 'bidirectional':
                new_value = bi_directional_broadcasting(input_value, target_shape)
            elif node.mode == 'explicit':
                new_value = explicit_broadcasting(input_value, target_shape, explicit_axes_mapping())
            else:
                raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
            node.out_port(0).data.set_value(new_value)
        else:
            if node.mode == 'numpy':
                new_shape = uni_directional_shape_broadcasting(input_shape, target_shape)
            elif node.mode == 'bidirectional':
                new_shape = bi_directional_shape_broadcasting(input_shape, target_shape)
            elif node.mode == 'explicit':
                new_shape, _ = explicit_shape_broadcasting(input_shape, target_shape, explicit_axes_mapping())
            else:
                raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
            # The shape broadcasting helpers return None when the shapes are incompatible; fail
            # here with a clear message instead of setting a None shape on the output.
            assert new_shape is not None, \
                'Input shape {} of node "{}" cannot be broadcast to target shape {} in mode "{}"'.format(
                    input_shape, node_name, target_shape, node.mode)
            node.out_port(0).data.set_shape(new_shape)
Beispiel #5
0
    def infer(node: Node):
        """
        Infer the output shape of an Einsum node from its `equation` attribute and input shapes.

        Each input operand is matched with its input subscript: every letter label maps to one
        dimension, and the ellipsis label "..." maps to the remaining (possibly zero) dimensions.
        Dimensions sharing a letter label must have equal sizes; the dimensions covered by the
        ellipsis are broadcast bidirectionally across operands. The output shape is the
        concatenation of the shapes of the labels in the output subscript.

        :param node: Einsum node to infer
        """
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        num_inputs = len(connected_in_ports)
        assert node.has_valid('equation'), "Einsum node {} must contain `equation` attribute".format(node_name)
        equation = node.equation

        # parse the equation and extract input and output subscripts
        input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)

        # check that each operand has the corresponding input subscript
        assert len(input_subscripts) == num_inputs, "The number of input operands of Einsum node {} " \
                                                    "must match the number of input subscripts " \
                                                    "in `equation`".format(node_name)

        # check compatibility of dimension sizes with the same label and generate a dictionary of shapes for labels
        label_to_shape = {}
        for input_ind in range(num_inputs):
            input_shape = node.in_port(input_ind).data.get_shape()
            labels = Einsum.extract_subscript_labels(node_name, input_subscripts[input_ind])
            num_dims = len(input_shape)
            num_labels = len(labels)

            # Validate the rank up front so that every label is registered (including an ellipsis
            # that covers zero dimensions) and no dimension of the operand is silently ignored.
            if "..." in labels:
                assert num_dims >= num_labels - 1, \
                    "The rank of input {} of Einsum node {} is lower than the number of labels " \
                    "in its subscript".format(input_ind, node_name)
            else:
                assert num_dims == num_labels, \
                    "The rank of input {} of Einsum node {} must be equal to the number of labels " \
                    "in its subscript".format(input_ind, node_name)
            # number of dimensions covered by the ellipsis (only used when "..." is present)
            num_broadcasted_dims = num_dims - num_labels + 1

            dim_ind = 0
            for label in labels:
                if label == "...":
                    sub_shape = input_shape[dim_ind:dim_ind + num_broadcasted_dims]
                    if label in label_to_shape:
                        common_shape = bi_directional_shape_broadcasting(sub_shape, label_to_shape[label])
                        assert common_shape is not None, "The dimensions labeled of ellipsis must be broadcastable " \
                                                         "for Einsum node {}".format(node_name)
                        label_to_shape[label] = common_shape
                    else:
                        label_to_shape[label] = sub_shape
                    dim_ind += num_broadcasted_dims
                else:
                    sub_shape = shape_array([input_shape[dim_ind]])
                    assert label not in label_to_shape or np.array_equal(label_to_shape[label], sub_shape), \
                        "Sizes of dimensions with the same label of Einsum node {} " \
                        "must be compatible".format(node_name)
                    label_to_shape[label] = sub_shape
                    dim_ind += 1

        # generate output shape based on the output subscript
        output_shape = shape_array([])
        labels = Einsum.extract_subscript_labels(node_name, output_subscript)
        for label in labels:
            assert label in label_to_shape, "The label in the output subscript must appear" \
                                            " in input subscripts in equation {} " \
                                            "of Einsum node {}".format(equation, node_name)
            output_shape = np.ma.concatenate((output_shape, label_to_shape[label]))

        node.out_port(0).data.set_shape(output_shape)
Beispiel #6
0
    def infer(node: Node):
        """
        Infer the output shape and, when the condition is known, the output value of a Select node.

        Inputs: 0 - condition, 1 - 'then' tensor, 2 - 'else' tensor. The broadcasting rule comes
        from the 'auto_broadcast' attribute (default 'numpy'). 'pdpd' raises ``Error`` because it
        is not implemented. Any other value requires all input shapes to be compatible without
        broadcasting.

        :param node: Select node to infer
        :raises Error: if the 'pdpd' broadcasting rule is requested
        """
        node_name = node.soft_get('name', node.id)
        assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
            "Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors for node {}".format(node_name)

        condition_value = node.in_port(0).data.get_value()
        condition_shape = node.in_port(0).data.get_shape()
        resulting_tensors = [
            node.in_port(1).data.get_value(),
            node.in_port(2).data.get_value()
        ]

        a_shape = node.in_port(1).data.get_shape()
        b_shape = node.in_port(2).data.get_shape()
        broadcast_rule = node.soft_get('auto_broadcast', 'numpy')

        if broadcast_rule == 'numpy':
            msg = "In Select node '{}' condition and then/else shapes must be broadcastable. " \
                  "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                    node_name, condition_shape, a_shape, b_shape)

            output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
            assert output_shape is not None, msg

            # if Select was created from TF Where operations then 1D condition must have the same size
            # as 0-index dimension of output_shape. This condition is different from being numpy compatible
            # but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
            if node.has_valid('format') and node['format'] == 'tf' and len(condition_shape) == 1:
                # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py#L4596-L4598
                msg_tf = "In Select node '{}' if 'condition' is a 1D tensor then it's size " \
                         "must be matching with the first dimension of then/else branches. " \
                         "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                            node_name, condition_shape, a_shape, b_shape)

                # Compare only static sizes: comparing against a dynamic (masked) dimension yields
                # a masked result, which would fail the assertion even for valid models.
                if is_fully_defined(condition_shape[0]) and is_fully_defined(output_shape[0]):
                    assert condition_shape[0] == output_shape[0], msg_tf
                # Pad with integer ones (np.ones defaults to float64, which would turn the shape
                # into floats). np.ma.concatenate keeps the mask of dynamic dimensions, whereas
                # np.concatenate drops input masks.
                condition_shape = np.ma.concatenate(
                    (condition_shape, np.ones(len(output_shape) - 1, dtype=np.int64)))

            output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)
            assert output_shape is not None, msg

        elif broadcast_rule == 'pdpd':
            # todo: add pdpd broadcasting rule
            # note that additionally to output_shape resulting_tensors must be broadcasted as well
            raise Error("PDPD broadcasting rule is not implemented yet")
        else:  # broadcasting is not allowed
            assert compatible_shapes(a_shape, b_shape) and compatible_shapes(condition_shape, a_shape), \
                'In node \'{}\' for Select operation when broadcasting is off all inputs must be of the same shape. ' \
                'But instead got: cond_shape={}, then_shape={}, else_shape={}'.format(
                    node_name, condition_shape, a_shape, b_shape)
            output_shape = shape_array([
                i if i is not dynamic_dimension else j
                for i, j in zip(a_shape, b_shape)
            ])

        node.out_port(0).data.set_shape(output_shape)

        if condition_value is not None:
            if is_fully_defined(condition_value) and np.all(condition_value == condition_value.item(0)):
                # in some graphs Select condition is always True[False] and
                # one of the branches is None (which is not selected)
                # if we use np.where for such cases then dtype of output_value will be object (non numeric type)
                # and subsequent numpy operation on such tensors will fail.
                # The builtin `bool` is used because `np.bool` was removed in NumPy 1.24.
                output_value = resulting_tensors[not bool(condition_value.item(0))]
                if output_value is None:
                    return
                if broadcast_rule == 'numpy':
                    output_value = bi_directional_broadcasting(output_value, output_shape)
                elif broadcast_rule == 'pdpd':
                    # todo: add pdpd broadcasting rule
                    raise Error("PDPD broadcasting rule is not implemented yet")

                node.out_port(0).data.set_value(output_value)
            elif resulting_tensors[0] is not None and resulting_tensors[1] is not None:
                output_value = np.ma.where(condition_value, resulting_tensors[0], resulting_tensors[1])
                node.out_port(0).data.set_value(output_value)