Example #1 (score: 0)
File: train.py — Project: jackd/graph-tf
def fit_single(
    model: tf.keras.Model,
    train_data,
    validation_data=None,
    epochs: int = 1,
    initial_epoch: int = 0,
    validation_freq: int = 1,
    callbacks: Iterable[tf.keras.callbacks.Callback] = (),
    verbose: bool = True,
    jit_compile: bool = False,
):
    """
    Optimized replacement for `keras.Model.fit` when training on a single graph.

    Each "epoch" is exactly one training step on the single training element,
    with the usual Keras callback hooks fired around it.

    Args:
        model: keras model to train.
        train_data: (inputs, labels, sample_weight) or single-element dataset
            used for training.
        validation_data: (inputs, labels, sample_weight) or single-element
            dataset used for validation, or None to skip validation.
        epochs: maximum number of epochs / steps to train for.
        initial_epoch: epoch to start from (may be overridden by checkpoints).
        validation_freq: number of training steps/epochs between validations.
        callbacks: iterable of `tf.keras.callbacks.Callback` instances.
        verbose: if True, an `EpochProgbarLogger` is appended to the callbacks.
        jit_compile: if True, train/validation steps are `jit`-compiled. Not
            all ops are jit compatible, but where they are this may speed
            things up.

    Returns:
        history: `tf.keras.callbacks.History` object.
    """
    train_data = unpack(train_data)
    validation_data = unpack(validation_data)

    callback_objs = list(callbacks)
    if verbose:
        callback_objs.append(EpochProgbarLogger())

    # Extra keyword arguments are forwarded to `CallbackList.set_params`.
    callback_list = tf.keras.callbacks.CallbackList(
        callback_objs,
        add_history=True,
        add_progbar=False,
        model=model,
        epochs=epochs,
        verbose=verbose,
        steps=1,
        do_validation=validation_data is not None,
    )

    train_step = _build_train_step(model, train_data, jit_compile=jit_compile)
    validation_step = (
        None
        if validation_data is None
        else _build_test_step(model, validation_data, jit_compile=jit_compile)
    )

    model.stop_training = False
    callback_list.on_train_begin(logs=None)
    # The starting epoch may be restored from a checkpoint; this is influenced
    # by callbacks.experimental.BackupAndRestore.
    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(  # pylint: disable=protected-access
        initial_epoch)

    logs = None
    epoch = initial_epoch
    while epoch < epochs:
        model.reset_metrics()
        callback_list.on_epoch_begin(epoch, logs=None)
        callback_list.on_train_batch_begin(batch=0)
        logs = train_step()
        callback_list.on_train_batch_end(batch=0, logs=logs)
        if model.stop_training:
            break
        # Run validation every `validation_freq` epochs.
        if validation_step is not None and (epoch + 1) % validation_freq == 0:
            for key, value in validation_step().items():
                logs[f"val_{key}"] = value
        callback_list.on_epoch_end(epoch, logs)
        if model.stop_training:
            break
        epoch += 1

    callback_list.on_train_end(logs)
    return model.history
Example #2 (score: 0)
File: models.py — Project: jackd/kblocks
def fit(
    model: tf.keras.Model,
    train_data: tf.data.Dataset,
    epochs: int = 1,
    steps_per_epoch: Optional[int] = None,
    validation_data: Optional[tf.data.Dataset] = None,
    validation_steps: Optional[int] = None,
    callbacks: Tuple[tf.keras.callbacks.Callback, ...] = (),
    initial_epoch: int = 0,
    validation_freq: int = 1,
    track_iterator: bool = False,
    verbose: bool = True,
) -> tf.keras.callbacks.History:
    """
    Custom fit implementation.

    Interface is intended to mimic best-practice usage of `tf.keras.Model.fit`.

    Unlike `tf.keras.Model.fit`, when `track_iterator` is True `_train_iter` is
    added as an attribute to `model`. If using `tf.train.Checkpoint`s to manage
    training state, this may result in larger files on disk.

    Args:
        model: keras model to train.
        train_data: dataset with (inputs, labels) or (inputs, labels, sample_weights)
        epochs: total number of epochs to train until.
        steps_per_epoch: number of steps per epoch. Must be provided if train_data has
            infinite cardinality.
        validation_data: optional dataset to perform validation on.
        validation_steps: number of steps of validation to perform per epoch.
        callbacks: `tf.keras.callbacks.Callback` instances.
        initial_epoch: starting epoch.
        validation_freq: number of epochs between validation.
        track_iterator: if True, `train_data`'s iterator is added as an attribute to
            `model`, meaning it will be saved in checkpoint's saving `model`.
        verbose: controls verbosity of printed output.

    Returns:
        `tf.keras.callbacks.History` object.

    Raises:
        `AttributeError` if `model` has an existing `_train_iter` attribute and
        `track_iterator` is True.
    """
    train_func = model.make_train_function()
    train_iter, steps_per_epoch = as_infinite_iterator(train_data,
                                                       steps_per_epoch)
    # Only reject a pre-existing `_train_iter` when we would actually clobber
    # it. (Previously this raised regardless of `track_iterator`, contradicting
    # the documented contract.)
    if track_iterator:
        if hasattr(model, "_train_iter"):
            raise AttributeError(
                "Cannot fit model with existing `_train_iter` attribute.")
        model._train_iter = train_iter  # pylint: disable=protected-access

    cb = tf.keras.callbacks.CallbackList(callbacks=callbacks,
                                         add_history=True,
                                         add_progbar=verbose,
                                         model=model)
    cb.set_params(
        dict(epochs=epochs, verbose=int(verbose), steps=steps_per_epoch))

    try:
        cb.on_train_begin()
        # Starting epoch may be restored from a checkpoint, e.g. via
        # callbacks.experimental.BackupAndRestore.
        initial_epoch = (
            model._maybe_load_initial_epoch_from_ckpt(  # pylint: disable=protected-access
                initial_epoch))

        training_logs = None
        model.stop_training = False
        for epoch in range(initial_epoch, epochs):
            model.reset_metrics()
            cb.on_epoch_begin(epoch)

            logs = None
            for step in range(steps_per_epoch):
                cb.on_train_batch_begin(step)
                logs = train_func(train_iter)
                cb.on_train_batch_end(step, logs)
                if model.stop_training:
                    break
            assert logs is not None
            epoch_logs = logs
            if (validation_data is not None and model._should_eval(  # pylint: disable=protected-access
                    epoch, validation_freq)):
                logs = model.evaluate(
                    validation_data,
                    steps=validation_steps,
                    callbacks=cb,
                    return_dict=True,
                )
                epoch_logs.update(
                    {"val_" + name: val
                     for name, val in logs.items()})
            cb.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if model.stop_training:
                break
        cb.on_train_end(logs=training_logs)
    finally:
        # Always detach the iterator, even if training raised, so `model` is
        # left in a reusable state.
        if track_iterator:
            del model._train_iter
    return model.history