Example #1
import logging
from pathlib import Path

import tensorflow as tf

# _BaseConfig and the _-prefixed helpers below are internal to the
# surrounding module and are assumed to be in scope.


def train_rnn(store: _BaseConfig):
    """
    Fit synthetic data model on training data.

    This will annotate the training data and create a new file that
    will be used for the actual training run. The updated training data,
    model, checkpoints, etc. will all be saved in the location specified
    by your config.

    Args:
        store: An instance of one of the available configs that you
            previously created

    Returns:
        None
    """
    # Refuse to clobber an existing model: if a checkpoint can be loaded
    # and overwrite mode is off, abort before touching anything.
    if not store.overwrite:  # pragma: no cover
        try:
            _load_model(store)
        except Exception:
            pass
        else:
            raise RuntimeError(
                "A model already exists in the checkpoint location, you must "
                "enable overwrite mode or delete the checkpoints first."
            )

    # Annotate the raw training data, fit the tokenizer on it, and build
    # the dataset pipeline used for training
    text = _annotate_training_data(store)
    sp = _train_tokenizer(store)
    dataset = _create_dataset(store, text, sp)
    logging.info("Initializing synthetic model")
    model = _build_sequential_model(vocab_size=len(sp),
                                    batch_size=store.batch_size,
                                    store=store)

    # Save checkpoints during training
    checkpoint_prefix = (Path(store.checkpoint_dir) / "synthetic").as_posix()
    if store.save_all_checkpoints:
        checkpoint_prefix = checkpoint_prefix + "-{epoch}"

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix, save_weights_only=True, monitor='accuracy')
    history_callback = _LossHistory()

    model.fit(dataset,
              epochs=store.epochs,
              callbacks=[checkpoint_callback, history_callback])
    _save_history_csv(history_callback, store.checkpoint_dir)
    store.save_model_params()
    logging.info(
        f"Saving model to {tf.train.latest_checkpoint(store.checkpoint_dir)}")

    if store.dp:
        # With differential privacy enabled, report the privacy budget spent
        logging.info(_compute_epsilon(len(text), store))
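
For context, here is a minimal sketch of how train_rnn might be driven. LocalConfig and its field names are assumptions standing in for whatever _BaseConfig subclass your project provides:

# Hypothetical driver script; LocalConfig, its fields, and the file
# paths are illustrative assumptions, not part of the example above.
config = LocalConfig(
    input_data_path="training_data.csv",  # raw records to model
    checkpoint_dir="./checkpoints",       # where weights and params land
    overwrite=True,                       # allow retraining in place
    epochs=30,
    batch_size=64,
)

train_rnn(config)  # annotates data, trains tokenizer and model, saves checkpoints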
Example #2
import sentencepiece as spm
import tensorflow as tf


def _prepare_model(sp: spm.SentencePieceProcessor, batch_size: int,
                   store: _BaseConfig) -> tf.keras.Sequential:
    model = _build_sequential_model(vocab_size=len(sp),
                                    batch_size=batch_size,
                                    store=store)

    load_dir = store.checkpoint_dir

    # Restore the latest checkpoint; expect_partial() suppresses warnings
    # about optimizer state that is not needed once training is done
    model.load_weights(tf.train.latest_checkpoint(load_dir)).expect_partial()

    # Rebuild with a batch size of 1 so the model can generate one
    # sequence at a time
    model.build(tf.TensorShape([1, None]))
    model.summary()

    return model
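
For completeness, a rough sketch of how the restored model might be used to sample tokens. The loop below is a generic sampler written for illustration; sp, config, the model path, and the BOS handling are all assumptions carried over from the training run:

import sentencepiece as spm
import tensorflow as tf

sp = spm.SentencePieceProcessor()
sp.Load("m.model")  # tokenizer produced by the training run (path assumed)

model = _prepare_model(sp, batch_size=1, store=config)

# Assumes the tokenizer was trained with a BOS token enabled
ids = [sp.bos_id()]
model.reset_states()
for _ in range(50):
    # Feed the last token, take the logits for the final position,
    # and sample the next token id from them
    logits = model(tf.expand_dims([ids[-1]], 0))
    next_id = int(tf.random.categorical(logits[:, -1, :], num_samples=1)[0, 0])
    ids.append(next_id)

print(sp.DecodeIds(ids))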