コード例 #1
0
 def f(x):
     """Return `x` as an ndarray, reshaped when `n` exceeds its rank.

     `n` and `new_shape` are read from the enclosing scope, which is not
     visible in this snippet.
     """
     arr = array_creation.asarray(x)

     def _reshaped():
         # Branch taken when `n` is greater than the rank of `arr`.
         target_shape = new_shape(n, tf.shape(arr.data))
         return array_methods.reshape(arr, target_shape).data

     def _unchanged():
         # Branch taken otherwise: hand back the backing data as-is.
         return arr.data

     rank_too_small = utils.greater(n, tf.rank(arr))
     selected = utils.cond(rank_too_small, _reshaped, _unchanged)
     return array_creation.asarray(selected)
コード例 #2
0
ファイル: array_ops.py プロジェクト: eudora-jia/trax
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
    """Creates an ndarray with the contents of val.

    Args:
      val: array_like. Could be an ndarray, a Tensor or any object that can be
        converted to a Tensor using `tf.convert_to_tensor`.
      dtype: Optional, defaults to dtype of the `val`. The type of the
        resulting ndarray. Could be a python type, a NumPy type or a
        TensorFlow `DType`.
      copy: Determines whether to create a copy of the backing buffer. Since
        Tensors are immutable, a copy is made only if val is placed on a
        different device than the current one. Even if `copy` is False, a new
        Tensor may need to be built to satisfy `dtype` and `ndmin`. This is
        used only if `val` is an ndarray or a Tensor.
      ndmin: The minimum rank of the returned array.

    Returns:
      An ndarray.
    """
    # Test the optional `dtype` against None instead of relying on truthiness:
    # dtype objects are not guaranteed to be truthy (NumPy dtype instances
    # expose a length of 0 for non-structured types, which some NumPy versions
    # treat as falsy), so a truthiness test could silently drop the request.
    if dtype is not None:
        dtype = utils.result_type(dtype)
    if isinstance(val, arrays_lib.ndarray):
        result_t = val.data
    else:
        result_t = val

    if copy and isinstance(result_t, tf.Tensor):
        # Note: In eager mode, a copy of `result_t` is made only if it is not on
        # the context device.
        result_t = tf.identity(result_t)

    if not isinstance(result_t, tf.Tensor):
        if dtype is None:
            dtype = utils.result_type(result_t)
        # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
        # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
        # while np.array allows them. We need to convert-then-cast.
        def maybe_data(x):
            if isinstance(x, arrays_lib.ndarray):
                return x.data
            return x

        # Handles lists of ndarrays
        result_t = tf.nest.map_structure(maybe_data, result_t)
        result_t = arrays_lib.convert_to_tensor(result_t)
        result_t = tf.cast(result_t, dtype=dtype)
    elif dtype is not None:
        result_t = tf.cast(result_t, dtype)
    ndims = tf.rank(result_t)

    def true_fn():
        # Prepend (ndmin - ndims) unit dimensions to the existing shape.
        old_shape = tf.shape(result_t)
        new_shape = tf.concat([tf.ones(ndmin - ndims, tf.int32), old_shape],
                              axis=0)
        return tf.reshape(result_t, new_shape)

    # Only reshape when the current rank is below the requested minimum.
    result_t = utils.cond(utils.greater(ndmin, ndims), true_fn,
                          lambda: result_t)
    return arrays_lib.tensor_to_ndarray(result_t)