class Config(ConfigBase):
    """Config schema bundling data, model, trainer, optimizer, scheduler and exporter settings.

    NOTE(review): the default sub-configs (e.g. ``Data.Config()``) are built
    once, when the class body runs. Whether ``ConfigBase`` copies them per
    instance is not visible here. Confirm before mutating a default in place.
    """

    # Data pipeline settings; defaults to a default-constructed Data.Config.
    data: Data.Config = Data.Config()
    # Declared without a default; presumably must be supplied by the user
    # config or a subclass -- verify against ConfigBase.
    model: Model.Config
    trainer: NewTaskTrainer.Config = NewTaskTrainer.Config()
    # Declared as the generic Optimizer.Config; the default is Adam's config.
    optimizer: Optimizer.Config = Adam.Config()
    scheduler: Scheduler.Config = Scheduler.Config()
    # Optional; defaults to None (no exporter config set).
    exporter: Optional[ModelExporter.Config] = None
def test_reset_incremental_states(self):
    """Generation must be repeatable after resetting the decoder's incremental states.

    This may look trivial, but calling the scripted sequence generator
    crosses the TorchScript boundary, which can behave unexpectedly. If the
    incremental states were not reset properly, the model would produce
    garbage on every call *after* the first one. That is painful to debug
    when it only surfaces after training.

    The test runs generation once, resets the decoder states, checks that
    they are empty, then runs generation again and expects identical scores.
    """
    tensorizers = get_tensorizers()
    # Use a known seed to avoid numeric issues with quantization.
    torch.manual_seed(42)
    embedding_dim = 512
    seq2seq = Seq2SeqModel.from_config(
        Seq2SeqModel.Config(
            source_embedding=WordEmbedding.Config(embed_dim=embedding_dim),
            target_embedding=WordEmbedding.Config(embed_dim=embedding_dim),
        ),
        tensorizers,
    )

    # Read sample inputs from the test TSV file. The column order follows
    # the schema's key order.
    field_types = {
        "source_sequence": str,
        "dict_feat": Gazetteer,
        "target_sequence": str,
    }
    source_config = TSVDataSource.Config(
        train_filename=TEST_FILE_NAME,
        field_names=list(field_types),
    )
    data = Data.from_config(
        Data.Config(source=source_config), field_types, tensorizers
    )
    data.batcher = Batcher(1, 1, 1)
    batch_iter = iter(data.batches(Stage.TRAIN, load_early=True))
    _, batch = next(batch_iter)

    model_inputs = seq2seq.arrange_model_inputs(batch)
    seq2seq.eval()
    outputs = seq2seq(*model_inputs)

    def run_generation():
        # A fresh context dict per call, in case get_pred mutates it.
        return seq2seq.get_pred(outputs, {"stage": Stage.TEST})

    _, first_scores = run_generation()

    # After the reset, the only entry left must be an empty state under "0".
    decoder = seq2seq.sequence_generator.beam_search.decoder_ens
    decoder.reset_incremental_states()
    self.assertDictEqual(decoder.incremental_states, {"0": {}})

    # A second generation pass must reproduce the same scores.
    _, second_scores = run_generation()
    self.assertEqual(second_scores, first_scores)
def test_force_predictions_on_eval(self):
    """An EVAL-stage get_pred skips sequence generation unless forced.

    By default, get_pred with an EVAL context must return (None, None).
    Once ``force_eval_predictions`` is set, the EVAL path is no longer
    skipped, and this test expects get_pred to raise AssertionError.
    """
    tensorizers = get_tensorizers()
    embedding_dim = 512
    seq2seq = Seq2SeqModel.from_config(
        Seq2SeqModel.Config(
            source_embedding=WordEmbedding.Config(embed_dim=embedding_dim),
            target_embedding=WordEmbedding.Config(embed_dim=embedding_dim),
        ),
        tensorizers,
    )

    # Read sample inputs from the test TSV file. The column order follows
    # the schema's key order.
    field_types = {
        "source_sequence": str,
        "dict_feat": Gazetteer,
        "target_sequence": str,
    }
    source_config = TSVDataSource.Config(
        train_filename=TEST_FILE_NAME,
        field_names=list(field_types),
    )
    data = Data.from_config(
        Data.Config(source=source_config), field_types, tensorizers
    )
    data.batcher = Batcher(1, 1, 1)
    _, batch = next(iter(data.batches(Stage.TRAIN, load_early=True)))
    outputs = seq2seq(*seq2seq.arrange_model_inputs(batch))

    # Without forcing, an EVAL-stage call must not run generation.
    self.assertEqual(
        seq2seq.get_pred(outputs, {"stage": Stage.EVAL}), (None, None)
    )

    # With forcing enabled, the EVAL path is taken, and the call is expected
    # to fail its internal assertion.
    seq2seq.force_eval_predictions = True
    with self.assertRaises(AssertionError):
        seq2seq.get_pred(outputs, {"stage": Stage.EVAL})
class Config(ConfigBase):
    """Config schema with data and trainer settings only (NewTaskTrainer).

    No ``model`` field is declared here; presumably it is added by a
    subclass -- verify against the enclosing class hierarchy.
    """

    # Data pipeline settings; defaults to a default-constructed Data.Config.
    data: Data.Config = Data.Config()
    trainer: NewTaskTrainer.Config = NewTaskTrainer.Config()
class Config(ConfigBase):
    """Config schema with data, TaskTrainer and elastic-training flag settings."""

    # Data pipeline settings; defaults to a default-constructed Data.Config.
    data: Data.Config = Data.Config()
    trainer: TaskTrainer.Config = TaskTrainer.Config()
    # TODO: deprecate this (the use_elastic field below)
    # Optional flag; defaults to None (unset).
    use_elastic: Optional[bool] = None
class Config(ConfigBase):
    """Config schema bundling data, model, trainer (NewTaskTrainer) and exporter settings."""

    # Data pipeline settings; defaults to a default-constructed Data.Config.
    data: Data.Config = Data.Config()
    # Declared without a default; presumably must be supplied by the user
    # config or a subclass -- verify against ConfigBase.
    model: Model.Config
    trainer: NewTaskTrainer.Config = NewTaskTrainer.Config()
    # Optional; defaults to None (no exporter config set).
    exporter: Optional[ModelExporter.Config] = None
class Config(ConfigBase):
    """Config schema bundling data, model and trainer (NewTaskTrainer) settings."""

    # Data pipeline settings; defaults to a default-constructed Data.Config.
    data: Data.Config = Data.Config()
    # Declared without a default; presumably must be supplied by the user
    # config or a subclass -- verify against ConfigBase.
    model: Model.Config
    trainer: NewTaskTrainer.Config = NewTaskTrainer.Config()
class Config(ConfigBase):
    """Config schema with data, TaskTrainer and elastic-training flag settings."""

    # Data pipeline settings; defaults to a default-constructed Data.Config.
    data: Data.Config = Data.Config()
    trainer: TaskTrainer.Config = TaskTrainer.Config()
    # Optional flag; defaults to None (unset).
    use_elastic: Optional[bool] = None