Code example #1
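This test exercises the TorchScript export path of a Seq2SeqModel whose source input carries both a gazetteer dict feature and a contextual token embedding. After torchscriptify(), the scripted model is called with the token list, a dict-feature triple whose layout appears to be (dict tokens, weights, per-token lengths), and a flat list of 7 * 2 floats, i.e. one 7-dimensional contextual embedding per input token. The encoder's embed_dim of 619 is consistent with concatenating the 512-dim word embedding, the dict embedding at a default width of 100, and the 7-dim contextual embedding (512 + 100 + 7 = 619).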
def test_tokens_dictfeat_contextual(self):
    # TODO (T65593688): this should be removed after
    # https://github.com/pytorch/pytorch/pull/33645 is merged.
    with torch.no_grad():
        model = Seq2SeqModel.from_config(
            Seq2SeqModel.Config(
                source_embedding=WordEmbedding.Config(embed_dim=512),
                target_embedding=WordEmbedding.Config(embed_dim=512),
                inputs=Seq2SeqModel.Config.ModelInput(
                    dict_feat=GazetteerTensorizer.Config(
                        text_column="source_sequence"
                    ),
                    contextual_token_embedding=ByteTokenTensorizer.Config(),
                ),
                encoder_decoder=RNNModel.Config(
                    encoder=LSTMSequenceEncoder.Config(embed_dim=619)
                ),
                dict_embedding=DictEmbedding.Config(),
                contextual_token_embedding=ContextualTokenEmbedding.Config(
                    embed_dim=7
                ),
            ),
            get_tensorizers(add_dict_feat=True, add_contextual_feat=True),
        )
        model.eval()
        ts_model = model.torchscriptify()
        res = ts_model(
            ["call", "mom"],
            (["call", "mom"], [0.42, 0.17], [4, 3]),
            [0.42] * (7 * 2),
        )
        assert res is not None
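Examples #1 and #3 both call a get_tensorizers() helper that this page does not show. Below is a minimal sketch of what such a helper could look like; the dict keys (the ModelInput field names), the "column" Config field on the token tensorizers, and the omitted vocabulary-initialization step are assumptions about the real helper in the PyText test utilities.

from pytext.data.tensorizers import (
    ByteTokenTensorizer,
    GazetteerTensorizer,
    TokenTensorizer,
)


def get_tensorizers(add_dict_feat=False, add_contextual_feat=False):
    # Keyed by the Seq2SeqModel.Config.ModelInput field names (assumption).
    tensorizers = {
        "src_seq_tokens": TokenTensorizer.from_config(
            TokenTensorizer.Config(column="source_sequence")
        ),
        "trg_seq_tokens": TokenTensorizer.from_config(
            TokenTensorizer.Config(column="target_sequence")
        ),
    }
    if add_dict_feat:
        tensorizers["dict_feat"] = GazetteerTensorizer.from_config(
            GazetteerTensorizer.Config(text_column="source_sequence")
        )
    if add_contextual_feat:
        tensorizers["contextual_token_embedding"] = ByteTokenTensorizer.from_config(
            ByteTokenTensorizer.Config(column="source_sequence")
        )
    # The real helper also initializes each tensorizer's vocabulary from a
    # small sample dataset before the model is built from the configs.
    return tensorizers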
Code example #2
File: rnng_test.py Project: rutyrinott/pytext-1
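This setUp builds a small RNNGParser fixture: a seven-action vocabulary, a two-layer LSTM of dimension 20, beam size 3, and an EmbeddingList that concatenates a word embedding with a dict-feature embedding, leaving the parser in training mode. A note on the hard-coded action indices follows the snippet.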
def setUp(self):
    actions_counter = Counter()
    for action in [
        "IN:A",
        "IN:B",
        "IN:UNSUPPORTED",
        "REDUCE",
        "SHIFT",
        "SL:C",
        "SL:D",
    ]:
        actions_counter[action] += 1
    actions_vocab = Vocab(actions_counter, specials=[])

    self.parser = RNNGParser(
        ablation=RNNGParser.Config.AblationParams(),
        constraints=RNNGParser.Config.RNNGConstraints(),
        lstm_num_layers=2,
        lstm_dim=20,
        max_open_NT=10,
        dropout=0.2,
        beam_size=3,
        top_k=3,
        actions_vocab=actions_vocab,
        shift_idx=4,
        reduce_idx=3,
        ignore_subNTs_roots=[2],
        valid_NT_idxs=[0, 1, 2, 5, 6],
        valid_IN_idxs=[0, 1, 2],
        valid_SL_idxs=[5, 6],
        embedding=EmbeddingList(
            embeddings=[
                WordEmbedding(
                    num_embeddings=5,
                    embedding_dim=20,
                    embeddings_weight=None,
                    init_range=[-1, 1],
                    unk_token_idx=4,
                    mlp_layer_dims=[],
                ),
                DictEmbedding(
                    num_embeddings=4, embed_dim=10, pooling_type=PoolingType.MEAN
                ),
            ],
            concat=True,
        ),
        p_compositional=CompositionalNN(lstm_dim=20),
    )
    self.parser.train()
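The hard-coded indices follow from the vocabulary construction: every action has count 1 in the Counter, and a torchtext-style Vocab with no specials breaks frequency ties alphabetically, yielding IN:A=0, IN:B=1, IN:UNSUPPORTED=2, REDUCE=3, SHIFT=4, SL:C=5, SL:D=6. Hence shift_idx=4, reduce_idx=3, valid_IN_idxs=[0, 1, 2], valid_SL_idxs=[5, 6], and ignore_subNTs_roots=[2] (the IN:UNSUPPORTED root).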
Code example #3
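The same Seq2SeqModel export test as example #1, but with the gazetteer dict feature only: the scripted model takes just the token list and the dict-feature triple, and the encoder's embed_dim drops to 612, consistent with the 512-dim word embedding concatenated with a 100-dim default dict embedding.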
def test_tokens_dictfeat(self):
    model = Seq2SeqModel.from_config(
        Seq2SeqModel.Config(
            source_embedding=WordEmbedding.Config(embed_dim=512),
            target_embedding=WordEmbedding.Config(embed_dim=512),
            inputs=Seq2SeqModel.Config.ModelInput(
                dict_feat=GazetteerTensorizer.Config(
                    text_column="source_sequence"
                )
            ),
            encoder_decoder=RNNModel.Config(
                encoder=LSTMSequenceEncoder.Config(embed_dim=612)
            ),
            dict_embedding=DictEmbedding.Config(),
        ),
        get_tensorizers(add_dict_feat=True),
    )
    model.eval()
    ts_model = model.torchscriptify()
    res = ts_model(["call", "mom"], (["call", "mom"], [0.42, 0.17], [4, 3]))
    assert res is not None
Code example #4
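This setUp wires up both halves of the RNNG test fixture: a training-mode RNNGModel with word, dict, and 1-dimensional contextual token embeddings, and an RNNGInference module assembled from the traced embedding, the scripted jit_model, and mock word, dict-feature, and action vocabularies. Note that num_actions=20 covers shift (0), reduce (1), and the 18 nonterminal indices 2 through 19.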
def setUp(self):
    contextual_emb_dim = 1
    emb_module = EmbeddingList(
        embeddings=[
            WordEmbedding(num_embeddings=103, embedding_dim=100),
            DictEmbedding(
                num_embeddings=59, embed_dim=10, pooling_type=PoolingType.MEAN
            ),
            ContextualTokenEmbedding(contextual_emb_dim),
        ],
        concat=True,
    )
    self.training_model = RNNGModel(
        input_for_trace=RNNGModel.get_input_for_trace(contextual_emb_dim),
        embedding=emb_module,
        ablation=RNNGParser.Config.AblationParams(),
        constraints=RNNGParser.Config.RNNGConstraints(),
        lstm_num_layers=2,
        lstm_dim=32,
        max_open_NT=10,
        dropout=0.4,
        num_actions=20,
        shift_idx=0,
        reduce_idx=1,
        ignore_subNTs_roots=[8, 15],
        valid_NT_idxs=[2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        + [12, 13, 14, 15, 16, 17, 18, 19],
        valid_IN_idxs=[2, 4, 7, 8, 10, 12, 13, 14, 15],
        valid_SL_idxs=[3, 5, 6, 9, 11, 16, 17, 18, 19],
        embedding_dim=emb_module.embedding_dim,
        p_compositional=CompositionalNN(lstm_dim=32, device="cpu"),
        device="cpu",
    )
    self.training_model.train()
    self.inference_model = RNNGInference(
        self.training_model.trace_embedding(),
        self.training_model.jit_model,
        MockVocab(["<unk>", "foo", "bar"]),
        MockVocab(["<unk>", "a", "b"]),
        MockVocab(["SHIFT", "REDUCE", "IN:END_CALL", "SL:METHOD_CALL"]),
    )
    self.inference_model.eval()
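MockVocab is a test helper that is not shown on this page. A plausible minimal stand-in is sketched below; the class name comes from the snippet, but the interface here (an itos/stoi pair plus a lookup_indices_1d method mirroring PyText's ScriptVocabulary) is an assumption about what RNNGInference actually needs.

class MockVocab:
    """Hypothetical minimal vocabulary for the fixture above; the real
    helper may expose a different interface."""

    def __init__(self, tokens):
        self.itos = list(tokens)                          # index -> token
        self.stoi = {t: i for i, t in enumerate(tokens)}  # token -> index

    def lookup_indices_1d(self, tokens):
        # Assumption: unknown tokens fall back to index 0, which holds
        # "<unk>" in the word and dict vocabularies above.
        return [self.stoi.get(t, 0) for t in tokens]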