def test_softsortlayer(self):
    """SoftSortLayer over the whole axis should track a hard descending sort.

    Builds the layer with a small epsilon and no topk, checks that the output
    keeps the input's shape, then compares it against `tf.sort` with an
    absolute tolerance of 1e-2.
    """
    order = 'DESCENDING'
    soft_sorter = layers.SoftSortLayer(
        axis=self._axis, direction=order, epsilon=1e-3)
    soft_sorted = soft_sorter(self._inputs)

    # With no topk given, the test expects every entry to be retained.
    self.assertAllEqual(soft_sorted.shape, self._inputs.shape)

    # Exact reference sort along the same axis and direction.
    hard_sorted = tf.sort(self._inputs, axis=self._axis, direction=order)
    self.assertAllClose(hard_sorted, soft_sorted, atol=1e-2)
def test_soft_topk_layer(self, topk):
    """SoftSortLayer with `topk` should return the top-`topk` sorted entries.

    Checks that the layer's output has the input shape with `self._axis`
    shrunk to `topk`. Then compares the output against the first `topk`
    entries of a hard descending `tf.sort` along that same axis, with an
    absolute tolerance of 1e-2.

    Args:
      topk: number of leading sorted entries the layer should keep (supplied
        by the test parameterization, which is not visible here).
    """
    direction = 'DESCENDING'
    layer = layers.SoftSortLayer(
        axis=self._axis, topk=topk, direction=direction, epsilon=1e-3)
    outputs = layer(self._inputs)

    expected_shape = list(self._input_shape)
    expected_shape[self._axis] = topk
    self.assertAllEqual(outputs.shape, expected_shape)

    sorted_inputs = tf.sort(self._inputs, axis=self._axis, direction=direction)
    # Take the first `topk` entries along `self._axis` itself. The old
    # `[:, :topk, :]` slice silently assumed axis == 1 and a rank-3 input,
    # which disagreed with the generic expected_shape computation above.
    expected = tf.gather(sorted_inputs, tf.range(topk), axis=self._axis)
    self.assertAllClose(expected, outputs, atol=1e-2)
def test_sortlayer_in_model(self, topk):
    """Run SoftSortLayer(topk=topk) through `take_model_output` and check shape.

    Feeds a (32, 10) uniform random batch and expects an output of shape
    [batch_size, 1].
    """
    batch = tf.random.uniform((32, 10))
    model_out = self.take_model_output(layers.SoftSortLayer(topk=topk), batch)
    # NOTE(review): the trailing dimension is expected to be 1 regardless of
    # topk, presumably because take_model_output appends a single-unit head
    # after the layer -- confirm against its definition.
    expected = [batch.shape[0], 1]
    self.assertAllEqual(expected, model_out.shape)