def fit_single(
    model: tf.keras.Model,
    train_data,
    validation_data=None,
    epochs: int = 1,
    initial_epoch: int = 0,
    validation_freq: int = 1,
    callbacks: Iterable[tf.keras.callbacks.Callback] = (),
    verbose: bool = True,
    jit_compile: bool = False,
):
    """
    Optimized keras.Model.fit for training on a single graph.

    Each epoch is a single training step on the one (pre-unpacked) element,
    so the usual per-batch machinery collapses to one batch per epoch.

    Args:
        model: keras model to train.
        train_data: (inputs, labels, sample_weight) or dataset with a single
            element for training.
        validation_data: (inputs, labels, sample_weight) or dataset with a
            single element for validation.
        epochs: int, maximum number of epochs / steps to train for.
        initial_epoch: int, starting epoch.
        validation_freq: int, number of training steps/epochs per validation.
        callbacks: Iterable of tf.keras.callbacks.Callbacks.
        verbose: flag resulting in verbose outputs.
        jit_compile: flag indicating whether train/validation steps are
            compiled with `jit`. Not all ops are jit compatible, though where
            they are this may result in speed-ups.

    Returns:
        history: `tf.keras.callbacks.History` object.
    """
    train_data = unpack(train_data)
    validation_data = unpack(validation_data)
    do_validation = validation_data is not None

    cb_params = dict(
        epochs=epochs,
        verbose=verbose,
        steps=1,
        do_validation=do_validation,
    )
    all_callbacks = list(callbacks)
    if verbose:
        # Per-epoch progress display; the stock per-batch progbar is disabled
        # below since every epoch is exactly one batch.
        all_callbacks.append(EpochProgbarLogger())
    callback_list = tf.keras.callbacks.CallbackList(
        all_callbacks,
        add_history=True,
        add_progbar=False,
        model=model,
        **cb_params,
    )

    train_step = _build_train_step(model, train_data, jit_compile=jit_compile)
    validation_step = (
        None
        if validation_data is None
        else _build_test_step(model, validation_data, jit_compile=jit_compile)
    )

    model.stop_training = False
    callback_list.on_train_begin(logs=None)
    # _maybe_load_initial_epoch_from_ckpt behaviour is influenced by
    # callbacks.experimental.BackupAndRestore
    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(  # pylint: disable=protected-access
        initial_epoch)

    logs = None
    for epoch in range(initial_epoch, epochs):
        model.reset_metrics()
        callback_list.on_epoch_begin(epoch, logs=None)
        callback_list.on_train_batch_begin(batch=0)
        logs = train_step()
        callback_list.on_train_batch_end(batch=0, logs=logs)
        if model.stop_training:
            break

        # validation
        if validation_step is not None and (epoch + 1) % validation_freq == 0:
            val_logs = validation_step()
            logs.update({f"val_{k}": v for k, v in val_logs.items()})

        callback_list.on_epoch_end(epoch, logs)
        if model.stop_training:
            break

    callback_list.on_train_end(logs)
    return model.history
def fit(
    model: tf.keras.Model,
    train_data: tf.data.Dataset,
    epochs: int = 1,
    steps_per_epoch: Optional[int] = None,
    validation_data: tf.data.Dataset = None,
    validation_steps: Optional[int] = None,
    callbacks: Tuple[tf.keras.callbacks.Callback, ...] = (),
    initial_epoch: int = 0,
    validation_freq: int = 1,
    track_iterator: bool = False,
    verbose: bool = True,
) -> tf.keras.callbacks.History:
    """
    Custom fit implementation.

    Interface is intended to mimic best-practice usage of
    `tf.keras.Model.fit`.

    Unlike `tf.keras.Model.fit`, when `track_iterator` is True `_train_iter`
    is added as an attribute to `model`. If using `tf.train.Checkpoint`s to
    manage training state, this may result in larger files on disk.

    Args:
        model: keras model to train.
        train_data: dataset with (inputs, labels) or
            (inputs, labels, sample_weights)
        epochs: total number of epochs to train until.
        steps_per_epoch: number of steps per epoch. Must be provided if
            train_data has infinite cardinality.
        validation_data: optional dataset to perform validation on.
        validation_steps: number of steps of validation to perform per epoch.
        callbacks: `tf.keras.callbacks.Callback` instances.
        initial_epoch: starting epoch.
        validation_freq: number of epochs between validation.
        track_iterator: if True, `train_data`'s iterator is added as an
            attribute to `model`, meaning it will be saved in checkpoint's
            saving `model`.
        verbose: controls verbosity of printed output.

    Returns:
        `tf.keras.callbacks.History` object.

    Raises:
        `AttributeError` if `model` has an existing `_train_iter` attribute
        and `track_iterator` is True.
    """
    train_func = model.make_train_function()
    train_iter, steps_per_epoch = as_infinite_iterator(train_data,
                                                       steps_per_epoch)
    # BUG FIX: previously this raised whenever `_train_iter` existed, even
    # for `track_iterator=False` fits that never touch the attribute. The
    # documented contract (see Raises above) only forbids a pre-existing
    # attribute when tracking is requested.
    if track_iterator:
        if hasattr(model, "_train_iter"):
            raise AttributeError(
                "Cannot fit model with existing `_train_iter` attribute.")
        model._train_iter = train_iter  # pylint: disable=protected-access

    cb = tf.keras.callbacks.CallbackList(callbacks=callbacks,
                                         add_history=True,
                                         add_progbar=verbose,
                                         model=model)
    cb.set_params(
        dict(epochs=epochs, verbose=int(verbose), steps=steps_per_epoch))
    cb.on_train_begin()
    # _maybe_load_initial_epoch_from_ckpt behaviour is influenced by
    # callbacks.experimental.BackupAndRestore
    initial_epoch = (
        model._maybe_load_initial_epoch_from_ckpt(  # pylint: disable=protected-access
            initial_epoch))
    training_logs = None
    model.stop_training = False
    for epoch in range(initial_epoch, epochs):
        model.reset_metrics()
        cb.on_epoch_begin(epoch)
        logs = None
        for step in range(steps_per_epoch):
            cb.on_train_batch_begin(step)
            logs = train_func(train_iter)
            cb.on_train_batch_end(step, logs)
            if model.stop_training:
                break
        # steps_per_epoch >= 1, so at least one batch always ran
        assert logs is not None
        epoch_logs = logs
        if (validation_data is not None and model._should_eval(  # pylint: disable=protected-access
                epoch, validation_freq)):
            logs = model.evaluate(
                validation_data,
                steps=validation_steps,
                callbacks=cb,
                return_dict=True,
            )
            epoch_logs.update(
                {"val_" + name: val for name, val in logs.items()})
        cb.on_epoch_end(epoch, epoch_logs)
        training_logs = epoch_logs
        if model.stop_training:
            break
    cb.on_train_end(logs=training_logs)
    # Drop the iterator reference so later checkpoints of `model` do not
    # carry iterator state.
    if track_iterator:
        del model._train_iter
    return model.history