Exemplo n.º 1
0
 def test(self):
     """Masking along self.a0 matches array_ops.boolean_mask on the raw tensor.

     The expected result keeps axes a1, a2, a3 and labels the leading (masked)
     axis 'x' -- presumably the name of self.a0 from setUp; not visible here.
     """
     # range(7) > 3 keeps positions 4, 5 and 6.
     keep_upper = core.LabeledTensor(math_ops.range(7) > 3, [self.a0])
     actual = ops.boolean_mask(self.original_lt, keep_upper)
     # Reference result computed directly on the underlying tensors.
     reference_tensor = array_ops.boolean_mask(
         self.original_lt.tensor, keep_upper.tensor)
     expected_axes = ['x', self.a1, self.a2, self.a3]
     expected = core.LabeledTensor(reference_tensor, expected_axes)
     self.assertLabeledTensorsEqual(actual, expected)
Exemplo n.º 2
0
 def test(self):
   """Masking along self.a0 matches array_ops.boolean_mask on the raw tensor.

   The expected result keeps axes a1, a2, a3 and labels the leading (masked)
   axis 'x' -- presumably the name of self.a0 from setUp; not visible here.
   """
   # range(7) > 3 keeps positions 4, 5 and 6.
   keep_upper = core.LabeledTensor(math_ops.range(7) > 3, [self.a0])
   actual = ops.boolean_mask(self.original_lt, keep_upper)
   # Reference result computed directly on the underlying tensors.
   reference_tensor = array_ops.boolean_mask(
       self.original_lt.tensor, keep_upper.tensor)
   expected_axes = ['x', self.a1, self.a2, self.a3]
   expected = core.LabeledTensor(reference_tensor, expected_axes)
   self.assertLabeledTensorsEqual(actual, expected)
Exemplo n.º 3
0
 def test_mismatched_axis(self):
     """A mask labeled on axis 'foo' must be rejected with a ValueError.

     The error message is expected to contain 'not equal'. Presumably 'foo'
     differs from the first axis of self.original_lt (fixture set up outside
     this view).
     """
     mask = core.LabeledTensor(math_ops.range(7) > 3, ['foo'])
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
     # assertRaisesRegex is the supported spelling with the same semantics.
     with self.assertRaisesRegex(ValueError, 'not equal'):
         ops.boolean_mask(self.original_lt, mask)
Exemplo n.º 4
0
 def test_invalid_rank(self):
     """A rank-2 mask (shape 7x3, axes a0 and a1) raises NotImplementedError."""
     # The mask values are irrelevant here (ones > 3 is all False); only its
     # rank matters.
     rank2_values = array_ops.ones((7, 3)) > 3
     rank2_mask = core.LabeledTensor(rank2_values, [self.a0, self.a1])
     self.assertRaises(
         NotImplementedError, ops.boolean_mask, self.original_lt, rank2_mask)
Exemplo n.º 5
0
 def test_name(self):
     """The tensor returned by ops.boolean_mask has 'lt_boolean_mask' in its name."""
     keep = core.LabeledTensor(math_ops.range(7) > 3, [self.a0])
     result_name = ops.boolean_mask(self.original_lt, keep).name
     self.assertIn('lt_boolean_mask', result_name)
Exemplo n.º 6
0
 def test_mismatched_axis(self):
   """A mask labeled on axis 'foo' must be rejected with a ValueError.

   The error message is expected to contain 'not equal'. Presumably 'foo'
   differs from the first axis of self.original_lt (fixture set up outside
   this view).
   """
   mask = core.LabeledTensor(math_ops.range(7) > 3, ['foo'])
   # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
   # assertRaisesRegex is the supported spelling with the same semantics.
   with self.assertRaisesRegex(ValueError, 'not equal'):
     ops.boolean_mask(self.original_lt, mask)
Exemplo n.º 7
0
 def test_invalid_rank(self):
   """A rank-2 mask (shape 7x3, axes a0 and a1) raises NotImplementedError."""
   # The mask values are irrelevant here (ones > 3 is all False); only its
   # rank matters.
   rank2_values = array_ops.ones((7, 3)) > 3
   rank2_mask = core.LabeledTensor(rank2_values, [self.a0, self.a1])
   self.assertRaises(
       NotImplementedError, ops.boolean_mask, self.original_lt, rank2_mask)
Exemplo n.º 8
0
 def test_name(self):
   """The tensor returned by ops.boolean_mask has 'lt_boolean_mask' in its name."""
   keep = core.LabeledTensor(math_ops.range(7) > 3, [self.a0])
   result_name = ops.boolean_mask(self.original_lt, keep).name
   self.assertIn('lt_boolean_mask', result_name)