コード例 #1
0
ファイル: math_ops.py プロジェクト: zhaoqiuye/trax
def kron(a, b):
    """Return the Kronecker product of `a` and `b`.

    Both operands are first promoted to a common dtype. If their ranks differ,
    the lower-rank operand is reshaped to the shape produced by
    `array_ops._pad_left_to(rank, shape)` (presumably the shape with size-1
    axes prepended up to `rank` -- TODO confirm against that helper).

    Each axis of `a` is then followed by a new size-1 axis, and each axis of
    `b` is preceded by one. Assuming `*` broadcasts NumPy-style, their product
    has shape (a0, b0, a1, b1, ...). Reshaping to the elementwise product of
    the two rank-aligned shapes merges every (a_i, b_i) pair into one axis.

    Args:
      a: first operand.
      b: second operand.

    Returns:
      An array whose shape is `a.shape` times `b.shape`, elementwise, after
      rank alignment.
    """
    # pylint: disable=protected-access
    a, b = array_ops._promote_dtype(a, b)
    rank = max(a.ndim, b.ndim)
    if a.ndim < rank:
        a = array_ops.reshape(a, array_ops._pad_left_to(rank, a.shape))
    if b.ndim < rank:
        b = array_ops.reshape(b, array_ops._pad_left_to(rank, b.shape))

    # Spread `a` as (a0, 1, a1, 1, ...) and `b` as (1, b0, 1, b1, ...).
    a_spread_shape = []
    for dim in a.shape:
        a_spread_shape.extend((dim, 1))
    b_spread_shape = []
    for dim in b.shape:
        b_spread_shape.extend((1, dim))
    a_spread = array_ops.reshape(a, a_spread_shape)
    b_spread = array_ops.reshape(b, b_spread_shape)

    result_shape = tuple(np.multiply(a.shape, b.shape))
    return array_ops.reshape(a_spread * b_spread, result_shape)
コード例 #2
0
 def run_test(arr, newshape, *args, **kwargs):
   """Check `array_ops.reshape` against `np.reshape` for every transform pair.

   Each callable in `self.array_transforms` is applied to `arr`, and, in a
   nested loop, each is applied to `newshape`. For every combination, the
   result of `array_ops.reshape` on the transformed inputs is passed to
   `self.match` together with the result of `np.reshape` on the transformed
   array and the untransformed `newshape`.

   Args:
     arr: the array to reshape, before transformation.
     newshape: the target shape, before transformation.
     *args: extra positional arguments forwarded to both reshape calls.
     **kwargs: extra keyword arguments forwarded to both reshape calls.
   """
   for fn1 in self.array_transforms:
     for fn2 in self.array_transforms:
       arr_arg = fn1(arr)
       newshape_arg = fn2(newshape)
       # If reshape is called on a Tensor, it calls out to the Tensor.reshape
       # method. The NumPy reference result would then not be computed by
       # NumPy itself, so give np.reshape a plain ndarray instead.
       np_arr_arg = arr_arg
       if isinstance(np_arr_arg, tf.Tensor):
         np_arr_arg = np_arr_arg.numpy()
       self.match(
           array_ops.reshape(arr_arg, newshape_arg, *args, **kwargs),
           np.reshape(np_arr_arg, newshape, *args, **kwargs))
コード例 #3
0
 def run_test(arr, newshape, *args, **kwargs):
     """Compare `array_ops.reshape` with `np.reshape` over all transform pairs.

     Every ordered pair (to_arr, to_shape) drawn from `self.array_transforms`
     is tried. `to_arr` is applied to `arr` and `to_shape` to `newshape`. The
     `array_ops.reshape` result on those transformed inputs is handed to
     `self.match` together with the `np.reshape` result on the same array,
     converted to NumPy when it is a `tf.Tensor`, and the original `newshape`.

     Args:
       arr: the array to reshape, before transformation.
       newshape: the target shape, before transformation.
       *args: extra positional arguments forwarded to both reshape calls.
       **kwargs: extra keyword arguments forwarded to both reshape calls.
     """
     transform_pairs = ((to_arr, to_shape)
                        for to_arr in self.array_transforms
                        for to_shape in self.array_transforms)
     for to_arr, to_shape in transform_pairs:
         transformed_arr = to_arr(arr)
         transformed_shape = to_shape(newshape)
         # np.reshape on a Tensor would call out to the Tensor.reshape method,
         # so the NumPy side gets the Tensor's concrete value instead.
         reference_input = (transformed_arr.numpy()
                            if isinstance(transformed_arr, tf.Tensor)
                            else transformed_arr)
         actual = array_ops.reshape(transformed_arr, transformed_shape,
                                    *args, **kwargs)
         expected = np.reshape(reference_input, newshape, *args, **kwargs)
         self.match(actual, expected)