def test_validation_split_none(self):
  """A `None` input passes through `train_validation_split` as `None`."""
  # A bare None splits into two None halves.
  train_part, val_part = data_adapter.train_validation_split(
      None, validation_split=0.2)
  self.assertIsNone(train_part)
  self.assertIsNone(val_part)

  # A None nested inside a tuple also stays None on both sides of the split.
  (_, train_part), (_, val_part) = data_adapter.train_validation_split(
      (np.ones((10, 1)), None), validation_split=0.2)
  self.assertIsNone(train_part)
  self.assertIsNone(val_part)
def test_validation_split_unshuffled(self, use_numpy):
  """Split preserves order: first 80% becomes train, last 20% validation."""
  to_tensor = ops.convert_to_tensor_v2_with_dispatch
  if use_numpy:
    x = np.array([0, 1, 2, 3, 4])
    y = np.array([0, 2, 4, 6, 8])
    sw = np.array([0, 4, 8, 12, 16])
  else:
    x = to_tensor([0, 1, 2, 3, 4])
    y = to_tensor([0, 2, 4, 6, 8])
    sw = to_tensor([0, 4, 8, 12, 16])

  (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
      data_adapter.train_validation_split((x, y, sw), validation_split=0.2))

  if use_numpy:
    # Normalize the NumPy outputs to tensors so the assertions below can use
    # a single `.numpy().tolist()` code path.
    train_x, train_y, train_sw = map(to_tensor, (train_x, train_y, train_sw))
    val_x, val_y, val_sw = map(to_tensor, (val_x, val_y, val_sw))

  self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3])
  self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6])
  self.assertEqual(train_sw.numpy().tolist(), [0, 4, 8, 12])
  self.assertEqual(val_x.numpy().tolist(), [4])
  self.assertEqual(val_y.numpy().tolist(), [8])
  self.assertEqual(val_sw.numpy().tolist(), [16])
def test_validation_split_shuffled(self, use_numpy):
  """Shuffled split keeps x/y/sw aligned and loses no elements.

  The fixtures satisfy `y == 2 * x` and `sw == 2 * y`, so alignment after
  shuffling can be verified element-wise on both the train and validation
  partitions.
  """
  if use_numpy:
    x = np.array([0, 1, 2, 3, 4])
    y = np.array([0, 2, 4, 6, 8])
    sw = np.array([0, 4, 8, 12, 16])
  else:
    x = ops.convert_to_tensor_v2([0, 1, 2, 3, 4])
    y = ops.convert_to_tensor_v2([0, 2, 4, 6, 8])
    sw = ops.convert_to_tensor_v2([0, 4, 8, 12, 16])
  (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
      data_adapter.train_validation_split(
          (x, y, sw), validation_split=0.2))

  self.assertEqual(int(train_x.shape[0]), 4)
  self.assertEqual(int(train_y.shape[0]), 4)
  self.assertEqual(int(train_sw.shape[0]), 4)
  for i in range(4):
    # Check that all arrays were shuffled in identical order.
    self.assertEqual(2 * train_x[i].numpy(), train_y[i].numpy())
    self.assertEqual(2 * train_y[i].numpy(), train_sw[i].numpy())

  self.assertEqual(int(val_x.shape[0]), 1)
  self.assertEqual(int(val_y.shape[0]), 1)
  self.assertEqual(int(val_sw.shape[0]), 1)
  for i in range(1):
    # Check that the *validation* arrays were shuffled in identical order.
    # (Bug fix: this loop previously re-asserted on the train arrays, so the
    # validation partition's alignment was never actually checked.)
    self.assertEqual(2 * val_x[i].numpy(), val_y[i].numpy())
    self.assertEqual(2 * val_y[i].numpy(), val_sw[i].numpy())

  # Check that arrays contain expected values: train + val together must be
  # a permutation of the original input.
  self.assertEqual(
      sorted(array_ops.concat([train_x, val_x], axis=0).numpy().tolist()),
      sorted(ops.convert_to_tensor_v2(x).numpy().tolist()))
  self.assertEqual(
      sorted(array_ops.concat([train_y, val_y], axis=0).numpy().tolist()),
      sorted(ops.convert_to_tensor_v2(y).numpy().tolist()))
  self.assertEqual(
      sorted(array_ops.concat([train_sw, val_sw], axis=0).numpy().tolist()),
      sorted(ops.convert_to_tensor_v2(sw).numpy().tolist()))
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        auto_switch=True,
        retry_fit=True,
        absorb=True,
        train_after_switch=True,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        revert_after_fit=False):
  """Custom fit loop for the context model.

  Mirrors `tf.keras.Model.fit`, but each epoch draws one window of data from
  a `WindowedDataHandler` and repeatedly performs a "fit call" on it: after
  every attempt `self.update_and_switch` reports whether the model switched
  context; on a switch the pre-fit weights are restored and, if `retry_fit`,
  the same window is retried until a context fits.

  Args beyond the standard `Model.fit` signature:
    auto_switch: Enable/disable autonomous context switching.
    retry_fit: Locate the next fitting context by re-performing fit.
    absorb: Reset the switch sequence counter upon successful training.
      Mainly used to maintain switch sequencing for temporally-extended
      tasks.
    train_after_switch: If False, weights trained in an epoch that contained
      a switch are reverted instead of kept.
    revert_after_fit: Debug parameter that reverts weights after performing
      a fit. Used to calculate the context deltas without incorrectly
      learning while auto switching is disabled.

  Returns:
    `self.history`, the record of training (and validation) metrics.
  """
  training._keras_api_gauge.get_cell('fit').set(True)
  # Legacy graph support is contained in `training_v1.Model`.
  version_utils.disallow_legacy_graph('Model', 'fit')
  self._assert_compile_was_called()
  self._check_call_args('fit')

  if validation_split:
    # Create the validation data using the training data. Only supported for
    # `Tensor` and `NumPy` input.
    (x, y, sample_weight), validation_data = (
        data_adapter.train_validation_split(
            (x, y, sample_weight),
            validation_split=validation_split,
            shuffle=False))

  with self.distribute_strategy.scope(
  ), training_utils.RespectCompiledTrainableState(self):
    # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
    data_handler = WindowedDataHandler(
        x=x,
        y=y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        initial_epoch=initial_epoch,
        epochs=epochs,
        shuffle=shuffle,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self)

    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=bool(verbose & Verbosity.Progress),
          model=self,
          verbose=verbose,
          epochs=epochs,
          steps=data_handler.inferred_steps)

    self.stop_training = False
    train_function = self.make_train_function()
    callbacks.on_train_begin()
    self.initialize_fit()
    # Handle fault-tolerance for multi-worker.
    # TODO(omalleyt): Fix the ordering issues that mean this has to
    # happen after `callbacks.on_train_begin`.
    data_handler._initial_epoch = (
        self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
    for epoch, window_iterator in data_handler.enumerate_epochs():
      self.reset_metrics()
      callbacks.on_epoch_begin(epoch)
      dataset = tf.data.Dataset.zip(next(window_iterator))
      # Indicates if the model has attempted at least one switch during this
      # epoch.
      switched_during_epoch = False
      # Indicates if the model switched on the most recent fit iteration.
      switched = True
      # Snapshot of the weights, restored whenever a switch (or a debug
      # revert) discards the training done by the last fit call.
      weights = backend.batch_get_value(self.trainable_variables)
      # Perform a 'fit call'. Assuming retry_fit, this call is re-attempted
      # after each switch until a context fits.
      while switched and (retry_fit or not switched_during_epoch):
        self.initialize_epoch(epoch)
        iterator = iter(dataset)
        # Perform a fit call over the window.
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with traceme.TraceMe(
                'TraceContext',
                graph_type='train',
                epoch_num=epoch,
                step_num=step,
                batch_size=batch_size):
              callbacks.on_train_batch_begin(step)
              tmp_logs = train_function(iterator)
              # Catch OutOfRangeError for Datasets of unknown size.
              # This blocks until the batch has finished executing.
              # TODO(b/150292341): Allow multiple async steps here.
              if not data_handler.inferred_steps:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              callbacks.on_train_batch_end(step, logs)

        switched = not self.update_and_switch(epoch, auto_switch, absorb,
                                              retry_fit, verbose)
        switched_during_epoch |= switched

        # If a switch occurred, we need to restore the weights.
        if switched or (switched_during_epoch and
                        not train_after_switch) or revert_after_fit:
          backend.batch_set_value(zip(self.trainable_variables, weights))
          self.reset_metrics()

      epoch_logs = copy.copy(logs)

      # Run validation.
      if validation_data and self._should_eval(epoch, validation_freq):
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))
        val_logs = self.evaluate(
            x=val_x,
            y=val_y,
            sample_weight=val_sample_weight,
            batch_size=validation_batch_size or batch_size,
            steps=validation_steps,
            callbacks=callbacks,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            return_dict=True)
        val_logs = {'val_' + name: val for name, val in val_logs.items()}
        epoch_logs.update(val_logs)

      callbacks.on_epoch_end(epoch, epoch_logs)
      if self.stop_training:
        break

    callbacks.on_train_end()
    return self.history
def test_validation_split_user_error(self):
  """`train_validation_split` rejects inputs that are not Tensor/NumPy."""
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias
  # (removed in Python 3.12) and matches the sibling test
  # `test_validation_split_examples_too_few`.
  with self.assertRaisesRegex(ValueError, 'is only supported for Tensors'):
    data_adapter.train_validation_split(
        lambda: np.ones((10, 1)), validation_split=0.2)
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False):
  """Trains the model; adapted from `tf.keras.Model.fit`.

  Copied from `tf.keras.Model.fit` with one customization: when
  `self._update_cycle > 1`, the loop resets `self._grad_accumulator` at the
  start of each epoch and calls `self.accumulate_function` on
  `self._update_cycle - 1` extra batches before each regular train step —
  presumably to combine gradients across several batches into one update;
  confirm against `accumulate_function` / `make_train_function`.

  Returns:
    `self.history`, the record of training (and validation) metrics.
  """
  training._keras_api_gauge.get_cell('fit').set(True)
  # Legacy graph support is contained in `training_v1.Model`.
  version_utils.disallow_legacy_graph('Model', 'fit')
  self._assert_compile_was_called()
  self._check_call_args('fit')
  training._disallow_inside_tf_function('fit')

  if validation_split:
    # Create the validation data using the training data. Only supported for
    # `Tensor` and `NumPy` input.
    (x, y, sample_weight), validation_data = (
        data_adapter.train_validation_split(
            (x, y, sample_weight), validation_split=validation_split))

  if validation_data:
    val_x, val_y, val_sample_weight = (
        data_adapter.unpack_x_y_sample_weight(validation_data))

  with self.distribute_strategy.scope(), \
      training_utils.RespectCompiledTrainableState(self):
    # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
    data_handler = data_adapter.DataHandler(
        x=x,
        y=y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        initial_epoch=initial_epoch,
        epochs=epochs,
        shuffle=shuffle,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self,
        steps_per_execution=self._steps_per_execution)

    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=verbose != 0,
          model=self,
          verbose=verbose,
          epochs=epochs,
          steps=data_handler.inferred_steps)

    self.stop_training = False
    train_function = self.make_train_function()
    self._train_counter.assign(0)
    callbacks.on_train_begin()
    training_logs = None
    # Handle fault-tolerance for multi-worker.
    # TODO(omalleyt): Fix the ordering issues that mean this has to
    # happen after `callbacks.on_train_begin`.
    data_handler._initial_epoch = (  # pylint: disable=protected-access
        self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
    for epoch, iterator in data_handler.enumerate_epochs():
      self.reset_metrics()
      callbacks.on_epoch_begin(epoch)
      with data_handler.catch_stop_iteration():
        if self._update_cycle > 1:
          # Start the epoch with an empty gradient accumulator.
          self._grad_accumulator.reset()
        for step in data_handler.steps():
          with trace.Trace(
              'TraceContext',
              graph_type='train',
              epoch_num=epoch,
              step_num=step,
              batch_size=batch_size):
            callbacks.on_train_batch_begin(step)
            if self._update_cycle > 1:
              # Accumulate gradients from extra batches ahead of the train
              # step that performs the combined update.
              for _ in range(self._update_cycle - 1):
                self.accumulate_function(iterator)
            tmp_logs = train_function(iterator)
            if data_handler.should_sync:
              context.async_wait()
            logs = tmp_logs  # No error, now safe to assign to logs.
            end_step = step + data_handler.step_increment
            callbacks.on_train_batch_end(end_step, logs)
      epoch_logs = copy.copy(logs)

      # Run validation.
      if validation_data and self._should_eval(epoch, validation_freq):
        # Create data_handler for evaluation and cache it.
        if getattr(self, '_eval_data_handler', None) is None:
          self._eval_data_handler = data_adapter.DataHandler(
              x=val_x,
              y=val_y,
              sample_weight=val_sample_weight,
              batch_size=validation_batch_size or batch_size,
              steps_per_epoch=validation_steps,
              initial_epoch=0,
              epochs=1,
              max_queue_size=max_queue_size,
              workers=workers,
              use_multiprocessing=use_multiprocessing,
              model=self,
              steps_per_execution=self._steps_per_execution)
        val_logs = self.evaluate(
            x=val_x,
            y=val_y,
            sample_weight=val_sample_weight,
            batch_size=validation_batch_size or batch_size,
            steps=validation_steps,
            callbacks=callbacks,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            return_dict=True)
        val_logs = {'val_' + name: val for name, val in val_logs.items()}
        epoch_logs.update(val_logs)

      callbacks.on_epoch_end(epoch, epoch_logs)
      training_logs = epoch_logs
      if self.stop_training:
        break

    # If the eval data_handler exists, delete it after all epochs are done.
    if getattr(self, '_eval_data_handler', None) is not None:
      del self._eval_data_handler
    callbacks.on_train_end(logs=training_logs)
    return self.history
def test_validation_split_examples_too_few(self):
  """Splitting a single-example batch must raise a ValueError."""
  single_example = np.ones((1, 10))
  with self.assertRaisesRegex(ValueError, 'not sufficient to split it'):
    data_adapter.train_validation_split(
        single_example, validation_split=0.2)
def fit(
    self,
    x: Optional[
        Union[np.ndarray, tf.Tensor, tf.data.Dataset, tf.keras.utils.Sequence]
    ] = None,
    y: Optional[
        Union[np.ndarray, tf.Tensor, tf.data.Dataset, tf.keras.utils.Sequence]
    ] = None,
    batch_size: Optional[int] = None,
    epochs: int = 1,
    verbose: int = 1,
    callbacks: Optional[List[Callback]] = None,
    validation_split: float = 0.0,
    validation_data: Optional[Any] = None,
    shuffle: bool = True,
    class_weight: Optional[Dict[int, float]] = None,
    sample_weight: Optional[np.ndarray] = None,
    initial_epoch: int = 0,
    steps_per_epoch: Optional[int] = None,
    validation_steps: Optional[int] = None,
    validation_batch_size: Optional[int] = None,
    validation_freq: int = 1,
    max_queue_size: int = 10,
    workers: int = 1,
    use_multiprocessing: bool = False,
):
    """Trains the model for a fixed number of epochs (iterations on a dataset).

    Adapted from `tf.keras.Model.fit`; the only structural difference is that
    data iteration goes through `CustomDataHandler` (both for training and for
    cached evaluation) instead of the stock `data_adapter.DataHandler`.

    Args:
        x: Input data.
        y: Target data.
        batch_size: Number of samples per gradient update.
        epochs: Number of epochs to train the model.
        verbose: Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line
            per epoch.
        callbacks: List of `keras.callbacks.Callback` instances.
        validation_split: Fraction of the training data to be used as
            validation data.
        validation_data: Data on which to evaluate the loss and any model
            metrics at the end of each epoch.
        shuffle: Whether to shuffle the training data before each epoch.
        class_weight: Optional dictionary mapping class indices (integers) to
            a weight (float) value, used for weighting the loss function
            (during training only).
        sample_weight: Optional Numpy array of weights for the training
            samples, used for weighting the loss function (during training
            only).
        initial_epoch: Epoch at which to start training.
        steps_per_epoch: Total number of steps (batches of samples) before
            declaring one epoch finished and starting the next epoch.
        validation_steps: Total number of steps (batches of samples) to draw
            before stopping when performing validation at the end of every
            epoch.
        validation_batch_size: Number of samples per validation batch.
        validation_freq: Specifies how many training epochs to run before a
            new validation run is performed.
        max_queue_size: Maximum size for the generator queue.
        workers: Maximum number of processes to spin up when using
            process-based threading.
        use_multiprocessing: If `True`, use process-based threading.

    Returns:
        A `History` object. Its `History.history` attribute is a record of
        training loss values and metrics values at successive epochs, as well
        as validation loss values and validation metrics values (if
        applicable).

    Raises:
        RuntimeError: 1. If the model was never compiled or,
            2. If `model.fit` is wrapped in `tf.function`.
        ValueError: In case of mismatch between the provided input data and
            what the model expects.
    """
    training._keras_api_gauge.get_cell("fit").set(True)
    # Legacy graph support is contained in `training_v1.Model`.
    version_utils.disallow_legacy_graph("Model", "fit")
    self._assert_compile_was_called()
    self._check_call_args("fit")
    training._disallow_inside_tf_function("fit")

    if validation_split:
        # Create the validation data using the training data. Only supported
        # for `Tensor` and `NumPy` input.
        (
            (x, y, sample_weight),
            validation_data,
        ) = data_adapter.train_validation_split(
            (x, y, sample_weight), validation_split=validation_split
        )

    if validation_data:
        val_x, val_y, val_sample_weight = data_adapter.unpack_x_y_sample_weight(
            validation_data
        )

    with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState(
        self
    ):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        # Use our own custom data handler to handle increasing batch size.
        data_handler = CustomDataHandler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self,
            steps_per_execution=self._steps_per_execution,
        )

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, training.callbacks_module.CallbackList):
            callbacks = training.callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                model=self,
                verbose=verbose,
                epochs=epochs,
                steps=data_handler.inferred_steps,
            )

        self.stop_training = False
        train_function = self.make_train_function()
        self._train_counter.assign(0)
        callbacks.on_train_begin()
        training_logs = None
        # Handle fault-tolerance for multi-worker.
        data_handler._initial_epoch = self._maybe_load_initial_epoch_from_ckpt(  # pylint: disable=protected-access
            initial_epoch
        )
        for epoch, iterator in data_handler.enumerate_epochs():
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)
            with data_handler.catch_stop_iteration():
                for step in data_handler.steps():
                    with training.trace.Trace(
                        "TraceContext",
                        graph_type="train",
                        epoch_num=epoch,
                        step_num=step,
                        batch_size=batch_size,
                    ):
                        callbacks.on_train_batch_begin(step)
                        tmp_logs = train_function(iterator)
                        if data_handler.should_sync:
                            context.async_wait()
                        # No error, now safe to assign to logs.
                        logs = tmp_logs
                        end_step = step + data_handler.step_increment
                        callbacks.on_train_batch_end(end_step, logs)
            epoch_logs = copy.copy(logs)

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                # Create data_handler for evaluation and cache it.
                if getattr(self, "_eval_data_handler", None) is None:
                    self._eval_data_handler = CustomDataHandler(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps_per_epoch=validation_steps,
                        initial_epoch=0,
                        epochs=1,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        model=self,
                        steps_per_execution=self._steps_per_execution,
                    )
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    max_queue_size=max_queue_size,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    return_dict=True,
                )
                val_logs = {"val_" + name: val for name, val in val_logs.items()}
                epoch_logs.update(val_logs)

            callbacks.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if self.stop_training:
                break

        # If the eval data_handler exists, delete it after all epochs are done.
        if getattr(self, "_eval_data_handler", None) is not None:
            del self._eval_data_handler
        callbacks.on_train_end(logs=training_logs)
        return self.history
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        dynamic_switch=True,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False):
  """Custom fit loop with dynamic context switching.

  Follows `tf.keras.Model.fit`, except that each epoch's window of data
  (drawn from a `WindowedDataHandler`) is refit in a `while` loop: after
  every fit attempt `self.update_and_switch` reports whether the model
  switched context; on a switch the pre-fit weights are restored and the
  same window is retried until no switch occurs.

  Args beyond the standard `Model.fit` signature:
    dynamic_switch: Forwarded to `self.update_and_switch`; presumably
      enables/disables autonomous context switching — confirm against that
      method.

  Returns:
    `self.history`, the record of training (and validation) metrics.
  """
  training._keras_api_gauge.get_cell('fit').set(True)
  # Legacy graph support is contained in `training_v1.Model`.
  version_utils.disallow_legacy_graph('Model', 'fit')
  self._assert_compile_was_called()
  self._check_call_args('fit')

  if validation_split:
    # Create the validation data using the training data. Only supported for
    # `Tensor` and `NumPy` input.
    (x, y, sample_weight), validation_data = (
        data_adapter.train_validation_split(
            (x, y, sample_weight),
            validation_split=validation_split,
            shuffle=False))

  with self.distribute_strategy.scope(
  ), training_utils.RespectCompiledTrainableState(self):
    # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
    data_handler = WindowedDataHandler(
        x=x,
        y=y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        initial_epoch=initial_epoch,
        epochs=epochs,
        shuffle=shuffle,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self)

    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=bool(verbose & Verbosity.Progress),
          model=self,
          verbose=verbose,
          epochs=epochs,
          steps=data_handler.inferred_steps)

    self.stop_training = False
    train_function = self.make_train_function()
    callbacks.on_train_begin()
    # Handle fault-tolerance for multi-worker.
    # TODO(omalleyt): Fix the ordering issues that mean this has to
    # happen after `callbacks.on_train_begin`.
    data_handler._initial_epoch = (
        self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
    for epoch, window_iterator in data_handler.enumerate_epochs():
      self.reset_metrics()
      callbacks.on_epoch_begin(epoch)
      dataset = tf.data.Dataset.zip(next(window_iterator))
      switched = True
      # Snapshot of the weights, restored whenever a fit attempt triggers a
      # context switch.
      weights = backend.batch_get_value(self.trainable_variables)
      while switched:
        self.initialize_epoch(epoch)
        iterator = iter(dataset)
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with traceme.TraceMe(
                'TraceContext',
                graph_type='train',
                epoch_num=epoch,
                step_num=step,
                batch_size=batch_size):
              callbacks.on_train_batch_begin(step)
              tmp_logs = train_function(iterator)
              # Catch OutOfRangeError for Datasets of unknown size.
              # This blocks until the batch has finished executing.
              # TODO(b/150292341): Allow multiple async steps here.
              if not data_handler.inferred_steps:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              callbacks.on_train_batch_end(step, logs)

        switched = not self.update_and_switch(epoch, dynamic_switch, verbose)

        # If a switch occurred, we need to restore the weights.
        if switched:
          backend.batch_set_value(zip(self.trainable_variables, weights))
          self.reset_metrics()

      epoch_logs = copy.copy(logs)
      if self.accumulate_gradients:
        # NOTE(review): applies `self.accumulated_gradients` once per epoch
        # after the fit loop settles — confirm the accumulation scope against
        # the train function.
        self.optimizer.apply_gradients(
            zip(self.accumulated_gradients, self.trainable_variables))

      # Run validation.
      if validation_data and self._should_eval(epoch, validation_freq):
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))
        val_logs = self.evaluate(
            x=val_x,
            y=val_y,
            sample_weight=val_sample_weight,
            batch_size=validation_batch_size or batch_size,
            steps=validation_steps,
            callbacks=callbacks,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            return_dict=True)
        val_logs = {'val_' + name: val for name, val in val_logs.items()}
        epoch_logs.update(val_logs)

      callbacks.on_epoch_end(epoch, epoch_logs)
      if self.stop_training:
        break

    callbacks.on_train_end()
    return self.history