Esempio n. 1
0
 def testGetGroundTruthIndices(self):
   """GROUND_TRUTH targeting should hand back the labels unchanged."""
   scores = tf.constant([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype='float32')
   true_labels = tf.constant([2, 1], dtype='int32')
   target_config = configs.AdvTargetConfig(
       target_method=configs.AdvTargetType.GROUND_TRUTH)
   # The expected targets are exactly the labels defined above.
   expected = tf.constant([2, 1], dtype='int32')
   targets = utils.get_target_indices(scores, true_labels, target_config)
   self.assertAllEqual(expected, self.evaluate(targets))
Esempio n. 2
0
 def testGetRandomIndices(self):
   """RANDOM targeting with random_seed=1 should yield a fixed index pair."""
   scores = tf.constant([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype='float32')
   true_labels = tf.constant([2, 1], dtype='int32')
   target_config = configs.AdvTargetConfig(
       target_method=configs.AdvTargetType.RANDOM, random_seed=1)
   # Hard-coded result for random_seed=1; it differs from the labels [2, 1].
   expected = tf.constant([0, 2], dtype='int32')
   targets = utils.get_target_indices(scores, true_labels, target_config)
   self.assertAllEqual(expected, self.evaluate(targets))
 def testGetLeastIndices(self):
   """Test get_target_indices function with AdvTargetType.LEAST.

   LEAST targeting should pick, for each row, the class with the lowest
   logit.
   """
   logits = tf.constant([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype='float32')
   labels = tf.constant([2, 1], dtype='int32')
   adv_target_config = configs.AdvTargetConfig(
       target_method=configs.AdvTargetType.LEAST)
   # Row-wise argmin of the logits: 0.1 sits at index 0 in row 0, and 0.2
   # sits at index 2 in row 1.
   # Use self.evaluate (as the sibling tests do) instead of an explicit
   # cached_session()/sess.run, so the test reads the same in graph and eager
   # mode.
   self.assertAllEqual(
       tf.constant([0, 2], dtype='int32'),
       self.evaluate(
           utils.get_target_indices(logits, labels, adv_target_config)))