def __init__(self, elements, name=None):
  if not elements:
    raise ValueError("Invalid `elements`. `elements` should not be empty.")
  if not isinstance(elements, list):
    raise ValueError("Invalid `elements`. `elements` must be a list.")

  elements = [structure.normalize_element(element) for element in elements]
  type_specs = [
      structure.type_spec_from_value(element) for element in elements
  ]

  # Check that elements have same nested structure.
  num_elements = len(elements)
  for i in range(1, num_elements):
    nest.assert_same_structure(type_specs[0], type_specs[i])

  # Infer elements' supershape.
  flattened_type_specs = [
      nest.flatten(type_spec) for type_spec in type_specs
  ]
  num_tensors_per_element = len(flattened_type_specs[0])
  flattened_structure = [None] * num_tensors_per_element
  for i in range(num_tensors_per_element):
    flattened_structure[i] = flattened_type_specs[0][i]
    for j in range(1, num_elements):
      flattened_structure[i] = flattened_structure[
          i].most_specific_common_supertype([flattened_type_specs[j][i]])

  if not isinstance(type_specs[0], dataset_ops.DatasetSpec):
    self._tensors = list(
        itertools.chain.from_iterable(
            [nest.flatten(element) for element in elements]))
  else:
    self._tensors = [x._variant_tensor for x in elements]  # pylint: disable=protected-access

  self._structure = nest.pack_sequence_as(type_specs[0], flattened_structure)
  self._name = name
  variant_tensor = gen_experimental_dataset_ops.list_dataset(
      self._tensors,
      output_types=self._flat_types,
      output_shapes=self._flat_shapes,
      metadata=self._metadata.SerializeToString())
  super(_ListDataset, self).__init__(variant_tensor)
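# A minimal usage sketch (an assumption, not part of this module): in recent
# TF releases this constructor backs the public `tf.data.experimental.from_list`
# entry point, which keeps each element intact rather than stacking and
# re-slicing a single tensor the way `from_tensor_slices` does. The supershape
# inference above is what allows elements whose shapes differ, as long as a
# common supertype exists.
#
#   import tensorflow as tf
#
#   # Elements may have different leading dimensions; only their common
#   # supershape must be inferable.
#   ds = tf.data.experimental.from_list(
#       [tf.constant([1, 2]), tf.constant([3, 4, 5])])
#   for element in ds:
#     print(element.numpy())  # [1 2], then [3 4 5]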
def __init__(self, input_dataset, initial_state, scan_func,
             use_default_device=None):
  """See `scan()` for details."""
  self._input_dataset = input_dataset
  self._initial_state = structure.normalize_element(initial_state)

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_structure = structure.type_spec_from_value(self._initial_state)

  # Iteratively rerun the scan function until reaching a fixed point on
  # `self._state_shapes`.
  need_to_rerun = True
  while need_to_rerun:
    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        scan_func,
        self._transformation_name(),
        input_structure=(self._state_structure, input_dataset.element_spec),
        add_to_graph=False)
    if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
            and len(wrapped_func.output_types) == 2):
      raise TypeError("The scan function must return a pair comprising the "
                      "new state and the output value.")

    # Extract and validate class information from the returned values.
    new_state_classes, output_classes = wrapped_func.output_classes
    old_state_classes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._state_structure)
    for new_state_class, old_state_class in zip(
        nest.flatten(new_state_classes), nest.flatten(old_state_classes)):
      if not issubclass(new_state_class, old_state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_classes, new_state_classes))

    # Extract and validate type information from the returned values.
    new_state_types, output_types = wrapped_func.output_types
    old_state_types = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._state_structure)
    for new_state_type, old_state_type in zip(
        nest.flatten(new_state_types), nest.flatten(old_state_types)):
      if new_state_type != old_state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_types, new_state_types))

    # Extract shape information from the returned values.
    new_state_shapes, output_shapes = wrapped_func.output_shapes
    old_state_shapes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._state_structure)
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    flat_state_shapes = nest.flatten(old_state_shapes)
    flat_new_state_shapes = nest.flatten(new_state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      # TODO(b/110122868): Support a "most specific compatible structure"
      # method for combining structures, to avoid using legacy structures
      # in this method.
      self._state_structure = structure.convert_legacy_structure(
          old_state_types,
          nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
          old_state_classes)

  self._scan_func = wrapped_func
  self._scan_func.function.add_to_graph(ops.get_default_graph())

  # pylint: disable=protected-access
  if compat.forward_compatible(2019, 10,
                               15) or use_default_device is not None:
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        use_default_device=use_default_device,
        **self._flat_structure)
  else:
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **self._flat_structure)
  super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
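# A minimal usage sketch of the public wrapper for `_ScanDataset`,
# `tf.data.experimental.scan` (the values shown are illustrative): the
# transformation threads a state through the dataset like a functional fold
# that also emits an output per step, and the fixed-point loop above is what
# lets the state's static shape weaken until it is self-consistent.
#
#   import tensorflow as tf
#
#   # Running sum: state starts at 0; scan_func returns (new_state, output).
#   # State and input dtypes must match (int64 here), per the type checks above.
#   ds = tf.data.Dataset.range(5).apply(
#       tf.data.experimental.scan(
#           initial_state=tf.constant(0, dtype=tf.int64),
#           scan_func=lambda state, x: (state + x, state + x)))
#   print(list(ds.as_numpy_iterator()))  # [0, 1, 3, 6, 10]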