def testErrors(self):
  """Checks argument validation of the sample/choose dataset helpers.

  Every case below passes one malformed argument and asserts the exception
  type and message fragment that the call is expected to produce.
  """
  # Four weights are supplied for only two input datasets.
  with self.assertRaisesRegexp(ValueError,
                               r"vector of length `len\(datasets\)`"):
    two_ranges = [
        dataset_ops.Dataset.range(10),
        dataset_ops.Dataset.range(20),
    ]
    interleave_ops.sample_from_datasets(two_ranges, weights=[0.25] * 4)

  # The weights are Python integers instead of floating-point values.
  with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
    two_ranges = [
        dataset_ops.Dataset.range(10),
        dataset_ops.Dataset.range(20),
    ]
    interleave_ops.sample_from_datasets(two_ranges, weights=[1, 1])

  # One input holds an integer element, the other a floating-point element.
  with self.assertRaisesRegexp(TypeError, "must have the same type"):
    int_input = dataset_ops.Dataset.from_tensors(0)
    float_input = dataset_ops.Dataset.from_tensors(0.0)
    interleave_ops.sample_from_datasets([int_input, float_input])

  # The choice dataset yields a float (1.0) rather than an integer index.
  with self.assertRaisesRegexp(TypeError, "tf.int64"):
    candidates = [
        dataset_ops.Dataset.from_tensors(0),
        dataset_ops.Dataset.from_tensors(1),
    ]
    float_choice = dataset_ops.Dataset.from_tensors(1.0)
    interleave_ops.choose_from_datasets(
        candidates, choice_dataset=float_choice)

  # The choice dataset yields a one-element list instead of a scalar.
  with self.assertRaisesRegexp(TypeError, "scalar"):
    candidates = [
        dataset_ops.Dataset.from_tensors(0),
        dataset_ops.Dataset.from_tensors(1),
    ]
    vector_choice = dataset_ops.Dataset.from_tensors([1.0])
    interleave_ops.choose_from_datasets(
        candidates, choice_dataset=vector_choice)
def testSelectFromDatasets(self):
  """Checks that `choose_from_datasets` follows the choice dataset.

  Builds one repeating dataset per word, draws 15 random indices in
  `[0, 3)`, and asserts that each element produced equals the word at the
  corresponding index. Once all choices are consumed, fetching another
  element must raise `OutOfRangeError`.
  """
  # NOTE(review): this file contains another definition named
  # `testSelectFromDatasets`. If both live in the same class, only the later
  # definition is run -- confirm, then rename or remove one of them.
  words = [b"foo", b"bar", b"baz"]
  datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
  choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
  choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
  dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset)
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()

  # `cached_session()` replaces the deprecated `test_session()` and matches
  # the session helper already used elsewhere in this file.
  with self.cached_session() as sess:
    for i in choice_array:
      self.assertEqual(words[i], sess.run(next_element))
    # The choice dataset is finite, so the output ends when it does.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def testSelectFromDatasets(self):
  """Verifies element order produced by `choose_from_datasets`.

  Three repeating single-word datasets are selected from according to 15
  randomly drawn indices; the output must reproduce the words in exactly
  that order and then signal end of sequence.
  """
  # NOTE(review): another method with this same name appears in this file;
  # if both belong to one class, only the later definition is run -- confirm.
  tokens = [b"foo", b"bar", b"baz"]

  # One endlessly repeating dataset per token.
  sources = []
  for token in tokens:
    sources.append(dataset_ops.Dataset.from_tensors(token).repeat())

  picks = np.random.randint(3, size=(15,), dtype=np.int64)
  selector = dataset_ops.Dataset.from_tensor_slices(picks)
  chosen = interleave_ops.choose_from_datasets(sources, selector)
  get_next = chosen.make_one_shot_iterator().get_next()

  with self.cached_session() as sess:
    for pick in picks:
      self.assertEqual(tokens[pick], sess.run(get_next))
    # All picks are consumed, so one more fetch must hit end of sequence.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)