Example #1
0
 def _apply_fn(dataset):
   """Returns `dataset` with `_check_shape` mapped over it and refined shapes.

   The returned dataset keeps the legacy output types and classes of the
   input. Its reported legacy output shapes come from `_merge_output_shapes`,
   applied to the input's legacy shapes and `expected_shapes`.

   NOTE(review): `_check_shape`, `_merge_output_shapes` and `expected_shapes`
   are captured from an enclosing scope not visible here. Confirm their exact
   contracts, e.g. whether merging raises on incompatible shapes.

   Args:
     dataset: the input dataset to wrap.

   Returns:
     A `dataset_ops._RestructuredDataset` wrapping `dataset.map(_check_shape)`.
   """
   merged_shapes = _merge_output_shapes(
       dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
   # The per-element check runs inside the mapped dataset. The wrapper below
   # only relabels the static structure metadata.
   checked = dataset.map(_check_shape)
   # pylint: disable=protected-access
   return dataset_ops._RestructuredDataset(
       checked,
       dataset_ops.get_legacy_output_types(dataset),
       output_shapes=merged_shapes,
       output_classes=dataset_ops.get_legacy_output_classes(dataset))
Example #2
0
    def testIncorrectPythonStructure(self):
        """Checks that a wrong Python structure fails cleanly.

        A `Dataset.range(10)` is deliberately re-labelled with a structure of
        two scalar int64 specs. It is then mapped with a two-argument
        function. Evaluating the result must raise an op error rather than
        crash the process (e.g. a segfault).
        """
        source = dataset_ops.Dataset.range(10)
        scalar_int64 = tensor_spec.TensorSpec([], dtypes.int64)
        # pylint: disable=protected-access
        mislabelled = dataset_ops._RestructuredDataset(
            source, (scalar_int64, scalar_int64))
        # pylint: enable=protected-access
        picked = mislabelled.map(lambda first, second: second)

        with self.assertRaisesOpError(""):
            self.getDatasetOutput(picked)
    def testRestructureDataset(self):
        """Checks which type/shape restructurings `_RestructuredDataset` accepts.

        The source dataset comes from `Dataset.from_tensors` over a nested
        tuple of three int32 placeholders:
        - an unknown-shape scalar placeholder,
        - a `[None]` vector,
        - a `[20, 30]` matrix.

        Compatible restructurings must report exactly the requested legacy
        types and, when given, the requested legacy shapes. Incompatible ones
        must raise `ValueError`.
        """
        components = (array_ops.placeholder(dtypes.int32),
                      (array_ops.placeholder(dtypes.int32, shape=[None]),
                       array_ops.placeholder(dtypes.int32, shape=[20, 30])))
        dataset = dataset_ops.Dataset.from_tensors(components)

        i32 = dtypes.int32

        test_cases = [((i32, i32, i32), None), (((i32, i32), i32), None),
                      ((i32, i32, i32), (None, None, None)),
                      ((i32, i32, i32), ([17], [17], [20, 30]))]

        for new_types, new_shape_lists in test_cases:
            # subTest labels each failure with the case that produced it.
            with self.subTest(types=new_types, shapes=new_shape_lists):
                # pylint: disable=protected-access
                new = dataset_ops._RestructuredDataset(dataset, new_types,
                                                       new_shape_lists)
                # pylint: enable=protected-access
                self.assertEqual(new_types,
                                 dataset_ops.get_legacy_output_types(new))
                if new_shape_lists is None:
                    continue
                expected_shapes = nest.flatten(new_shape_lists)
                actual_shapes = nest.flatten(
                    dataset_ops.get_legacy_output_shapes(new))
                # zip() silently truncates, so check the counts match first.
                self.assertEqual(len(expected_shapes), len(actual_shapes))
                for expected_shape_list, shape in zip(expected_shapes,
                                                      actual_shapes):
                    if expected_shape_list is None:
                        # A `None` request means "unknown rank".
                        self.assertIsNone(shape.ndims)
                    else:
                        self.assertEqual(expected_shape_list, shape.as_list())

        fail_cases = [((i32, dtypes.int64, i32), None),
                      ((i32, i32, i32, i32), None),
                      ((i32, i32, i32), ((None, None), None)),
                      ((i32, i32, i32), (None, None, None, None)),
                      ((i32, i32, i32), (None, [None], [21, 30]))]

        for new_types, new_shape_lists in fail_cases:
            with self.subTest(types=new_types, shapes=new_shape_lists):
                with self.assertRaises(ValueError):
                    # Only the raised error matters; the result is discarded.
                    # pylint: disable=protected-access
                    dataset_ops._RestructuredDataset(dataset, new_types,
                                                     new_shape_lists)
                    # pylint: enable=protected-access