def train(store: BaseConfig, tokenizer_trainer: Optional[BaseTokenizerTrainer] = None):
    """Train a Synthetic Model. This is a facade entrypoint that implements the
    engine-specific training operation based on the provided configuration.

    Args:
        store: A subclass instance of ``BaseConfig``. This config is responsible for
            providing the actual training entrypoint for a specific training routine.
        tokenizer_trainer: An optional subclass instance of a ``BaseTokenizerTrainer``.
            If provided, this tokenizer will be used to pre-process and create an
            annotated dataset for training. If not provided, a default tokenizer
            will be used.
    """
    if tokenizer_trainer is None:
        tokenizer_trainer = _create_default_tokenizer(store)

    tokenizer_trainer.create_annotated_training_data()
    tokenizer_trainer.train()
    tokenizer = tokenizer_from_model_dir(store.checkpoint_dir)
    params = TrainingParams(
        tokenizer_trainer=tokenizer_trainer,
        tokenizer=tokenizer,
        config=store,
    )
    train_fn = store.get_training_callable()
    store.save_model_params()
    store.gpu_check()
    train_fn(params)
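

# Minimal usage sketch (illustrative only). It assumes a ``TensorFlowConfig``
# subclass of ``BaseConfig`` is exported by this package's config module; the
# paths and epoch count below are placeholders, not values taken from this file.
#
#   from gretel_synthetics.config import TensorFlowConfig  # assumed import path
#
#   config = TensorFlowConfig(
#       input_data_path="/tmp/training_data.csv",  # placeholder training file
#       checkpoint_dir="/tmp/checkpoints",         # placeholder output directory
#       epochs=30,
#   )
#   train(config)  # no tokenizer_trainer given, so the default tokenizer is created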


def train_rnn(store: BaseConfig):
    """
    Fit synthetic data model on training data.

    This will annotate the training data and create a new file that will be
    used to actually train on. The updated training data, model, checkpoints,
    etc. will all be saved in the location specified by your config.

    Args:
        store: An instance of one of the available configs that you
            previously created

    Returns:
        None
    """
    if not store.overwrite:  # pragma: no cover
        try:
            _load_model(store)
        except Exception:
            pass
        else:
            raise RuntimeError(
                "A model already exists in the checkpoint location, you must "
                "enable overwrite mode or delete the checkpoints first."
            )

    text = _annotate_training_data(store)
    sp = _train_tokenizer(store)
    dataset = _create_dataset(store, text, sp)
    logging.info("Initializing synthetic model")
    model = _build_sequential_model(
        vocab_size=len(sp), batch_size=store.batch_size, store=store
    )

    # Save checkpoints during training
    checkpoint_prefix = (Path(store.checkpoint_dir) / "synthetic").as_posix()
    if store.save_all_checkpoints:
        checkpoint_prefix = checkpoint_prefix + "-{epoch}"

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True,
        monitor="accuracy",
    )
    history_callback = _LossHistory()
    model.fit(
        dataset,
        epochs=store.epochs,
        callbacks=[checkpoint_callback, history_callback],
    )
    _save_history_csv(history_callback, store.checkpoint_dir)
    store.save_model_params()
    logging.info(
        f"Saving model to {tf.train.latest_checkpoint(store.checkpoint_dir)}"
    )

    if store.dp:
        logging.info(_compute_epsilon(len(text), store))
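

# Minimal usage sketch for the RNN path (illustrative only). In practice the
# ``train`` facade above is the usual entrypoint and reaches this function via
# ``store.get_training_callable()``; calling it directly is shown here with
# placeholder config values and the same assumed ``TensorFlowConfig`` class.
#
#   config = TensorFlowConfig(
#       input_data_path="/tmp/training_data.csv",  # placeholder training file
#       checkpoint_dir="/tmp/checkpoints",         # placeholder output directory
#       overwrite=True,  # allow retraining into an existing checkpoint directory
#   )
#   train_rnn(config)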