def test(self):
    """Compare ops.boolean_mask against array_ops.boolean_mask on raw tensors.

    The expected LabeledTensor is built by masking the underlying tensor of
    `self.original_lt` directly and labeling the result with a leading 'x'
    axis followed by `self.a1`, `self.a2` and `self.a3`.
    """
    keep = math_ops.range(7) > 3
    mask = core.LabeledTensor(keep, [self.a0])

    # Labeled op under test.
    masked_lt = ops.boolean_mask(self.original_lt, mask)

    # Reference result computed on the unlabeled tensors.
    expected_tensor = array_ops.boolean_mask(
        self.original_lt.tensor, mask.tensor)
    expected_axes = ['x', self.a1, self.a2, self.a3]
    golden_lt = core.LabeledTensor(expected_tensor, expected_axes)

    self.assertLabeledTensorsEqual(masked_lt, golden_lt)
def test_mismatched_axis(self):
    """Expect a ValueError matching 'not equal' for a mask labeled 'foo'."""
    mask_axes = ['foo']
    keep = math_ops.range(7) > 3
    mask = core.LabeledTensor(keep, mask_axes)

    with self.assertRaisesRegexp(ValueError, 'not equal'):
        ops.boolean_mask(self.original_lt, mask)
def test_invalid_rank(self):
    """Expect NotImplementedError when the mask has two axes."""
    mask_values = array_ops.ones((7, 3)) > 3
    mask_axes = [self.a0, self.a1]
    mask = core.LabeledTensor(mask_values, mask_axes)

    with self.assertRaises(NotImplementedError):
        ops.boolean_mask(self.original_lt, mask)
def test_name(self):
    """Check that the masked result's name contains 'lt_boolean_mask'."""
    keep = math_ops.range(7) > 3
    mask = core.LabeledTensor(keep, [self.a0])
    result = ops.boolean_mask(self.original_lt, mask)

    self.assertIn('lt_boolean_mask', result.name)
def test_invalid_rank(self):
    """Verify that a rank-2 mask is rejected with NotImplementedError."""
    # NOTE(review): a method with this same name and logic appears earlier in
    # the visible source; if both live in one class, this later definition
    # shadows the earlier one. Confirm whether the duplicate is intentional.
    two_axis_mask = core.LabeledTensor(
        array_ops.ones((7, 3)) > 3,
        [self.a0, self.a1],
    )

    with self.assertRaises(NotImplementedError):
        ops.boolean_mask(self.original_lt, two_axis_mask)