コード例 #1
0
 def test_softsortlayer(self):
   """Checks SoftSortLayer against a hard descending sort.

   The layer output must keep the shape of the inputs and must match
   `tf.sort` along `self._axis` within an absolute tolerance of 1e-2.
   """
   sort_kwargs = dict(axis=self._axis, direction='DESCENDING')
   soft_layer = layers.SoftSortLayer(epsilon=1e-3, **sort_kwargs)
   soft_sorted = soft_layer(self._inputs)
   self.assertAllEqual(soft_sorted.shape, self._inputs.shape)

   # Reference values come from the exact (non-differentiable) sort.
   hard_sorted = tf.sort(self._inputs, **sort_kwargs)
   self.assertAllClose(hard_sorted, soft_sorted, atol=1e-2)
コード例 #2
0
 def test_soft_topk_layer(self, topk):
   """Checks that SoftSortLayer with `topk` matches the top-k sorted values.

   The output must have size `topk` along `self._axis` (other dimensions
   unchanged) and must match the first `topk` entries of a descending
   `tf.sort` along that axis within an absolute tolerance of 1e-2.

   Args:
     topk: number of leading sorted entries requested from the layer.
   """
   direction = 'DESCENDING'
   layer = layers.SoftSortLayer(
       axis=self._axis, topk=topk, direction=direction, epsilon=1e-3)
   outputs = layer(self._inputs)

   # Only the sorted axis is expected to shrink, down to `topk` entries.
   expected_shape = list(self._input_shape)
   expected_shape[self._axis] = topk
   self.assertAllEqual(outputs.shape, expected_shape)

   sorted_inputs = tf.sort(self._inputs, axis=self._axis, direction=direction)
   # Slice the reference along `self._axis` instead of a hard-coded dimension,
   # so the comparison stays valid if the fixture's axis or rank changes.
   topk_slices = [slice(None)] * len(expected_shape)
   topk_slices[self._axis] = slice(topk)
   self.assertAllClose(sorted_inputs[tuple(topk_slices)], outputs, atol=1e-2)
コード例 #3
0
 def test_sortlayer_in_model(self, topk):
   """Checks the output shape of a model built around a SoftSortLayer.

   Args:
     topk: value forwarded as the `topk` argument of the SoftSortLayer.
   """
   batch = tf.random.uniform((32, 10))
   soft_sort = layers.SoftSortLayer(topk=topk)
   # `take_model_output` is a helper defined outside this method; only the
   # shape of what it returns is verified here.
   predictions = self.take_model_output(soft_sort, batch)
   expected_shape = [batch.shape[0], 1]
   self.assertAllEqual(expected_shape, predictions.shape)