Example #1
    def test_validation_split_none(self):
        train_sw, val_sw = data_adapter.train_validation_split(
            None, validation_split=0.2)
        self.assertIsNone(train_sw)
        self.assertIsNone(val_sw)

        (_, train_sw), (_, val_sw) = data_adapter.train_validation_split(
            (np.ones((10, 1)), None), validation_split=0.2)
        self.assertIsNone(train_sw)
        self.assertIsNone(val_sw)
Example #2
  def test_validation_split_unshuffled(self, use_numpy):
    if use_numpy:
      x = np.array([0, 1, 2, 3, 4])
      y = np.array([0, 2, 4, 6, 8])
      sw = np.array([0, 4, 8, 12, 16])
    else:
      x = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2, 3, 4])
      y = ops.convert_to_tensor_v2_with_dispatch([0, 2, 4, 6, 8])
      sw = ops.convert_to_tensor_v2_with_dispatch([0, 4, 8, 12, 16])

    (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
        data_adapter.train_validation_split((x, y, sw), validation_split=0.2))

    if use_numpy:
      train_x = ops.convert_to_tensor_v2_with_dispatch(train_x)
      train_y = ops.convert_to_tensor_v2_with_dispatch(train_y)
      train_sw = ops.convert_to_tensor_v2_with_dispatch(train_sw)
      val_x = ops.convert_to_tensor_v2_with_dispatch(val_x)
      val_y = ops.convert_to_tensor_v2_with_dispatch(val_y)
      val_sw = ops.convert_to_tensor_v2_with_dispatch(val_sw)

    self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3])
    self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6])
    self.assertEqual(train_sw.numpy().tolist(), [0, 4, 8, 12])

    self.assertEqual(val_x.numpy().tolist(), [4])
    self.assertEqual(val_y.numpy().tolist(), [8])
    self.assertEqual(val_sw.numpy().tolist(), [16])
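The unshuffled case above pins down the split rule: the first (1 - validation_split) fraction of every array becomes the training slice and the remaining tail becomes the validation slice. A minimal NumPy-only sketch of that behaviour (the helper simple_train_validation_split is illustrative, not part of the Keras API):

import numpy as np

def simple_train_validation_split(arrays, validation_split=0.2):
    # First (1 - validation_split) fraction -> train, remaining tail -> validation.
    split_at = int(len(arrays[0]) * (1.0 - validation_split))
    return (tuple(a[:split_at] for a in arrays),
            tuple(a[split_at:] for a in arrays))

x, y, sw = np.arange(5), 2 * np.arange(5), 4 * np.arange(5)
(train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
    simple_train_validation_split((x, y, sw)))
assert train_x.tolist() == [0, 1, 2, 3] and val_x.tolist() == [4]
assert train_sw.tolist() == [0, 4, 8, 12] and val_sw.tolist() == [16]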
Example #3
    def test_validation_split_shuffled(self, use_numpy):
        if use_numpy:
            x = np.array([0, 1, 2, 3, 4])
            y = np.array([0, 2, 4, 6, 8])
            sw = np.array([0, 4, 8, 12, 16])
        else:
            x = ops.convert_to_tensor_v2([0, 1, 2, 3, 4])
            y = ops.convert_to_tensor_v2([0, 2, 4, 6, 8])
            sw = ops.convert_to_tensor_v2([0, 4, 8, 12, 16])

        (train_x, train_y,
         train_sw), (val_x, val_y,
                     val_sw) = (data_adapter.train_validation_split(
                         (x, y, sw), validation_split=0.2))

        self.assertEqual(int(train_x.shape[0]), 4)
        self.assertEqual(int(train_y.shape[0]), 4)
        self.assertEqual(int(train_sw.shape[0]), 4)
        for i in range(4):
            # Check that all arrays were shuffled in identical order.
            self.assertEqual(2 * train_x[i].numpy(), train_y[i].numpy())
            self.assertEqual(2 * train_y[i].numpy(), train_sw[i].numpy())

        self.assertEqual(int(val_x.shape[0]), 1)
        self.assertEqual(int(val_y.shape[0]), 1)
        self.assertEqual(int(val_sw.shape[0]), 1)
        for i in range(1):
            # Check that all arrays were shuffled in identical order.
            self.assertEqual(2 * val_x[i].numpy(), val_y[i].numpy())
            self.assertEqual(2 * val_y[i].numpy(), val_sw[i].numpy())

        # Check that arrays contain expected values.
        self.assertEqual(
            sorted(
                array_ops.concat([train_x, val_x], axis=0).numpy().tolist()),
            sorted(ops.convert_to_tensor_v2(x).numpy().tolist()))
        self.assertEqual(
            sorted(
                array_ops.concat([train_y, val_y], axis=0).numpy().tolist()),
            sorted(ops.convert_to_tensor_v2(y).numpy().tolist()))
        self.assertEqual(
            sorted(
                array_ops.concat([train_sw, val_sw], axis=0).numpy().tolist()),
            sorted(ops.convert_to_tensor_v2(sw).numpy().tolist()))
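The shuffled variant only guarantees that every array is permuted with the same index order before the split, which is exactly what the pairing and "contains expected values" checks above verify. A small NumPy sketch of that invariant (one shared permutation; shuffled_split is an illustrative helper, not the Keras implementation):

import numpy as np

def shuffled_split(arrays, validation_split=0.2, seed=0):
    order = np.random.default_rng(seed).permutation(len(arrays[0]))  # one shared permutation
    shuffled = [np.asarray(a)[order] for a in arrays]
    split_at = int(len(order) * (1.0 - validation_split))
    return ([a[:split_at] for a in shuffled], [a[split_at:] for a in shuffled])

x, y, sw = np.arange(5), 2 * np.arange(5), 4 * np.arange(5)
(train_x, train_y, train_sw), (val_x, val_y, val_sw) = shuffled_split((x, y, sw))
# The x/y/sw pairing survives the shuffle, and no element is lost in the split.
assert np.array_equal(2 * train_x, train_y) and np.array_equal(2 * val_x, val_y)
assert sorted(np.concatenate([train_x, val_x]).tolist()) == x.tolist()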
Example #4
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            auto_switch=True,
            retry_fit=True,
            absorb=True,
            train_after_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            revert_after_fit=False):
        """
        Custom fit function for the context model

        auto_switch:        Enable/disable autonomous context switching
        train_after_switch:
        retry_fit:          Locate the next fitting context by re-performing fit.
        absorb:             Reset the switch sequence counter upon successful training.
                            This is mainly used to maintain switch sequencing for temporally-extended tasks
        revert_after_fit    This is a debug parameter to revert weights after performing a fit. This is used
                            to calculate the context deltas without incorrectly learning while auto switching
                            is disabled
        """

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(
        ), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            self.initialize_fit()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched_during_epoch = False  # Whether at least one switch has been attempted this epoch
                switched = True  # Whether the model switched on the most recent fit iteration
                weights = backend.batch_get_value(self.trainable_variables)
                # Perform a fit call. With retry_fit, the call is re-attempted after
                # each switch until a context fits.
                while switched and (retry_fit or not switched_during_epoch):
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)

                    # Perform a fit call
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        switched = not self.update_and_switch(
                            epoch, auto_switch, absorb, retry_fit, verbose)
                        switched_during_epoch |= switched

                        # If a switch occurred, we need to restore the weights
                        if switched or (switched_during_epoch
                                        and not train_after_switch
                                        ) or revert_after_fit:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history
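The retry loop above depends on snapshotting the trainable variables before each fit attempt and rolling them back whenever a switch (or revert_after_fit) requires it. A minimal sketch of that save/restore pattern on a plain tf.keras model (TF 2.x assumed; the model and data below are placeholders, not the context model):

import tensorflow as tf
from tensorflow.keras import backend

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

weights = backend.batch_get_value(model.trainable_variables)  # snapshot before the attempt
model.fit(tf.random.normal((8, 4)), tf.random.normal((8, 1)), epochs=1, verbose=0)
backend.batch_set_value(zip(model.trainable_variables, weights))  # roll the attempt back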
Example #5
  def test_validation_split_user_error(self):
    with self.assertRaisesRegexp(ValueError,
                                 'is only supported for Tensors'):
      data_adapter.train_validation_split(lambda: np.ones((10, 1)),
                                          validation_split=0.2)
Example #6
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
        """ From tf.keras.Model. """
        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')
        training._disallow_inside_tf_function('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight), validation_split=validation_split))

        if validation_data:
            val_x, val_y, val_sample_weight = (
                data_adapter.unpack_x_y_sample_weight(validation_data))

        with self.distribute_strategy.scope(), \
             training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = data_adapter.DataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self,
                steps_per_execution=self._steps_per_execution)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=verbose != 0,
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            self._train_counter.assign(0)
            callbacks.on_train_begin()
            training_logs = None
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (  # pylint: disable=protected-access
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                with data_handler.catch_stop_iteration():
                    if self._update_cycle > 1:
                        self._grad_accumulator.reset()
                    for step in data_handler.steps():
                        with trace.Trace(
                            'TraceContext',
                            graph_type='train',
                            epoch_num=epoch,
                            step_num=step,
                            batch_size=batch_size):
                            callbacks.on_train_batch_begin(step)
                            if self._update_cycle > 1:
                                for _ in range(self._update_cycle - 1):
                                    self.accumulate_function(iterator)
                            tmp_logs = train_function(iterator)
                            if data_handler.should_sync:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            end_step = step + data_handler.step_increment
                            callbacks.on_train_batch_end(end_step, logs)
                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(epoch, validation_freq):
                    # Create data_handler for evaluation and cache it.
                    if getattr(self, '_eval_data_handler', None) is None:
                        self._eval_data_handler = data_adapter.DataHandler(
                            x=val_x,
                            y=val_y,
                            sample_weight=val_sample_weight,
                            batch_size=validation_batch_size or batch_size,
                            steps_per_epoch=validation_steps,
                            initial_epoch=0,
                            epochs=1,
                            max_queue_size=max_queue_size,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            model=self,
                            steps_per_execution=self._steps_per_execution)
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {'val_' + name: val for name, val in val_logs.items()}
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                training_logs = epoch_logs
                if self.stop_training:
                    break

            # If an eval data_handler exists, delete it after all epochs are done.
            if getattr(self, '_eval_data_handler', None) is not None:
                del self._eval_data_handler
            callbacks.on_train_end(logs=training_logs)
            return self.history
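The fit in Example #6 folds gradient accumulation into the training loop: when self._update_cycle > 1 it runs accumulate_function for the first N-1 mini-batches and applies an update only on the last one. The _grad_accumulator internals are not shown, but the underlying idea can be sketched with a plain custom training loop (TF 2.x assumed; all names below are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.keras.optimizers.SGD(0.1)
loss_fn = tf.keras.losses.MeanSquaredError()
update_cycle = 4  # accumulate gradients over 4 mini-batches per weight update

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((32, 4)), tf.random.normal((32, 1)))).batch(8)
accum = [tf.zeros_like(v) for v in model.trainable_variables]

for step, (x, y) in enumerate(dataset):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    accum = [a + g for a, g in zip(accum, grads)]
    if (step + 1) % update_cycle == 0:  # apply the averaged gradients, then reset
        optimizer.apply_gradients(
            zip([a / update_cycle for a in accum], model.trainable_variables))
        accum = [tf.zeros_like(v) for v in model.trainable_variables]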
Example #7
  def test_validation_split_examples_too_few(self):
    with self.assertRaisesRegex(ValueError, 'not sufficient to split it'):
      data_adapter.train_validation_split(
          np.ones((1, 10)), validation_split=0.2)
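Examples #5 and #7 cover the two failure modes of train_validation_split: inputs that are not tensors or NumPy arrays, and datasets too small to yield both a non-empty training and validation slice. A minimal sketch of that validation logic (illustrative only, not the Keras source):

import numpy as np

def check_splittable(data, validation_split=0.2):
    if callable(data) or not hasattr(data, '__len__'):
        raise ValueError('`validation_split` is only supported for Tensors or '
                         'NumPy arrays.')
    split_at = int(len(data) * (1.0 - validation_split))
    if split_at == 0 or split_at == len(data):
        raise ValueError('Training data contains %d samples, which is not '
                         'sufficient to split it into a validation and training '
                         'set.' % len(data))

check_splittable(np.ones((10, 1)))            # fine: 8 train / 2 validation samples
# check_splittable(lambda: np.ones((10, 1)))  # ValueError: only Tensors/NumPy supported
# check_splittable(np.ones((1, 10)))          # ValueError: not sufficient to split it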
Example #8
    def fit(
        self,
        x: Optional[
            Union[np.ndarray, tf.Tensor, tf.data.Dataset, tf.keras.utils.Sequence]
        ] = None,
        y: Optional[
            Union[np.ndarray, tf.Tensor, tf.data.Dataset, tf.keras.utils.Sequence]
        ] = None,
        batch_size: Optional[int] = None,
        epochs: int = 1,
        verbose: int = 1,
        callbacks: Optional[List[Callback]] = None,
        validation_split: float = 0.0,
        validation_data: Optional[Any] = None,
        shuffle: bool = True,
        class_weight: Optional[Dict[int, float]] = None,
        sample_weight: Optional[np.ndarray] = None,
        initial_epoch: int = 0,
        steps_per_epoch: Optional[int] = None,
        validation_steps: Optional[int] = None,
        validation_batch_size: Optional[int] = None,
        validation_freq: int = 1,
        max_queue_size: int = 10,
        workers: int = 1,
        use_multiprocessing: bool = False,
    ):
        """Trains the model for a fixed number of epochs (iterations on a dataset).

        Args:
            x: Input data.
            y: Target data.
            batch_size: Number of samples per gradient update.
            epochs: Number of epochs to train the model.
            verbose: Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per
                     epoch.
            callbacks: List of `keras.callbacks.Callback` instances.
            validation_split: Fraction of the training data to be used as validation
                              data.
            validation_data: Data on which to evaluate the loss and any model metrics
                             at the end of each epoch.
            shuffle: Whether to shuffle the training data before each epoch.
            class_weight: Optional dictionary mapping class indices (integers)
                          to a weight (float) value, used for weighting the loss
                          function (during training only).
            sample_weight: Optional Numpy array of weights for
                           the training samples, used for weighting the loss function
                           (during training only).
            initial_epoch: Epoch at which to start training.
            steps_per_epoch: Total number of steps (batches of samples)
                             before declaring one epoch finished and starting the
                             next epoch.
            validation_steps: Total number of steps (batches of
                              samples) to draw before stopping when performing
                              validation at the end of every epoch.
            validation_batch_size: Number of samples per validation batch.
            validation_freq: Specifies how many training epochs to run before a
                             new validation run is performed.
            max_queue_size: Maximum size for the generator queue.
            workers: Maximum number of processes to spin up
                     when using process-based threading.
            use_multiprocessing: If `True`, use process-based threading.

        Returns:
            A `History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).

        Raises:
            RuntimeError: 1. If the model was never compiled, or
                2. If `model.fit` is wrapped in `tf.function`.

            ValueError: In case of mismatch between the provided input data
                and what the model expects.
        """
        training._keras_api_gauge.get_cell("fit").set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph("Model", "fit")
        self._assert_compile_was_called()
        self._check_call_args("fit")
        training._disallow_inside_tf_function("fit")

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (
                (x, y, sample_weight),
                validation_data,
            ) = data_adapter.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split
            )

        if validation_data:
            val_x, val_y, val_sample_weight = data_adapter.unpack_x_y_sample_weight(
                validation_data
            )

        with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState(
            self
        ):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            # Use our own custom data handler to handle increasing batch size
            data_handler = CustomDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self,
                steps_per_execution=self._steps_per_execution,
            )

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, training.callbacks_module.CallbackList):
                callbacks = training.callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=verbose != 0,
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps,
                )

            self.stop_training = False
            train_function = self.make_train_function()
            self._train_counter.assign(0)
            callbacks.on_train_begin()
            training_logs = None
            # Handle fault-tolerance for multi-worker.
            data_handler._initial_epoch = self._maybe_load_initial_epoch_from_ckpt(  # pylint: disable=protected-access
                initial_epoch
            )
            for epoch, iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                with data_handler.catch_stop_iteration():
                    for step in data_handler.steps():
                        with training.trace.Trace(
                            "TraceContext",
                            graph_type="train",
                            epoch_num=epoch,
                            step_num=step,
                            batch_size=batch_size,
                        ):
                            callbacks.on_train_batch_begin(step)
                            tmp_logs = train_function(iterator)
                            if data_handler.should_sync:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            end_step = step + data_handler.step_increment
                            callbacks.on_train_batch_end(end_step, logs)
                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(epoch, validation_freq):
                    # Create data_handler for evaluation and cache it.
                    if getattr(self, "_eval_data_handler", None) is None:
                        self._eval_data_handler = CustomDataHandler(
                            x=val_x,
                            y=val_y,
                            sample_weight=val_sample_weight,
                            batch_size=validation_batch_size or batch_size,
                            steps_per_epoch=validation_steps,
                            initial_epoch=0,
                            epochs=1,
                            max_queue_size=max_queue_size,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            model=self,
                            steps_per_execution=self._steps_per_execution,
                        )
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True,
                    )
                    val_logs = {"val_" + name: val for name, val in val_logs.items()}
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                training_logs = epoch_logs
                if self.stop_training:
                    break

            # If an eval data_handler exists, delete it after all epochs are done.
            if getattr(self, "_eval_data_handler", None) is not None:
                del self._eval_data_handler
            callbacks.on_train_end(logs=training_logs)
            return self.history
Example #9
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            dynamic_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(
        ), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched = True
                weights = backend.batch_get_value(self.trainable_variables)
                while switched:
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        switched = not self.update_and_switch(
                            epoch, dynamic_switch, verbose)
                        # If a switch occurred, we need to restore the weights
                        if switched:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                epoch_logs = copy.copy(logs)

                if self.accumulate_gradients:
                    self.optimizer.apply_gradients(
                        zip(self.accumulated_gradients,
                            self.trainable_variables))

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history