def test_sortquantile_shape(self):
    """Checks softquantiles output shapes with and without squeezing.

    A single quantile taken over axis 1 of a (3, 20, 4) input is expected to
    have shape (3, 4) by default, and (3, 1, 4) when `may_squeeze=False`.
    """
    reduce_axis = 1
    values = tf.random.uniform((3, 20, 4), dtype=tf.float32)

    squeezed = ops.softquantiles(values, 0.3, 0.03, axis=reduce_axis)
    self.assertEqual(squeezed.shape, (3, 4))

    kept = ops.softquantiles(
        values, 0.3, 0.03, axis=reduce_axis, may_squeeze=False)
    self.assertEqual(kept.shape, (3, 1, 4))
def call(self, inputs):
    """Applies `ops.softquantiles` to `inputs` along `self._axis`.

    The result is reshaped to `self.get_output_shape(tf.shape(inputs))`.
    """
    quantiles = ops.softquantiles(
        inputs, self._quantiles, axis=self._axis, **self._kwargs)
    target_shape = self.get_output_shape(tf.shape(inputs))
    # For some reason, in graph mode the reshape is necessary.
    return tf.reshape(quantiles, target_shape)
def call(self, y_true, y_pred):
    """Soft quantile band of the powered absolute error along axis 0.

    The band [self._start_quantile, self._end_quantile] is passed to
    `ops.softquantiles` as its midpoint, with `quantile_width` equal to the
    band's length.
    """
    # NOTE(review): tf.squeeze without `axis` drops every size-1 dimension;
    # if y_true keeps a trailing size-1 dim (e.g. (batch, 1)) the subtraction
    # broadcasts to (batch, batch) — confirm the expected y_true shape.
    residual = tf.squeeze(y_pred) - y_true
    error = tf.pow(tf.abs(residual), self._power)
    lo, hi = self._start_quantile, self._end_quantile
    return ops.softquantiles(
        error, 0.5 * (hi + lo), quantile_width=hi - lo, axis=0,
        **self._kwargs)
def test_softquantiles(self):
    """Soft quartiles/median of a shuffled 0..18 range approximate hard ones."""
    n = 19
    ordered = tf.range(0, n, dtype=tf.float32)
    shuffled = tf.random.shuffle(ordered)
    quantiles = [0.25, 0.50, 0.75]
    # Reference positions in the sorted range for each target quantile.
    indices = [4, 9, 14]
    actual = ops.softquantiles(shuffled, quantiles, epsilon=1e-3)
    expected = tf.gather(ordered, indices)
    self.assertAllClose(actual, expected, rtol=0.2, atol=0.2)
def test_softquantile(self, quantile):
    """Compares a soft quantile along `axis` with the exact sorted value.

    Args:
      quantile: the target quantile (test parameter); presumably in (0, 1].
    """
    # Sizes the input so that the desired quantile lands close to an integer
    # index (int() truncation below may round it down).
    num_points_before_quantile = 10
    step = quantile / num_points_before_quantile
    num_points = int(1.0 / step + 1.0)
    quantile_width = step
    axis = 1
    x = tf.random.uniform((3, num_points, 4), dtype=tf.float32)
    soft_q = ops.softquantiles(
        x, quantile, quantile_width, axis=axis, epsilon=1e-3)
    # Compare to the exact quantile, gathered along the same axis that was
    # sorted and passed to softquantiles (previously hard-coded to 1).
    hard_q = tf.gather(
        tf.sort(x, direction='ASCENDING', axis=axis),
        int(quantile * num_points),
        axis=axis)
    self.assertAllClose(soft_q, hard_q, rtol=0.2, atol=0.2)
def call(self, y_true, y_pred):
    """Soft `self._quantile` of |squeeze(y_pred) - y_true| ** self._power.

    The quantile is taken along axis 0; extra `self._kwargs` are forwarded to
    `ops.softquantiles`.
    """
    # NOTE(review): tf.squeeze without `axis` drops every size-1 dimension;
    # if y_true keeps a trailing size-1 dim the subtraction broadcasts —
    # confirm the expected y_true shape.
    abs_error = tf.abs(tf.squeeze(y_pred) - y_true)
    powered = tf.pow(abs_error, self._power)
    return ops.softquantiles(powered, self._quantile, axis=0, **self._kwargs)
def test_softquantile_errors(self, q1, q2, width):
    """Expects InvalidArgumentError for the parameterized quantile/width combo.

    Args:
      q1, q2: the two quantiles requested together (test parameters).
      width: the `quantile_width` passed alongside them.
    """
    values = tf.random.uniform((3, 10))
    requested = [q1, q2]
    with self.assertRaises(tf.errors.InvalidArgumentError):
        ops.softquantiles(values, requested, quantile_width=width, axis=-1)