예제 #1
0
 def _apply_fn(dataset):
     """Maps `_check_shape` over `dataset` and relabels it with merged shapes.

     `expected_shapes`, `_check_shape` and `_merge_output_shapes` are resolved
     from an outer scope. Presumably `_check_shape` validates each element
     against `expected_shapes`; that is not visible here, so confirm it.

     Args:
       dataset: The input dataset.

     Returns:
       A `batching._RestructuredDataset` wrapping `dataset.map(_check_shape)`.
       It reports the legacy output types and classes of `dataset` and the
       shapes returned by `_merge_output_shapes`.
     """
     merged_shapes = _merge_output_shapes(
         dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
     checked_dataset = dataset.map(_check_shape)
     legacy_types = dataset_ops.get_legacy_output_types(dataset)
     legacy_classes = dataset_ops.get_legacy_output_classes(dataset)
     # pylint: disable=protected-access
     restructured = batching._RestructuredDataset(
         checked_dataset,
         legacy_types,
         output_shapes=merged_shapes,
         output_classes=legacy_classes)
     # pylint: enable=protected-access
     return restructured
예제 #2
0
 def _apply_fn(dataset):
   """Returns `dataset` mapped through `_check_shape` with merged output shapes.

   The reported shapes come from combining the legacy output shapes of
   `dataset` with `expected_shapes` via `_merge_output_shapes`. The output
   types and classes are copied unchanged from `dataset`. `expected_shapes`,
   `_check_shape` and `_merge_output_shapes` come from an outer scope that is
   not shown here.

   Args:
     dataset: The input dataset.

   Returns:
     A `batching._RestructuredDataset` over `dataset.map(_check_shape)`.
   """
   source_shapes = dataset_ops.get_legacy_output_shapes(dataset)
   new_shapes = _merge_output_shapes(source_shapes, expected_shapes)
   mapped = dataset.map(_check_shape)
   source_types = dataset_ops.get_legacy_output_types(dataset)
   source_classes = dataset_ops.get_legacy_output_classes(dataset)
   # pylint: disable=protected-access
   result = batching._RestructuredDataset(
       mapped, source_types,
       output_shapes=new_shapes, output_classes=source_classes)
   # pylint: enable=protected-access
   return result
예제 #3
0
    def testRestructureDataset(self):
        """Checks which type/shape restructurings `_RestructuredDataset` allows.

        The source element is built from three int32 placeholders: one with
        no static shape, one of shape ``[None]`` and one of shape
        ``[20, 30]``.

        Every ``(new_types, new_shape_lists)`` pair in ``test_cases`` must be
        accepted, and the result must report exactly those legacy output
        types and shapes. Every pair in ``fail_cases`` must raise
        ``ValueError``.
        """
        components = (array_ops.placeholder(dtypes.int32),
                      (array_ops.placeholder(dtypes.int32, shape=[None]),
                       array_ops.placeholder(dtypes.int32, shape=[20, 30])))
        dataset = dataset_ops.Dataset.from_tensors(components)

        i32 = dtypes.int32

        test_cases = [
            # Flat types, no explicit shapes.
            ((i32, i32, i32), None),
            # Regrouped nesting, no explicit shapes.
            (((i32, i32), i32), None),
            # Every shape explicitly unknown.
            ((i32, i32, i32), (None, None, None)),
            # Shapes more specific than the input's.
            ((i32, i32, i32), ([17], [17], [20, 30])),
        ]

        for new_types, new_shape_lists in test_cases:
            # pylint: disable=protected-access
            new = batching._RestructuredDataset(dataset, new_types,
                                                new_shape_lists)
            # pylint: enable=protected-access
            self.assertEqual(new_types,
                             dataset_ops.get_legacy_output_types(new))
            if new_shape_lists is None:
                continue
            expected_shapes = nest.flatten(new_shape_lists)
            actual_shapes = nest.flatten(
                dataset_ops.get_legacy_output_shapes(new))
            # zip() would silently drop trailing components on a length
            # mismatch, so compare the component counts first.
            self.assertEqual(len(expected_shapes), len(actual_shapes))
            for expected_shape_list, shape in zip(expected_shapes,
                                                  actual_shapes):
                if expected_shape_list is None:
                    # A `None` entry must produce a shape of unknown rank.
                    self.assertIsNone(shape.ndims)
                else:
                    self.assertEqual(expected_shape_list, shape.as_list())

        fail_cases = [
            # int64 does not match the corresponding int32 component.
            ((i32, dtypes.int64, i32), None),
            # Four types for a three-component element.
            ((i32, i32, i32, i32), None),
            # Shape nesting does not match the type nesting.
            ((i32, i32, i32), ((None, None), None)),
            # Four shapes for three types.
            ((i32, i32, i32), (None, None, None, None)),
            # [21, 30] conflicts with the placeholder's static [20, 30].
            ((i32, i32, i32), (None, [None], [21, 30])),
        ]

        for new_types, new_shape_lists in fail_cases:
            with self.assertRaises(ValueError):
                # pylint: disable=protected-access
                batching._RestructuredDataset(dataset, new_types,
                                              new_shape_lists)
                # pylint: enable=protected-access
  def testRestructureDataset(self):
    """Checks which type/shape restructurings `_RestructuredDataset` allows.

    The source element is built from three int32 placeholders: one with no
    static shape, one of shape ``[None]`` and one of shape ``[20, 30]``.

    Every ``(new_types, new_shape_lists)`` pair in ``test_cases`` must be
    accepted, and the result must report exactly those legacy output types
    and shapes. Every pair in ``fail_cases`` must raise ``ValueError``.
    """
    components = (array_ops.placeholder(dtypes.int32),
                  (array_ops.placeholder(dtypes.int32, shape=[None]),
                   array_ops.placeholder(dtypes.int32, shape=[20, 30])))
    dataset = dataset_ops.Dataset.from_tensors(components)

    i32 = dtypes.int32

    test_cases = [
        # Flat types, no explicit shapes.
        ((i32, i32, i32), None),
        # Regrouped nesting, no explicit shapes.
        (((i32, i32), i32), None),
        # Every shape explicitly unknown.
        ((i32, i32, i32), (None, None, None)),
        # Shapes more specific than the input's.
        ((i32, i32, i32), ([17], [17], [20, 30])),
    ]

    for new_types, new_shape_lists in test_cases:
      # pylint: disable=protected-access
      new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
      # pylint: enable=protected-access
      self.assertEqual(new_types, dataset_ops.get_legacy_output_types(new))
      if new_shape_lists is None:
        continue
      expected_shapes = nest.flatten(new_shape_lists)
      actual_shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(new))
      # zip() would silently drop trailing components on a length mismatch,
      # so compare the component counts first.
      self.assertEqual(len(expected_shapes), len(actual_shapes))
      for expected_shape_list, shape in zip(expected_shapes, actual_shapes):
        if expected_shape_list is None:
          # A `None` entry must produce a shape of unknown rank.
          self.assertIsNone(shape.ndims)
        else:
          self.assertEqual(expected_shape_list, shape.as_list())

    fail_cases = [
        # int64 does not match the corresponding int32 component.
        ((i32, dtypes.int64, i32), None),
        # Four types for a three-component element.
        ((i32, i32, i32, i32), None),
        # Shape nesting does not match the type nesting.
        ((i32, i32, i32), ((None, None), None)),
        # Four shapes for three types.
        ((i32, i32, i32), (None, None, None, None)),
        # [21, 30] conflicts with the placeholder's static [20, 30].
        ((i32, i32, i32), (None, [None], [21, 30])),
    ]

    for new_types, new_shape_lists in fail_cases:
      with self.assertRaises(ValueError):
        # pylint: disable=protected-access
        batching._RestructuredDataset(dataset, new_types, new_shape_lists)
        # pylint: enable=protected-access