Beispiel #1
0
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32, seed=None, name=None):
    """Outputs random values from a truncated normal distribution.

    The generated values follow a normal distribution with the specified mean
    and standard deviation, except that values more than 2 standard deviations
    away from the mean are dropped and re-picked.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        truncated normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
        of the normal distribution, before truncation (the truncated samples
        themselves have a smaller spread).
      dtype: The type of the output.
      seed: A Python integer. Used to create a random seed for the distribution.
        See
        [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
        for behavior.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random truncated normal values.
    """
    with ops.name_scope(name, "truncated_normal", [shape, mean, stddev]) as name:
        shape_tensor = _ShapeTensor(shape)
        # Both parameters are converted to the output dtype so the scale/shift
        # below operates on matching types.
        mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
        stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
        # NOTE(review): random_seed.get_seed may derive an op-level seed from
        # graph state when `seed` is None; keep its position relative to op
        # creation unchanged — confirm before reordering.
        seed1, seed2 = random_seed.get_seed(seed)
        # The raw op takes no mean/stddev, so scaling by stddev and shifting
        # by mean are applied explicitly afterwards.
        rnd = gen_random_ops._truncated_normal(shape_tensor, dtype, seed=seed1, seed2=seed2)
        mul = rnd * stddev_tensor
        # The final op takes the scope's name, so the returned tensor is named
        # after the scope.
        value = math_ops.add(mul, mean_tensor, name=name)
        return value
Beispiel #2
0
def truncated_normal(shape,
                     mean=0.0,
                     stddev=1.0,
                     dtype=dtypes.float32,
                     seed=None,
                     name=None):
    """Samples a tensor from a normal distribution truncated at two sigmas.

    Samples are drawn from a normal distribution with the given `mean` and
    `stddev`; any sample lying more than two standard deviations away from
    `mean` is discarded and drawn again.

    Args:
      shape: A 1-D integer Tensor or Python array giving the output shape.
      mean: A 0-D Tensor or Python value of type `dtype`; the mean of the
        truncated normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`; the standard
        deviation of the underlying normal distribution, before truncation.
      dtype: The type of the output.
      seed: A Python integer used to create a random seed for the
        distribution. See
        @{tf.set_random_seed}
        for behavior.
      name: Optional name for the operation.

    Returns:
      A tensor of shape `shape` filled with truncated normal samples.
    """
    with ops.name_scope(name, "truncated_normal",
                        [shape, mean, stddev]) as scope:
        # Creation order below matches the historical implementation; seed
        # derivation may depend on graph state, so do not reorder.
        shape_t = _ShapeTensor(shape)
        mean_t = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
        stddev_t = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
        first_seed, second_seed = random_seed.get_seed(seed)
        raw = gen_random_ops._truncated_normal(
            shape_t, dtype, seed=first_seed, seed2=second_seed)
        # Scale, then shift; the final op carries the scope name.
        return math_ops.add(raw * stddev_t, mean_t, name=scope)