def __init__(self, settings: Settings):
    self.settings = settings
    if self.settings.multi_seed:
        # With multiple seed values, predict one record at a time and reset
        # model states between predictions.
        self.settings.config.predict_batch_size = 1
        self.settings.config.reset_states = True
    self.model = load_model(settings.config, self.settings.tokenizer)
    self.delim = settings.config.field_delimiter
    self._predictions = self._predict_forever()
def train_rnn(params: TrainingParams):
    """
    Fit synthetic data model on training data.

    This will annotate the training data and create a new file that will be
    used to actually train on. The updated training data, model, checkpoints,
    etc. will all be saved in the location specified by your config.

    Args:
        params: The parameters controlling model training.

    Returns:
        None
    """
    store = params.config
    # TODO: We should check that store is an instance of TensorFlowConfig, but that
    # would currently lead to an import cycle.
    tokenizer = params.tokenizer
    num_lines = params.tokenizer_trainer.num_lines
    text_iter = params.tokenizer_trainer.data_iterator()

    # Refuse to train over an existing model unless overwrite mode is enabled.
    if not store.overwrite:  # pragma: no cover
        try:
            load_model(store, tokenizer)
        except Exception:
            pass
        else:
            raise RuntimeError(
                "A model already exists in the checkpoint location, you must enable "
                "overwrite mode or delete the checkpoints first."
            )

    total_token_count, validation_dataset, training_dataset = _create_dataset(
        store, text_iter, num_lines, tokenizer
    )
    logging.info("Initializing synthetic model")
    model = build_model(
        vocab_size=tokenizer.total_vocab_size,
        batch_size=store.batch_size,
        store=store,
    )

    # Save checkpoints during training
    checkpoint_prefix = (Path(store.checkpoint_dir) / "synthetic").as_posix()
    if store.save_all_checkpoints:
        checkpoint_prefix = checkpoint_prefix + "-{epoch}"

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True,
        monitor=store.best_model_metric,
        save_best_only=store.save_best_model,
    )
    history_callback = _ModelHistory(total_token_count, store)

    _callbacks = [checkpoint_callback, history_callback]

    if store.early_stopping:
        early_stopping_callback = tf.keras.callbacks.EarlyStopping(
            monitor=store.best_model_metric,
            patience=store.early_stopping_patience,
            restore_best_weights=store.save_best_model,
        )
        _callbacks.append(early_stopping_callback)

    if store.epoch_callback:
        _callbacks.append(_EpochCallbackWrapper(store.epoch_callback))

    best_val = None
    try:
        model.fit(
            training_dataset,
            epochs=store.epochs,
            callbacks=_callbacks,
            validation_data=validation_dataset,
        )

        if store.save_best_model:
            best_val = checkpoint_callback.best
        if store.early_stopping:
            # NOTE: In this callback, the "best" attr does not get set in the constructor,
            # so we'll set it to None if for some reason we can't get it. This also covers
            # a test case that doesn't run any epochs but accesses this attr.
            try:
                best_val = early_stopping_callback.best
            except AttributeError:
                best_val = None
    except (ValueError, IndexError):
        raise RuntimeError(
            "Model training failed. Your training data may have too few records in it. "
            "Please try increasing your training rows and try again"
        )
    except KeyboardInterrupt:
        ...

    _save_history_csv(
        history_callback,
        store.checkpoint_dir,
        store.dp,
        store.best_model_metric,
        best_val,
    )
    logging.info(
        f"Saving model to {tf.train.latest_checkpoint(store.checkpoint_dir)}"
    )
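# Illustrative usage sketch (an assumption, not taken from this module): train_rnn()
# only reads the three attributes of TrainingParams used above, so a caller needs a
# tokenizer trainer exposing num_lines and data_iterator(), a trained tokenizer exposing
# total_vocab_size, and the TensorFlow config object. The attribute names come from the
# function body; the variable names below are hypothetical placeholders.
#
#     params = TrainingParams(
#         tokenizer_trainer=tokenizer_trainer,  # provides num_lines / data_iterator()
#         tokenizer=tokenizer,                  # provides total_vocab_size
#         config=config,                        # checkpoint_dir, epochs, batch_size, ...
#     )
#     train_rnn(params)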
def __init__(self, settings: Settings):
    self.settings = settings
    self.model = load_model(settings.config, self.settings.tokenizer)
    self.delim = settings.config.field_delimiter
    self._predictions = self._predict_forever()
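# Note on the pattern above (an assumption about _predict_forever, which is defined
# elsewhere in the class): assigning the generator here does not run the model yet;
# records are only produced as callers iterate self._predictions. A minimal sketch of
# such a generator, with a hypothetical helper name:
#
#     def _predict_forever(self):
#         while True:
#             # yields one batch of generated lines per loop iteration
#             yield from self._predict_batch()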