示例#1
0
def _broadcast(shape1: TensorFluentShape,
               shape2: TensorFluentShape) -> TensorFluentShape:
    '''Returns the shape obtained by broadcasting `shape1` against `shape2`.

    `TensorFluentShape.broadcast` is consulted first for any reshaping needed
    to align the two shapes; where it returns None, that shape's own
    `as_list()` is used unchanged. The aligned shapes are then combined using
    NumPy's broadcasting rules.

    Args:
        shape1: The first operand's shape.
        shape2: The second operand's shape.

    Returns:
        A TensorFluentShape with the broadcast dimensions, marked as batched
        if either input shape is batched.

    Raises:
        ValueError: If the aligned shapes are not broadcast-compatible.
    '''
    s1, s2 = TensorFluentShape.broadcast(shape1, shape2)
    s1 = s1 if s1 is not None else shape1.as_list()
    s2 = s2 if s2 is not None else shape2.as_list()
    # Only the resulting shape is needed, so broadcast zero-stride views of a
    # scalar instead of allocating full zero-filled arrays of each shape.
    x1 = np.broadcast_to(0, s1)
    x2 = np.broadcast_to(0, s2)
    y = np.broadcast(x1, x2)
    return TensorFluentShape(y.shape, batch=(shape1.batch or shape2.batch))
示例#2
0
    def _binary_op(cls,
                   x: 'TensorFluent',
                   y: 'TensorFluent',
                   op: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
                   dtype: tf.DType) -> 'TensorFluent':
        '''Applies the binary `op` to fluents `x` and `y`.

        Both operands are aligned before `op` runs:
        1. Their scopes are broadcast, and each operand is transposed by the
           permutation `TensorFluentScope.broadcast` returns for it.
        2. Each operand is reshaped when `TensorFluentShape.broadcast` asks
           for it.
        3. Both are cast to `dtype`.

        Args:
            x: The first operand.
            y: The second operand.
            op: The binary operator, applied to the two underlying tensors.
            dtype: The output's data type.

        Returns:
            A TensorFluent wrapping the binary operator's output over the
            broadcast scope, batched if either aligned operand is batched.
        '''
        def _with_batch_axis(fluent, perm):
            # For a batched operand, axis 0 is kept fixed and every scope
            # index is shifted by one; an empty permutation is left as-is.
            if fluent.batch and perm != []:
                return [0] + [axis + 1 for axis in perm]
            return perm

        # Align scopes.
        scope, lhs_perm, rhs_perm = TensorFluentScope.broadcast(
            x.scope.as_list(), y.scope.as_list())
        lhs_perm = _with_batch_axis(x, lhs_perm)
        rhs_perm = _with_batch_axis(y, rhs_perm)
        x = x.transpose(lhs_perm)
        y = y.transpose(rhs_perm)

        # Align shapes: reshape whichever operand broadcasting requires.
        lhs_shape, rhs_shape = TensorFluentShape.broadcast(x.shape, y.shape)
        if lhs_shape is not None:
            x = x.reshape(lhs_shape)
        if rhs_shape is not None:
            y = y.reshape(rhs_shape)

        # Align dtypes, then apply the operator to the raw tensors.
        x, y = x.cast(dtype), y.cast(dtype)
        output = op(x.tensor, y.tensor)

        return TensorFluent(output, scope, batch=x.batch or y.batch)
示例#3
0
    def test_broadcast(self):
        '''Checks the reshapings returned by TensorFluentShape.broadcast.

        Each case is ``(shape1, shape2, expected_reshape1, expected_reshape2)``.
        An expected value of None asserts that the corresponding reshaping is
        None; otherwise the returned list must equal the expected list.
        '''
        Shape = TensorFluentShape
        tests = [
            # Neither operand batched.
            (Shape([], False), Shape([], False), None, None),
            (Shape([8], False), Shape([], False), None, None),
            (Shape([], False), Shape([8], False), None, None),
            (Shape([8, 8], False), Shape([8], False), None, None),
            (Shape([8], False), Shape([8, 8], False), None, None),
            # Both operands batched.
            (Shape([100], True), Shape([100], True), None, None),
            (Shape([100, 8], True), Shape([100], True), None, [100, 1]),
            (Shape([100], True), Shape([100, 8], True), [100, 1], None),
            (Shape([100, 8, 8], True), Shape([100], True), None, [100, 1, 1]),
            (Shape([100], True), Shape([100, 8, 8], True), [100, 1, 1], None),
            (Shape([100, 8, 8], True), Shape([100, 8], True),
             None, [100, 1, 8]),
            (Shape([100, 8], True), Shape([100, 8, 8], True),
             [100, 1, 8], None),
            # One batched and one non-batched operand.
            (Shape([100], True), Shape([], False), None, None),
            (Shape([], False), Shape([], True), None, None),
            # Mirror of the [100]-batch vs []-non-batch case (previously a
            # verbatim duplicate of it, leaving this order untested).
            (Shape([], False), Shape([100], True), None, None),
            (Shape([100], True), Shape([8], False), [100, 1], None),
            (Shape([8], False), Shape([100], True), None, [100, 1]),
            (Shape([100], True), Shape([8, 7], False), [100, 1, 1], None),
            (Shape([8, 7], False), Shape([100], True), None, [100, 1, 1]),
            (Shape([100, 8], True), Shape([], False), None, None),
            (Shape([], False), Shape([100, 8], True), None, None),
            (Shape([100, 8], True), Shape([8], False), None, None),
            (Shape([8], False), Shape([100, 8], True), None, None),
            (Shape([100, 8, 7], True), Shape([7], False), None, [1, 7]),
            (Shape([7], False), Shape([100, 8, 7], True), [1, 7], None),
            (Shape([100, 7, 8], True), Shape([7, 8], False), None, None),
            (Shape([7, 8], False), Shape([100, 7, 8], True), None, None),
            (Shape([8, 8], False), Shape([100, 8], True), None, [100, 1, 8]),
            (Shape([100, 8], True), Shape([8, 8], False), [100, 1, 8], None),
            (Shape([2, 2], False), Shape([1, 2], True), None, [1, 1, 2]),
            (Shape([1, 2], True), Shape([2, 2], False), [1, 1, 2], None),
        ]

        for case, (s1, s2, ss1, ss2) in enumerate(tests):
            # subTest reports every failing case instead of stopping at the
            # first one.
            with self.subTest(case=case):
                reshape1, reshape2 = TensorFluentShape.broadcast(s1, s2)
                for actual, expected in ((reshape1, ss1), (reshape2, ss2)):
                    if expected is None:
                        self.assertIsNone(actual)
                    else:
                        self.assertListEqual(actual, expected)