  def testErrors(self):
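    # `weights` must be a vector with one entry per input dataset.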
    with self.assertRaisesRegexp(ValueError,
                                 r"vector of length `len\(datasets\)`"):
      interleave_ops.sample_from_datasets(
          [dataset_ops.Dataset.range(10),
           dataset_ops.Dataset.range(20)],
          weights=[0.25, 0.25, 0.25, 0.25])

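    # `weights` must have dtype tf.float32 or tf.float64.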
    with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
      interleave_ops.sample_from_datasets(
          [dataset_ops.Dataset.range(10),
           dataset_ops.Dataset.range(20)],
          weights=[1, 1])

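    # All input datasets must produce elements of the same type.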
    with self.assertRaisesRegexp(TypeError, "must have the same type"):
      interleave_ops.sample_from_datasets([
          dataset_ops.Dataset.from_tensors(0),
          dataset_ops.Dataset.from_tensors(0.0)
      ])

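    # `choice_dataset` must produce tf.int64 elements.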
    with self.assertRaisesRegexp(TypeError, "tf.int64"):
      interleave_ops.choose_from_datasets([
          dataset_ops.Dataset.from_tensors(0),
          dataset_ops.Dataset.from_tensors(1)
      ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))

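    # `choice_dataset` elements must be scalars.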
    with self.assertRaisesRegexp(TypeError, "scalar"):
      interleave_ops.choose_from_datasets([
          dataset_ops.Dataset.from_tensors(0),
          dataset_ops.Dataset.from_tensors(1)
      ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
  def testSelectFromDatasets(self):
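    # Each word becomes an infinitely repeated dataset; a random sequence of
    # 15 indices selects which dataset supplies each output element.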
    words = [b"foo", b"bar", b"baz"]
    datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
    choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

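    # The output should follow `choice_array` and end once `choice_dataset`
    # is exhausted.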
    with self.cached_session() as sess:
      for i in choice_array:
        self.assertEqual(words[i], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)