def testGetGroundTruthIndices(self):
  """Test get_target_indices function with AdvTargetType.GROUND_TRUTH.

  With the GROUND_TRUTH target method, the expected target indices in this
  test are exactly the values in `labels`.
  """
  labels = tf.constant([2, 1], dtype='int32')
  logits = tf.constant([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype='float32')
  target_config = configs.AdvTargetConfig(
      target_method=configs.AdvTargetType.GROUND_TRUTH)

  expected_indices = tf.constant([2, 1], dtype='int32')
  target_indices = utils.get_target_indices(logits, labels, target_config)
  actual_indices = self.evaluate(target_indices)

  self.assertAllEqual(expected_indices, actual_indices)
def testGetRandomIndices(self):
  """Test get_target_indices function with AdvTargetType.RANDOM.

  The target config fixes `random_seed=1`, so the sampled indices should be
  reproducible across runs.
  """
  labels = tf.constant([2, 1], dtype='int32')
  logits = tf.constant([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype='float32')
  target_config = configs.AdvTargetConfig(
      target_method=configs.AdvTargetType.RANDOM, random_seed=1)

  # NOTE(review): these values depend on the seeded RNG op used inside
  # get_target_indices; confirm they still hold if that implementation or
  # the TF version changes.
  expected_indices = tf.constant([0, 2], dtype='int32')
  target_indices = utils.get_target_indices(logits, labels, target_config)
  actual_indices = self.evaluate(target_indices)

  self.assertAllEqual(expected_indices, actual_indices)
def testGetLeastIndices(self):
  """Test get_target_indices function with AdvTargetType.LEAST.

  The expected indices are the per-row argmin of `logits`: row 0 has its
  smallest value (0.1) at index 0, and row 1 has its smallest value (0.2)
  at index 2.
  """
  logits = tf.constant([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype='float32')
  labels = tf.constant([2, 1], dtype='int32')
  adv_target_config = configs.AdvTargetConfig(
      target_method=configs.AdvTargetType.LEAST)
  # Use self.evaluate (like the sibling tests) rather than an explicit
  # cached_session/sess.run: it works in both graph and eager mode and keeps
  # the tests consistent.
  self.assertAllEqual(
      tf.constant([0, 2], dtype='int32'),
      self.evaluate(
          utils.get_target_indices(logits, labels, adv_target_config)))