Example #1
0
    def __init__(self, elements, name=None):
        """Creates a dataset whose elements are the entries of `elements`.

        Each entry is normalized and converted to a type spec. All entries
        must share the nested structure of the first one; the dataset's
        structure is obtained by folding, component by component, the type
        specs of all entries with `most_specific_common_supertype`.

        Args:
          elements: A non-empty `list` of dataset elements.
          name: Optional name, stored on `self._name`.

        Raises:
          ValueError: If `elements` is empty or is not a `list`.
        """
        if not elements:
            raise ValueError(
                "Invalid `elements`. `elements` should not be empty.")
        if not isinstance(elements, list):
            raise ValueError("Invalid `elements`. `elements` must be a list.")

        normalized = [structure.normalize_element(item) for item in elements]
        specs = [structure.type_spec_from_value(item) for item in normalized]
        reference_spec = specs[0]

        # Every entry has to match the nested structure of the first entry.
        for other_spec in specs[1:]:
            nest.assert_same_structure(reference_spec, other_spec)

        # Fold the per-component specs of all entries into one common
        # supertype per component, starting from the first entry's spec.
        flat_specs = [nest.flatten(spec) for spec in specs]
        merged_components = []
        for index, merged in enumerate(flat_specs[0]):
            for other_flat in flat_specs[1:]:
                merged = merged.most_specific_common_supertype(
                    [other_flat[index]])
            merged_components.append(merged)

        if isinstance(reference_spec, dataset_ops.DatasetSpec):
            # Entries that are datasets contribute their variant tensor.
            self._tensors = [item._variant_tensor for item in normalized]
        else:
            # Otherwise concatenate the flattened components of every entry.
            self._tensors = [
                tensor
                for item in normalized
                for tensor in nest.flatten(item)
            ]
        self._structure = nest.pack_sequence_as(reference_spec,
                                                merged_components)
        self._name = name
        variant_tensor = gen_experimental_dataset_ops.list_dataset(
            self._tensors,
            output_types=self._flat_types,
            output_shapes=self._flat_shapes,
            metadata=self._metadata.SerializeToString())
        super(_ListDataset, self).__init__(variant_tensor)
Example #2
0
  def __init__(self,
               input_dataset,
               initial_state,
               scan_func,
               use_default_device=None):
    """Builds the scan dataset op; see `scan()` for details.

    Traces `scan_func` repeatedly, relaxing the state shapes after each trace,
    until the state shapes stop changing, then creates the `scan_dataset` op
    from the last traced function.

    Args:
      input_dataset: Dataset to scan over. Its `element_spec` is the second
        component of the input structure handed to `scan_func`.
      initial_state: Starting state. It is normalized with
        `structure.normalize_element` and used to derive the initial state
        structure.
      scan_func: Function called with `(state, input_element)`; it must return
        a pair `(new_state, output_element)`.
      use_default_device: Passed to the `scan_dataset` op when it is not
        `None` or when `compat.forward_compatible(2019, 10, 15)` is true;
        otherwise the argument is omitted from the op call.

    Raises:
      TypeError: If the output of `scan_func` is not a two-element sequence,
        if a new state class is not a subclass of the corresponding initial
        state class, or if a new state type differs from the corresponding
        initial state type.
    """
    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_element(initial_state)

    # Compute the initial structure (classes, shapes and types) of the state
    # from the initial state. The shapes may be relaxed by tracing `scan_func`
    # one or more times below.
    self._state_structure = structure.type_spec_from_value(self._initial_state)

    # Iteratively re-trace the scan function until reaching a fixed point on
    # the shapes recorded in `self._state_structure`.
    need_to_rerun = True
    while need_to_rerun:

      # Trace without adding to the graph; only the final trace is added,
      # after the loop.
      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          scan_func,
          self._transformation_name(),
          input_structure=(self._state_structure,
                           input_dataset.element_spec),
          add_to_graph=False)
      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
              and len(wrapped_func.output_types) == 2):
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      # NOTE(review): `new_state_classes` is unpacked again by the next
      # statement, so the only lasting effect of this line is setting
      # `self._output_classes`. Confirm whether anything still reads that
      # attribute before removing it.
      new_state_classes, self._output_classes = wrapped_func.output_classes

      # Extract and validate class information from the returned values.
      new_state_classes, output_classes = wrapped_func.output_classes
      old_state_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
          self._state_structure)
      # NOTE(review): zip() stops at the shorter sequence, so this loop does
      # not itself detect a differing number of state components; presumably
      # that is caught elsewhere - confirm.
      for new_state_class, old_state_class in zip(
          nest.flatten(new_state_classes),
          nest.flatten(old_state_classes)):
        if not issubclass(new_state_class, old_state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_classes, new_state_classes))

      # Extract and validate type information from the returned values.
      new_state_types, output_types = wrapped_func.output_types
      old_state_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_type, old_state_type in zip(
          nest.flatten(new_state_types), nest.flatten(old_state_types)):
        if new_state_type != old_state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_types, new_state_types))

      # Extract shape information from the returned values.
      new_state_shapes, output_shapes = wrapped_func.output_shapes
      old_state_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
          self._state_structure)
      # The dataset's element spec comes from the output half of the pair and
      # is overwritten on every trace, so the last trace wins.
      self._element_spec = structure.convert_legacy_structure(
          output_types, output_shapes, output_classes)

      # Relax each state shape to the most specific shape compatible with both
      # the shape fed into this trace and the shape the trace produced.
      flat_state_shapes = nest.flatten(old_state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Another trace is needed if any state component whose rank was known
      # got relaxed: its rank became unknown or its dimension list changed.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # in this method.
        # Rebuild the state structure from the unchanged types and classes
        # plus the relaxed shapes, so the next trace uses the relaxed shapes.
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    # Keep the function from the final trace and add it to the default graph.
    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())
    # pylint: disable=protected-access
    # `use_default_device` is only passed to the op when the compatibility
    # check is true or the caller set it explicitly; the two calls are
    # otherwise identical.
    if compat.forward_compatible(2019, 10,
                                 15) or use_default_device is not None:
      variant_tensor = gen_experimental_dataset_ops.scan_dataset(
          self._input_dataset._variant_tensor,
          structure.to_tensor_list(self._state_structure, self._initial_state),
          self._scan_func.function.captured_inputs,
          f=self._scan_func.function,
          preserve_cardinality=True,
          use_default_device=use_default_device,
          **self._flat_structure)
    else:
      variant_tensor = gen_experimental_dataset_ops.scan_dataset(
          self._input_dataset._variant_tensor,
          structure.to_tensor_list(self._state_structure, self._initial_state),
          self._scan_func.function.captured_inputs,
          f=self._scan_func.function,
          preserve_cardinality=True,
          **self._flat_structure)
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)