def kron(a, b):  # pylint: disable=protected-access,g-complex-comprehension
  """Returns the Kronecker product of `a` and `b`.

  Both operands are first passed through `array_ops._promote_dtype`
  (presumably to bring them to a common dtype -- confirm against that helper).
  The lower-rank operand is then left-padded with size-1 axes so both have the
  same rank. Each axis of `a` is followed by a new size-1 axis, and each axis
  of `b` is preceded by one. Broadcasting their product therefore gives shape
  (a0, b0, a1, b1, ...). That result is reshaped to (a0*b0, a1*b1, ...).

  Args:
    a: First operand.
    b: Second operand.

  Returns:
    The Kronecker product, with shape equal to the element-wise product of the
    (rank-equalized) shapes of `a` and `b`.
  """
  a, b = array_ops._promote_dtype(a, b)
  rank = max(a.ndim, b.ndim)

  def _lift(x):
    # Prepend size-1 axes so that both operands end up with the same rank.
    if x.ndim >= rank:
      return x
    return array_ops.reshape(x, array_ops._pad_left_to(rank, x.shape))

  a = _lift(a)
  b = _lift(b)

  # a -> (a0, 1, a1, 1, ...), b -> (1, b0, 1, b1, ...)
  a_dims = []
  for dim in a.shape:
    a_dims.extend((dim, 1))
  b_dims = []
  for dim in b.shape:
    b_dims.extend((1, dim))

  spread_a = array_ops.reshape(a, a_dims)
  spread_b = array_ops.reshape(b, b_dims)
  result_shape = tuple(np.multiply(a.shape, b.shape))
  return array_ops.reshape(spread_a * spread_b, result_shape)
def run_test(arr, newshape, *args, **kwargs):
  """Checks `array_ops.reshape` against `np.reshape` for every transform pair.

  Each function in `self.array_transforms` is applied to `arr`, and each one
  is also applied to `newshape`. The pairs are taken over the full cross
  product. The TF result is compared with the NumPy result via `self.match`.
  The NumPy reference always uses the untransformed `newshape`.

  Args:
    arr: Input array-like to reshape.
    newshape: Target shape.
    *args: Extra positional arguments forwarded to both reshape calls.
    **kwargs: Extra keyword arguments forwarded to both reshape calls.
  """
  for fn1 in self.array_transforms:
    for fn2 in self.array_transforms:
      arr_arg = fn1(arr)
      newshape_arg = fn2(newshape)
      # If `arr_arg` is a Tensor, np.reshape calls out to its `reshape`
      # method. The "expected" value would then come from TF itself. Convert
      # to a plain ndarray so the reference is really computed by NumPy.
      np_arr_arg = np.asarray(arr_arg)
      self.match(
          array_ops.reshape(arr_arg, newshape_arg, *args, **kwargs),
          np.reshape(np_arr_arg, newshape, *args, **kwargs))
def run_test(arr, newshape, *args, **kwargs):
  """Compares `array_ops.reshape` with `np.reshape` over all transform pairs.

  For every (array transform, shape transform) pair drawn from
  `self.array_transforms`, the transformed inputs are reshaped with
  `array_ops.reshape`. The result is checked via `self.match` against
  `np.reshape` applied to the original `newshape`.

  Args:
    arr: Input array-like to reshape.
    newshape: Target shape.
    *args: Extra positional arguments forwarded to both reshape calls.
    **kwargs: Extra keyword arguments forwarded to both reshape calls.
  """
  for transform_arr in self.array_transforms:
    for transform_shape in self.array_transforms:
      candidate = transform_arr(arr)
      candidate_shape = transform_shape(newshape)
      # np.reshape on a Tensor would delegate to the Tensor.reshape method, so
      # NumPy is handed the materialized numpy value instead.
      reference_input = (
          candidate.numpy() if isinstance(candidate, tf.Tensor) else candidate)
      actual = array_ops.reshape(candidate, candidate_shape, *args, **kwargs)
      expected = np.reshape(reference_input, newshape, *args, **kwargs)
      self.match(actual, expected)