def _broadcast(shape1: TensorFluentShape, shape2: TensorFluentShape) -> TensorFluentShape:
    '''Returns the shape obtained by broadcasting `shape1` against `shape2`.

    Args:
        shape1: The first operand's shape.
        shape2: The second operand's shape.

    Returns:
        A TensorFluentShape holding the broadcast dimensions; its batch flag
        is set when either input shape is flagged as batch.

    Raises:
        ValueError: If the (possibly reshaped) dimensions cannot be
            broadcast together, as reported by NumPy.
    '''
    reshape1, reshape2 = TensorFluentShape.broadcast(shape1, shape2)

    # A `None` entry means no reshape was suggested for that operand,
    # so its own dimensions are used unchanged.
    dims1 = shape1.as_list() if reshape1 is None else reshape1
    dims2 = shape2.as_list() if reshape2 is None else reshape2

    # Only the resulting shape is needed. Broadcasting a scalar yields
    # zero-stride views of the requested shape, so no dense array of that
    # size is allocated just to be thrown away.
    view1 = np.broadcast_to(0, dims1)
    view2 = np.broadcast_to(0, dims2)
    broadcasted = np.broadcast(view1, view2)

    return TensorFluentShape(broadcasted.shape, batch=(shape1.batch or shape2.batch))
def _binary_op(cls,
               x: 'TensorFluent',
               y: 'TensorFluent',
               op: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
               dtype: tf.DType) -> 'TensorFluent':
    '''Applies the binary `op` to fluents `x` and `y` after aligning them.

    Both operands are transposed to a common scope, reshaped where
    `TensorFluentShape.broadcast` asks for it, and cast to `dtype`
    before `op` is applied to their tensors.

    Args:
        x: The first operand.
        y: The second operand.
        op: The binary operator.
        dtype: The data type both operands are cast to.

    Returns:
        A TensorFluent wrapping the operator's output, with the combined
        scope; it is flagged as batch if either aligned operand is.
    '''
    # Scope alignment: one permutation per operand plus the combined scope.
    scope, perm_x, perm_y = TensorFluentScope.broadcast(
        x.scope.as_list(), y.scope.as_list())

    # For batched operands the permutation is shifted by one and 0 is
    # prepended, so axis 0 (presumably the batch axis) stays in front.
    if x.batch and perm_x != []:
        perm_x = [0, *(axis + 1 for axis in perm_x)]
    if y.batch and perm_y != []:
        perm_y = [0, *(axis + 1 for axis in perm_y)]

    lhs = x.transpose(perm_x)
    rhs = y.transpose(perm_y)

    # Shape alignment: a `None` entry means that operand is left as is.
    new_shape_lhs, new_shape_rhs = TensorFluentShape.broadcast(lhs.shape, rhs.shape)
    lhs = lhs if new_shape_lhs is None else lhs.reshape(new_shape_lhs)
    rhs = rhs if new_shape_rhs is None else rhs.reshape(new_shape_rhs)

    # Bring both operands to the requested data type.
    lhs = lhs.cast(dtype)
    rhs = rhs.cast(dtype)

    # Apply the operator on the underlying tensors.
    result = op(lhs.tensor, rhs.tensor)

    return TensorFluent(result, scope, batch=(lhs.batch or rhs.batch))
def test_broadcast(self):
    '''Checks the reshape lists returned by `TensorFluentShape.broadcast`.

    Each case gives the dimensions and batch flag of two shapes together
    with the reshape expected for each one; `None` means that no reshape
    is expected for that operand.
    '''
    # (dims1, batch1, dims2, batch2, expected_reshape1, expected_reshape2)
    # NOTE(review): the `([100], True)` vs `([], False)` case appears twice
    # and one case uses `([], True)`; both are kept exactly as they were --
    # confirm whether different inputs were intended.
    raw_cases = [
        # neither shape is batch
        ([], False, [], False, None, None),
        ([8], False, [], False, None, None),
        ([], False, [8], False, None, None),
        ([8, 8], False, [8], False, None, None),
        ([8], False, [8, 8], False, None, None),
        # both shapes are batch
        ([100], True, [100], True, None, None),
        ([100, 8], True, [100], True, None, [100, 1]),
        ([100], True, [100, 8], True, [100, 1], None),
        ([100, 8, 8], True, [100], True, None, [100, 1, 1]),
        ([100], True, [100, 8, 8], True, [100, 1, 1], None),
        ([100, 8, 8], True, [100, 8], True, None, [100, 1, 8]),
        ([100, 8], True, [100, 8, 8], True, [100, 1, 8], None),
        # mixed batch flags
        ([100], True, [], False, None, None),
        ([], False, [], True, None, None),
        ([100], True, [], False, None, None),
        ([100], True, [8], False, [100, 1], None),
        ([8], False, [100], True, None, [100, 1]),
        ([100], True, [8, 7], False, [100, 1, 1], None),
        ([8, 7], False, [100], True, None, [100, 1, 1]),
        ([100, 8], True, [], False, None, None),
        ([], False, [100, 8], True, None, None),
        ([100, 8], True, [8], False, None, None),
        ([8], False, [100, 8], True, None, None),
        ([100, 8, 7], True, [7], False, None, [1, 7]),
        ([7], False, [100, 8, 7], True, [1, 7], None),
        ([100, 7, 8], True, [7, 8], False, None, None),
        ([7, 8], False, [100, 7, 8], True, None, None),
        ([8, 8], False, [100, 8], True, None, [100, 1, 8]),
        ([100, 8], True, [8, 8], False, [100, 1, 8], None),
        ([2, 2], False, [1, 2], True, None, [1, 1, 2]),
        ([1, 2], True, [2, 2], False, [1, 1, 2], None),
    ]

    # Build every shape pair up front, before any broadcast is attempted.
    cases = [
        (TensorFluentShape(dims1, batch1), TensorFluentShape(dims2, batch2),
         expected1, expected2)
        for dims1, batch1, dims2, batch2, expected1, expected2 in raw_cases
    ]

    for first, second, expected1, expected2 in cases:
        actual1, actual2 = TensorFluentShape.broadcast(first, second)
        for actual, expected in ((actual1, expected1), (actual2, expected2)):
            if expected is None:
                self.assertIsNone(actual)
            else:
                self.assertListEqual(actual, expected)