def test_quantilelayer_in_model(self):
    """Checks the output shape of SoftQuantilesLayer run via take_model_output.

    A random (32, 10) batch is fed through `self.take_model_output` together
    with a three-quantile SoftQuantilesLayer, and the result is expected to
    have shape [batch_size, 1].
    """
    batch = tf.random.uniform((32, 10))
    quantile_layer = layers.SoftQuantilesLayer(
        quantiles=[0.2, 0.5, 0.8], output_shape=(32, 3))
    # NOTE(review): the trailing dimension of 1 (rather than 3 quantiles)
    # presumably comes from whatever head take_model_output adds — confirm.
    result = self.take_model_output(quantile_layer, batch)
    self.assertAllEqual([batch.shape[0], 1], result.shape)
  def test_softquantiles(self):
    """Soft quartiles of the values 0..100 should approximate 25, 50 and 75."""
    values = tf.range(101, dtype=tf.float32)
    # One row holding all 101 values; quantiles are taken along axis 1.
    row = tf.reshape(values, (1, -1))
    levels = [0.25, 0.50, 0.75]
    soft_quantiles = layers.SoftQuantilesLayer(
        quantiles=levels, output_shape=None, axis=1, epsilon=1e-3)

    result = soft_quantiles(row)
    # One output per requested quantile level, for the single input row.
    self.assertAllEqual(result.shape, (1, 3))

    # The exact quartiles of 0..100 are 25/50/75. The tolerance presumably
    # allows for the smoothing controlled by epsilon — confirm.
    expected = tf.constant([[25., 50., 75.]])
    self.assertAllClose(expected, result, atol=0.5)