コード例 #1
0
def train(store: BaseConfig,
          tokenizer_trainer: Optional[BaseTokenizerTrainer] = None):
    """Train a synthetic model using the engine that ``store`` selects.

    This is a facade entrypoint. It prepares the tokenizer and the annotated
    training data, then hands off to the engine-specific training routine
    returned by ``store.get_training_callable()``.

    Args:
        store: A ``BaseConfig`` subclass instance. This config is responsible
            for providing the actual training entrypoint for a specific
            training routine.
        tokenizer_trainer: An optional ``BaseTokenizerTrainer`` subclass
            instance. If provided, it is used to pre-process the input and
            create an annotated dataset for training. If omitted, a default
            tokenizer trainer is created from ``store``.
    """
    trainer = (_create_default_tokenizer(store)
               if tokenizer_trainer is None
               else tokenizer_trainer)

    # Produce the annotated training data, then fit the tokenizer on it.
    trainer.create_annotated_training_data()
    trainer.train()

    # Load a tokenizer from the checkpoint directory (presumably where
    # ``trainer.train()`` persisted it -- TODO confirm) and bundle everything
    # the engine needs into one parameter object.
    loaded_tokenizer = tokenizer_from_model_dir(store.checkpoint_dir)
    training_params = TrainingParams(
        tokenizer_trainer=trainer,
        tokenizer=loaded_tokenizer,
        config=store,
    )

    engine_train = store.get_training_callable()
    store.save_model_params()
    store.gpu_check()
    engine_train(training_params)
コード例 #2
0
ファイル: train.py プロジェクト: eminentli/gretel-synthetics
def train_rnn(store: BaseConfig):
    """
    Fit a synthetic data model on the training data.

    This annotates the training data and creates a new file that is
    used for the actual training. The updated training data, model,
    checkpoints, etc. are all saved in the location specified by
    your config.

    Args:
        store: An instance of one of the available configs that you
            previously created.

    Returns:
        None

    Raises:
        RuntimeError: If ``store.overwrite`` is false and a model can
            already be loaded from the checkpoint location.
    """
    if not store.overwrite:  # pragma: no cover
        # Probe for an existing model. A successful load means training
        # would clobber it, so refuse. Any failure to load is treated as
        # "no model yet", and training proceeds.
        try:
            _load_model(store)
        except Exception as err:
            # Deliberately best-effort, but record why the probe failed so a
            # broken checkpoint dir is not silently mistaken for an empty one.
            logging.debug(
                "No existing model could be loaded (%s: %s); proceeding with training",
                type(err).__name__, err)
        else:
            raise RuntimeError(
                'A model already exists in the checkpoint location, you must enable overwrite mode or delete the checkpoints first.'
            )  # noqa

    text = _annotate_training_data(store)
    sp = _train_tokenizer(store)
    dataset = _create_dataset(store, text, sp)
    logging.info("Initializing synthetic model")
    model = _build_sequential_model(vocab_size=len(sp),
                                    batch_size=store.batch_size,
                                    store=store)

    # Save checkpoints during training. When save_all_checkpoints is set,
    # the "{epoch}" placeholder is filled in by Keras so each epoch gets its
    # own checkpoint; otherwise every save reuses the "synthetic" prefix.
    checkpoint_prefix = (Path(store.checkpoint_dir) / "synthetic").as_posix()
    if store.save_all_checkpoints:
        checkpoint_prefix = checkpoint_prefix + "-{epoch}"

    # NOTE(review): per the Keras ModelCheckpoint docs, `monitor` only affects
    # which checkpoint is kept when save_best_only=True; with the default
    # (False) weights are written every epoch regardless of accuracy.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix, save_weights_only=True, monitor='accuracy')
    history_callback = _LossHistory()

    model.fit(dataset,
              epochs=store.epochs,
              callbacks=[checkpoint_callback, history_callback])
    _save_history_csv(history_callback, store.checkpoint_dir)
    store.save_model_params()
    logging.info("Saving model to %s",
                 tf.train.latest_checkpoint(store.checkpoint_dir))

    if store.dp:
        logging.info(_compute_epsilon(len(text), store))