Ejemplo n.º 1
0
def kron(a, b):  # pylint: disable=missing-function-docstring
    # pylint: disable=protected-access,g-complex-comprehension
    """Compute the Kronecker product of `a` and `b` using runtime shapes.

    Both inputs are passed through `np_array_ops._promote_dtype`. The
    lower-rank input is left-padded with size-1 dimensions (selected through
    `np_utils.cond`) so both operands have the same rank. Each operand is then
    reshaped so that its dimensions are interleaved with 1s. Broadcasting the
    product gives shape (a0, b0, a1, b1, ...), which is finally collapsed to
    the element-wise product of the two shapes.
    """
    a, b = np_array_ops._promote_dtype(a, b)

    def _pad_rank(x, target_ndim):
        # Left-pad x's shape with 1s only when x has fewer dims than target.
        return np_utils.cond(
            x.ndim < target_ndim,
            lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
                x, np_array_ops._pad_left_to(target_ndim, x.shape)),
            lambda: x)

    def _interleave_with_ones(dims, ones_first):
        # Build [d0, 1, d1, 1, ...] (or [1, d0, 1, d1, ...] when ones_first)
        # by stacking dims with a same-shaped vector of ones along axis 1.
        ones = array_ops.ones_like(dims)
        pair = [ones, dims] if ones_first else [dims, ones]
        return array_ops.reshape(array_ops.stack(pair, axis=1), [-1])

    # Each operand is padded against the *other* original input's rank.
    lhs = _pad_rank(a, b.ndim)
    rhs = _pad_rank(b, a.ndim)

    lhs_dims = array_ops.shape(lhs)
    rhs_dims = array_ops.shape(rhs)
    lhs_spread = np_array_ops.reshape(lhs, _interleave_with_ones(lhs_dims, False))
    rhs_spread = np_array_ops.reshape(rhs, _interleave_with_ones(rhs_dims, True))
    result_dims = lhs_dims * rhs_dims
    return np_array_ops.reshape(lhs_spread * rhs_spread, result_dims)
Ejemplo n.º 2
0
def vdot(a, b):  # pylint: disable=missing-docstring
    """Return `dot` of `a` and `b` after flattening both to one dimension.

    Inputs are first passed through `np_array_ops._promote_dtype`. If the
    resulting dtype of `a` equals complex128 or complex64, `a` is conjugated
    before the dot product is taken.
    """
    a, b = np_array_ops._promote_dtype(a, b)
    flat_a = np_array_ops.reshape(a, [-1])
    flat_b = np_array_ops.reshape(b, [-1])
    # Checked in the same order (complex128, then complex64) as a short-circuit.
    complex_dtypes = (np_dtypes.complex128, np_dtypes.complex64)
    if any(flat_a.dtype == candidate for candidate in complex_dtypes):
        flat_a = conj(flat_a)
    return dot(flat_a, flat_b)
Ejemplo n.º 3
0
def kron(a, b):  # pylint: disable=missing-function-docstring
  # pylint: disable=protected-access,g-complex-comprehension
  """Compute the Kronecker product of `a` and `b` using their static shapes.

  After `np_array_ops._promote_dtype`, the lower-rank input is left-padded
  with size-1 dimensions. Each operand is reshaped so that its dimensions are
  interleaved with 1s: (a0, 1, a1, 1, ...) and (1, b0, 1, b1, ...). The
  broadcasted product then has shape (a0, b0, a1, b1, ...) and is reshaped to
  the element-wise product of the two shapes.
  """
  a, b = np_array_ops._promote_dtype(a, b)
  rank = max(a.ndim, b.ndim)

  def _lift(x):
    # Left-pad x's shape with 1s so that it has `rank` dimensions.
    if x.ndim >= rank:
      return x
    return np_array_ops.reshape(x, np_array_ops._pad_left_to(rank, x.shape))

  def _interleave(shape, ones_first):
    # Pair every dimension with a 1, placing the 1 before or after it.
    dims = []
    for d in shape:
      dims.extend((1, d) if ones_first else (d, 1))
    return dims

  a, b = _lift(a), _lift(b)
  spread_a = np_array_ops.reshape(a, _interleave(a.shape, False))
  spread_b = np_array_ops.reshape(b, _interleave(b.shape, True))
  result_shape = tuple(np.multiply(a.shape, b.shape))
  return np_array_ops.reshape(spread_a * spread_b, result_shape)
Ejemplo n.º 4
0
 def run_test(arr, newshape, *args, **kwargs):
   """Compare np_array_ops.reshape with np.reshape over all transform pairs.

   Every transform in `self.array_transforms` is applied to `arr`, and
   independently to `newshape`. `self` is captured from the enclosing test
   method, which is not shown here.
   """
   for to_arr in self.array_transforms:
     for to_shape in self.array_transforms:
       arr_in = to_arr(arr)
       shape_in = to_shape(newshape)
       actual = np_array_ops.reshape(arr_in, shape_in, *args, **kwargs)
       # The NumPy reference always receives the untransformed `newshape`.
       expected = np.reshape(arr_in, newshape, *args, **kwargs)
       self.match(actual, expected)