Пример #1
0
def _max_precision_sum(a, b):
  """Sums `a` and `b` after casting them to a shared dtype when needed.

  If the base dtypes already agree, the operands are added unchanged.
  Otherwise the operand whose dtype has the smaller `dtype_util.size` is cast
  to the other operand's dtype; when the sizes tie, `b` is cast to `a.dtype`.

  Args:
    a: Operand exposing a `dtype` attribute.
    b: Operand exposing a `dtype` attribute.

  Returns:
    The value of `a + b`, computed after any cast.
  """
  if dtype_util.base_equal(a.dtype, b.dtype):
    return a + b

  a_size = dtype_util.size(a.dtype)
  b_size = dtype_util.size(b.dtype)
  if a_size < b_size:
    # `b` is strictly wider, so promote `a`.
    return tf.cast(a, b.dtype) + b
  # `a` is at least as wide as `b` (ties included), so promote `b`.
  return a + tf.cast(b, a.dtype)
Пример #2
0
  def test_size(self):
    """Verifies `dtype_util.size` for TensorFlow and NumPy dtypes."""
    # (dtype, expected size) pairs, checked in order: TF dtypes first, then
    # their NumPy counterparts.
    expected_sizes = (
        (tf.int32, 4),
        (tf.int64, 8),
        (tf.float32, 4),
        (tf.float64, 8),
        (np.int32, 4),
        (np.int64, 8),
        (np.float32, 4),
        (np.float64, 8),
    )
    for dtype, expected in expected_sizes:
      self.assertEqual(dtype_util.size(dtype), expected)
Пример #3
0
 def _sample_n(self, n, seed=None):
   """Draws `n` samples per flattened logits row via `tf.random.categorical`.

   Args:
     n: Number of draws requested for each row of the flattened logits.
     seed: Seed forwarded unchanged to `tf.random.categorical`.

   Returns:
     The draws cast to `self.dtype`, transposed, and reshaped to `[n]`
     followed by `self._batch_shape_tensor(logits)`.
   """
   logits = self._logits_parameter_no_checks()
   num_categories = self._num_categories(logits)
   # Collapse all leading dimensions so the sampler sees a 2-D input.
   flat_logits = tf.reshape(logits, [-1, num_categories])

   # Sample as int64 only when the output dtype is wider than 4 (per
   # `dtype_util.size`); otherwise int32 is used.
   if dtype_util.size(self.dtype) > 4:
     draw_dtype = tf.int64
   else:
     draw_dtype = tf.int32

   raw_draws = tf.random.categorical(
       flat_logits, n, dtype=draw_dtype, seed=seed)
   typed_draws = tf.cast(raw_draws, self.dtype)

   transposed = tf.transpose(typed_draws)
   output_shape = tf.concat(
       [[n], self._batch_shape_tensor(logits)], axis=0)
   return tf.reshape(transposed, shape=output_shape)
Пример #4
0
 def _sample_n(self, n, seed=None):
     """Draws `n` samples from `self.logits` via `tf.random.categorical`.

     Args:
       n: Number of draws requested for each row of the 2-D logits.
       seed: Seed forwarded unchanged to `tf.random.categorical`.

     Returns:
       The draws transposed, reshaped to `[n]` followed by
       `self.batch_shape_tensor()`, and finally cast to `self.dtype`.
     """
     # Logits that are not already rank 2 are flattened to
     # `[-1, num_categories]`; rank-2 logits are used as-is.
     if tensorshape_util.rank(self.logits.shape) != 2:
         flat_logits = tf.reshape(self.logits, [-1, self.num_categories])
     else:
         flat_logits = self.logits

     # Sample as int64 only when the output dtype is wider than 4 (per
     # `dtype_util.size`); otherwise int32 is used.
     if dtype_util.size(self.dtype) > 4:
         draw_dtype = tf.int64
     else:
         draw_dtype = tf.int32

     raw_draws = tf.random.categorical(
         flat_logits, n, dtype=draw_dtype, seed=seed)

     transposed = tf.transpose(a=raw_draws)
     output_shape = tf.concat([[n], self.batch_shape_tensor()], 0)
     shaped_draws = tf.reshape(transposed, output_shape)
     return tf.cast(shaped_draws, self.dtype)
Пример #5
0
 def _sample_n(self, n, seed=None):
   """Draws `n` categorical samples per flattened logits row.

   `tf.random.categorical` is used when `seed` is `None` or a Python integer;
   any other seed value is routed to `samplers.categorical`.

   Args:
     n: Number of draws requested for each row of the flattened logits.
     seed: Seed forwarded unchanged to whichever sampler is selected.

   Returns:
     The draws cast to `self.dtype`, transposed, and reshaped to `[n]`
     followed by `self._batch_shape_tensor(logits=logits)`.
   """
   logits = self._logits_parameter_no_checks()
   num_categories = self._num_categories(logits)
   # Collapse all leading dimensions so the sampler sees a 2-D input.
   flat_logits = tf.reshape(logits, [-1, num_categories])

   # Sample as int64 only when the output dtype is wider than 4 (per
   # `dtype_util.size`); otherwise int32 is used.
   if dtype_util.size(self.dtype) > 4:
     draw_dtype = tf.int64
   else:
     draw_dtype = tf.int32

   # TODO(b/147874898): Remove workaround for seed-sensitive tests.
   use_tf_sampler = seed is None or isinstance(seed, six.integer_types)
   if use_tf_sampler:
     categorical_fn = tf.random.categorical
   else:
     categorical_fn = samplers.categorical
   raw_draws = categorical_fn(flat_logits, n, dtype=draw_dtype, seed=seed)
   typed_draws = tf.cast(raw_draws, self.dtype)

   transposed = tf.transpose(typed_draws)
   output_shape = ps.concat(
       [[n], self._batch_shape_tensor(logits=logits)], axis=0)
   return tf.reshape(transposed, shape=output_shape)