@torch.no_grad()
def _evaluate(self, loader):
    self.model.eval()

    total_loss, metric = 0, SpanMetric()

    for batch in loader:
        words, *feats, trees, charts = batch
        word_mask = words.ne(self.args.pad_index)[:, 1:]
        mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
        s_span, s_pair, s_label = self.model(words, feats)
        loss, s_span = self.model.loss(s_span, s_pair, s_label, charts, mask)
        chart_preds = self.model.decode(s_span, s_label, mask)
        # since the evaluation relies on terminals,
        # the tree should be first built and then factorized
        preds = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
                 for tree, chart in zip(trees, chart_preds)]
        total_loss += loss.item()
        metric([Tree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
               [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in trees])
    total_loss /= len(loader)

    return total_loss, metric
@torch.no_grad()
def _evaluate(self, loader):
    self.model.eval()

    total_loss, metric = 0, BracketMetric()

    for words, feats, trees, (spans, labels) in loader:
        batch_size, seq_len = words.shape
        lens = words.ne(self.args.pad_index).sum(1) - 1
        mask = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
        mask = mask & mask.new_ones(seq_len - 1, seq_len - 1).triu_(1)
        s_span, s_label = self.model(words, feats)
        loss, s_span = self.model.loss(s_span, s_label, spans, labels, mask, self.args.mbr)
        chart_preds = self.model.decode(s_span, s_label, mask)
        # since the evaluation relies on terminals,
        # the tree should be first built and then factorized
        preds = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
                 for tree, chart in zip(trees, chart_preds)]
        total_loss += loss.item()
        metric([Tree.factorize(tree, self.args.delete, self.args.equal) for tree in preds],
               [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in trees])
    total_loss /= len(loader)

    return total_loss, metric
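# A minimal, self-contained sketch (not part of the parser) illustrating the two
# span-mask constructions used in the methods above. Both yield a
# (batch, seq_len-1, seq_len-1) boolean chart whose [b, i, j] cell marks a candidate
# span with fenceposts i < j; the toy batch, pad index, and BOS/EOS ids are
# assumptions chosen for illustration.
import torch

pad_index = 0
# two sentences of 3 and 2 words: [BOS, w..., EOS, pad...]
words = torch.tensor([[1, 4, 5, 6, 2],
                      [1, 7, 8, 2, 0]])
batch_size, seq_len = words.shape

# older construction: compare fencepost indices against per-sentence lengths
lens = words.ne(pad_index).sum(1) - 1
mask_old = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
mask_old = mask_old & mask_old.new_ones(seq_len - 1, seq_len - 1).triu_(1)

# newer construction: outer product of the word mask, upper triangle only
word_mask = words.ne(pad_index)[:, 1:]
mask_new = (word_mask.unsqueeze(1) & word_mask.unsqueeze(2)).triu_(1)

assert torch.equal(mask_old, mask_new)  # identical candidate-span charts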
@torch.no_grad()
def _predict(self, loader):
    self.model.eval()

    preds = {'trees': [], 'probs': [] if self.args.prob else None}
    for batch in progress_bar(loader):
        words, *feats, trees = batch
        word_mask = words.ne(self.args.pad_index)[:, 1:]
        mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
        lens = mask[:, 0].sum(-1)
        s_span, s_label = self.model(words, feats)
        s_span = ConstituencyCRF(s_span, lens).marginals if self.args.mbr else s_span
        chart_preds = self.model.decode(s_span, s_label, mask)
        preds['trees'].extend([Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
                               for tree, chart in zip(trees, chart_preds)])
        if self.args.prob:
            preds['probs'].extend([prob[:i - 1, 1:i].cpu() for i, prob in zip(lens, s_span)])

    return preds
@torch.no_grad()
def _predict(self, loader):
    self.model.eval()

    preds, probs = {'trees': []}, []
    for words, feats, trees in progress_bar(loader):
        batch_size, seq_len = words.shape
        lens = words.ne(self.args.pad_index).sum(1) - 1
        mask = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
        mask = mask & mask.new_ones(seq_len - 1, seq_len - 1).triu_(1)
        s_span, s_label = self.model(words, feats)
        if self.args.mbr:
            s_span = self.model.crf(s_span, mask, mbr=True)
        chart_preds = self.model.decode(s_span, s_label, mask)
        preds['trees'].extend([Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
                               for tree, chart in zip(trees, chart_preds)])
        if self.args.prob:
            probs.extend([prob[:i - 1, 1:i].cpu() for i, prob in zip(lens, s_span.unbind())])
    if self.args.prob:
        preds['probs'] = probs

    return preds
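# A small sketch (assumed shapes, not library code) of the prob[:i-1, 1:i] slicing
# above: marginal charts are indexed by fenceposts 0..n, so for a sentence with
# i - 1 = n words the slice keeps rows 0..n-1 (left fenceposts) and columns 1..n
# (right fenceposts), dropping padded cells.
import torch

n = 3                                # toy sentence length
i = n + 1                            # matching lens above: word count + 1
prob = torch.rand(6, 6).triu(1)      # padded chart; only cells above the diagonal are used

span_probs = prob[:i - 1, 1:i]       # -> (n, n)
print(span_probs.shape)              # torch.Size([3, 3])
# the probability of span (l, r), covering words l+1..r, sits at span_probs[l, r - 1]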
@classmethod
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
    r"""
    Build a brand-new Parser, including initialization of all data fields and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        min_freq (int):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.
    """

    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(os.path.dirname(path) or './', exist_ok=True)
    if os.path.exists(path) and not args.build:
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser

    logger.info("Building the fields")
    TAG, CHAR, ELMO, BERT = None, None, None, None
    if args.encoder == 'bert':
        from transformers import (AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast)
        t = AutoTokenizer.from_pretrained(args.bert)
        WORD = SubwordField('words',
                            pad=t.pad_token,
                            unk=t.unk_token,
                            bos=t.cls_token or t.bos_token,
                            eos=t.sep_token or t.eos_token,
                            fix_len=args.fix_len,
                            tokenize=t.tokenize,
                            fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' ' + x)
        WORD.vocab = t.get_vocab()
    else:
        WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True)
        if 'tag' in args.feat:
            TAG = Field('tags', bos=BOS, eos=EOS)
        if 'char' in args.feat:
            CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len)
        if 'elmo' in args.feat:
            from allennlp.modules.elmo import batch_to_ids
            ELMO = RawField('elmo')
            ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
        if 'bert' in args.feat:
            from transformers import (AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast)
            t = AutoTokenizer.from_pretrained(args.bert)
            BERT = SubwordField('bert',
                                pad=t.pad_token,
                                unk=t.unk_token,
                                bos=t.cls_token or t.bos_token,
                                eos=t.sep_token or t.eos_token,
                                fix_len=args.fix_len,
                                tokenize=t.tokenize,
                                fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' ' + x)
            BERT.vocab = t.get_vocab()
    TREE = RawField('trees')
    CHART = ChartField('charts')
    transform = Tree(WORD=(WORD, CHAR, ELMO, BERT), POS=TAG, TREE=TREE, CHART=CHART)

    train = Dataset(transform, args.train)
    if args.encoder != 'bert':
        WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
        if TAG is not None:
            TAG.build(train)
        if CHAR is not None:
            CHAR.build(train)
    CHART.build(train)
    args.update({
        'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
        'n_labels': len(CHART.vocab),
        'n_tags': len(TAG.vocab) if TAG is not None else None,
        'n_chars': len(CHAR.vocab) if CHAR is not None else None,
        'char_pad_index': CHAR.pad_index if CHAR is not None else None,
        'bert_pad_index': BERT.pad_index if BERT is not None else None,
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index,
        'eos_index': WORD.eos_index
    })
    logger.info(f"{transform}")

    logger.info("Building the model")
    model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None).to(args.device)
    logger.info(f"{model}\n")

    return cls(args, model, transform)
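# Hypothetical invocation of build on a concrete parser subclass (the class name,
# paths, and values below are illustrative, not taken from the library docs).
# Unconsumed kwargs land in Config(**locals()), which is where the args.train,
# args.encoder, args.feat, args.embed, and args.build lookups above get their values.
parser = CRFConstituencyParser.build(
    'exp/crf-con/model',                # path: where the model is saved
    train='data/ptb/train.pid',         # args.train: training treebank
    encoder='lstm',                     # selects the non-BERT field branch
    feat=['char', 'tag'],               # builds CHAR and TAG alongside WORD
    embed='data/glove.6B.100d.txt',     # optional pretrained embeddings
)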
@classmethod
def build(cls, path,
          optimizer_args={'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12},
          scheduler_args={'gamma': .75**(1/5000)},
          min_freq=2, fix_len=20, **kwargs):
    r"""
    Build a brand-new Parser, including initialization of all data fields and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        optimizer_args (dict):
            Arguments for creating an optimizer.
        scheduler_args (dict):
            Arguments for creating a scheduler.
        min_freq (int):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.
    """

    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(os.path.dirname(path) or './', exist_ok=True)
    if os.path.exists(path) and not args.build:
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser

    logger.info("Building the fields")
    WORD = Field('words', pad=pad, unk=unk, bos=bos, eos=eos, lower=True)
    if args.feat == 'char':
        FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, eos=eos, fix_len=args.fix_len)
    elif args.feat == 'bert':
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.bert)
        FEAT = SubwordField('bert',
                            pad=tokenizer.pad_token,
                            unk=tokenizer.unk_token,
                            bos=tokenizer.bos_token or tokenizer.cls_token,
                            eos=tokenizer.eos_token or tokenizer.sep_token,
                            fix_len=args.fix_len,
                            tokenize=tokenizer.tokenize)
        FEAT.vocab = tokenizer.get_vocab()
    else:
        FEAT = Field('tags', bos=bos, eos=eos)
    TREE = RawField('trees')
    CHART = ChartField('charts')
    if args.feat in ('char', 'bert'):
        transform = Tree(WORD=(WORD, FEAT), TREE=TREE, CHART=CHART)
    else:
        transform = Tree(WORD=WORD, POS=FEAT, TREE=TREE, CHART=CHART)

    train = Dataset(transform, args.train)
    WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
    FEAT.build(train)
    CHART.build(train)
    args.update({
        'n_words': WORD.vocab.n_init,
        'n_feats': len(FEAT.vocab),
        'n_labels': len(CHART.vocab),
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index,
        'eos_index': WORD.eos_index,
        'feat_pad_index': FEAT.pad_index
    })
    logger.info(f"{transform}")

    logger.info("Building the model")
    model = cls.MODEL(**args).load_pretrained(WORD.embed).to(args.device)
    logger.info(f"{model}\n")

    optimizer = Adam(model.parameters(), **optimizer_args)
    scheduler = ExponentialLR(optimizer, **scheduler_args)

    return cls(args, model, transform, optimizer, scheduler)
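# A standalone sketch of what the defaults above imply (toy parameter, not parser
# code): gamma = .75 ** (1/5000) multiplies the learning rate by .75 every 5000
# scheduler steps, assuming the scheduler is stepped once per training batch.
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([param], lr=2e-3, betas=(.9, .9), eps=1e-12)
scheduler = ExponentialLR(optimizer, gamma=.75 ** (1 / 5000))

for _ in range(5000):
    optimizer.step()       # would follow loss.backward() in real training
    scheduler.step()

print(scheduler.get_last_lr())  # ~[0.0015] == 2e-3 * 0.75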