Example 1
from pathlib import Path
from typing import Dict, Text

from rasa.utils.tensorflow.callback import RasaModelCheckpoint


# The three dict/bool arguments are supplied by pytest parametrization (see below).
def test_does_model_improve(
    previous_best: Dict[Text, float],
    current_values: Dict[Text, float],
    improved: bool,
    tmpdir: Path,
):
    checkpoint = RasaModelCheckpoint(tmpdir)
    checkpoint.best_metrics_so_far = previous_best
    # true iff all values are equal or better and at least one is better
    assert checkpoint._does_model_improve(current_values) == improved
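
A hypothetical parametrization for the test above (the metric names and numbers are invented purely to illustrate the rule stated in the comment; they are not the project's actual test cases):

import pytest

@pytest.mark.parametrize(
    "previous_best, current_values, improved",
    [
        # every tracked metric gets better -> improvement
        ({"val_f1": 0.5, "val_acc": 0.5}, {"val_f1": 0.7, "val_acc": 0.6}, True),
        # one metric better, the other unchanged -> still an improvement
        ({"val_f1": 0.5, "val_acc": 0.5}, {"val_f1": 0.5, "val_acc": 0.6}, True),
        # one metric better but another worse -> not an improvement
        ({"val_f1": 0.5, "val_acc": 0.5}, {"val_f1": 0.7, "val_acc": 0.4}, False),
        # nothing changes -> not an improvement
        ({"val_f1": 0.5, "val_acc": 0.5}, {"val_f1": 0.5, "val_acc": 0.5}, False),
    ],
)
def test_does_model_improve(previous_best, current_values, improved, tmpdir):
    ...  # body as in the example above
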
Example 2
from pathlib import Path
from typing import List, Optional, Text

import rasa.shared.utils.io

# "Callback" is a forward reference to tf.keras.callbacks.Callback;
# RasaTrainingLogger and RasaModelCheckpoint are defined in the same module
# as this function.
def create_common_callbacks(
    epochs: int,
    tensorboard_log_dir: Optional[Text] = None,
    tensorboard_log_level: Optional[Text] = None,
    checkpoint_dir: Optional[Path] = None,
) -> List["Callback"]:
    """Create common callbacks.

    The following callbacks are created:
    - RasaTrainingLogger callback
    - Optional TensorBoard callback
    - Optional RasaModelCheckpoint callback

    Args:
        epochs: the number of epochs to train
        tensorboard_log_dir: optional directory that should be used for tensorboard
        tensorboard_log_level: defines when training metrics for tensorboard should be
                               logged. Valid values: 'epoch' and 'batch'.
        checkpoint_dir: optional directory that should be used for model checkpointing

    Returns:
        A list of callbacks.
    """
    import tensorflow as tf

    callbacks = [RasaTrainingLogger(epochs, silent=False)]

    if tensorboard_log_dir:
        if tensorboard_log_level == "minibatch":
            tensorboard_log_level = "batch"
            rasa.shared.utils.io.raise_deprecation_warning(
                "You set 'tensorboard_log_level' to 'minibatch'. This value should not "
                "be used anymore. Please use 'batch' instead."
            )

        callbacks.append(
            tf.keras.callbacks.TensorBoard(
                log_dir=tensorboard_log_dir,
                update_freq=tensorboard_log_level,
                write_graph=True,
                write_images=True,
                histogram_freq=10,
            )
        )

    if checkpoint_dir:
        callbacks.append(RasaModelCheckpoint(checkpoint_dir))

    return callbacks
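
A minimal usage sketch for create_common_callbacks. It assumes the function is importable from rasa.utils.tensorflow.callback (the module path may differ between Rasa versions) and that a compiled tf.keras model named model exists; the paths are placeholders:

from pathlib import Path

from rasa.utils.tensorflow.callback import create_common_callbacks

callbacks = create_common_callbacks(
    epochs=100,
    tensorboard_log_dir="logs/tensorboard",     # enables the TensorBoard callback
    tensorboard_log_level="epoch",              # log once per epoch ("batch" logs per batch)
    checkpoint_dir=Path("models/checkpoints"),  # enables RasaModelCheckpoint
)

# The returned list is then handed to a Keras training loop, e.g.:
# model.fit(train_dataset, epochs=100, callbacks=callbacks)
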
Example 3
from pathlib import Path
from typing import Dict, Text

from rasa.core.policies.ted_policy import TEDPolicy
from rasa.utils.tensorflow.callback import RasaModelCheckpoint


# Parametrized like the first example; trained_ted is a fixture that provides
# a trained TEDPolicy.
def test_on_epoch_end_saves_checkpoints_file(
    previous_best: Dict[Text, float],
    current_values: Dict[Text, float],
    improved: bool,
    tmp_path: Path,
    trained_ted: TEDPolicy,
):
    model_name = "checkpoint"
    best_model_file = tmp_path / model_name
    assert not best_model_file.exists()
    checkpoint = RasaModelCheckpoint(tmp_path)
    checkpoint.best_metrics_so_far = previous_best
    checkpoint.model = trained_ted.model
    checkpoint.on_epoch_end(1, current_values)
    if improved:
        assert best_model_file.exists()
    else:
        assert not best_model_file.exists()
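
Taken together, the two tests pin down the callback's contract: _does_model_improve is true only when every tracked metric is at least as good as before and at least one is strictly better, and on_epoch_end writes a checkpoint into the given directory only in that case. Below is a simplified, hypothetical stand-in that illustrates this contract; it is not the actual RasaModelCheckpoint implementation:

from pathlib import Path
from typing import Dict, Optional, Text

import tensorflow as tf


class MinimalModelCheckpoint(tf.keras.callbacks.Callback):
    """Illustrative stand-in: keep only the weights of the best model so far."""

    def __init__(self, checkpoint_dir: Path) -> None:
        super().__init__()
        self.checkpoint_dir = Path(checkpoint_dir)
        self.best_metrics_so_far: Dict[Text, float] = {}

    def _does_model_improve(self, current: Dict[Text, float]) -> bool:
        # True iff all previously tracked values are equal or better
        # and at least one of them is strictly better.
        previous = self.best_metrics_so_far
        if not previous:
            return True  # assumption: treat the very first epoch as an improvement
        all_at_least_as_good = all(
            current.get(k, float("-inf")) >= v for k, v in previous.items()
        )
        any_strictly_better = any(
            current.get(k, float("-inf")) > v for k, v in previous.items()
        )
        return all_at_least_as_good and any_strictly_better

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict[Text, float]] = None
    ) -> None:
        logs = logs or {}
        if self._does_model_improve(logs):
            self.best_metrics_so_far = dict(logs)
            # Persist the current weights under the checkpoint directory;
            # TF-format save_weights also writes metadata files next to them.
            self.model.save_weights(str(self.checkpoint_dir / "best_model"))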