Code example #1
def __init__(self, settings: Settings):
    self.settings = settings
    if self.settings.multi_seed:
        # Multi-seed generation: use single-record prediction batches and
        # enable model state resets.
        self.settings.config.predict_batch_size = 1
        self.settings.config.reset_states = True
    self.model = load_model(settings.config, self.settings.tokenizer)
    self.delim = settings.config.field_delimiter
    # _predict_forever() lazily yields predictions; store the generator for later consumption.
    self._predictions = self._predict_forever()
Code example #2
File: train.py  Project: gretelai/gretel-synthetics
def train_rnn(params: TrainingParams):
    """
    Fit synthetic data model on training data.

    This will annotate the training data and create a new file that
    will be used to actually train on. The updated training data, model,
    checkpoints, etc will all be saved in the location specified
    by your config.

    Args:
        params: The parameters controlling model training.

    Returns:
        None
    """
    store = params.config
    # TODO: We should check that store is an instance of TensorFlowConfig, but that would currently
    # lead to an import cycle.

    tokenizer = params.tokenizer
    num_lines = params.tokenizer_trainer.num_lines
    text_iter = params.tokenizer_trainer.data_iterator()

    if not store.overwrite:  # pragma: no cover
        try:
            load_model(store, tokenizer)
        except Exception:
            pass
        else:
            raise RuntimeError(
                "A model already exists in the checkpoint location, you must enable "
                "overwrite mode or delete the checkpoints first.")

    total_token_count, validation_dataset, training_dataset = _create_dataset(
        store, text_iter, num_lines, tokenizer)
    logging.info("Initializing synthetic model")
    model = build_model(vocab_size=tokenizer.total_vocab_size,
                        batch_size=store.batch_size,
                        store=store)

    # Save checkpoints during training
    checkpoint_prefix = (Path(store.checkpoint_dir) / "synthetic").as_posix()
    if store.save_all_checkpoints:
        checkpoint_prefix = checkpoint_prefix + "-{epoch}"

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True,
        monitor=store.best_model_metric,
        save_best_only=store.save_best_model,
    )
    history_callback = _ModelHistory(total_token_count, store)

    _callbacks = [checkpoint_callback, history_callback]

    if store.early_stopping:
        early_stopping_callback = tf.keras.callbacks.EarlyStopping(
            monitor=store.best_model_metric,
            patience=store.early_stopping_patience,
            restore_best_weights=store.save_best_model,
        )
        _callbacks.append(early_stopping_callback)

    if store.epoch_callback:
        _callbacks.append(_EpochCallbackWrapper(store.epoch_callback))

    best_val = None
    try:
        model.fit(training_dataset,
                  epochs=store.epochs,
                  callbacks=_callbacks,
                  validation_data=validation_dataset)

        if store.save_best_model:
            best_val = checkpoint_callback.best
        if store.early_stopping:
            # NOTE: In this callback, the "best" attr does not get set in the constructor, so we'll
            # set it to None if for some reason we can't get it. This also covers a test case that doesn't
            # run any epochs but accesses this attr.
            try:
                best_val = early_stopping_callback.best
            except AttributeError:
                best_val = None
    except (ValueError, IndexError):
        raise RuntimeError(
            "Model training failed. Your training data may have too few records in it. "
            "Please try increasing your training rows and try again")
    except KeyboardInterrupt:
        # Allow manual interruption of training; the history CSV and any
        # checkpoints written so far are still saved below.
        ...
    _save_history_csv(
        history_callback,
        store.checkpoint_dir,
        store.dp,
        store.best_model_metric,
        best_val,
    )
    logging.info(
        f"Saving model to {tf.train.latest_checkpoint(store.checkpoint_dir)}")
Code example #3
def __init__(self, settings: Settings):
    self.settings = settings
    self.model = load_model(settings.config, self.settings.tokenizer)
    self.delim = settings.config.field_delimiter
    self._predictions = self._predict_forever()
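
Code examples #1 and #3 both wire an endless prediction generator (self._predict_forever()) into self._predictions during construction. The fragment below is a minimal, self-contained sketch of that generator-backed pattern; the class name TextGenerator, the dummy records, and the generate_next helper are illustrative assumptions, not the project's actual implementation.

from typing import Iterator


class TextGenerator:
    """Illustrative stand-in for a predictor that owns an endless prediction stream."""

    def __init__(self):
        # Mirrors `self._predictions = self._predict_forever()` above: the
        # generator is created once and consumed lazily.
        self._predictions: Iterator[str] = self._predict_forever()

    def _predict_forever(self) -> Iterator[str]:
        # Assumption: the real generator repeatedly samples the trained model;
        # this placeholder just yields numbered dummy records forever.
        n = 0
        while True:
            yield f"record-{n}"
            n += 1

    def generate_next(self, num_lines: int) -> Iterator[str]:
        # Pull a fixed number of records off the infinite stream.
        for _ in range(num_lines):
            yield next(self._predictions)


for line in TextGenerator().generate_next(3):
    print(line)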