def _apply_fn(dataset):
  """Wraps `dataset` with a per-element shape check and refined shapes.

  Each element of `dataset` is passed through `_check_shape` via `map`. The
  result is re-labeled so that its legacy output types and classes are those
  of `dataset`. Its legacy output shapes are the result of
  `_merge_output_shapes` applied to the input shapes and `expected_shapes`.

  NOTE(review): `_check_shape`, `expected_shapes` and `_merge_output_shapes`
  come from the enclosing scope (not visible here). Presumably
  `_check_shape` validates element shapes at runtime; confirm.

  Args:
    dataset: the input dataset to wrap.

  Returns:
    A `dataset_ops._RestructuredDataset` over `dataset.map(_check_shape)`.
  """
  input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
  merged_shapes = _merge_output_shapes(input_shapes, expected_shapes)
  checked_dataset = dataset.map(_check_shape)
  # pylint: disable=protected-access
  return dataset_ops._RestructuredDataset(
      checked_dataset,
      dataset_ops.get_legacy_output_types(dataset),
      output_shapes=merged_shapes,
      output_classes=dataset_ops.get_legacy_output_classes(dataset))
  # pylint: enable=protected-access
def testIncorrectPythonStructure(self):
  """A wrongly declared element structure must raise, not crash.

  `Dataset.range(10)` is re-labeled as producing a pair of scalar int64
  tensors. Mapping over that mislabeled dataset and evaluating it must raise
  an op error (as opposed to a segfault).
  """
  range_ds = dataset_ops.Dataset.range(10)
  scalar_int64_spec = tensor_spec.TensorSpec([], dtypes.int64)
  pair_structure = (scalar_int64_spec, scalar_int64_spec)
  # pylint: disable=protected-access
  mislabeled_ds = dataset_ops._RestructuredDataset(range_ds, pair_structure)
  # pylint: enable=protected-access
  # The map fn unpacks the (bogus) two-component structure.
  second_only_ds = mislabeled_ds.map(lambda first, second: second)
  with self.assertRaisesOpError(""):
    self.getDatasetOutput(second_only_ds)
def testRestructureDataset(self):
  """Checks `_RestructuredDataset` type/shape overrides.

  The source dataset has three int32 components: a scalar-or-unknown
  placeholder, a `[None]` vector and a `[20, 30]` matrix.

  Compatible overrides must be reflected in the restructured dataset's legacy
  output types and shapes. Incompatible overrides must raise `ValueError` at
  construction time.
  """
  components = (array_ops.placeholder(dtypes.int32),
                (array_ops.placeholder(dtypes.int32, shape=[None]),
                 array_ops.placeholder(dtypes.int32, shape=[20, 30])))
  dataset = dataset_ops.Dataset.from_tensors(components)

  i32 = dtypes.int32

  test_cases = [
      # Same three int32 components, no shape override.
      ((i32, i32, i32), None),
      # Same three int32 components, re-nested.
      (((i32, i32), i32), None),
      # Explicitly unknown shapes.
      ((i32, i32, i32), (None, None, None)),
      # Shapes refined to be compatible with the placeholders.
      ((i32, i32, i32), ([17], [17], [20, 30])),
  ]

  for new_types, new_shape_lists in test_cases:
    # pylint: disable=protected-access
    new = dataset_ops._RestructuredDataset(dataset, new_types,
                                            new_shape_lists)
    # pylint: enable=protected-access
    self.assertEqual(new_types, dataset_ops.get_legacy_output_types(new))
    if new_shape_lists is not None:
      # NOTE(review): assumes `nest` treats lists (e.g. [20, 30]) as leaves,
      # so each shape list flattens to one entry; confirm.
      expected_shapes = nest.flatten(new_shape_lists)
      actual_shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(new))
      # zip() silently truncates, so a component-count mismatch would
      # otherwise go unnoticed.
      self.assertEqual(len(expected_shapes), len(actual_shapes))
      for expected_shape_list, shape in zip(expected_shapes, actual_shapes):
        if expected_shape_list is None:
          self.assertIsNone(shape.ndims)
        else:
          self.assertEqual(expected_shape_list, shape.as_list())

  fail_cases = [
      # int64 does not match the int32 component.
      ((i32, dtypes.int64, i32), None),
      # Four types for three components.
      ((i32, i32, i32, i32), None),
      # Shape structure does not match the type structure.
      ((i32, i32, i32), ((None, None), None)),
      # Four shapes for three components.
      ((i32, i32, i32), (None, None, None, None)),
      # [21, 30] conflicts with the [20, 30] placeholder.
      ((i32, i32, i32), (None, [None], [21, 30])),
  ]

  for new_types, new_shape_lists in fail_cases:
    with self.assertRaises(ValueError):
      # pylint: disable=protected-access
      dataset_ops._RestructuredDataset(dataset, new_types, new_shape_lists)
      # pylint: enable=protected-access