Example #1
0
def vdot(a, b):
    """Return the dot product of `a` and `b`, flattened to 1-D first.

    Mirrors `numpy.vdot`: operands are promoted to a common dtype and
    ravelled, and `a` is conjugated when it is complex-valued.
    """
    a, b = np_array_ops._promote_dtype(a, b)  # pylint: disable=protected-access
    a = np_array_ops.reshape(a, [-1])
    b = np_array_ops.reshape(b, [-1])
    # vdot conjugates the first argument for complex inputs.
    if a.dtype in (np_dtypes.complex64, np_dtypes.complex128):
        a = conj(a)
    return dot(a, b)
Example #2
0
def kron(a, b):  # pylint: disable=missing-function-docstring
    """Kronecker product of `a` and `b`, mirroring `numpy.kron`."""
    # pylint: disable=protected-access,g-complex-comprehension
    a, b = np_array_ops._promote_dtype(a, b)
    # Pad the lower-rank operand's shape with leading 1s so both operands end
    # up with the same ndim. `np_utils.cond` keeps this graph-compatible when
    # the comparison is only known symbolically.
    t_a = np_utils.cond(
        a.ndim < b.ndim,
        lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
            a, np_array_ops._pad_left_to(b.ndim, a.shape)),
        lambda: a)
    t_b = np_utils.cond(
        b.ndim < a.ndim,
        lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
            b, np_array_ops._pad_left_to(a.ndim, b.shape)),
        lambda: b)

    def _make_shape(shape, prepend):
        """Interleave `shape` with 1s: [d0,1,d1,1,...] or, if `prepend`,
        [1,d0,1,d1,...]."""
        ones = array_ops.ones_like(shape)
        if prepend:
            shapes = [ones, shape]
        else:
            shapes = [shape, ones]
        # Stacking along axis=1 and flattening interleaves the two vectors.
        return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])

    a_shape = array_ops.shape(t_a)
    b_shape = array_ops.shape(t_b)
    # `a` gets the 1s after each dim and `b` gets them before, so broadcasting
    # their elementwise product produces every pairwise product laid out in
    # Kronecker order; the final reshape collapses paired axes.
    a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
    b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
    out_shape = a_shape * b_shape
    return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
Example #3
0
def _bin_op(tf_fun, a, b, promote=True):
  """Apply the binary TF op `tf_fun` to `a` and `b` and wrap as an ndarray.

  When `promote` is true, both operands are first promoted to a common
  dtype; otherwise they are only converted to ndarrays as-is.
  """
  if promote:
    a, b = np_array_ops._promote_dtype(a, b)  # pylint: disable=protected-access
  else:
    a, b = np_array_ops.array(a), np_array_ops.array(b)
  return np_utils.tensor_to_ndarray(tf_fun(a.data, b.data))
Example #4
0
def einsum(subscripts, *operands, **kwargs):
    """Evaluate an Einstein-summation expression over `operands`.

    Recognized kwargs: `casting` ('safe' promotes dtypes, 'no' converts
    without promotion) and `optimize` (falsy, True, 'greedy', or
    'optimal'). TF has no "no optimization" mode, so everything except
    'optimal' maps to the 'greedy' strategy.
    """
    casting = kwargs.get('casting', 'safe')
    optimize = kwargs.get('optimize', False)

    if casting == 'safe':
        operands = np_array_ops._promote_dtype(*operands)  # pylint: disable=protected-access
    elif casting == 'no':
        operands = [np_array_ops.asarray(x) for x in operands]
    else:
        raise ValueError('casting policy not supported: %s' % casting)

    if optimize == 'optimal':
        tf_optimize = 'optimal'
    elif not optimize or optimize == True or optimize == 'greedy':  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
        # TF doesn't have a "no optimization" option, so a falsy `optimize`
        # also falls through to 'greedy'.
        # TODO(wangpeng): Print a warning that np and tf use different
        #   optimizations.
        tf_optimize = 'greedy'
    else:
        raise ValueError('`optimize` method not supported: %s' % optimize)

    tensors = [x.data for x in operands]
    result = special_math_ops.einsum(subscripts, *tensors, optimize=tf_optimize)
    return np_utils.tensor_to_ndarray(result)
Example #5
0
def einsum(subscripts, *operands, **kwargs):  # pylint: disable=missing-docstring
  """Evaluate an Einstein-summation expression over `operands`.

  Recognized kwargs: `casting` ('safe' or 'no') and `optimize` (falsy,
  True, 'greedy', or 'optimal'); anything else raises ValueError.
  """
  casting = kwargs.get('casting', 'safe')
  optimize = kwargs.get('optimize', False)
  if casting == 'safe':
    # 'safe' promotes all operands to a common dtype first.
    operands = np_array_ops._promote_dtype(*operands)  # pylint: disable=protected-access
  elif casting == 'no':
    # 'no' converts without promotion.
    operands = [np_array_ops.asarray(x) for x in operands]
  else:
    raise ValueError(
        'Invalid value for argument `casting`. '
        f'Expected casting="safe" or casting="no". Received: casting={casting}')
  if not optimize:
    # TF doesn't have a "no optimization" option.
    # TODO(wangpeng): Print a warning that np and tf use different
    #   optimizations.
    tf_optimize = 'greedy'
  elif optimize == True:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
    tf_optimize = 'greedy'
  elif optimize == 'greedy':
    tf_optimize = 'greedy'
  elif optimize == 'optimal':
    tf_optimize = 'optimal'
  else:
    raise ValueError(
        'Invalid value for argument `optimize`. '
        'Expected one of {True, "greedy", "optimal"}. '
        f'Received: optimize={optimize}')

  res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
  return res
Example #6
0
def clip(a, a_min, a_max):
  """Clip values of `a` into the interval [a_min, a_max], like `numpy.clip`.

  At most one of `a_min`/`a_max` may be None, giving one-sided clipping;
  raises ValueError when both are None.
  """
  if a_min is None:
    if a_max is None:
      raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.')
    # Only an upper bound: clip from above.
    return minimum(a, a_max)
  if a_max is None:
    # Only a lower bound: clip from below.
    return maximum(a, a_min)
  a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max)  # pylint: disable=protected-access
  return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max))
Example #7
0
def kron(a, b):
  """Kronecker product of `a` and `b`, mirroring `numpy.kron`."""
  # pylint: disable=protected-access
  a, b = np_array_ops._promote_dtype(a, b)
  # Left-pad the lower-rank operand's shape with 1s so both share a rank.
  rank = max(a.ndim, b.ndim)
  if a.ndim < rank:
    a = np_array_ops.reshape(a, np_array_ops._pad_left_to(rank, a.shape))
  if b.ndim < rank:
    b = np_array_ops.reshape(b, np_array_ops._pad_left_to(rank, b.shape))

  def _interleaved(dims, ones_first):
    """Interleave `dims` with 1s: [1,d0,1,d1,...] or [d0,1,d1,1,...]."""
    out = []
    for d in dims:
      out.extend((1, d) if ones_first else (d, 1))
    return out

  # `a` gets the 1s after each dim and `b` gets them before, so broadcasting
  # the product yields every pairwise element product in Kronecker order.
  a_expanded = np_array_ops.reshape(a, _interleaved(a.shape, False))
  b_expanded = np_array_ops.reshape(b, _interleaved(b.shape, True))
  out_shape = tuple(np.multiply(a.shape, b.shape))
  return np_array_ops.reshape(a_expanded * b_expanded, out_shape)