def setUp(self):
    """Build a small RNNGParser fixture and leave it in training mode."""
    action_names = [
        "IN:A",
        "IN:B",
        "IN:UNSUPPORTED",
        "REDUCE",
        "SHIFT",
        "SL:C",
        "SL:D",
    ]
    # Counter(iterable) counts each name once, same as incrementing in a loop.
    vocab = Vocab(Counter(action_names), specials=[])
    word_embedding = WordEmbedding(
        num_embeddings=5,
        embedding_dim=20,
        embeddings_weight=None,
        init_range=[-1, 1],
        unk_token_idx=4,
        mlp_layer_dims=[],
    )
    dict_embedding = DictEmbedding(
        num_embeddings=4, embed_dim=10, pooling_type=PoolingType.MEAN
    )
    self.parser = RNNGParser(
        ablation=RNNGParser.Config.AblationParams(),
        constraints=RNNGParser.Config.RNNGConstraints(),
        lstm_num_layers=2,
        lstm_dim=20,
        max_open_NT=10,
        dropout=0.2,
        beam_size=3,
        top_k=3,
        actions_vocab=vocab,
        shift_idx=4,
        reduce_idx=3,
        ignore_subNTs_roots=[2],
        valid_NT_idxs=[0, 1, 2, 5, 6],
        valid_IN_idxs=[0, 1, 2],
        valid_SL_idxs=[5, 6],
        embedding=EmbeddingList(
            embeddings=[word_embedding, dict_embedding], concat=True
        ),
        p_compositional=CompositionalNN(lstm_dim=20),
    )
    self.parser.train()
class Config(Task.Config):
    """Task configuration wiring the RNNG parser components together.

    Defaults every sub-config so the task is constructible with no overrides.
    """

    model: RNNGParser.Config = RNNGParser.Config()
    trainer: HogwildTrainer.Config = HogwildTrainer.Config()
    data_handler: CompositionalDataHandler.Config = (
        CompositionalDataHandler.Config()
    )
    # No word-level label config by default.
    labels: Optional[WordLabelConfig] = None
    metric_reporter: CompositionalMetricReporter.Config = (
        CompositionalMetricReporter.Config()
    )
class Config(NewTask.Config):
    """New-style task configuration for the RNNG parser.

    Defaults every sub-config so the task is constructible with no overrides.
    """

    model: RNNGParser.Config = RNNGParser.Config()
    trainer: HogwildTrainer.Config = HogwildTrainer.Config()
    metric_reporter: CompositionalMetricReporter.Config = (
        CompositionalMetricReporter.Config()
    )
class RNNGParserTest(unittest.TestCase):
    """Tests for RNNGParser action-constraint logic and forward-pass shapes."""

    def setUp(self):
        """Build a small RNNGParser fixture and leave it in training mode."""
        action_names = [
            "IN:A",
            "IN:B",
            "IN:UNSUPPORTED",
            "REDUCE",
            "SHIFT",
            "SL:C",
            "SL:D",
        ]
        # Counter(iterable) counts each name once, same as a manual loop.
        vocab = Vocab(Counter(action_names), specials=[])
        word_embedding = WordEmbedding(
            num_embeddings=5,
            embedding_dim=20,
            embeddings_weight=None,
            init_range=[-1, 1],
            unk_token_idx=4,
            mlp_layer_dims=[],
        )
        dict_embedding = DictEmbedding(
            num_embeddings=4, embed_dim=10, pooling_type=PoolingType.MEAN
        )
        self.parser = RNNGParser(
            ablation=RNNGParser.Config.AblationParams(),
            constraints=RNNGParser.Config.RNNGConstraints(),
            lstm_num_layers=2,
            lstm_dim=20,
            max_open_NT=10,
            dropout=0.2,
            beam_size=3,
            top_k=3,
            actions_vocab=vocab,
            shift_idx=4,
            reduce_idx=3,
            ignore_subNTs_roots=[2],
            valid_NT_idxs=[0, 1, 2, 5, 6],
            valid_IN_idxs=[0, 1, 2],
            valid_SL_idxs=[5, 6],
            embedding=EmbeddingList(
                embeddings=[word_embedding, dict_embedding], concat=True
            ),
            p_compositional=CompositionalNN(lstm_dim=20),
        )
        self.parser.train()

    def populate_buffer(self):
        """Return a fresh ParserState with two dummy tokens on the buffer."""
        state = ParserState(self.parser)
        state.buffer_stackrnn.push(torch.zeros(1, 30), Element("Token"))
        state.buffer_stackrnn.push(torch.zeros(1, 30), Element("Token"))
        return state

    def check_valid_actions(self, state, actions):
        """Assert the parser's valid actions for *state* equal *actions* as sets."""
        actual = set(self.parser.valid_actions(state))
        self.assertSetEqual(actual, set(actions))

    def _push(self, state, action_name):
        """Apply the named action to *state* through the parser."""
        self.parser.push_action(state, self.parser.actions_vocab.stoi[action_name])

    def test_valid_actions_unconstrained(self):
        self.parser.constraints_intent_slot_nesting = False
        self.parser.constraints_no_slots_inside_unsupported = False
        state = self.populate_buffer()

        # With no constraints, any nonterminal may open first.
        self.check_valid_actions(state, self.parser.valid_NT_idxs)

        # After IN:A opens, SHIFT also becomes legal.
        self._push(state, "IN:A")
        self.check_valid_actions(
            state, self.parser.valid_NT_idxs + [self.parser.shift_idx]
        )

        # After SL:C opens and one token is shifted, REDUCE joins the set.
        self._push(state, "SL:C")
        self._push(state, "SHIFT")
        self.check_valid_actions(
            state,
            self.parser.valid_NT_idxs
            + [self.parser.shift_idx]
            + [self.parser.reduce_idx],
        )

        # Buffer exhausted: only REDUCE remains.
        self._push(state, "SHIFT")
        self.check_valid_actions(state, [self.parser.reduce_idx])

        # All nonterminals closed: nothing left to do.
        self._push(state, "REDUCE")
        self._push(state, "REDUCE")
        self.check_valid_actions(state, [])

    def test_valid_actions_constraint_insl(self):
        self.parser.constraints_intent_slot_nesting = True
        self.parser.constraints_no_slots_inside_unsupported = False
        state = self.populate_buffer()

        # Nesting constraint: the tree must start with an intent.
        self.check_valid_actions(state, self.parser.valid_IN_idxs)

        # Inside an intent only slots (or SHIFT) are allowed.
        self._push(state, "IN:A")
        self.check_valid_actions(
            state, self.parser.valid_SL_idxs + [self.parser.shift_idx]
        )

        # Inside a slot only intents (or SHIFT) are allowed.
        self._push(state, "SL:C")
        self.check_valid_actions(
            state, self.parser.valid_IN_idxs + [self.parser.shift_idx]
        )

        # Buffer exhausted: only REDUCE remains.
        self._push(state, "SHIFT")
        self._push(state, "SHIFT")
        self.check_valid_actions(state, [self.parser.reduce_idx])

        # All nonterminals closed: nothing left to do.
        self._push(state, "REDUCE")
        self._push(state, "REDUCE")
        self.check_valid_actions(state, [])

    def test_valid_actions_constraint_unsupported(self):
        self.parser.constraints_intent_slot_nesting = True
        self.parser.constraints_no_slots_inside_unsupported = True
        state = self.populate_buffer()

        # Nesting constraint: the tree must start with an intent.
        self.check_valid_actions(state, self.parser.valid_IN_idxs)

        # Needed to make test logic work
        state.predicted_actions_idx.append(
            self.parser.actions_vocab.stoi["IN:UNSUPPORTED"]
        )
        # Inside IN:UNSUPPORTED no slots may open; only SHIFT is legal.
        self._push(state, "IN:UNSUPPORTED")
        self.check_valid_actions(state, [self.parser.shift_idx])

    def test_forward_shapes(self):
        self.parser.eval(Stage.EVAL)
        tokens = torch.tensor([[0, 1, 2, 3]])
        seq_lens = torch.tensor([tokens.shape[1]])
        dict_feat = (
            torch.tensor([[1, 1, 1, 1, 1, 1, 3, 1]]),
            torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]]),
            torch.tensor([1, 1, 2, 1]),
        )

        def assert_result_shapes(actions, scores):
            # The action sequence includes nonterminal open/close actions,
            # so it must be strictly longer than the token sequence.
            self.assertGreater(actions.shape[1], tokens.shape[1])
            self.assertEqual(actions.shape[0:2], scores.shape[0:2])
            # One score per vocabulary action at each step.
            self.assertEqual(scores.shape[2], len(self.parser.actions_vocab.itos))

        actions, scores = self.parser(
            tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat
        )[0]
        assert_result_shapes(actions, scores)

        # Beam Search Test: expect 3 hypotheses (beam_size/top_k are 3 in setUp).
        self.parser.eval(Stage.TEST)
        results = self.parser(tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat)
        self.assertEqual(len(results), 3)
        for actions, scores in results:
            assert_result_shapes(actions, scores)