コード例 #1
0
ファイル: array_ops.py プロジェクト: rouniuyizu/trax
 def new_shape(_, old_shape):
   """Pads the 1-D shape vector `old_shape` with 1s so it has >= 3 entries.

   Args:
     _: Unused.
     old_shape: 1-D integer shape vector, presumably produced by `tf.shape`
       (the scalar branch below builds an int32 vector, so both branches of the
       cond are expected to be int32 — TODO confirm).

   Returns:
     `[1, 1, 1]` when `old_shape` is empty, `[1, d, 1]` when it has a single
     entry `d`, and `old_shape` with one trailing 1 appended otherwise.
   """
   # pylint: disable=g-long-lambda
   num_dims = tf.size(old_shape)

   def _for_scalar():
     return tf.constant([1, 1, 1], dtype=tf.int32)

   def _for_vector():
     # One leading and one trailing 1 around the single dimension.
     return tf.pad(old_shape, [[1, 1]], constant_values=1)

   def _for_matrix_or_more():
     # Only a trailing 1 is appended.
     return tf.pad(old_shape, [[0, 1]], constant_values=1)

   return utils.cond(
       num_dims == 0, _for_scalar,
       lambda: utils.cond(num_dims == 1, _for_vector, _for_matrix_or_more))
コード例 #2
0
 def f(a, b):  # pylint: disable=missing-docstring
     """Product of `a` and `b` following NumPy `dot`'s case analysis.

     * Either operand is a scalar (rank 0): elementwise `a * b`.
     * `b` is 1-D: sum-product over the last axis of `a` and of `b`.
     * Otherwise: sum-product over the last axis of `a` and the
       second-to-last axis of `b`.
     """
     def _vector_rhs():
         return tf.tensordot(a, b, axes=[[-1], [-1]])

     def _matrix_rhs():
         return tf.tensordot(a, b, axes=[[-1], [-2]])

     def _tensor_case():
         return utils.cond(tf.rank(b) == 1, _vector_rhs, _matrix_rhs)

     any_scalar = utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0)
     return utils.cond(any_scalar, lambda: a * b, _tensor_case)
コード例 #3
0
ファイル: math.py プロジェクト: jaeyounkim/trax
 def f(x1, x2):
   """Matrix product of `x1` and `x2` with special handling of 1-D operands.

   * `x2` is 1-D: contract the last axis of `x1` with `x2`.
   * `x1` is 1-D: contract `x1` with the second-to-last axis of `x2`.
   * Otherwise: `tf.matmul`.

   Raises:
     ValueError: if TensorFlow raises `InvalidArgumentError` while computing
       the product; the original traceback is kept.
   """
   def _vector_rhs():
     return tf.tensordot(x1, x2, axes=1)

   def _vector_lhs():
     return tf.tensordot(x1, x2, axes=[[0], [-2]])

   def _general():
     return tf.matmul(x1, x2)

   try:
     return utils.cond(
         tf.rank(x2) == 1, _vector_rhs,
         lambda: utils.cond(tf.rank(x1) == 1, _vector_lhs, _general))
   except tf.errors.InvalidArgumentError as err:
     # Surface TF argument errors under the exception type NumPy callers expect.
     six.reraise(ValueError, ValueError(str(err)), sys.exc_info()[2])
コード例 #4
0
ファイル: math_ops.py プロジェクト: victorustc/trax
 def f(a, b):  # pylint: disable=missing-docstring
   """Cross product of `a` and `b` along configurable vector axes.

   Appears to mirror `np.cross`. Reads `axisa`, `axisb`, `axisc` and `axis`
   from the enclosing scope (not visible here). When `axis` is not None it
   overrides all three.

   Args:
     a: Tensor whose vector axis (`axis_a`) presumably has size 2 or 3.
     b: Tensor whose vector axis (`axis_b`) presumably has size 2 or 3.

   Returns:
     If both vector axes have size 2, the z-component `c[..., 2]` (the vector
     axis is dropped). Otherwise the 3-component cross product, with its
     vector axis moved to position `axis_c`.
   """
   # We can't assign to captured variable `axisa`, so make a new variable
   axis_a = axisa
   axis_b = axisb
   axis_c = axisc
   if axis is not None:
     axis_a = axis
     axis_b = axis
     axis_c = axis
   # Normalize negative axes against each operand's runtime rank.
   if axis_a < 0:
     axis_a = utils.add(axis_a, tf.rank(a))
   if axis_b < 0:
     axis_b = utils.add(axis_b, tf.rank(b))
   # Transposes `a` so that `axis` becomes the innermost axis (no-op if it
   # already is). Expects a non-negative `axis` (normalized above).
   def maybe_move_axis_to_last(a, axis):
     def move_axis_to_last(a, axis):
       return tf.transpose(
           a, tf.concat(
               [tf.range(axis), tf.range(axis + 1, tf.rank(a)), [axis]],
               axis=0))
     return utils.cond(
         axis == utils.subtract(tf.rank(a), 1),
         lambda: a,
         lambda: move_axis_to_last(a, axis))
   a = maybe_move_axis_to_last(a, axis_a)
   b = maybe_move_axis_to_last(b, axis_b)
   # Number of vector components of each operand, now on the last axis.
   a_dim = utils.getitem(tf.shape(a), -1)
   b_dim = utils.getitem(tf.shape(b), -1)
   # 2-component vectors get one trailing zero (z = 0), presumably because
   # tf.linalg.cross only accepts an innermost dimension of 3.
   def maybe_pad_0(a, size_of_last_dim):
     def pad_0(a):
       return tf.pad(a, tf.concat([tf.zeros([tf.rank(a) - 1, 2], tf.int32),
                                   tf.constant([[0, 1]], tf.int32)], axis=0))
     return utils.cond(size_of_last_dim == 2,
                       lambda: pad_0(a),
                       lambda: a)
   a = maybe_pad_0(a, a_dim)
   b = maybe_pad_0(b, b_dim)
   # NOTE(review): utils.tf_broadcast presumably broadcasts `a` and `b` to a
   # common shape before the cross product — confirm against utils.
   c = tf.linalg.cross(*utils.tf_broadcast(a, b))
   # Normalized against rank(c) *before* any axis is dropped below.
   if axis_c < 0:
     axis_c = utils.add(axis_c, tf.rank(c))
   # Inverse of move_axis_to_last: moves the last axis to position `axis`.
   def move_last_to_axis(a, axis):
     r = tf.rank(a)
     return tf.transpose(
         a, tf.concat(
             [tf.range(axis), [r - 1], tf.range(axis, r - 1)], axis=0))
   # For two 2-vectors only the z-component is returned, and `axis_c` is not
   # used. Otherwise the result's vector axis is placed at `axis_c`.
   c = utils.cond(
       (a_dim == 2) & (b_dim == 2),
       lambda: c[..., 2],
       lambda: utils.cond(  # pylint: disable=g-long-lambda
           axis_c == utils.subtract(tf.rank(c), 1),
           lambda: c,
           lambda: move_last_to_axis(c, axis_c)))
   return c
コード例 #5
0
 def f(x):
     """Reshapes `x` with `new_shape` when its rank is smaller than `n`.

     `n` and `new_shape` are not defined in this block (enclosing/module
     scope). Inputs whose rank is already >= `n` are returned unchanged,
     re-wrapped by `array_creation.asarray`.
     """
     x = array_creation.asarray(x)

     def _reshaped():
         target_shape = new_shape(n, tf.shape(x.data))
         return array_methods.reshape(x, target_shape).data

     def _unchanged():
         return x.data

     rank_is_too_small = utils.greater(n, tf.rank(x))
     return array_creation.asarray(
         utils.cond(rank_is_too_small, _reshaped, _unchanged))
コード例 #6
0
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a).data

  static_rank = a.shape.rank
  if (static_rank is not None and offset == 0 and
      axis1 in (static_rank - 2, -2) and axis2 in (static_rank - 1, -1)):
    # Fast path: the main diagonal over the two trailing axes is exactly
    # tf.linalg.diag_part.
    return utils.tensor_to_ndarray(tf.linalg.diag_part(a))

  # Bring the requested pair of axes to the end so diag_part can handle them.
  a = moveaxis(utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data

  shape = tf.shape(a)
  num_rows = utils.getitem(shape, -2)
  num_cols = utils.getitem(shape, -1)
  offset_out_of_range = utils.logical_or(
      utils.less_equal(offset, -1 * num_rows),
      utils.greater_equal(offset, num_cols),
  )

  def _empty_input():
    # A (..., rows, 0) zero tensor with offset 0 gives an empty diagonal.
    return (tf.zeros(tf.concat([shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # diag_part cannot take every possible k (offset), and cond runs shape
  # inference on both branches. So an out-of-range offset swaps in an input
  # for which offset 0 is valid, rather than calling diag_part with a bad k.
  a, offset = utils.cond(offset_out_of_range, _empty_input,
                         lambda: (a, offset))

  return utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset))
コード例 #7
0
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Builds a diagonal matrix from 1-D input, or extracts a diagonal from 2-D.

  Args:
    v: array_like of rank 1 or 2.
    k: Diagonal offset, passed as `k` to tf.linalg.diag / tf.linalg.diag_part.

  Returns:
    An ndarray.

  Raises:
    ValueError: from `with_rank_at_most` when the static rank of `v` exceeds 2.
    A dynamic rank outside {1, 2} is caught by a runtime tf.debugging.Assert.
  """
  v = asarray(v).data
  rank = tf.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a utils.Assert version that will fail during
  # tracing time if the shape is known.
  tf.debugging.Assert(
      utils.logical_or(tf.equal(rank, 1), tf.equal(rank, 2)), [rank])

  def _from_vector(vec, offset):
    # An empty vector yields an all-zero |offset| x |offset| matrix.
    def _zero_matrix():
      return tf.zeros([abs(offset), abs(offset)], dtype=vec.dtype)

    def _diag_matrix():
      return tf.linalg.diag(vec, k=offset)

    return utils.cond(tf.equal(tf.size(vec), 0), _zero_matrix, _diag_matrix)

  def _from_matrix(mat, offset):
    mat_shape = tf.shape(mat)
    out_of_range = utils.logical_or(
        utils.less_equal(offset, -1 * utils.getitem(mat_shape, 0)),
        utils.greater_equal(offset, utils.getitem(mat_shape, 1)),
    )
    # Out-of-range offsets are replaced by a 0x0 matrix with offset 0, so
    # diag_part is never asked for a diagonal that does not exist.
    mat, offset = utils.cond(
        out_of_range,
        lambda: (tf.zeros([0, 0], dtype=mat.dtype), 0),
        lambda: (mat, offset))
    return tf.linalg.diag_part(mat, k=offset)

  is_vector = tf.equal(rank, 1)
  result = utils.cond(is_vector, lambda: _from_vector(v, k),
                      lambda: _from_matrix(v, k))
  return utils.tensor_to_ndarray(result)
コード例 #8
0
ファイル: math_ops.py プロジェクト: victorustc/trax
 def maybe_pad_0(a, size_of_last_dim):
   """Appends one zero to the last axis of `a` when that axis has size 2.

   Args:
     a: Tensor to (maybe) pad.
     size_of_last_dim: Size of `a`'s last axis (Python int or scalar tensor).

   Returns:
     `a` unchanged, or `a` padded with a single trailing zero on its last
     axis when `size_of_last_dim == 2`.
   """
   def _pad_last_axis(t):
     # One [before, after] row per axis: [0, 0] for every axis except the
     # last, which gets [0, 1] (a single trailing zero).
     untouched_axes = tf.zeros([tf.rank(t) - 1, 2], tf.int32)
     last_axis = tf.constant([[0, 1]], tf.int32)
     return tf.pad(t, tf.concat([untouched_axes, last_axis], axis=0))

   needs_padding = size_of_last_dim == 2
   return utils.cond(needs_padding, lambda: _pad_last_axis(a), lambda: a)
コード例 #9
0
ファイル: array_ops.py プロジェクト: eudora-jia/trax
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
    """Creates an ndarray with the contents of val.

    Args:
      val: array_like. Could be an ndarray, a Tensor or any object that can be
        converted to a Tensor using `tf.convert_to_tensor`.
      dtype: Optional, defaults to dtype of the `val`. The type of the
        resulting ndarray. Could be a python type, a NumPy type or a
        TensorFlow `DType`.
      copy: Determines whether to create a copy of the backing buffer. Since
        Tensors are immutable, a copy is made only if val is placed on a
        different device than the current one. Even if `copy` is False, a new
        Tensor may need to be built to satisfy `dtype` and `ndim`. This is
        used only if `val` is an ndarray or a Tensor.
      ndmin: The minimum rank of the returned array. Missing leading
        dimensions are filled with size-1 axes.

    Returns:
      An ndarray.
    """
    if dtype:
        dtype = utils.result_type(dtype)

    result_t = val.data if isinstance(val, arrays_lib.ndarray) else val

    if copy and isinstance(result_t, tf.Tensor):
        # In eager mode tf.identity only copies when `result_t` is not already
        # on the context device.
        result_t = tf.identity(result_t)

    if isinstance(result_t, tf.Tensor):
        if dtype:
            result_t = tf.cast(result_t, dtype)
    else:
        if not dtype:
            dtype = utils.result_type(result_t)

        # convert_to_tensor(result_t, dtype=dtype) rejects combinations that
        # np.array accepts, such as (5.5, int), so convert first, then cast.
        def _unwrap(x):
            return x.data if isinstance(x, arrays_lib.ndarray) else x

        # Nested Python structures may contain ndarrays.
        result_t = tf.nest.map_structure(_unwrap, result_t)
        result_t = arrays_lib.convert_to_tensor(result_t)
        result_t = tf.cast(result_t, dtype=dtype)

    ndims = tf.rank(result_t)

    def _prepend_unit_dims():
        current_shape = tf.shape(result_t)
        leading_ones = tf.ones(ndmin - ndims, tf.int32)
        padded_shape = tf.concat([leading_ones, current_shape], axis=0)
        return tf.reshape(result_t, padded_shape)

    result_t = utils.cond(utils.greater(ndmin, ndims), _prepend_unit_dims,
                          lambda: result_t)
    return arrays_lib.tensor_to_ndarray(result_t)
コード例 #10
0
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
    """Computes the (optionally weighted) average of `a`.

    Args:
      a: array_like. Non-inexact (e.g. integer/bool) inputs are promoted to a
        floating dtype before averaging.
      axis: None (average over all elements) or a single integer. Tuples of
        ints are not supported yet.
      weights: Optional array_like. If `axis` is None, its shape must equal
        `a`'s (checked by a runtime tf.debugging.Assert). Otherwise it must
        either have `a`'s rank and shape, or be 1-D, in which case it is
        contracted against `a` along `axis`.
      returned: If True, also return the sum of the weights, broadcast to the
        shape of the average.

    Returns:
      The average as an ndarray, or an `(avg, weights_sum)` pair when
      `returned` is True.

    Raises:
      ValueError: if `axis` is neither None nor an integer.
    """
    if axis is not None and not isinstance(axis, six.integer_types):
        # TODO(wangpeng): Support tuple of ints as `axis`
        raise ValueError('`axis` must be an integer. Tuple of ints is not '
                         'supported yet. Got type: %s' % type(axis))
    a = array_ops.array(a)
    if weights is None:  # Treat all weights as 1
        # Promote to a float type so tf.reduce_mean does not average with
        # integer arithmetic.
        if not np.issubdtype(a.dtype, np.inexact):
            a = a.astype(
                utils.result_type(a.dtype, dtypes.default_float_type()))
        avg = tf.reduce_mean(a.data, axis=axis)
        if returned:
            # With unit weights the weight sum is just the element count.
            if axis is None:
                weights_sum = tf.size(a.data)
            else:
                weights_sum = tf.shape(a.data)[axis]
            weights_sum = tf.cast(weights_sum, a.data.dtype)
    else:
        # Common output dtype of `a` and `weights`, forced to inexact.
        if np.issubdtype(a.dtype, np.inexact):
            out_dtype = utils.result_type(a.dtype, weights)
        else:
            out_dtype = utils.result_type(a.dtype, weights,
                                          dtypes.default_float_type())
        a = array_ops.array(a, out_dtype).data
        weights = array_ops.array(weights, out_dtype).data

        # `weights` has exactly `a`'s shape: elementwise weighting.
        def rank_equal_case():
            tf.debugging.Assert(
                tf.reduce_all(tf.shape(a) == tf.shape(weights)),
                [tf.shape(a), tf.shape(weights)])
            weights_sum = tf.reduce_sum(weights, axis=axis)
            avg = tf.reduce_sum(a * weights, axis=axis) / weights_sum
            return avg, weights_sum

        if axis is None:
            avg, weights_sum = rank_equal_case()
        else:

            # 1-D `weights` applied along `axis` via a tensordot contraction.
            def rank_not_equal_case():
                tf.debugging.Assert(tf.rank(weights) == 1, [tf.rank(weights)])
                weights_sum = tf.reduce_sum(weights)
                axes = tf.convert_to_tensor([[axis], [0]])
                avg = tf.tensordot(a, weights, axes) / weights_sum
                return avg, weights_sum

            # We condition on rank rather than shape equality, because if we do the
            # latter, when the shapes are partially unknown but the ranks are known
            # and different, utils.cond will run shape checking on the true branch,
            # which will raise a shape-checking error.
            avg, weights_sum = utils.cond(
                tf.rank(a) == tf.rank(weights), rank_equal_case,
                rank_not_equal_case)

    avg = array_ops.array(avg)
    if returned:
        weights_sum = array_ops.broadcast_to(weights_sum, tf.shape(avg.data))
        return avg, weights_sum
    return avg
コード例 #11
0
ファイル: array_ops.py プロジェクト: rouniuyizu/trax
def hstack(tup):
  """Stacks arrays horizontally (column-wise), like `np.hstack`.

  Every input is first made at least 1-D and promoted to a common dtype.
  If the first array is 1-D, the inputs are concatenated along axis 0;
  otherwise they are concatenated along axis 1.

  Args:
    tup: Sequence of array_likes.

  Returns:
    An ndarray (previously a raw tf.Tensor leaked out of `utils.cond`).
  """
  arrays = [atleast_1d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays
  ]
  # As in NumPy, only the first array's rank selects the concatenation axis.
  rank = tf.rank(unwrapped_arrays[0])
  result = utils.cond(rank == 1, lambda: tf.concat(unwrapped_arrays, axis=0),
                      lambda: tf.concat(unwrapped_arrays, axis=1))
  # utils.cond hands back a tf.Tensor. Wrap it so hstack returns an ndarray
  # like the other array-creation functions in this module.
  return arrays_lib.tensor_to_ndarray(result)
コード例 #12
0
 def _diag_part(v, k):
   """Returns diagonal `k` of the 2-D tensor `v`.

   If `k` lies outside the valid range (k <= -rows or k >= cols), diag_part is
   instead applied to a 0x0 matrix with k = 0, which yields an empty result.
   """
   shape = tf.shape(v)
   too_low = utils.less_equal(k, -1 * utils.getitem(shape, 0))
   too_high = utils.greater_equal(k, utils.getitem(shape, 1))
   v, k = utils.cond(
       utils.logical_or(too_low, too_high),
       lambda: (tf.zeros([0, 0], dtype=v.dtype), 0),
       lambda: (v, k))
   return tf.linalg.diag_part(v, k=k)
コード例 #13
0
ファイル: math_ops.py プロジェクト: victorustc/trax
 def maybe_move_axis_to_last(a, axis):
   """Transposes `a` so that `axis` becomes its last axis.

   Returns `a` untouched when `axis` is already the last axis. `axis` is
   expected to be non-negative (it is not normalized here — presumably the
   caller has already done so).
   """
   def move_axis_to_last(t, ax):
     # Every axis except `ax`, in order, followed by `ax` itself.
     perm = tf.concat(
         [tf.range(ax), tf.range(ax + 1, tf.rank(t)), [ax]], axis=0)
     return tf.transpose(t, perm)

   last_axis = utils.subtract(tf.rank(a), 1)
   return utils.cond(axis == last_axis, lambda: a,
                     lambda: move_axis_to_last(a, axis))
コード例 #14
0
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
    """Sums the diagonal at `offset` of every (`axis1`, `axis2`) plane of `a`.

    Args:
      a: array_like of rank >= 2.
      offset: Diagonal offset, used as `k` for `tf.linalg.diag_part`.
      axis1: First axis of the planes holding the diagonals.
      axis2: Second axis of the planes holding the diagonals.
      dtype: Optional dtype forwarded to `array_ops.sum` for the reduction.

    Returns:
      An ndarray with `axis1` and `axis2` removed.
    """
    a = array_ops.asarray(a).data

    if offset == 0:
        a_shape = a.shape
        if a_shape.rank is not None:
            rank = len(a_shape)
            if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1
                                                       or axis2 == rank - 1):
                # Fast path: the diagonal already lies in the two trailing
                # axes. Reduce through array_ops.sum (instead of returning
                # tf.linalg.trace directly) so that `dtype` is honoured and
                # the result dtype matches the general path below.
                diag = utils.tensor_to_ndarray(tf.linalg.diag_part(a))
                return array_ops.sum(diag, -1, dtype)

    # Normalize negative axes against the (possibly dynamic) rank.
    a_rank = tf.rank(a)
    if axis1 < 0:
        axis1 += a_rank
    if axis2 < 0:
        axis2 += a_rank

    minaxis = tf.minimum(axis1, axis2)
    maxaxis = tf.maximum(axis1, axis2)

    # Move axes of interest to the end.
    range_rank = tf.range(a_rank)
    perm = tf.concat([
        range_rank[0:minaxis], range_rank[minaxis + 1:maxaxis],
        range_rank[maxaxis + 1:], [axis1, axis2]
    ],
                     axis=0)
    a = tf.transpose(a, perm)

    a_shape = tf.shape(a)

    # All zeros since diag_part doesn't handle all possible k (aka offset).
    # Written this way since cond will run shape inference on both branches,
    # and diag_part shape inference will fail when offset is out of bounds.
    a, offset = utils.cond(
        utils.logical_or(
            utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)),
            utils.greater_equal(offset, utils.getitem(a_shape, -1)),
        ), lambda: (tf.zeros_like(a), 0), lambda: (a, offset))

    a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset))
    return array_ops.sum(a, -1, dtype)
コード例 #15
0
 def f(a, b):
     """Multiplies when either input is a scalar, else contracts last axes.

     Non-scalar inputs are combined with `tf.tensordot` over the last axis of
     `a` and the last axis of `b`.
     """
     def _contract_last_axes():
         return tf.tensordot(a, b, axes=[[-1], [-1]])

     has_scalar = utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0)
     return utils.cond(has_scalar, lambda: a * b, _contract_last_axes)
コード例 #16
0
 def _diag(v, k):
   """Builds a matrix that has the 1-D `v` on diagonal `k`.

   An empty `v` yields an all-zero |k| x |k| matrix; `tf.linalg.diag` is
   not called in that case.
   """
   def _zero_matrix():
     return tf.zeros([abs(k), abs(k)], dtype=v.dtype)

   def _diag_matrix():
     return tf.linalg.diag(v, k=k)

   is_empty = tf.equal(tf.size(v), 0)
   return utils.cond(is_empty, _zero_matrix, _diag_matrix)