def test_convert_to_generator_like(self, input_fn, inputs):
  expected_batches = 5
  data = input_fn(self, inputs, expected_batches)

  # Dataset and Iterator not supported in Legacy Graph mode.
  if (not context.executing_eagerly() and
      isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
    return

  generator, steps = training_generator.convert_to_generator_like(
      data, batch_size=2, steps_per_epoch=expected_batches)
  self.assertEqual(steps, expected_batches)
  for _ in range(expected_batches):
    outputs = next(generator)
    nest.assert_same_structure(outputs, inputs)
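# For context: `convert_to_generator_like` turns array-like inputs into a
# `(generator, steps)` pair that a training loop can drive with `next()`.
# The sketch below is a minimal, self-contained illustration of that contract
# using plain NumPy; `_array_to_generator_like` is a hypothetical stand-in,
# not the TensorFlow implementation.
import numpy as np


def _array_to_generator_like(data, batch_size, steps_per_epoch):
  """Yields `batch_size`-sized slices of `data`, cycling indefinitely so
  that `steps_per_epoch` batches can always be drawn (Keras-style)."""

  def gen():
    while True:
      for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

  return gen(), steps_per_epoch


# Mirrors the assertions in the test above: the returned step count matches
# `steps_per_epoch`, and each drawn batch preserves the input's structure.
inputs = np.arange(10.0)
generator, steps = _array_to_generator_like(inputs, batch_size=2,
                                            steps_per_epoch=5)
assert steps == 5
for _ in range(steps):
  batch = next(generator)
  assert isinstance(batch, np.ndarray) and batch.shape == (2,)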
def adapt(self, data, reset_state=True):
  """Fits the state of the preprocessing layer to the data being passed.

  Arguments:
    data: The data to train on. It can be passed either as a tf.data Dataset,
      or as a numpy array.
    reset_state: Optional argument specifying whether to clear the state of
      the layer at the start of the call to `adapt`, or whether to start
      from the existing state. Subclasses may choose to throw if reset_state
      is set to 'False'.
  """
  if reset_state:
    accumulator = None
  else:
    accumulator = self._combiner.restore(self._restore_updates())

  if not isinstance(data, (dataset_ops.DatasetV2, np.ndarray)):
    raise ValueError(
        'adapt() requires a Dataset or a Numpy array as input.')

  if isinstance(data, dataset_ops.DatasetV2):
    # Validate the datasets to try and ensure we haven't been passed one with
    # infinite size. That would cause an infinite loop here.
    if self._dataset_is_infinite(data):
      raise ValueError(
          'The dataset passed to "adapt()" has an infinite number of '
          'elements. Please use dataset.take(...) to make the number '
          'of elements finite.')
    next_data = self._get_dataset_iterator(data)
  else:
    generator, _ = training_generator.convert_to_generator_like(
        data, batch_size=len(data))
    # If the data is not a dataset, we can iterate over it using next(foo);
    # here, we wrap that into a callable.
    next_data = lambda: next(generator)

  # TODO(momernick): Some sort of status bar?
  # TODO(momernick): Implement parallel processing here?
  try:
    data_element = next_data()

    # First, see if the layer is built or not. If it is not, then we must
    # build it.
    if not self.built:
      try:
        # If this is a Numpy array or tensor, we can get shape from .shape.
        # If not, an attribute error will be thrown (and we can assume the
        # input data is a scalar with shape None).
        shape = data_element.shape
      except AttributeError:
        shape = None
      self.build(shape)

    # Once we have built the Layer, we can process the input data. We do so
    # until we've gotten an exception indicating that we have no more data.
    while True:
      accumulator = self._combiner.compute(data_element, accumulator)
      data_element = next_data()
  # Note that this belongs to the outer indentation of 'try' - we need to
  # catch exceptions resulting from the first 'next_data()' invocation as
  # well.
  except (StopIteration, errors.OutOfRangeError):
    pass

  updates = self._combiner.extract(accumulator)
  self._set_state_variables(updates)
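# The `adapt` implementation above guards against infinite datasets because
# its accumulation loop only terminates on StopIteration/OutOfRangeError.
# Below is a self-contained example of the remediation the error message
# suggests (standard tf.data API; the values are illustrative):
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0]).repeat()  # infinite
finite_ds = ds.take(3).batch(2)  # finite (and batched), safe for adapt()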
def adapt(self, data, reset_state=True):
  """Fits the state of the preprocessing layer to the data being passed.

  Arguments:
    data: The data to train on. It can be passed either as a tf.data Dataset,
      or as a numpy array.
    reset_state: Optional argument specifying whether to clear the state of
      the layer at the start of the call to `adapt`, or whether to start
      from the existing state. Subclasses may choose to throw if reset_state
      is set to 'False'.
  """
  if reset_state:
    accumulator = None
  else:
    accumulator = self._combiner.restore(self._restore_updates())

  if isinstance(data, (list, tuple)):
    data = ops.convert_to_tensor_v2(data)
  if not isinstance(data,
                    (dataset_ops.DatasetV2,
                     np.ndarray,
                     ops.Tensor,
                     ragged_tensor.RaggedTensor)):
    raise ValueError(
        '`adapt()` requires a batched Dataset, a Tensor, '
        'or a Numpy array as input, '
        'got {}'.format(type(data)))

  if isinstance(data, dataset_ops.DatasetV2):
    # Validate the datasets to try and ensure we haven't been passed one with
    # infinite size. That would cause an infinite loop here.
    if tf_utils.dataset_is_infinite(data):
      raise ValueError(
          'The dataset passed to `adapt()` has an infinite number of '
          'elements. Please use `dataset.take(...)` to make the number '
          'of elements finite.')
    next_data = self._get_dataset_iterator(data)
    # TODO(fchollet): consider checking if the dataset is already batched
    # and otherwise batching it.
  elif isinstance(data, (ops.Tensor, ragged_tensor.RaggedTensor)):
    next_data = self._get_dataset_iterator(
        dataset_ops.Dataset.from_tensor_slices(data).batch(512))
  else:
    generator, _ = training_generator.convert_to_generator_like(
        data, batch_size=512)
    # If the data is not a dataset, we can iterate over it using next(foo);
    # here, we wrap that into a callable.
    next_data = lambda: next(generator)

  # TODO(momernick): Some sort of status bar?
  # TODO(momernick): Implement parallel processing here?
  try:
    data_element = next_data()

    # First, see if the layer is built or not. If it is not, then we must
    # build it.
    if not self.built:
      try:
        # If this is a Numpy array or tensor, we can get shape from .shape.
        # If not, an attribute error will be thrown.
        data_shape = data_element.shape
        data_shape_nones = tuple([None] * len(data_element.shape))
      except AttributeError:
        # The input has an unknown number of dimensions.
        data_shape = None
        data_shape_nones = None

      # TODO (b/159261555): move this to base layer build.
      batch_input_shape = getattr(self, '_batch_input_shape', None)
      if batch_input_shape is None:
        # Set the number of dimensions.
        self._batch_input_shape = data_shape_nones

      self.build(data_shape)

    # Once we have built the Layer, we can process the input data. We do so
    # until we've gotten an exception indicating that we have no more data.
    while True:
      accumulator = self._combiner.compute(data_element, accumulator)
      data_element = next_data()
  # Note that this belongs to the outer indentation of 'try' - we need to
  # catch exceptions resulting from the first 'next_data()' invocation as
  # well.
  except (StopIteration, errors.OutOfRangeError):
    pass

  updates = self._combiner.extract(accumulator)
  self._set_state_variables(updates)
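# Both versions of `adapt` above delegate the actual statistics to
# `self._combiner`, using only three methods: `compute` folds one batch into
# an accumulator, `extract` turns the final accumulator into state-variable
# updates, and `restore` rebuilds an accumulator from previously extracted
# updates (the `reset_state=False` path). The toy combiner below is a minimal
# sketch of that protocol; `MeanCombiner` and its update format are
# illustrative, not one of TensorFlow's combiners.
import numpy as np


class MeanCombiner(object):
  """Tracks a running mean across batches via a dict accumulator."""

  def compute(self, batch, accumulator=None):
    if accumulator is None:
      accumulator = {'sum': 0.0, 'count': 0}
    accumulator['sum'] += float(np.sum(batch))
    accumulator['count'] += int(np.size(batch))
    return accumulator

  def extract(self, accumulator):
    return {'mean': accumulator['sum'] / accumulator['count'],
            'count': accumulator['count']}

  def restore(self, updates):
    # Rebuild the accumulator from extracted updates so adaptation can
    # continue from the previous state instead of starting over.
    return {'sum': updates['mean'] * updates['count'],
            'count': updates['count']}


# Driving the combiner the same way `adapt` does: one `compute` per batch,
# then a single `extract` once the data is exhausted.
combiner = MeanCombiner()
accumulator = None
for batch in (np.array([1.0, 2.0]), np.array([3.0, 4.0])):
  accumulator = combiner.compute(batch, accumulator)
print(combiner.extract(accumulator))  # {'mean': 2.5, 'count': 4}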