def adapt(self, data, batch_size=None, steps=None, reset_state=True):
  """Fits the state of the preprocessing layer to the data being passed.

  Arguments:
    data: The data to train on. It can be passed either as a tf.data
      Dataset, or as a numpy array.
    batch_size: Integer or `None`. Number of samples per state update. If
      unspecified, `batch_size` will default to 32. Do not specify the
      `batch_size` if your data is in the form of datasets, generators, or
      `keras.utils.Sequence` instances (since they generate batches).
    steps: Integer or `None`. Total number of steps (batches of samples)
      to process. When training with input tensors such as TensorFlow data
      tensors, the default `None` is equal to the number of samples in
      your dataset divided by the batch size, or 1 if that cannot be
      determined. If x is a `tf.data` dataset, and `steps` is None, the
      epoch will run until the input dataset is exhausted. When passing an
      infinitely repeating dataset, you must specify the `steps` argument.
      This argument is not supported with array inputs.
    reset_state: Optional argument specifying whether to clear the state
      of the layer at the start of the call to `adapt`, or whether to
      start from the existing state. This argument may not be relevant to
      all preprocessing layers: a subclass of PreprocessingLayer may
      choose to raise if `reset_state` is set to False.
  """
  _disallow_inside_tf_function('adapt')
  if not version_utils.should_use_v2():
    raise RuntimeError(
        '`adapt` is only supported in tensorflow v2.'
    )  # pylint: disable=g-doc-exception
  if not self.stateful:
    return
  if not self.streaming and self._is_adapted and not reset_state:
    raise ValueError(
        '{} does not support calling `adapt` twice without '
        'resetting the state.'.format(self.__class__.__name__))
  if not self._is_compiled:
    self.compile()  # Compile with defaults.
  if self.built and reset_state:
    self.reset_state()
  data_handler = data_adapter.DataHandler(
      data,
      batch_size=batch_size,
      steps_per_epoch=steps,
      epochs=1,
      steps_per_execution=self._steps_per_execution,
      distribute=False)
  self._adapt_function = self.make_adapt_function()
  for _, iterator in data_handler.enumerate_epochs():
    with data_handler.catch_stop_iteration():
      for _ in data_handler.steps():
        self._adapt_function(iterator)
        if data_handler.should_sync:
          context.async_wait()
  self.finalize_state()
  self._is_adapted = True

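# --- Illustrative usage (not part of the library source) ---
# A minimal sketch of calling the `adapt` signature above, assuming a
# streaming-capable layer from this API version such as
# `tf.keras.layers.experimental.preprocessing.Normalization`:
#
#   import numpy as np
#   import tensorflow as tf
#
#   layer = tf.keras.layers.experimental.preprocessing.Normalization()
#   layer.adapt(np.array([[0.], [1.], [2.]]), batch_size=2)  # Fresh state.
#   # Accumulate onto the existing state instead of resetting it:
#   layer.adapt(np.array([[3.], [4.]]), reset_state=False)
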
def test_infinite_dataset_with_steps_per_epoch(self):
  data = tf.data.Dataset.from_tensor_slices([0, 1, 2]).batch(1).repeat()
  data_handler = data_adapter.DataHandler(
      data, initial_epoch=0, epochs=2, steps_per_epoch=3)
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator).numpy())
    returned_data.append(epoch_data)
  self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])

def test_single_x_input_no_tuple_wrapping(self, use_numpy):
  x = np.ones((10, 1))

  if use_numpy:
    batch_size = 2
  else:
    x = tf.data.Dataset.from_tensor_slices(x).batch(2)
    batch_size = None

  data_handler = data_adapter.DataHandler(x, batch_size=batch_size)
  for _, iterator in data_handler.enumerate_epochs():
    for _ in data_handler.steps():
      # Check that single x input is not wrapped in a tuple.
      self.assertIsInstance(next(iterator), tf.Tensor)

def test_class_weight_user_errors(self):
  with self.assertRaisesRegex(ValueError, 'to be a dict with keys'):
    data_adapter.DataHandler(
        x=[[0], [1], [2]],
        y=[[2], [1], [0]],
        batch_size=1,
        sample_weight=[[1.], [2.], [4.]],
        class_weight={
            0: 0.5,
            1: 1.,
            3: 1.5  # Skips class `2`.
        })

  with self.assertRaisesRegex(ValueError, 'with a single output'):
    data_adapter.DataHandler(
        x=np.ones((10, 1)),
        y=[np.ones((10, 1)), np.zeros((10, 1))],
        batch_size=2,
        class_weight={
            0: 0.5,
            1: 1.,
            2: 1.5
        })

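# --- Illustrative usage (not part of the test suite) ---
# For contrast with the error cases above, a sketch of a `class_weight`
# dict that `DataHandler` accepts: keys must run from 0 to one less than
# the number of classes, with no gaps, and the model must have a single
# output:
#
#   data_adapter.DataHandler(
#       x=np.ones((10, 1)),
#       y=np.zeros((10, 1)),
#       batch_size=2,
#       class_weight={0: 0.5, 1: 1., 2: 1.5})
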
def test_list_of_scalars(self):
  data_handler = data_adapter.DataHandler([[0], [1], [2]],
                                          epochs=2,
                                          steps_per_epoch=3)
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator))
    returned_data.append(epoch_data)
  returned_data = self.evaluate(returned_data)
  self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
                                   [([0],), ([1],), ([2],)]])

def test_finite_dataset_with_steps_per_epoch_exact_size(self):
  data = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
  # If user specifies exact size of `Dataset` as `steps_per_epoch`,
  # create a new iterator each epoch.
  data_handler = data_adapter.DataHandler(
      data, initial_epoch=0, epochs=2, steps_per_epoch=4)
  self.assertTrue(data_handler._adapter.should_recreate_iterator())
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator).numpy())
    returned_data.append(epoch_data)
  self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])

def test_finite_dataset_with_steps_per_epoch(self):
  data = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
  # User can choose to only partially consume `Dataset`.
  data_handler = data_adapter.DataHandler(
      data, initial_epoch=0, epochs=2, steps_per_epoch=2)
  self.assertEqual(data_handler.inferred_steps, 2)
  self.assertFalse(data_handler._adapter.should_recreate_iterator())
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator).numpy())
    returned_data.append(epoch_data)
  self.assertEqual(returned_data, [[0, 1], [2, 3]])

def test_insufficient_data(self):
  ds = tf.data.Dataset.from_tensor_slices([0, 1])
  ds = ds.filter(lambda *args, **kwargs: True)
  data_handler = data_adapter.DataHandler(
      ds, initial_epoch=0, epochs=2, steps_per_epoch=3)
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      with data_handler.catch_stop_iteration():
        epoch_data.append(next(iterator))
    returned_data.append(epoch_data)
  returned_data = self.evaluate(returned_data)
  self.assertTrue(data_handler._insufficient_data)
  self.assertEqual(returned_data, [[0, 1]])

def test_composite_tensor(self):
  st = tf.SparseTensor(
      indices=[[0, 0], [1, 0], [2, 0]], values=[0, 1, 2], dense_shape=[3, 1])
  data_handler = data_adapter.DataHandler(st, epochs=2, steps_per_epoch=3)
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator))
    returned_data.append(epoch_data)
  returned_data = self.evaluate(
      tf.nest.map_structure(tf.sparse.to_dense, returned_data))
  self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
                                   [([0],), ([1],), ([2],)]])

def test_generator(self):

  def generator():
    for _ in range(2):
      for step in range(3):
        yield (tf.convert_to_tensor([step]),)

  data_handler = data_adapter.DataHandler(
      generator(), epochs=2, steps_per_epoch=3)
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator))
    returned_data.append(epoch_data)
  returned_data = self.evaluate(returned_data)
  self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
                                   [([0],), ([1],), ([2],)]])

def test_numpy(self):
  x = np.array([0, 1, 2])
  y = np.array([0, 2, 4])
  sw = np.array([0, 4, 8])
  data_handler = data_adapter.DataHandler(
      x=x, y=y, sample_weight=sw, batch_size=1, epochs=2)
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator))
    returned_data.append(epoch_data)
  returned_data = self.evaluate(returned_data)
  self.assertEqual(returned_data, [[(0, 0, 0), (1, 2, 4), (2, 4, 8)],
                                   [(0, 0, 0), (1, 2, 4), (2, 4, 8)]])

def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
  ds = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
  filtered_ds = ds.filter(lambda x: x < 4)
  self.assertEqual(
      tf.data.experimental.cardinality(filtered_ds).numpy(),
      tf.data.experimental.UNKNOWN_CARDINALITY)

  data_handler = data_adapter.DataHandler(
      filtered_ds, initial_epoch=0, epochs=2)
  self.assertEqual(data_handler.inferred_steps, None)
  self.assertTrue(data_handler._adapter.should_recreate_iterator())
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    with data_handler.catch_stop_iteration():
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
    returned_data.append(epoch_data)
  returned_data = self.evaluate(returned_data)
  self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
  self.assertEqual(data_handler.inferred_steps, 4)

def test_unknown_cardinality_dataset_with_steps_per_epoch(self):
  ds = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
  filtered_ds = ds.filter(lambda x: x < 4)
  self.assertEqual(
      tf.data.experimental.cardinality(filtered_ds).numpy(),
      tf.data.experimental.UNKNOWN_CARDINALITY)

  # User can choose to only partially consume `Dataset`.
  data_handler = data_adapter.DataHandler(
      filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2)
  self.assertFalse(data_handler._adapter.should_recreate_iterator())
  returned_data = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    for _ in data_handler.steps():
      epoch_data.append(next(iterator))
    returned_data.append(epoch_data)
  returned_data = self.evaluate(returned_data)
  self.assertEqual(returned_data, [[0, 1], [2, 3]])
  self.assertEqual(data_handler.inferred_steps, 2)

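# --- Illustrative note (not part of the test suite) ---
# `tf.data` reports unknown cardinality whenever a transformation such as
# `filter` makes the element count undecidable, which is why the two tests
# above must either run the iterator to exhaustion or cap it with
# `steps_per_epoch`:
#
#   ds = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)
#   assert (tf.data.experimental.cardinality(ds).numpy()
#           == tf.data.experimental.UNKNOWN_CARDINALITY)
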
def adapt(self, data, batch_size=None, steps=None):
  """Fits the state of the preprocessing layer to the data being passed.

  After calling `adapt` on a layer, a preprocessing layer's state will not
  update during training. In order to make preprocessing layers efficient
  in any distribution context, they are kept constant with respect to any
  compiled `tf.Graph`s that call the layer. This does not affect the layer
  use when adapting each layer only once, but if you adapt a layer
  multiple times you will need to take care to re-compile any compiled
  functions as follows:

  * If you are adding a preprocessing layer to a `keras.Model`, you need
    to call `model.compile` after each subsequent call to `adapt`.
  * If you are calling a preprocessing layer inside `tf.data.Dataset.map`,
    you should call `map` again on the input `tf.data.Dataset` after each
    `adapt`.
  * If you are using a `tf.function` directly which calls a preprocessing
    layer, you need to call `tf.function` again on your callable after
    each subsequent call to `adapt`.

  `tf.keras.Model` example with multiple adapts:

  >>> layer = tf.keras.layers.Normalization(
  ...     axis=None)
  >>> layer.adapt([0, 2])
  >>> model = tf.keras.Sequential(layer)
  >>> model.predict([0, 1, 2])
  array([-1.,  0.,  1.], dtype=float32)
  >>> layer.adapt([-1, 1])
  >>> model.compile()  # This is needed to re-compile model.predict!
  >>> model.predict([0, 1, 2])
  array([0., 1., 2.], dtype=float32)

  `tf.data.Dataset` example with multiple adapts:

  >>> layer = tf.keras.layers.Normalization(
  ...     axis=None)
  >>> layer.adapt([0, 2])
  >>> input_ds = tf.data.Dataset.range(3)
  >>> normalized_ds = input_ds.map(layer)
  >>> list(normalized_ds.as_numpy_iterator())
  [array([-1.], dtype=float32),
   array([0.], dtype=float32),
   array([1.], dtype=float32)]
  >>> layer.adapt([-1, 1])
  >>> normalized_ds = input_ds.map(layer)  # Re-map over the input dataset.
  >>> list(normalized_ds.as_numpy_iterator())
  [array([0.], dtype=float32),
   array([1.], dtype=float32),
   array([2.], dtype=float32)]

  `adapt()` is meant only as a single machine utility to compute layer
  state. To analyze a dataset that cannot fit on a single machine, see
  [Tensorflow Transform](
  https://www.tensorflow.org/tfx/transform/get_started) for a
  multi-machine, map-reduce solution.

  Arguments:
    data: The data to train on. It can be passed either as a tf.data
      Dataset, or as a numpy array.
    batch_size: Integer or `None`. Number of samples per state update. If
      unspecified, `batch_size` will default to 32. Do not specify the
      `batch_size` if your data is in the form of datasets, generators, or
      `keras.utils.Sequence` instances (since they generate batches).
    steps: Integer or `None`. Total number of steps (batches of samples)
      to process. When training with input tensors such as TensorFlow data
      tensors, the default `None` is equal to the number of samples in
      your dataset divided by the batch size, or 1 if that cannot be
      determined. If x is a `tf.data` dataset, and `steps` is None, the
      epoch will run until the input dataset is exhausted. When passing an
      infinitely repeating dataset, you must specify the `steps` argument.
      This argument is not supported with array inputs.
  """
  _disallow_inside_tf_function("adapt")
  if not version_utils.should_use_v2():
    raise RuntimeError(
        "`adapt` is only supported in tensorflow v2."
    )  # pylint: disable=g-doc-exception
  if not self._is_compiled:
    self.compile()  # Compile with defaults.
  if self.built:
    self.reset_state()
  data_handler = data_adapter.DataHandler(
      data,
      batch_size=batch_size,
      steps_per_epoch=steps,
      epochs=1,
      steps_per_execution=self._steps_per_execution,
      distribute=False,
  )
  self._adapt_function = self.make_adapt_function()
  for _, iterator in data_handler.enumerate_epochs():
    with data_handler.catch_stop_iteration():
      for _ in data_handler.steps():
        self._adapt_function(iterator)
        if data_handler.should_sync:
          context.async_wait()
  self.finalize_state()
  self._is_adapted = True
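# --- Illustrative usage (not part of the library source) ---
# A sketch of the v2 `adapt` above with an infinitely repeating dataset;
# per the docstring, `steps` is mandatory in that case (layer choice is an
# assumption):
#
#   layer = tf.keras.layers.Normalization(axis=None)
#   ds = tf.data.Dataset.from_tensor_slices([0., 1., 2., 3.]).batch(2).repeat()
#   layer.adapt(ds, steps=2)  # Consume exactly two batches to compute state.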