def test_tokens_dictfeat_contextual(self):
    """Check TorchScript export of a seq2seq model with dict and contextual inputs.

    Builds a ``Seq2SeqModel`` whose inputs include a gazetteer (``dict_feat``)
    tensorizer and a byte-token tensorizer registered as
    ``contextual_token_embedding``, switches it to eval mode, converts it with
    ``torchscriptify()``, and asserts that calling the scripted model on a
    two-token example returns something other than ``None``.
    """
    # TODO (T65593688): this should be removed after
    # https://github.com/pytorch/pytorch/pull/33645 is merged.
    with torch.no_grad():
        source_embedding_cfg = WordEmbedding.Config(embed_dim=512)
        target_embedding_cfg = WordEmbedding.Config(embed_dim=512)
        model_inputs_cfg = Seq2SeqModel.Config.ModelInput(
            dict_feat=GazetteerTensorizer.Config(text_column="source_sequence"),
            contextual_token_embedding=ByteTokenTensorizer.Config(),
        )
        # NOTE(review): 619 presumably equals the concatenated width of the
        # word (512), dict, and contextual (7) embeddings -- confirm against
        # the DictEmbedding default.
        encoder_decoder_cfg = RNNModel.Config(
            encoder=LSTMSequenceEncoder.Config(embed_dim=619)
        )
        model_cfg = Seq2SeqModel.Config(
            source_embedding=source_embedding_cfg,
            target_embedding=target_embedding_cfg,
            inputs=model_inputs_cfg,
            encoder_decoder=encoder_decoder_cfg,
            dict_embedding=DictEmbedding.Config(),
            contextual_token_embedding=ContextualTokenEmbedding.Config(embed_dim=7),
        )
        tensorizers = get_tensorizers(add_dict_feat=True, add_contextual_feat=True)

        model = Seq2SeqModel.from_config(model_cfg, tensorizers)
        model.eval()
        scripted_model = model.torchscriptify()

        tokens = ["call", "mom"]
        dict_feat_input = (["call", "mom"], [0.42, 0.17], [4, 3])
        # Flat list of 7 * 2 floats; presumably one 7-dim contextual vector
        # per input token -- confirm against the scripted model's signature.
        contextual_input = [0.42] * (7 * 2)

        res = scripted_model(tokens, dict_feat_input, contextual_input)
        assert res is not None
def get_tensorizers(add_dict_feat=False, add_contextual_feat=False):
    """Build and initialize the tensorizers used by the seq2seq export tests.

    A TSV data source is created from ``TEST_FILE_NAME`` with the columns
    ``source_sequence``, ``dict_feat`` and ``target_sequence``. Source and
    target token tensorizers (both with BOS and EOS tokens enabled) are always
    created and initialized on the training split.

    Args:
        add_dict_feat: When true, also create and initialize a
            ``GazetteerTensorizer`` under the ``"dict_feat"`` key.
        add_contextual_feat: When true, also create and initialize a
            ``ByteTokenTensorizer`` under the ``"contextual_token_embedding"``
            key. Previously this flag was accepted but silently ignored.

    Returns:
        A dict mapping tensorizer names to initialized tensorizers. It always
        contains ``"src_seq_tokens"`` and ``"trg_seq_tokens"``.
    """
    schema = {"source_sequence": str, "dict_feat": Gazetteer, "target_sequence": str}
    data_source = TSVDataSource.from_config(
        TSVDataSource.Config(
            train_filename=TEST_FILE_NAME,
            field_names=["source_sequence", "dict_feat", "target_sequence"],
        ),
        schema,
    )

    src_tensorizer = TokenTensorizer.from_config(
        TokenTensorizer.Config(
            column="source_sequence", add_eos_token=True, add_bos_token=True
        )
    )
    tgt_tensorizer = TokenTensorizer.from_config(
        TokenTensorizer.Config(
            column="target_sequence", add_eos_token=True, add_bos_token=True
        )
    )
    tensorizers = {"src_seq_tokens": src_tensorizer, "trg_seq_tokens": tgt_tensorizer}
    initialize_tensorizers(tensorizers, data_source.train)

    # Optional tensorizers are initialized on their own so the already
    # initialized token tensorizers are not processed a second time.
    if add_dict_feat:
        tensorizers["dict_feat"] = GazetteerTensorizer.from_config(
            GazetteerTensorizer.Config(
                text_column="source_sequence", dict_column="dict_feat"
            )
        )
        initialize_tensorizers(
            {"dict_feat": tensorizers["dict_feat"]}, data_source.train
        )

    if add_contextual_feat:
        # The key matches the ``contextual_token_embedding`` field of the
        # model's ModelInput config, the same way "dict_feat" does above.
        # NOTE(review): reads the source text column, like the dict feature
        # tensorizer -- confirm this is the intended column.
        tensorizers["contextual_token_embedding"] = ByteTokenTensorizer.from_config(
            ByteTokenTensorizer.Config(column="source_sequence")
        )
        initialize_tensorizers(
            {
                "contextual_token_embedding": tensorizers[
                    "contextual_token_embedding"
                ]
            },
            data_source.train,
        )

    return tensorizers
def test_tokens_dictfeat(self):
    """Check TorchScript export of a seq2seq model with a dict-feature input.

    Builds a ``Seq2SeqModel`` whose inputs include a gazetteer (``dict_feat``)
    tensorizer, switches it to eval mode, converts it with
    ``torchscriptify()``, and asserts that calling the scripted model on a
    two-token example returns something other than ``None``.
    """
    # Export and inference need no gradients. The sibling contextual export
    # test wraps the same steps in no_grad as a workaround (T65593688,
    # https://github.com/pytorch/pytorch/pull/33645); this test does the same
    # for consistency.
    # NOTE(review): confirm the workaround is required for this variant too.
    with torch.no_grad():
        model = Seq2SeqModel.from_config(
            Seq2SeqModel.Config(
                source_embedding=WordEmbedding.Config(embed_dim=512),
                target_embedding=WordEmbedding.Config(embed_dim=512),
                inputs=Seq2SeqModel.Config.ModelInput(
                    dict_feat=GazetteerTensorizer.Config(
                        text_column="source_sequence"
                    )
                ),
                # NOTE(review): 612 presumably equals the word embedding (512)
                # plus the DictEmbedding default width -- confirm.
                encoder_decoder=RNNModel.Config(
                    encoder=LSTMSequenceEncoder.Config(embed_dim=612)
                ),
                dict_embedding=DictEmbedding.Config(),
            ),
            get_tensorizers(add_dict_feat=True),
        )
        model.eval()
        ts_model = model.torchscriptify()
        res = ts_model(["call", "mom"], (["call", "mom"], [0.42, 0.17], [4, 3]))
        assert res is not None