def test(self):
    """Check `ops.where` against a hand-built expected tensor.

    The predicate is `range(5) < 3` labeled along 'x'; `x` is all ones and
    `y` is all zeros, so the expected result is three ones followed by two
    zeros along the same axis.
    """
    # Inputs: predicate, value taken where it holds, value taken elsewhere.
    predicate = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
    if_true = core.LabeledTensor(array_ops.ones(5), ['x'])
    if_false = core.LabeledTensor(array_ops.zeros(5), ['x'])

    actual = ops.where(predicate, if_true, if_false)

    # Expected values: ones(3) concatenated with zeros(2) along dimension 0.
    expected_values = array_ops.concat(
        [array_ops.ones(3), array_ops.zeros(2)], 0)
    expected = core.LabeledTensor(expected_values, ['x'])

    self.assertLabeledTensorsEqual(actual, expected)
def test(self):
    """Verify that `ops.where` selects from `x` or `y` per the condition.

    With a condition of `range(5) < 3`, ones for `x` and zeros for `y`, the
    output is compared to ones(3) concatenated with zeros(2), labeled 'x'.
    """
    mask_values = math_ops.range(5) < 3
    mask = core.LabeledTensor(mask_values, ['x'])
    ones_lt = core.LabeledTensor(array_ops.ones(5), ['x'])
    zeros_lt = core.LabeledTensor(array_ops.zeros(5), ['x'])

    selected_lt = ops.where(mask, ones_lt, zeros_lt)

    # Reference tensor assembled directly from its two constant pieces.
    pieces = [array_ops.ones(3), array_ops.zeros(2)]
    reference_lt = core.LabeledTensor(array_ops.concat(pieces, 0), ['x'])

    self.assertLabeledTensorsEqual(selected_lt, reference_lt)
def test(self):
    """Check `ops.where` against a hand-built expected tensor.

    The condition is `tf.range(5) < 3` labeled along 'x'; `x` is all ones
    and `y` is all zeros, so the expected result is three ones followed by
    two zeros along the same axis.
    """
    condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
    x = core.LabeledTensor(tf.ones(5), ['x'])
    y = core.LabeledTensor(tf.zeros(5), ['x'])

    where_lt = ops.where(condition, x, y)

    # `tf.concat_v2` was a transitional alias that is no longer part of the
    # TensorFlow API; `tf.concat(values, axis)` is its replacement and takes
    # the same argument order.
    # NOTE(review): assumes TensorFlow >= 1.0, where `tf.concat` accepts
    # `(values, axis)` -- confirm against the pinned TF version.
    golden_lt = core.LabeledTensor(
        tf.concat([tf.ones(3), tf.zeros(2)], 0), ['x'])

    self.assertLabeledTensorsEqual(where_lt, golden_lt)
def test_mismatched_axes(self):
    """Expect a ValueError matching 'equal axes' for a `[:3]`-sliced operand.

    The error is expected both when the sliced tensor is passed as `x` and
    when it is passed as `y`.
    """
    mask = core.LabeledTensor(math_ops.range(5) < 3, ['x'])

    # Sliced tensor in the `x` position. The lambda keeps the slicing inside
    # the guarded call.
    self.assertRaisesRegexp(
        ValueError, 'equal axes',
        lambda: ops.where(mask, mask[:3], mask))

    # Sliced tensor in the `y` position.
    self.assertRaisesRegexp(
        ValueError, 'equal axes',
        lambda: ops.where(mask, mask, mask[:3]))
def test_name(self):
    """Check that the name of the `ops.where` result contains 'lt_where'."""
    predicate = core.LabeledTensor(math_ops.range(5) < 3, ['x'])

    # The same tensor serves as condition, `x` and `y`; only the name matters.
    result = ops.where(predicate, predicate, predicate)
    result_name = result.name

    self.assertIn('lt_where', result_name)
def test_mismatched_axes(self):
    """Check that `ops.where` rejects a `[:3]`-sliced `x` or `y` operand.

    Each call must raise a ValueError whose message matches 'equal axes'.
    """
    predicate = core.LabeledTensor(math_ops.range(5) < 3, ['x'])

    # Case 1: the `x` operand is the sliced tensor.
    with self.assertRaisesRegexp(ValueError, 'equal axes'):
        sliced_x = predicate[:3]
        ops.where(predicate, sliced_x, predicate)

    # Case 2: the `y` operand is the sliced tensor.
    with self.assertRaisesRegexp(ValueError, 'equal axes'):
        sliced_y = predicate[:3]
        ops.where(predicate, predicate, sliced_y)
def test_name(self):
    """Assert that 'lt_where' appears in the name of the `ops.where` output."""
    mask_values = math_ops.range(5) < 3
    mask = core.LabeledTensor(mask_values, ['x'])

    # The mask is reused for all three arguments of `ops.where`.
    output_lt = ops.where(mask, mask, mask)

    self.assertIn('lt_where', output_lt.name)