示例#1
0
    def setUp(self):
        """Build a small RNNGParser over a seven-action vocab, in training mode."""
        action_names = [
            "IN:A",
            "IN:B",
            "IN:UNSUPPORTED",
            "REDUCE",
            "SHIFT",
            "SL:C",
            "SL:D",
        ]
        # Each action name is counted exactly once.
        actions_vocab = Vocab(Counter(action_names), specials=[])

        ablation_params = RNNGParser.Config.AblationParams()
        constraint_params = RNNGParser.Config.RNNGConstraints()

        # Sub-modules are created in the same order as a single nested call
        # would create them: word embedding, dict embedding, list, compositional.
        word_embedding = WordEmbedding(
            num_embeddings=5,
            embedding_dim=20,
            embeddings_weight=None,
            init_range=[-1, 1],
            unk_token_idx=4,
            mlp_layer_dims=[],
        )
        dict_embedding = DictEmbedding(
            num_embeddings=4, embed_dim=10, pooling_type=PoolingType.MEAN
        )
        combined_embedding = EmbeddingList(
            embeddings=[word_embedding, dict_embedding],
            concat=True,
        )
        compositional = CompositionalNN(lstm_dim=20)

        # NOTE(review): the hard-coded indices below presumably assume the vocab
        # numbers the actions in the order listed above (IN:A=0 ... SL:D=6);
        # confirm against Vocab's ordering rules.
        self.parser = RNNGParser(
            ablation=ablation_params,
            constraints=constraint_params,
            lstm_num_layers=2,
            lstm_dim=20,
            max_open_NT=10,
            dropout=0.2,
            beam_size=3,
            top_k=3,
            actions_vocab=actions_vocab,
            shift_idx=4,
            reduce_idx=3,
            ignore_subNTs_roots=[2],
            valid_NT_idxs=[0, 1, 2, 5, 6],
            valid_IN_idxs=[0, 1, 2],
            valid_SL_idxs=[5, 6],
            embedding=combined_embedding,
            p_compositional=compositional,
        )
        self.parser.train()
示例#2
0
文件: tasks.py 项目: LinHR000/pytext
 class Config(Task.Config):
     # Declarative task configuration. Every component field defaults to a
     # default-constructed Config of the component named in its annotation.
     model: RNNGParser.Config = RNNGParser.Config()
     trainer: HogwildTrainer.Config = HogwildTrainer.Config()
     data_handler: CompositionalDataHandler.Config = (
         CompositionalDataHandler.Config()
     )
     # No word-label config unless one is supplied.
     labels: Optional[WordLabelConfig] = None
     metric_reporter: CompositionalMetricReporter.Config = (
         CompositionalMetricReporter.Config()
     )
示例#3
0
 class Config(NewTask.Config):
     # Declarative task configuration. Every field defaults to a
     # default-constructed Config of the component named in its annotation.
     model: RNNGParser.Config = RNNGParser.Config()
     trainer: HogwildTrainer.Config = HogwildTrainer.Config()
     metric_reporter: CompositionalMetricReporter.Config = (
         CompositionalMetricReporter.Config()
     )
示例#4
0
class RNNGParserTest(unittest.TestCase):
    """Tests for RNNGParser's valid-action constraints and forward output shapes."""

    def setUp(self):
        """Build a small RNNGParser over a seven-action vocab, in training mode."""
        action_names = [
            "IN:A",
            "IN:B",
            "IN:UNSUPPORTED",
            "REDUCE",
            "SHIFT",
            "SL:C",
            "SL:D",
        ]
        # Each action name is counted exactly once.
        actions_vocab = Vocab(Counter(action_names), specials=[])

        ablation_params = RNNGParser.Config.AblationParams()
        constraint_params = RNNGParser.Config.RNNGConstraints()

        # Sub-modules are created in the same order as a single nested call
        # would create them: word embedding, dict embedding, list, compositional.
        word_embedding = WordEmbedding(
            num_embeddings=5,
            embedding_dim=20,
            embeddings_weight=None,
            init_range=[-1, 1],
            unk_token_idx=4,
            mlp_layer_dims=[],
        )
        dict_embedding = DictEmbedding(
            num_embeddings=4, embed_dim=10, pooling_type=PoolingType.MEAN
        )
        combined_embedding = EmbeddingList(
            embeddings=[word_embedding, dict_embedding],
            concat=True,
        )
        compositional = CompositionalNN(lstm_dim=20)

        # NOTE(review): the hard-coded indices below presumably assume the vocab
        # numbers the actions in the order listed above (IN:A=0 ... SL:D=6);
        # confirm against Vocab's ordering rules.
        self.parser = RNNGParser(
            ablation=ablation_params,
            constraints=constraint_params,
            lstm_num_layers=2,
            lstm_dim=20,
            max_open_NT=10,
            dropout=0.2,
            beam_size=3,
            top_k=3,
            actions_vocab=actions_vocab,
            shift_idx=4,
            reduce_idx=3,
            ignore_subNTs_roots=[2],
            valid_NT_idxs=[0, 1, 2, 5, 6],
            valid_IN_idxs=[0, 1, 2],
            valid_SL_idxs=[5, 6],
            embedding=combined_embedding,
            p_compositional=compositional,
        )
        self.parser.train()

    def _push_actions(self, state, *action_names):
        """Push each named action onto ``state`` through the parser, in order."""
        for name in action_names:
            self.parser.push_action(state, self.parser.actions_vocab.stoi[name])

    def _assert_output_shapes(self, tokens, actions, scores):
        """Check one (actions, scores) result against the input token tensor."""
        # More actions than input tokens along dimension 1.
        self.assertGreater(actions.shape[1], tokens.shape[1])
        # Actions and scores agree on their first two dimensions.
        self.assertEqual(actions.shape[0:2], scores.shape[0:2])
        # The last score dimension has one entry per vocabulary action.
        self.assertEqual(scores.shape[2], len(self.parser.actions_vocab.itos))

    def populate_buffer(self):
        """Return a new ParserState with two placeholder tokens on its buffer."""
        parser_state = ParserState(self.parser)
        num_tokens = 2
        for _ in range(num_tokens):
            # NOTE(review): width 30 presumably equals the concatenated
            # embedding size (20 + 10) configured in setUp; confirm.
            placeholder = torch.zeros(1, 30)
            parser_state.buffer_stackrnn.push(placeholder, Element("Token"))
        return parser_state

    def check_valid_actions(self, state, actions):
        """Assert the parser's valid actions for ``state`` equal ``actions`` as sets."""
        observed = set(self.parser.valid_actions(state))
        expected = set(actions)
        self.assertSetEqual(observed, expected)

    def test_valid_actions_unconstrained(self):
        # Both nesting constraints are switched off for this scenario.
        parser = self.parser
        parser.constraints_intent_slot_nesting = False
        parser.constraints_no_slots_inside_unsupported = False
        state = self.populate_buffer()

        # Start: every nonterminal is allowed.
        self.check_valid_actions(state, parser.valid_NT_idxs)

        # IN:A opened: every nonterminal plus SHIFT.
        self._push_actions(state, "IN:A")
        nts_and_shift = parser.valid_NT_idxs + [parser.shift_idx]
        self.check_valid_actions(state, nts_and_shift)

        # SL:C opened and one token shifted: nonterminals, SHIFT and REDUCE.
        self._push_actions(state, "SL:C", "SHIFT")
        nts_shift_reduce = parser.valid_NT_idxs + [
            parser.shift_idx,
            parser.reduce_idx,
        ]
        self.check_valid_actions(state, nts_shift_reduce)

        # Buffer exhausted by the second SHIFT: REDUCE is the only option.
        self._push_actions(state, "SHIFT")
        self.check_valid_actions(state, [parser.reduce_idx])

        # Both open nonterminals reduced: nothing is valid any more.
        self._push_actions(state, "REDUCE", "REDUCE")
        self.check_valid_actions(state, [])

    def test_valid_actions_constraint_insl(self):
        # Only the intent/slot nesting constraint is active here.
        parser = self.parser
        parser.constraints_intent_slot_nesting = True
        parser.constraints_no_slots_inside_unsupported = False
        state = self.populate_buffer()

        # Start: only intents are allowed.
        self.check_valid_actions(state, parser.valid_IN_idxs)

        # IN:A opened: every slot plus SHIFT.
        self._push_actions(state, "IN:A")
        slots_and_shift = parser.valid_SL_idxs + [parser.shift_idx]
        self.check_valid_actions(state, slots_and_shift)

        # SL:C opened: every intent plus SHIFT.
        self._push_actions(state, "SL:C")
        intents_and_shift = parser.valid_IN_idxs + [parser.shift_idx]
        self.check_valid_actions(state, intents_and_shift)

        # Both tokens shifted: REDUCE is the only option.
        self._push_actions(state, "SHIFT", "SHIFT")
        self.check_valid_actions(state, [parser.reduce_idx])

        # Both open nonterminals reduced: nothing is valid any more.
        self._push_actions(state, "REDUCE", "REDUCE")
        self.check_valid_actions(state, [])

    def test_valid_actions_constraint_unsupported(self):
        # Both constraints are active for this scenario.
        parser = self.parser
        parser.constraints_intent_slot_nesting = True
        parser.constraints_no_slots_inside_unsupported = True
        state = self.populate_buffer()

        # Start: only intents are allowed.
        self.check_valid_actions(state, parser.valid_IN_idxs)

        unsupported_idx = parser.actions_vocab.stoi["IN:UNSUPPORTED"]
        # The predicted-action history must be filled in by hand for the test
        # logic to work.
        state.predicted_actions_idx.append(unsupported_idx)

        # IN:UNSUPPORTED opened: SHIFT is the only option.
        parser.push_action(state, unsupported_idx)
        self.check_valid_actions(state, [parser.shift_idx])

    def test_forward_shapes(self):
        self.parser.eval(Stage.EVAL)
        tokens = torch.tensor([[0, 1, 2, 3]])
        seq_lens = torch.tensor([tokens.shape[1]])
        # NOTE(review): the three tensors are presumably dictionary feature ids,
        # their weights and per-token lengths; confirm against DictEmbedding.
        feat_ids = torch.tensor([[1, 1, 1, 1, 1, 1, 3, 1]])
        feat_weights = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
        feat_lengths = torch.tensor([1, 1, 2, 1])
        dict_feat = (feat_ids, feat_weights, feat_lengths)

        # EVAL stage: only the first returned result is inspected.
        eval_results = self.parser(
            tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat
        )
        actions, scores = eval_results[0]
        self._assert_output_shapes(tokens, actions, scores)

        # Beam Search Test
        self.parser.eval(Stage.TEST)
        results = self.parser(tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat)
        self.assertEqual(len(results), 3)
        for beam_actions, beam_scores in results:
            self._assert_output_shapes(tokens, beam_actions, beam_scores)