Beispiel #1
0
 def test_tokens_dictfeat_contextual(self):
     """Script a Seq2SeqModel fed tokens, dict features and contextual embeddings.

     The model is built from config, put in eval mode, exported with
     ``torchscriptify()`` and invoked once; the test only checks that the
     scripted call yields a non-None result.
     """
     # TODO (T65593688): this should be removed after
     # https://github.com/pytorch/pytorch/pull/33645 is merged.
     with torch.no_grad():
         contextual_dim = 7

         # Sub-configs are built in the same order the original keyword
         # arguments were evaluated.
         source_emb_cfg = WordEmbedding.Config(embed_dim=512)
         target_emb_cfg = WordEmbedding.Config(embed_dim=512)
         input_cfg = Seq2SeqModel.Config.ModelInput(
             dict_feat=GazetteerTensorizer.Config(text_column="source_sequence"),
             contextual_token_embedding=ByteTokenTensorizer.Config(),
         )
         # NOTE(review): 619 presumably is the concatenated width of the word,
         # dict and contextual embeddings -- confirm against the defaults.
         enc_dec_cfg = RNNModel.Config(
             encoder=LSTMSequenceEncoder.Config(embed_dim=619)
         )
         dict_emb_cfg = DictEmbedding.Config()
         contextual_emb_cfg = ContextualTokenEmbedding.Config(
             embed_dim=contextual_dim
         )

         model_cfg = Seq2SeqModel.Config(
             source_embedding=source_emb_cfg,
             target_embedding=target_emb_cfg,
             inputs=input_cfg,
             encoder_decoder=enc_dec_cfg,
             dict_embedding=dict_emb_cfg,
             contextual_token_embedding=contextual_emb_cfg,
         )
         tensorizers = get_tensorizers(add_dict_feat=True, add_contextual_feat=True)

         model = Seq2SeqModel.from_config(model_cfg, tensorizers)
         model.eval()
         scripted = model.torchscriptify()

         # Positional inputs: tokens, a dict-feature triple, and a flat list
         # of contextual_dim * 2 floats.
         tokens = ["call", "mom"]
         dict_feat = (["call", "mom"], [0.42, 0.17], [4, 3])
         contextual = [0.42] * (contextual_dim * 2)

         res = scripted(tokens, dict_feat, contextual)
         assert res is not None
 def test_tokens_contextual(self):
     """Script a Seq2SeqModel fed tokens plus contextual token embeddings.

     Builds the model from config, switches it to eval mode, exports it via
     ``torchscriptify()`` and asserts that one call on a two-token input
     returns something other than None.
     """
     contextual_dim = 7

     # Sub-configs are created in the original keyword-argument order.
     source_emb_cfg = WordEmbedding.Config(embed_dim=512)
     target_emb_cfg = WordEmbedding.Config(embed_dim=512)
     input_cfg = Seq2SeqModel.Config.ModelInput(
         contextual_token_embedding=ByteTokenTensorizer.Config()
     )
     contextual_emb_cfg = ContextualTokenEmbedding.Config(embed_dim=contextual_dim)
     # NOTE(review): 519 presumably = 512 (word) + 7 (contextual) -- confirm.
     enc_dec_cfg = RNNModel.Config(
         encoder=LSTMSequenceEncoder.Config(embed_dim=519)
     )

     model_cfg = Seq2SeqModel.Config(
         source_embedding=source_emb_cfg,
         target_embedding=target_emb_cfg,
         inputs=input_cfg,
         contextual_token_embedding=contextual_emb_cfg,
         encoder_decoder=enc_dec_cfg,
     )
     tensorizers = get_tensorizers(add_contextual_feat=True)

     model = Seq2SeqModel.from_config(model_cfg, tensorizers)
     model.eval()
     scripted = model.torchscriptify()

     # The contextual embedding is passed by keyword as a flat list of
     # contextual_dim * 2 floats.
     tokens = ["call", "mom"]
     contextual = [0.42] * (contextual_dim * 2)
     res = scripted(tokens, contextual_token_embedding=contextual)
     assert res is not None
Beispiel #3
0
 def setUp(self):
     """Create the RNNG training model and its inference counterpart.

     ``self.training_model`` is an ``RNNGModel`` left in train mode;
     ``self.inference_model`` is an ``RNNGInference`` built from the training
     model's traced embedding and ``jit_model``, left in eval mode.
     """
     contextual_emb_dim = 1
     lstm_dim = 32

     # Word, dict and contextual embeddings, concatenated by EmbeddingList.
     word_emb = WordEmbedding(num_embeddings=103, embedding_dim=100)
     dict_emb = DictEmbedding(
         num_embeddings=59, embed_dim=10, pooling_type=PoolingType.MEAN
     )
     contextual_emb = ContextualTokenEmbedding(contextual_emb_dim)
     emb_module = EmbeddingList(
         embeddings=[word_emb, dict_emb, contextual_emb], concat=True
     )

     # Constructor arguments are materialised in the same order the original
     # call evaluated them, so module construction order is unchanged.
     trace_input = RNNGModel.get_input_for_trace(contextual_emb_dim)
     ablation = RNNGParser.Config.AblationParams()
     constraints = RNNGParser.Config.RNNGConstraints()
     nt_idxs = list(range(2, 20))  # every index from 2 through 19
     in_idxs = [2, 4, 7, 8, 10, 12, 13, 14, 15]
     sl_idxs = [3, 5, 6, 9, 11, 16, 17, 18, 19]
     total_emb_dim = emb_module.embedding_dim
     compositional = CompositionalNN(lstm_dim=lstm_dim, device="cpu")

     training_model = RNNGModel(
         input_for_trace=trace_input,
         embedding=emb_module,
         ablation=ablation,
         constraints=constraints,
         lstm_num_layers=2,
         lstm_dim=lstm_dim,
         max_open_NT=10,
         dropout=0.4,
         num_actions=20,
         shift_idx=0,
         reduce_idx=1,
         ignore_subNTs_roots=[8, 15],
         valid_NT_idxs=nt_idxs,
         valid_IN_idxs=in_idxs,
         valid_SL_idxs=sl_idxs,
         embedding_dim=total_emb_dim,
         p_compositional=compositional,
         device="cpu",
     )
     self.training_model = training_model
     training_model.train()

     # The inference wrapper reuses the traced embedding and jit model from
     # the training model, plus three mock vocabularies.
     traced_embedding = training_model.trace_embedding()
     jit_model = training_model.jit_model
     inference_model = RNNGInference(
         traced_embedding,
         jit_model,
         MockVocab(["<unk>", "foo", "bar"]),
         MockVocab(["<unk>", "a", "b"]),
         MockVocab(["SHIFT", "REDUCE", "IN:END_CALL", "SL:METHOD_CALL"]),
     )
     self.inference_model = inference_model
     inference_model.eval()
 def test_tokens_dictfeat_contextual(self):
     """Script a Seq2SeqModel fed tokens, dict features and contextual embeddings.

     Builds the model from config, switches it to eval mode, exports it via
     ``torchscriptify()`` and asserts that one positional call returns
     something other than None.
     """
     contextual_dim = 7

     # Sub-configs are created in the original keyword-argument order.
     source_emb_cfg = WordEmbedding.Config(embed_dim=512)
     target_emb_cfg = WordEmbedding.Config(embed_dim=512)
     input_cfg = Seq2SeqModel.Config.ModelInput(
         dict_feat=GazetteerTensorizer.Config(text_column="source_sequence"),
         contextual_token_embedding=ByteTokenTensorizer.Config(),
     )
     # NOTE(review): 619 presumably is the concatenated width of the word,
     # dict and contextual embeddings -- confirm against the defaults.
     enc_dec_cfg = RNNModel.Config(
         encoder=LSTMSequenceEncoder.Config(embed_dim=619)
     )
     dict_emb_cfg = DictEmbedding.Config()
     contextual_emb_cfg = ContextualTokenEmbedding.Config(embed_dim=contextual_dim)

     model_cfg = Seq2SeqModel.Config(
         source_embedding=source_emb_cfg,
         target_embedding=target_emb_cfg,
         inputs=input_cfg,
         encoder_decoder=enc_dec_cfg,
         dict_embedding=dict_emb_cfg,
         contextual_token_embedding=contextual_emb_cfg,
     )
     tensorizers = get_tensorizers(add_dict_feat=True, add_contextual_feat=True)

     model = Seq2SeqModel.from_config(model_cfg, tensorizers)
     model.eval()
     scripted = model.torchscriptify()

     # Positional inputs: tokens, a dict-feature triple, and a flat list of
     # contextual_dim * 2 floats.
     tokens = ["call", "mom"]
     dict_feat = (["call", "mom"], [0.42, 0.17], [4, 3])
     contextual = [0.42] * (contextual_dim * 2)

     res = scripted(tokens, dict_feat, contextual)
     assert res is not None