def test_does_model_improve(
    previous_best: Dict[Text, float],
    current_values: Dict[Text, float],
    improved: bool,
    tmpdir: Path,
):
    """Check `_does_model_improve` against a stored set of best metrics.

    The expected result (`improved`) is true iff all values in `current_values`
    are equal or better than `previous_best` and at least one is better.
    """
    # pytest's `tmpdir` fixture yields a `py.path.local`, not a `pathlib.Path`;
    # convert it so the checkpoint callback receives the same path type that
    # `create_common_callbacks` (checkpoint_dir: Path) and the other tests use.
    checkpoint = RasaModelCheckpoint(Path(tmpdir))
    checkpoint.best_metrics_so_far = previous_best
    # true iff all values are equal or better and at least one is better
    assert checkpoint._does_model_improve(current_values) == improved
def create_common_callbacks(
    epochs: int,
    tensorboard_log_dir: Optional[Text] = None,
    tensorboard_log_level: Optional[Text] = None,
    checkpoint_dir: Optional[Path] = None,
) -> List["Callback"]:
    """Create common callbacks.

    The following callbacks are created:
    - RasaTrainingLogger callback
    - Optional TensorBoard callback
    - Optional RasaModelCheckpoint callback

    Args:
        epochs: the number of epochs to train
        tensorboard_log_dir: optional directory that should be used for tensorboard;
            no TensorBoard callback is created when it is empty or `None`.
        tensorboard_log_level: defines when training metrics for tensorboard should
            be logged. Valid values: 'epoch' and 'batch'. The deprecated value
            'minibatch' is mapped to 'batch' (and a deprecation warning is raised);
            `None` falls back to 'epoch'.
        checkpoint_dir: optional directory that should be used for model
            checkpointing; no RasaModelCheckpoint callback is created when it is
            `None`.

    Returns:
        A list of callbacks.
    """
    # Imported lazily so that importing this module does not pull in tensorflow.
    import tensorflow as tf

    callbacks = [RasaTrainingLogger(epochs, silent=False)]

    if tensorboard_log_dir:
        if tensorboard_log_level == "minibatch":
            tensorboard_log_level = "batch"
            rasa.shared.utils.io.raise_deprecation_warning(
                "You set 'tensorboard_log_level' to 'minibatch'. This value should not "
                "be used anymore. Please use 'batch' instead."
            )
        elif tensorboard_log_level is None:
            # Keras' TensorBoard accepts only 'batch', 'epoch' or an integer for
            # `update_freq`; forwarding `None` is unsupported, so use Keras'
            # own default instead.
            tensorboard_log_level = "epoch"

        callbacks.append(
            tf.keras.callbacks.TensorBoard(
                log_dir=tensorboard_log_dir,
                update_freq=tensorboard_log_level,
                write_graph=True,
                write_images=True,
                histogram_freq=10,
            )
        )

    if checkpoint_dir:
        callbacks.append(RasaModelCheckpoint(checkpoint_dir))

    return callbacks
def test_on_epoch_end_saves_checkpoints_file(
    previous_best: Dict[Text, float],
    current_values: Dict[Text, float],
    improved: bool,
    tmp_path: Path,
    trained_ted: TEDPolicy,
):
    """`on_epoch_end` writes the `checkpoint` file exactly when metrics improved."""
    checkpoint_file = tmp_path / "checkpoint"
    # Sanity check: the file must not exist before the callback has run.
    assert not checkpoint_file.exists()

    callback = RasaModelCheckpoint(tmp_path)
    callback.best_metrics_so_far = previous_best
    callback.model = trained_ted.model

    callback.on_epoch_end(1, current_values)

    # The file is present if and only if the epoch counted as an improvement.
    assert checkpoint_file.exists() == bool(improved)