def _max_precision_sum(a, b):
  """Coerces `a` or `b` to the higher-precision dtype, and returns the sum.

  When the dtypes already share a base type the operands are added as-is.
  Otherwise the operand with the smaller dtype size is cast up to the other
  operand's dtype; ties keep `a`'s dtype.
  """
  if dtype_util.base_equal(a.dtype, b.dtype):
    return a + b
  # Cast the narrower operand up so no precision is lost in the sum.
  if dtype_util.size(a.dtype) < dtype_util.size(b.dtype):
    a = tf.cast(a, b.dtype)
  else:
    b = tf.cast(b, a.dtype)
  return a + b
def test_size(self):
  """Verifies `dtype_util.size` byte counts for TF and NumPy dtypes."""
  # (dtype, expected byte size), checked in the same order as before.
  cases = [
      (tf.int32, 4),
      (tf.int64, 8),
      (tf.float32, 4),
      (tf.float64, 8),
      (np.int32, 4),
      (np.int64, 8),
      (np.float32, 4),
      (np.float64, 8),
  ]
  for dtype, expected_bytes in cases:
    self.assertEqual(dtype_util.size(dtype), expected_bytes)
def _sample_n(self, n, seed=None):
  """Draws `n` categorical samples per batch member.

  Args:
    n: Number of samples to draw.
    seed: Optional seed forwarded to `tf.random.categorical`.

  Returns:
    Samples cast to `self.dtype`, with shape `[n] + batch_shape`.
  """
  logits = self._logits_parameter_no_checks()
  # Collapse all batch dims so the sampler sees a [batch, num_categories]
  # matrix, which is the layout tf.random.categorical requires.
  flat_logits = tf.reshape(logits, [-1, self._num_categories(logits)])
  # int32 draws suffice unless the output dtype is wider than 4 bytes.
  draw_dtype = tf.int32 if dtype_util.size(self.dtype) <= 4 else tf.int64
  raw_draws = tf.random.categorical(
      flat_logits, n, dtype=draw_dtype, seed=seed)
  samples = tf.cast(raw_draws, self.dtype)
  output_shape = tf.concat([[n], self._batch_shape_tensor(logits)], axis=0)
  # Transpose puts the sample dim first before restoring the batch shape.
  return tf.reshape(tf.transpose(samples), shape=output_shape)
def _sample_n(self, n, seed=None):
  """Draws `n` categorical samples per batch member.

  Args:
    n: Number of samples to draw.
    seed: Optional seed forwarded to `tf.random.categorical`.

  Returns:
    Samples cast to `self.dtype`, with shape `[n] + batch_shape`.
  """
  # Fast path: logits already in the [batch, num_categories] layout the
  # sampler requires; otherwise collapse leading batch dims.
  if tensorshape_util.rank(self.logits.shape) == 2:
    flat_logits = self.logits
  else:
    flat_logits = tf.reshape(self.logits, [-1, self.num_categories])
  # int32 draws suffice unless the output dtype is wider than 4 bytes.
  draw_dtype = tf.int32 if dtype_util.size(self.dtype) <= 4 else tf.int64
  raw_draws = tf.random.categorical(
      flat_logits, n, dtype=draw_dtype, seed=seed)
  # Transpose puts the sample dim first before restoring the batch shape;
  # the cast to self.dtype happens last, as in the original contract.
  shaped = tf.reshape(
      tf.transpose(a=raw_draws),
      tf.concat([[n], self.batch_shape_tensor()], 0))
  return tf.cast(shaped, self.dtype)
def _sample_n(self, n, seed=None):
  """Draws `n` categorical samples per batch member.

  Args:
    n: Number of samples to draw.
    seed: Optional seed; `None` or a plain int routes to the legacy
      stateful sampler, anything else to `samplers.categorical`.

  Returns:
    Samples cast to `self.dtype`, with shape `[n] + batch_shape`.
  """
  logits = self._logits_parameter_no_checks()
  # Collapse all batch dims so the sampler sees a [batch, num_categories]
  # matrix, which is the layout both samplers require.
  flat_logits = tf.reshape(logits, [-1, self._num_categories(logits)])
  # int32 draws suffice unless the output dtype is wider than 4 bytes.
  draw_dtype = tf.int32 if dtype_util.size(self.dtype) <= 4 else tf.int64
  # TODO(b/147874898): Remove workaround for seed-sensitive tests.
  if seed is None or isinstance(seed, six.integer_types):
    raw_draws = tf.random.categorical(
        flat_logits, n, dtype=draw_dtype, seed=seed)
  else:
    raw_draws = samplers.categorical(
        flat_logits, n, dtype=draw_dtype, seed=seed)
  samples = tf.cast(raw_draws, self.dtype)
  # Transpose puts the sample dim first before restoring the batch shape.
  return tf.reshape(
      tf.transpose(samples),
      shape=ps.concat([[n], self._batch_shape_tensor(logits=logits)], axis=0))