def test_get_mask(self):
  """get_mask rejects a mask-less OneOf and returns the mask of a masked one."""
  # Without an explicit mask, get_mask must raise a ValueError whose message
  # matches 'OneOf must have a mask'.
  with self.assertRaisesWithPredicateMatch(
      ValueError, 'OneOf must have a mask'):
    unmasked_choice = schema.OneOf([6, 7, 8, 9], basic_specs.OP_TAG)
    cost_model_lib.get_mask(unmasked_choice)

  # When a mask is passed as the third OneOf argument, get_mask should hand
  # back a tensor that evaluates to exactly that mask.
  expected = [0, 0, 1, 0]
  masked_choice = schema.OneOf(
      [6, 7, 8, 9], basic_specs.OP_TAG, tf.constant(expected))
  mask_tensor = cost_model_lib.get_mask(masked_choice)
  self.assertAllEqual(self.evaluate(mask_tensor), expected)
def _maybe_get_mask(value):
  """Returns the mask attached to `value` when it is a `schema.OneOf`.

  Args:
    value: Any object. Only `schema.OneOf` instances are inspected.

  Returns:
    The result of `cost_model_lib.get_mask(value)` when `value` is a
    `schema.OneOf`, otherwise None. Errors raised by `get_mask` (such as the
    ValueError for a OneOf constructed without a mask) propagate to the caller.
  """
  is_one_of = isinstance(value, schema.OneOf)
  return cost_model_lib.get_mask(value) if is_one_of else None