예제 #1
0
  def test(self):
    """`ops.where` takes `x` where the mask is true and `y` elsewhere."""
    # Mask is True at positions 0..2 and False at positions 3..4.
    mask = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
    on_true = core.LabeledTensor(array_ops.ones(5), ['x'])
    on_false = core.LabeledTensor(array_ops.zeros(5), ['x'])
    actual = ops.where(mask, on_true, on_false)

    # Expected: three ones (from `on_true`) followed by two zeros.
    head = array_ops.ones(3)
    tail = array_ops.zeros(2)
    expected = core.LabeledTensor(array_ops.concat([head, tail], 0), ['x'])
    self.assertLabeledTensorsEqual(actual, expected)
예제 #2
0
  def test(self):
    """Elementwise selection: ones for the first 3 slots, zeros after."""
    selector = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
    ones_lt = core.LabeledTensor(array_ops.ones(5), ['x'])
    zeros_lt = core.LabeledTensor(array_ops.zeros(5), ['x'])
    result_lt = ops.where(selector, ones_lt, zeros_lt)

    # Build the reference answer by stitching 3 ones onto 2 zeros along axis 0.
    pieces = [array_ops.ones(3), array_ops.zeros(2)]
    reference_lt = core.LabeledTensor(array_ops.concat(pieces, 0), ['x'])
    self.assertLabeledTensorsEqual(result_lt, reference_lt)
예제 #3
0
  def test(self):
    """Check that `ops.where` selects elementwise between two labeled tensors.

    The condition is True at positions 0-2 and False at positions 3-4. The
    result should therefore be three elements taken from `x` (ones) followed
    by two elements taken from `y` (zeros).
    """
    condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
    x = core.LabeledTensor(tf.ones(5), ['x'])
    y = core.LabeledTensor(tf.zeros(5), ['x'])
    where_lt = ops.where(condition, x, y)

    # `tf.concat_v2` was a transitional name that was removed in TensorFlow 1.0.
    # `tf.concat(values, axis)` takes its arguments in the same order.
    golden_lt = core.LabeledTensor(
        tf.concat([tf.ones(3), tf.zeros(2)], 0), ['x'])
    self.assertLabeledTensorsEqual(where_lt, golden_lt)
예제 #4
0
 def test_mismatched_axes(self):
   """`ops.where` must reject a branch whose 'x' axis differs from the mask's.

   Each branch position is checked in turn, first the true branch and then the
   false branch. The substituted tensor is a 3-element slice of the
   5-element mask.
   """
   mask = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
   for shorten_true_branch in (True, False):
     with self.assertRaisesRegexp(ValueError, 'equal axes'):
       # The slice is taken inside the context, matching the original order.
       short = mask[:3]
       if shorten_true_branch:
         ops.where(mask, short, mask)
       else:
         ops.where(mask, mask, short)
예제 #5
0
 def test_name(self):
   """The tensor returned by `ops.where` has 'lt_where' in its name."""
   mask = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
   self.assertIn('lt_where', ops.where(mask, mask, mask).name)
예제 #6
0
 def test_mismatched_axes(self):
   """A 3-element branch against a 5-element mask raises 'equal axes'."""
   selector = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
   # Each callable builds the truncated slice lazily, so the slice op is
   # created inside the assertRaises context exactly as in the inline form.
   bad_calls = (
       lambda: ops.where(selector, selector[:3], selector),
       lambda: ops.where(selector, selector, selector[:3]),
   )
   for call in bad_calls:
     with self.assertRaisesRegexp(ValueError, 'equal axes'):
       call()
예제 #7
0
 def test_name(self):
   """Check the 'lt_where' substring appears in the result tensor's name."""
   selector = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
   selected = ops.where(selector, selector, selector)
   result_name = selected.name
   self.assertIn('lt_where', result_name)