示例#1
0
    def test_sortquantile_shape(self):
        """Check the output shape of `ops.softquantiles` for a scalar quantile.

        With the default arguments the reduced axis is expected to be dropped
        from the result; with `may_squeeze=False` it is expected to be kept
        with size 1.
        """
        reduce_axis = 1
        inputs = tf.random.uniform((3, 20, 4), dtype=tf.float32)

        # Default call: axis 1 of a (3, 20, 4) input should disappear.
        squeezed = ops.softquantiles(inputs, 0.3, 0.03, axis=reduce_axis)
        self.assertEqual(squeezed.shape, (3, 4))

        # Opting out of squeezing should leave a singleton dimension instead.
        kept = ops.softquantiles(
            inputs, 0.3, 0.03, axis=reduce_axis, may_squeeze=False)
        self.assertEqual(kept.shape, (3, 1, 4))
示例#2
0
 def call(self, inputs):
     """Apply `ops.softquantiles` to `inputs` and reshape the result.

     Quantiles `self._quantiles` are computed along `self._axis`, with any
     extra options forwarded from `self._kwargs`. The result is then reshaped
     to `self.get_output_shape(tf.shape(inputs))`.
     """
     soft_quantiles = ops.softquantiles(
         inputs, self._quantiles, axis=self._axis, **self._kwargs)
     # For some reason, in graph mode the reshape is necessary.
     # NOTE(review): presumably this restores shape information that is not
     # statically known in graph mode — confirm.
     target_shape = self.get_output_shape(tf.shape(inputs))
     return tf.reshape(soft_quantiles, target_shape)
示例#3
0
 def call(self, y_true, y_pred):
     """Return a soft quantile of the powered absolute errors over axis 0.

     The error is `|squeeze(y_pred) - y_true| ** self._power`. The quantile
     is the midpoint of [`self._start_quantile`, `self._end_quantile`], and
     the width of that interval is passed as `quantile_width`. Extra options
     come from `self._kwargs`.
     """
     lower = self._start_quantile
     upper = self._end_quantile
     # NOTE(review): tf.squeeze without an axis removes every size-1
     # dimension, including a batch dimension of size 1 — confirm intended.
     residual = tf.squeeze(y_pred) - y_true
     error = tf.pow(tf.abs(residual), self._power)
     center = 0.5 * (upper + lower)
     return ops.softquantiles(
         error,
         center,
         quantile_width=upper - lower,
         axis=0,
         **self._kwargs)
示例#4
0
    def test_softquantiles(self):
        """Soft quantiles of a shuffled ramp should approximate hard-picked values.

        The input is the values 0..18 in random order. The expected values are
        gathered from the sorted ramp at the indices this test hard-codes for
        the 0.25 / 0.50 / 0.75 quantiles.
        """
        num_points = 19
        sorted_values = tf.range(0, num_points, dtype=tf.float32)
        shuffled = tf.random.shuffle(sorted_values)

        quantiles = [0.25, 0.50, 0.75]
        actual = ops.softquantiles(shuffled, quantiles, epsilon=1e-3)
        expected = tf.gather(sorted_values, [4, 9, 14])

        self.assertAllClose(actual, expected, rtol=0.2, atol=0.2)
示例#5
0
  def test_softquantile(self, quantile):
    """Compare `ops.softquantiles` along one axis with the exact sorted value.

    Args:
      quantile: the quantile to test (parameterized by the test harness).
    """
    # Builds the input vector so that the desired quantile always corresponds to
    # an exact integer index.
    # NOTE(review): int() truncates num_points. When 1 / step is not integral
    # (e.g. quantile=0.3 gives num_points=34 and quantile * num_points=10.2),
    # the gather index below is a truncation — confirm this is intended.
    num_points_before_quantile = 10
    step = quantile / num_points_before_quantile
    num_points = int(1.0 / step + 1.0)
    quantile_width = step

    axis = 1
    x = tf.random.uniform((3, num_points, 4), dtype=tf.float32)
    soft_q = ops.softquantiles(
        x, quantile, quantile_width, axis=axis, epsilon=1e-3)

    # Compare to getting the exact quantile. Sort and gather along the same
    # `axis` used above, so changing `axis` keeps the two sides in sync.
    hard_q = tf.gather(
        tf.sort(x, direction='ASCENDING', axis=axis),
        int(quantile * num_points), axis=axis)

    self.assertAllClose(soft_q, hard_q, 0.2, 0.2)
示例#6
0
 def call(self, y_true, y_pred):
     """Return the soft `self._quantile` of powered absolute errors over axis 0.

     The error is `|squeeze(y_pred) - y_true| ** self._power`. Extra options
     for `ops.softquantiles` come from `self._kwargs`.
     """
     # NOTE(review): tf.squeeze without an axis also drops a size-1 batch
     # dimension — confirm intended.
     residual = tf.squeeze(y_pred) - y_true
     error = tf.pow(tf.abs(residual), self._power)
     return ops.softquantiles(
         error, self._quantile, axis=0, **self._kwargs)
示例#7
0
 def test_softquantile_errors(self, q1, q2, width):
     """`ops.softquantiles` must reject the given quantile/width combination.

     The test expects `tf.errors.InvalidArgumentError` when the quantiles
     `[q1, q2]` with `quantile_width=width` are requested along the last axis.
     """
     inputs = tf.random.uniform((3, 10))
     # Callable form of assertRaises: positional and keyword arguments after
     # the callable are forwarded to it.
     self.assertRaises(
         tf.errors.InvalidArgumentError,
         ops.softquantiles,
         inputs,
         [q1, q2],
         quantile_width=width,
         axis=-1)