示例#1
0
文件: arrays.py 项目: youngjt/trax
    def __init__(self, shape, dtype=float, buffer=None):  # pylint: disable=redefined-builtin
        """Initializes an ndarray.

        This is a low level interface for building ndarrays and should be avoided.
        Users should instead use methods in array_creation.py.

        This class provides a numpy.ndarray like interface for a TF Tensor with a
        fully-defined shape. Note that, unlike the backing buffer of np.ndarray,
        Tensors are immutable. So, operations like `__setitem__` are performed by
        replacing the Tensor. This restricts the ability to implement NumPy `view`
        semantics.

        Compared to numpy.ndarray, this does not support `offset`, `strides`
        and `order` arguments.

        Args:
          shape: The shape of the array. Must be a scalar, an iterable of integers
            or a `TensorShape` object. A scalar `n` is treated as the 1-D shape
            `(n,)`, as `numpy.ndarray(n)` does.
          dtype: Optional. The dtype of the array. Must be a python type, a numpy
            type or a tensorflow `DType` object. If it differs from the dtype of
            the backing Tensor, the Tensor is reinterpreted with `tf.bitcast`.
          buffer: Optional. The backing buffer of the array. Must have shape
            `shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`. If None, a
            zero-filled Tensor of shape `shape` is created.

        Raises:
          ValueError: If `buffer` is not an `ndarray`, `np.ndarray` or `Tensor`,
            or if `buffer` is specified and its shape does not match `shape`.
        """
        # Normalize python/numpy dtypes to a tf.DType.
        if dtype and not isinstance(dtype, tf.DType):
            dtype = tf.as_dtype(np.dtype(dtype))
        if buffer is None:
            buffer = tf.zeros(shape, dtype=dtype)
        else:
            if isinstance(buffer, ndarray):
                # Unwrap another ndarray; presumably `.data` returns its backing
                # Tensor (the value stored in `self._data` below).
                buffer = buffer.data
            elif isinstance(buffer, np.ndarray):
                # If `buffer` is a np.ndarray, the Tensor may share the underlying
                # storage of the array.
                buffer = convert_to_tensor(value=buffer, dtype=dtype)
            elif not isinstance(buffer, tf.Tensor):
                raise ValueError(
                    'Unexpected type for `buffer` {}. Must be an ndarray,'
                    ' Tensor or np.ndarray.'.format(type(buffer)))

            if shape is not None:
                try:
                    expected_shape = tuple(shape)
                except TypeError:
                    # A scalar shape (allowed by the docstring) is not iterable;
                    # treat `n` as the 1-D shape `(n,)`, matching numpy.
                    expected_shape = (shape,)
                if expected_shape != buffer._shape_tuple():  # pylint: disable=protected-access
                    # TODO(srbs): NumPy allows this. Investigate if/how to support this.
                    raise ValueError('shape arg must match buffer.shape.')

        # Every branch above leaves `buffer` as a tf.Tensor.
        assert isinstance(buffer, tf.Tensor)
        if dtype and dtype != buffer.dtype:
            # Reinterpret the bits rather than converting values, mirroring how
            # numpy.ndarray views a raw buffer under a different dtype.
            buffer = tf.bitcast(buffer, dtype)
        self._data = buffer
        self.base = None
示例#2
0
def test_seed(hardcoded_seed=None,
              set_eager_seed=True,
              sampler_type='stateful'):
    """Returns a PRNG seed for unit tests that the command line can control.

    Prefer `test_seed_stream` when a test needs seeds for more than one
    operation.

    Goals of this helper:

    - Most runs use a fixed, arbitrary seed, so a test with a noticeable
      failure probability does not flake.
    - With --runs_per_test, each run can use a different seed, which makes it
      possible to measure a test's failure probability.
    - A specific seed can be forced to reproduce a rare event, e.g. a crash
      that only some seeds trigger.

    Resolution order for the integer seed:

    1. `--fixed_seed=<seed>`, if set.
    2. Otherwise, if `--vary_seed` is set, an unpredictable value drawn from
       system entropy. It is logged so the run can be reproduced.
    3. Otherwise `hardcoded_seed`, if given.
    4. Otherwise 17.

    In TensorFlow graph mode, ops read seed state from both a graph-level seed
    and an op-level seed. test_util.TestCase fixes the former per test. In
    general, you may need to set both explicitly to get reproducible results.

    Args:
      hardcoded_seed: Optional Python value used instead of 17 when neither
        `--fixed_seed` nor `--vary_seed` is given. A well-written test should
        rarely need it. In JAX mode, a length-2 value is reduced to its last
        element.
      set_eager_seed: Python bool. If True (the default) and the returned seed
        is stateful, `tf.random.set_seed` is also called when executing
        eagerly. Should become unnecessary once b/68017812 is resolved.
      sampler_type: 'stateful' or 'stateless'. 'stateless' returns a seed pair.

    Returns:
      seed: The resolved integer seed for stateful sampling outside JAX mode.
        For 'stateless' sampling, or in JAX mode, returns the pair
        `[0, seed % (2**32 - 1)]` as a tensor. It has dtype uint32 in JAX mode
        and is bitcast to int32 otherwise.
    """

    def _resolve_integer_seed():
        """Picks the integer seed from flags, `hardcoded_seed`, or the default."""
        if flags.FLAGS.fixed_seed is not None:
            return int(flags.FLAGS.fixed_seed)
        if flags.FLAGS.vary_seed:
            random_bytes = os.urandom(64)
            # Interpret the random bytes as a single big-endian integer; PY2
            # has no int.from_bytes, so it goes through a hex string.
            drawn = (int(random_bytes.encode('hex'), 16) if six.PY2
                     else int.from_bytes(random_bytes, 'big'))
            logging.warning('Using seed %s', drawn)
            return drawn
        if hardcoded_seed is None:
            return 17
        if JAX_MODE and np.shape(hardcoded_seed) == (2, ):
            # Workaround for test_seed(hardcoded_seed=test_seed()), which can happen
            # e.g. with the run_test_sample_consistent_log_prob methods above.
            return hardcoded_seed[-1]
        return hardcoded_seed

    seed = _resolve_integer_seed()

    wants_pair = sampler_type == 'stateless' or JAX_MODE
    if not wants_pair:
        # TODO(b/68017812): Remove this clause once eager correctly supports seeding.
        if tf.executing_eagerly() and set_eager_seed:
            tf.random.set_seed(seed)
        return seed

    seed_pair = tf.constant([0, seed % (2**32 - 1)], dtype=tf.uint32)
    if JAX_MODE:
        return seed_pair
    return tf.bitcast(seed_pair, tf.int32)