Exemplo n.º 1
0
  def test(self):
    """Check `ops.where` selects from `x` where the mask is true, else `y`."""
    # Boolean mask on axis 'x': range(5) < 3 marks the first three entries.
    mask = math_ops.range(5) < 3
    mask_lt = core.LabeledTensor(mask, ['x'])
    ones_lt = core.LabeledTensor(array_ops.ones(5), ['x'])
    zeros_lt = core.LabeledTensor(array_ops.zeros(5), ['x'])
    actual_lt = ops.where(mask_lt, ones_lt, zeros_lt)

    # Expected values: three ones followed by two zeros on the same axis.
    expected = array_ops.concat([array_ops.ones(3), array_ops.zeros(2)], 0)
    expected_lt = core.LabeledTensor(expected, ['x'])
    self.assertLabeledTensorsEqual(actual_lt, expected_lt)
Exemplo n.º 2
0
  def test(self):
    """Verify the element-wise selection done by `ops.where` on axis 'x'."""
    # The condition holds for indices 0..2 and fails for indices 3..4.
    selector_lt = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
    when_true_lt = core.LabeledTensor(array_ops.ones(5), ['x'])
    when_false_lt = core.LabeledTensor(array_ops.zeros(5), ['x'])
    result_lt = ops.where(selector_lt, when_true_lt, when_false_lt)

    # Reference tensor built by hand: ones for the first three positions,
    # zeros for the remaining two.
    reference_parts = [array_ops.ones(3), array_ops.zeros(2)]
    reference_lt = core.LabeledTensor(
        array_ops.concat(reference_parts, 0), ['x'])
    self.assertLabeledTensorsEqual(result_lt, reference_lt)
Exemplo n.º 3
0
  def test(self):
    """Check `ops.where` picks ones where the mask is true and zeros elsewhere."""
    # Mask along 'x' that is true for the first three of five positions.
    mask = tf.range(5) < 3
    mask_lt = core.LabeledTensor(mask, ['x'])
    ones_lt = core.LabeledTensor(tf.ones(5), ['x'])
    zeros_lt = core.LabeledTensor(tf.zeros(5), ['x'])
    actual_lt = ops.where(mask_lt, ones_lt, zeros_lt)

    # NOTE(review): `tf.concat_v2` looks like a pre-1.0 TensorFlow name;
    # confirm the TensorFlow version this test targets still exposes it.
    expected = tf.concat_v2([tf.ones(3), tf.zeros(2)], 0)
    expected_lt = core.LabeledTensor(expected, ['x'])
    self.assertLabeledTensorsEqual(actual_lt, expected_lt)
Exemplo n.º 4
0
 def test_mismatched_axes(self):
     """Check `ops.where` rejects operands whose axes differ from the mask's.

     Each case passes a three-element slice of the mask in place of one
     operand and expects a ValueError whose message matches 'equal axes'.
     """
     mask_lt = core.LabeledTensor(math_ops.range(5) < 3, ['x'])

     # Case 1: the second argument is the shortened one.
     with self.assertRaisesRegexp(ValueError, 'equal axes'):
         shortened_x = mask_lt[:3]
         ops.where(mask_lt, shortened_x, mask_lt)

     # Case 2: the third argument is the shortened one.
     with self.assertRaisesRegexp(ValueError, 'equal axes'):
         shortened_y = mask_lt[:3]
         ops.where(mask_lt, mask_lt, shortened_y)
Exemplo n.º 5
0
 def test_name(self):
     """Check the result of `ops.where` carries 'lt_where' in its name."""
     mask_lt = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
     # The same boolean tensor is used for all three arguments; only the
     # name of the result is inspected.
     result_lt = ops.where(mask_lt, mask_lt, mask_lt)
     result_name = result_lt.name
     self.assertIn('lt_where', result_name)
Exemplo n.º 6
0
 def test_mismatched_axes(self):
   """Ensure `ops.where` raises when an operand's axes differ from the mask's.

   A three-element slice of the five-element mask stands in for each
   operand in turn; both calls must raise ValueError matching 'equal axes'.
   """
   selector_lt = core.LabeledTensor(math_ops.range(5) < 3, ['x'])

   # Shortened tensor in the second position.
   with self.assertRaisesRegexp(ValueError, 'equal axes'):
     short_lt = selector_lt[:3]
     ops.where(selector_lt, short_lt, selector_lt)

   # Shortened tensor in the third position.
   with self.assertRaisesRegexp(ValueError, 'equal axes'):
     short_lt = selector_lt[:3]
     ops.where(selector_lt, selector_lt, short_lt)
Exemplo n.º 7
0
 def test_name(self):
   """Verify that the tensor returned by `ops.where` is named with 'lt_where'."""
   selector_lt = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
   # All three arguments are the same labeled tensor; the check is on the
   # name attribute of the result only.
   selected_lt = ops.where(selector_lt, selector_lt, selector_lt)
   name_of_result = selected_lt.name
   self.assertIn('lt_where', name_of_result)