def enumerate(self, semiring):
    """Exhaustively enumerate all valid trees and score them under ``semiring``.

    Brute-force reference implementation: for every sentence in the batch,
    every head assignment is tried and kept only if it forms a valid tree;
    each surviving tree is scored from both arc and sibling score tables
    (``self.scores[0]`` and ``self.scores[1]``).

    Args:
        semiring: Semiring providing ``prod`` (reduction) and ``mul``
            (combination) used to fold arc and sibling scores together.

    Returns:
        list[~torch.Tensor]: One stacked 1-D tensor of tree scores per
        sentence in the batch.
    """
    trees = []
    for i, length in enumerate(self.lens.tolist()):
        trees.append([])
        # Each `seq` maps token j (1-based) to a candidate head in [0, length].
        for seq in itertools.product(range(length + 1), repeat=length):
            # Skip assignments that do not form a valid tree. NOTE(review):
            # the second argument `True` presumably enforces projectivity —
            # confirm against CoNLL.istree's signature.
            if not CoNLL.istree(list(seq), True, self.multiroot):
                continue
            # Sibling index per token; 0 means "no sibling" and is masked out.
            sibs = self.lens.new_tensor(CoNLL.get_sibs(seq))
            sib_mask = sibs.gt(0)
            # scores[0]: arc scores; scores[1]: dependent-head-sibling scores.
            s_arc = self.scores[0][i, :length+1, :length+1]
            s_sib = self.scores[1][i, :length+1, :length+1, :length+1]
            # Fold the scores of the arcs actually used by this tree.
            s_arc = semiring.prod(s_arc[range(1, length + 1), seq], -1)
            # Gather the sibling score of each token that has a sibling.
            s_sib = semiring.prod(s_sib[1:][sib_mask].gather(-1, sibs[sib_mask].unsqueeze(-1)).squeeze(-1))
            trees[-1].append(semiring.mul(s_arc, s_sib))
    return [torch.stack(seq) for seq in trees]
def decode(self, s_arc, s_rel, mask, tree=False, proj=False):
    r"""
    Args:
        s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all possible arcs.
        s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
            Scores of all possible labels on each arc.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            The mask for covering the unpadded tokens.
        tree (bool):
            If ``True``, ensures to output well-formed trees. Default: ``False``.
        proj (bool):
            If ``True``, ensures to output projective trees. Default: ``False``.

    Returns:
        ~torch.LongTensor, ~torch.LongTensor:
            Predicted arcs and labels of shape ``[batch_size, seq_len]``.
    """
    lens = mask.sum(1)
    # Greedy decoding: pick the highest-scoring head for every token.
    arc_preds = s_arc.argmax(-1)
    # Flag every sentence whose greedy head sequence is not a valid tree.
    ill_formed = [
        not CoNLL.istree(heads[1:n + 1], proj)
        for n, heads in zip(lens.tolist(), arc_preds.tolist())
    ]
    if tree and any(ill_formed):
        # Re-decode only the flagged sentences with a structured algorithm.
        fallback = eisner if proj else mst
        arc_preds[ill_formed] = fallback(s_arc[ill_formed], mask[ill_formed])
    # Select the best label along each predicted head-dependent arc.
    best_labels = s_rel.argmax(-1)
    rel_preds = best_labels.gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
    return arc_preds, rel_preds
def _predict(self, loader):
    """Predict labeled graphs for every batch in ``loader``.

    Returns:
        dict: ``{'labels': list of relation charts}`` and, when
        ``self.args.prob`` is set, ``'probs'`` with per-sentence edge
        probability tensors (moved to CPU).
    """
    self.model.eval()
    preds = {}
    charts, probs = [], []
    for words, *feats in progress_bar(loader):
        mask = words.ne(self.WORD.pad_index)
        # Pairwise token mask: [batch, seq_len, seq_len].
        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
        # Row 0 is the BOS position; it never receives an incoming edge.
        mask[:, 0] = 0
        lens = mask[:, 1].sum(-1).tolist()
        s_edge, s_label = self.model(words, feats)
        edge_preds, label_preds = self.model.decode(s_edge, s_label)
        # -1 marks cells with no predicted edge (or padding).
        chart_preds = label_preds.masked_fill(~(edge_preds.gt(0) & mask), -1)
        # Keep rows 1..i (tokens) against columns 0..i-1 (heads incl. BOS).
        charts.extend(chart[1:i, :i].tolist() for i, chart in zip(lens, chart_preds.unbind()))
        if self.args.prob:
            probs.extend([prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.softmax(-1).unbind())])
    # Map label indices back to vocabulary strings; -1 becomes None.
    charts = [
        CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row] for row in chart])
        for chart in charts
    ]
    preds = {'labels': charts}
    if self.args.prob:
        preds['probs'] = probs
    return preds
def _predict(self, loader):
    """Predict labeled graphs using variational inference over edge factors.

    Returns:
        dict: ``{'labels': ..., 'probs': ...}``; ``'probs'`` is ``None``
        unless ``self.args.prob`` is set.
    """
    self.model.eval()
    preds = {'labels': [], 'probs': [] if self.args.prob else None}
    for words, *feats in progress_bar(loader):
        word_mask = words.ne(self.args.pad_index)
        # Collapse the subword dimension when words are [batch, seq, fix_len].
        mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
        # Pairwise token mask: [batch, seq_len, seq_len].
        mask = mask.unsqueeze(1) & mask.unsqueeze(2)
        # The BOS row never receives an incoming edge.
        mask[:, 0] = 0
        lens = mask[:, 1].sum(-1).tolist()
        s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
        # Marginalize sibling/co-parent/grandparent factors into edge scores.
        s_edge = self.model.inference((s_edge, s_sib, s_cop, s_grd), mask)
        # -1 marks cells outside the mask (no edge).
        label_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1)
        preds['labels'].extend(chart[1:i, :i].tolist() for i, chart in zip(lens, label_preds))
        if self.args.prob:
            preds['probs'].extend([prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.unbind())])
    # Map label indices back to vocabulary strings; -1 becomes None.
    preds['labels'] = [
        CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row] for row in chart])
        for chart in preds['labels']
    ]
    return preds
def enumerate(self, semiring):
    """Brute-force enumeration of all valid trees, scored under ``semiring``.

    For every sentence, every possible head assignment is generated; the
    ones forming valid trees are scored by folding the arc scores of the
    arcs they use.

    Returns:
        list[~torch.Tensor]: one stacked tensor of tree scores per sentence.
    """
    all_trees = []
    for batch_idx, n in enumerate(self.lens.tolist()):
        tree_scores = []
        # `heads` maps token j (1-based) to a candidate head in [0, n].
        for heads in itertools.product(range(n + 1), repeat=n):
            if CoNLL.istree(list(heads), True, self.multiroot):
                arc_scores = self.scores[batch_idx, range(1, n + 1), heads]
                tree_scores.append(semiring.prod(arc_scores, -1))
        all_trees.append(tree_scores)
    return [torch.stack(scores) for scores in all_trees]
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
    """
    Build a brand-new Parser, including initialization of all data fields
    and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        min_freq (str):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece
            will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.

    Returns:
        The created parser.
    """
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # `or './'` guards against a bare filename: os.path.dirname('model') is ''
    # and os.makedirs('') raises FileNotFoundError.
    os.makedirs(os.path.dirname(path) or './', exist_ok=True)
    if os.path.exists(path) and not args.build:
        # Reuse an existing checkpoint instead of rebuilding from scratch.
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser
    logger.info("Build the fields")
    WORD = Field('words', pad=pad, unk=unk, lower=True)
    CPOS = Field('tags')
    transform = CoNLL(FORM=WORD, CPOS=CPOS)
    train = Dataset(transform, args.train)
    # `not_extend_vocab=True` keeps the vocabulary restricted to the training
    # corpus even when pretrained embeddings are supplied.
    WORD.build(train,
               args.min_freq,
               (Embedding.load(args.embed, args.unk) if args.embed else None),
               not_extend_vocab=True)
    CPOS.build(train)
    args.update({
        'n_words': len(WORD.vocab),
        'n_cpos': len(CPOS.vocab),
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
    })
    model = cls.MODEL(normalize_paras=not args.em_alg, **args)
    if args.em_alg:
        # EM updates the parameters in closed form, not by gradient descent.
        model.requires_grad_(False)
    model.to(args.device)
    return cls(args, model, transform)
def decode(self, s_arc, s_sib, s_rel, mask, tree=False, mbr=True, proj=False):
    """
    Args:
        s_arc (torch.Tensor): [batch_size, seq_len, seq_len]
            The scores of all possible arcs.
        s_sib (torch.Tensor): [batch_size, seq_len, seq_len, seq_len]
            The scores of all possible dependent-head-sibling triples.
        s_rel (torch.Tensor): [batch_size, seq_len, seq_len, n_labels]
            The scores of all possible labels on each arc.
        mask (torch.BoolTensor): [batch_size, seq_len, seq_len]
            Mask for covering the unpadded tokens.
        tree (bool):
            If True, ensures to output well-formed trees. Default: False.
        mbr (bool):
            If True, performs MBR decoding. Default: True.
        proj (bool):
            If True, ensures to output projective trees. Default: False.

    Returns:
        arc_preds (torch.Tensor): [batch_size, seq_len]
            The predicted arcs.
        rel_preds (torch.Tensor): [batch_size, seq_len]
            The predicted labels.
    """
    lens = mask.sum(1)
    # prevent self-loops; NOTE: this mutates the caller's s_arc in place.
    s_arc.diagonal(0, 1, 2).fill_(float('-inf'))
    # Greedy decoding: best head per token.
    arc_preds = s_arc.argmax(-1)
    # Flag sentences whose greedy heads do not form a valid tree.
    bad = [
        not CoNLL.istree(seq[1:i + 1], proj)
        for i, seq in zip(lens.tolist(), arc_preds.tolist())
    ]
    if tree and any(bad):
        if proj and not mbr:
            # Second-order projective decoding with sibling scores; note it
            # re-decodes the whole batch, not only the flagged sentences.
            arc_preds = eisner2o((s_arc, s_sib), mask)
        else:
            # First-order fallback, applied only to the flagged sentences.
            alg = eisner if proj else mst
            arc_preds[bad] = alg(s_arc[bad], mask[bad])
    # Select the best label along each predicted arc.
    rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
    return arc_preds, rel_preds
def decode(self, s_arc, s_sib, s_rel, mask, tree=False, mbr=True, proj=False):
    r"""
    Args:
        s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all possible arcs.
        s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
            Scores of all possible dependent-head-sibling triples.
        s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
            Scores of all possible labels on each arc.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            The mask for covering the unpadded tokens.
        tree (bool):
            If ``True``, ensures to output well-formed trees. Default: ``False``.
        mbr (bool):
            If ``True``, performs MBR decoding. Default: ``True``.
        proj (bool):
            If ``True``, ensures to output projective trees. Default: ``False``.

    Returns:
        ~torch.LongTensor, ~torch.LongTensor:
            Predicted arcs and labels of shape ``[batch_size, seq_len]``.
    """
    lens = mask.sum(1)
    # Greedy decoding: best head per token.
    arc_preds = s_arc.argmax(-1)
    # Flag sentences whose greedy heads do not form a valid tree.
    bad = [
        not CoNLL.istree(seq[1:i + 1], proj)
        for i, seq in zip(lens.tolist(), arc_preds.tolist())
    ]
    if tree and any(bad):
        if proj and not mbr:
            # Second-order projective CRF decoding over the flagged sentences.
            arc_preds[bad] = Dependency2oCRF((s_arc[bad], s_sib[bad]), mask[bad].sum(-1)).argmax
        else:
            # First-order structured decoding (projective CRF or matrix-tree).
            arc_preds[bad] = (DependencyCRF if proj else MatrixTree)(
                s_arc[bad], mask[bad].sum(-1)).argmax
    # Select the best label along each predicted arc.
    rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
    return arc_preds, rel_preds
def _predict(self, loader):
    """Decode labeled charts for every batch in ``loader``.

    Returns:
        dict: ``{'labels': relation charts}`` plus ``'probs'`` (per-sentence
        edge probabilities on CPU) when ``self.args.prob`` is set.
    """
    self.model.eval()
    charts, probs = [], []
    for words, feats in progress_bar(loader):
        token_mask = words.ne(self.WORD.pad_index)
        # Pairwise token mask: [batch, seq_len, seq_len].
        mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
        lens = mask[:, 0].sum(-1).tolist()
        s_edge, s_label = self.model(words, feats)
        charts.extend(self.model.decode(s_edge, s_label, mask))
        if self.args.prob:
            for n, prob in zip(lens, s_edge.softmax(-1).unbind()):
                probs.append(prob[:n, :n].cpu())
    # Map label indices back to vocabulary strings; -1 becomes None.
    named_charts = []
    for chart in charts:
        named_charts.append(
            [[self.LABEL.vocab[i] if i >= 0 else None for i in row]
             for row in chart])
    preds = {'labels': [CoNLL.build_relations(chart) for chart in named_charts]}
    if self.args.prob:
        preds['probs'] = probs
    return preds
def build(cls,
          path,
          optimizer_args=None,
          scheduler_args=None,
          min_freq=2,
          fix_len=20,
          **kwargs):
    r"""
    Build a brand-new Parser, including initialization of all data fields
    and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        optimizer_args (dict):
            Arguments for creating an optimizer.
            Default: ``{'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12}``.
        scheduler_args (dict):
            Arguments for creating a scheduler.
            Default: ``{'gamma': .75**(1/5000)}``.
        min_freq (str):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece
            will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.
    """
    # Avoid mutable default arguments: materialize the documented defaults
    # here, BEFORE locals() is captured into the config, so the resulting
    # `args` is identical to what the old dict defaults produced.
    if optimizer_args is None:
        optimizer_args = {'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12}
    if scheduler_args is None:
        scheduler_args = {'gamma': .75**(1/5000)}
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # `or './'` guards against a bare filename whose dirname is ''.
    os.makedirs(os.path.dirname(path) or './', exist_ok=True)
    if os.path.exists(path) and not args.build:
        # Reuse an existing checkpoint instead of rebuilding from scratch.
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser
    logger.info("Building the fields")
    WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
    if args.feat == 'char':
        FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
    elif args.feat == 'bert':
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.bert)
        FEAT = SubwordField('bert',
                            pad=tokenizer.pad_token,
                            unk=tokenizer.unk_token,
                            bos=tokenizer.bos_token or tokenizer.cls_token,
                            fix_len=args.fix_len,
                            tokenize=tokenizer.tokenize)
        FEAT.vocab = tokenizer.get_vocab()
    else:
        FEAT = Field('tags', bos=bos)
    ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
    SIB = Field('sibs', bos=bos, use_vocab=False, fn=CoNLL.get_sibs)
    REL = Field('rels', bos=bos)
    if args.feat in ('char', 'bert'):
        transform = CoNLL(FORM=(WORD, FEAT), HEAD=(ARC, SIB), DEPREL=REL)
    else:
        transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=(ARC, SIB), DEPREL=REL)
    train = Dataset(transform, args.train)
    WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
    FEAT.build(train)
    REL.build(train)
    args.update({
        'n_words': WORD.vocab.n_init,
        'n_feats': len(FEAT.vocab),
        'n_rels': len(REL.vocab),
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index,
        'feat_pad_index': FEAT.pad_index
    })
    logger.info(f"{transform}")
    logger.info("Building the model")
    model = cls.MODEL(**args).load_pretrained(WORD.embed).to(args.device)
    logger.info(f"{model}\n")
    optimizer = Adam(model.parameters(), **optimizer_args)
    scheduler = ExponentialLR(optimizer, **scheduler_args)
    return cls(args, model, transform, optimizer, scheduler)
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
    r"""
    Build a brand-new Parser, including initialization of all data fields
    and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        min_freq (str):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece
            will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.
    """
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): raises FileNotFoundError when `path` has no directory
    # component (dirname is '') — sibling builders guard with `or './'`.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if os.path.exists(path) and not args.build:
        # Reuse an existing checkpoint instead of rebuilding from scratch.
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser
    logger.info("Building the fields")
    WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
    if args.feat == 'char':
        FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
    elif args.feat == 'bert':
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.bert)
        FEAT = SubwordField('bert',
                            pad=tokenizer.pad_token,
                            unk=tokenizer.unk_token,
                            bos=tokenizer.bos_token or tokenizer.cls_token,
                            fix_len=args.fix_len,
                            tokenize=tokenizer.tokenize)
        FEAT.vocab = tokenizer.get_vocab()
    elif args.feat == 'elmo':
        logger.info("Hello, initing ElmoField")
        FEAT = ElmoField('elmo', bos=bos)
    else:
        FEAT = Field('tags', bos=bos)
    ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
    REL = Field('rels', bos=bos)
    if args.feat in ('char', 'bert'):
        transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
    elif args.feat == 'elmo':
        logger.info("calling CoNLL transform")
        # FEAT still has as many as 3 layers; this will need fixing somehow.
        # (translated from the original Slovenian comment)
        transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
    else:
        transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL)
    logger.info("initing train Dataset")
    train = Dataset(transform, args.train)
    # NOTE(review): unlike sibling builders, WORD is built WITHOUT min_freq
    # and pretrained embeddings here; the original call was left commented out:
    # WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
    logger.info("Building WORD, FEAT, REL fields")
    WORD.build(train)
    FEAT.build(train)
    REL.build(train)
    args.update({
        'n_words': WORD.vocab.n_init,
        'n_feats': len(FEAT.vocab),
        'n_rels': len(REL.vocab),
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index,
        'feat_pad_index': FEAT.pad_index,
    })
    logger.info("Loading model")
    model = cls.MODEL(**args)
    model.load_pretrained(WORD.embed).to(args.device)
    return cls(args, model, transform)
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
    r"""
    Build a brand-new Parser, including initialization of all data fields
    and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        min_freq (str):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece
            will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.
    """
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): raises FileNotFoundError when `path` has no directory
    # component (dirname is '') — sibling builders guard with `or './'`.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if os.path.exists(path) and not args.build:
        # Load the existing model. (translated from the original Chinese comment)
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser
    logger.info("Building the fields")
    WORD = Field('words', pad=pad, unk=unk, lower=True)
    if args.feat == 'char':
        FEAT = SubwordField('chars', pad=pad, unk=unk, fix_len=args.fix_len)
    # How to use BERT — study this. (translated from the original Chinese comment)
    elif args.feat == 'bert':
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.bert)
        FEAT = SubwordField('bert',
                            pad=tokenizer.pad_token,
                            unk=tokenizer.unk_token,
                            fix_len=args.fix_len,
                            tokenize=tokenizer.tokenize)
        FEAT.vocab = tokenizer.get_vocab()
    else:
        FEAT = Field('tags')
    EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
    LABEL = ChartField('labels', fn=CoNLL.get_labels)
    # New chart fields defined for extracting graph edges and labels.
    # (translated from the original Chinese comment)
    if args.feat in ('char', 'bert'):
        transform = CoNLL(FORM=(WORD, FEAT), PHEAD=(EDGE, LABEL))
    else:
        transform = CoNLL(FORM=WORD, POS=FEAT, PHEAD=(EDGE, LABEL))
    train = Dataset(transform, args.train)
    WORD.build(train,
               args.min_freq,
               (Embedding.load(args.embed, args.unk) if args.embed else None))
    FEAT.build(train)
    LABEL.build(train)
    args.update({
        'n_words': WORD.vocab.n_init,
        'n_feats': len(FEAT.vocab),
        'n_labels': len(LABEL.vocab),
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'feat_pad_index': FEAT.pad_index
    })
    model = cls.MODEL(**args)
    model.load_pretrained(WORD.embed).to(args.device)
    return cls(args, model, transform)
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
    r"""
    Build a brand-new Parser, including initialization of all data fields
    and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        min_freq (str):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece
            will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.
    """
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # `or './'` guards against a bare filename whose dirname is ''.
    os.makedirs(os.path.dirname(path) or './', exist_ok=True)
    if os.path.exists(path) and not args.build:
        # Reuse an existing checkpoint instead of rebuilding from scratch.
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser
    logger.info("Building the fields")
    TAG, CHAR, BERT = None, None, None
    if args.encoder != 'lstm':
        # Transformer encoder: words are tokenized into subword pieces.
        from transformers import (AutoTokenizer, GPT2Tokenizer,
                                  GPT2TokenizerFast)
        t = AutoTokenizer.from_pretrained(args.bert)
        # GPT-2 tokenizers need a leading space to mark word boundaries.
        WORD = SubwordField('words',
                            pad=t.pad_token,
                            unk=t.unk_token,
                            bos=t.bos_token or t.cls_token,
                            fix_len=args.fix_len,
                            tokenize=t.tokenize,
                            fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' '+x)
        WORD.vocab = t.get_vocab()
    else:
        WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
        # Auxiliary features are only used with the LSTM encoder.
        if 'tag' in args.feat:
            TAG = Field('tags', bos=bos)
        if 'char' in args.feat:
            CHAR = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
        if 'bert' in args.feat:
            from transformers import (AutoTokenizer, GPT2Tokenizer,
                                      GPT2TokenizerFast)
            t = AutoTokenizer.from_pretrained(args.bert)
            BERT = SubwordField('bert',
                                pad=t.pad_token,
                                unk=t.unk_token,
                                bos=t.bos_token or t.cls_token,
                                fix_len=args.fix_len,
                                tokenize=t.tokenize,
                                fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' '+x)
            BERT.vocab = t.get_vocab()
    TEXT = RawField('texts')
    ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
    SIB = ChartField('sibs', bos=bos, use_vocab=False, fn=CoNLL.get_sibs)
    REL = Field('rels', bos=bos)
    transform = CoNLL(FORM=(WORD, TEXT, CHAR, BERT), CPOS=TAG, HEAD=(ARC, SIB), DEPREL=REL)
    train = Dataset(transform, args.train)
    if args.encoder == 'lstm':
        WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
        if TAG is not None:
            TAG.build(train)
        if CHAR is not None:
            CHAR.build(train)
    REL.build(train)
    args.update({
        'n_words': len(WORD.vocab) if args.encoder != 'lstm' else WORD.vocab.n_init,
        'n_rels': len(REL.vocab),
        'n_tags': len(TAG.vocab) if TAG is not None else None,
        'n_chars': len(CHAR.vocab) if CHAR is not None else None,
        'char_pad_index': CHAR.pad_index if CHAR is not None else None,
        'bert_pad_index': BERT.pad_index if BERT is not None else None,
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index
    })
    logger.info(f"{transform}")
    logger.info("Building the model")
    model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None).to(args.device)
    logger.info(f"{model}\n")
    return cls(args, model, transform)
def build(cls, path, optimizer_args={ 'lr': 1e-3, 'betas': (.0, .95), 'eps': 1e-12, 'weight_decay': 3e-9 }, scheduler_args={'gamma': .75**(1 / 5000)}, min_freq=7, fix_len=20, **kwargs): r""" Build a brand-new Parser, including initialization of all data fields and model parameters. Args: path (str): The path of the model to be saved. optimizer_args (dict): Arguments for creating an optimizer. scheduler_args (dict): Arguments for creating a scheduler. min_freq (str): The minimum frequency needed to include a token in the vocabulary. Default:7. fix_len (int): The max length of all subword pieces. The excess part of each piece will be truncated. Required if using CharLSTM/BERT. Default: 20. kwargs (dict): A dict holding the unconsumed arguments. """ args = Config(**locals()) args.device = 'cuda' if torch.cuda.is_available() else 'cpu' os.makedirs(os.path.dirname(path), exist_ok=True) if os.path.exists(path) and not args.build: parser = cls.load(**args) parser.model = cls.MODEL(**parser.args) parser.model.load_pretrained(parser.WORD.embed).to(args.device) return parser logger.info("Building the fields") WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True) TAG, CHAR, LEMMA, BERT = None, None, None, None if 'tag' in args.feat: TAG = Field('tags', bos=bos) if 'char' in args.feat: CHAR = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len) if 'lemma' in args.feat: LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True) if 'bert' in args.feat: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(args.bert) BERT = SubwordField('bert', pad=tokenizer.pad_token, unk=tokenizer.unk_token, bos=tokenizer.bos_token or tokenizer.cls_token, fix_len=args.fix_len, tokenize=tokenizer.tokenize) BERT.vocab = tokenizer.get_vocab() EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges) LABEL = ChartField('labels', fn=CoNLL.get_labels) transform = CoNLL(FORM=(WORD, CHAR, BERT), LEMMA=LEMMA, POS=TAG, PHEAD=(EDGE, LABEL)) 
train = Dataset(transform, args.train) WORD.build( train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None)) if TAG is not None: TAG.build(train) if CHAR is not None: CHAR.build(train) if LEMMA is not None: LEMMA.build(train) LABEL.build(train) args.update({ 'n_words': WORD.vocab.n_init, 'n_labels': len(LABEL.vocab), 'n_tags': len(TAG.vocab) if TAG is not None else None, 'n_chars': len(CHAR.vocab) if CHAR is not None else None, 'char_pad_index': CHAR.pad_index if CHAR is not None else None, 'n_lemmas': len(LEMMA.vocab) if LEMMA is not None else None, 'bert_pad_index': BERT.pad_index if BERT is not None else None, 'pad_index': WORD.pad_index, 'unk_index': WORD.unk_index }) logger.info(f"{transform}") logger.info("Building the model") model = cls.MODEL(**args).load_pretrained(WORD.embed).to(args.device) logger.info(f"{model}\n") optimizer = Adam(model.parameters(), **optimizer_args) scheduler = ExponentialLR(optimizer, **scheduler_args) return cls(args, model, transform, optimizer, scheduler)
def build(cls, path, min_freq=7, fix_len=20, **kwargs):
    r"""
    Build a brand-new Parser, including initialization of all data fields
    and model parameters.

    Args:
        path (str):
            The path of the model to be saved.
        min_freq (str):
            The minimum frequency needed to include a token in the vocabulary. Default:7.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece
            will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        kwargs (dict):
            A dict holding the unconsumed arguments.
    """
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # `or './'` guards against a bare filename whose dirname is ''.
    os.makedirs(os.path.dirname(path) or './', exist_ok=True)
    if os.path.exists(path) and not args.build:
        # Reuse an existing checkpoint instead of rebuilding from scratch.
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser
    logger.info("Building the fields")
    # NOTE(review): this WORD assignment is shadowed by both branches of the
    # `if args.encoder == 'bert'` below and appears redundant.
    WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True)
    TAG, CHAR, LEMMA, ELMO, BERT = None, None, None, None, None
    if args.encoder == 'bert':
        # Transformer encoder: words are tokenized into subword pieces.
        from transformers import (AutoTokenizer, GPT2Tokenizer,
                                  GPT2TokenizerFast)
        t = AutoTokenizer.from_pretrained(args.bert)
        # GPT-2 tokenizers need a leading space to mark word boundaries.
        WORD = SubwordField(
            'words',
            pad=t.pad_token,
            unk=t.unk_token,
            bos=t.bos_token or t.cls_token,
            fix_len=args.fix_len,
            tokenize=t.tokenize,
            fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' ' + x)
        WORD.vocab = t.get_vocab()
    else:
        WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True)
        # Auxiliary features are only used with the non-BERT encoder.
        if 'tag' in args.feat:
            TAG = Field('tags', bos=BOS)
        if 'char' in args.feat:
            CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len)
        if 'lemma' in args.feat:
            LEMMA = Field('lemmas', pad=PAD, unk=UNK, bos=BOS, lower=True)
        if 'elmo' in args.feat:
            from allennlp.modules.elmo import batch_to_ids
            ELMO = RawField('elmo')
            # Raw sentences are converted to ELMo character ids at batch time.
            ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device)
        if 'bert' in args.feat:
            from transformers import (AutoTokenizer, GPT2Tokenizer,
                                      GPT2TokenizerFast)
            t = AutoTokenizer.from_pretrained(args.bert)
            BERT = SubwordField(
                'bert',
                pad=t.pad_token,
                unk=t.unk_token,
                bos=t.bos_token or t.cls_token,
                fix_len=args.fix_len,
                tokenize=t.tokenize,
                fn=None if not isinstance(t, (GPT2Tokenizer, GPT2TokenizerFast)) else lambda x: ' ' + x)
            BERT.vocab = t.get_vocab()
    LABEL = ChartField('labels', fn=CoNLL.get_labels)
    transform = CoNLL(FORM=(WORD, CHAR, ELMO, BERT), LEMMA=LEMMA, POS=TAG, PHEAD=LABEL)
    train = Dataset(transform, args.train)
    if args.encoder != 'bert':
        WORD.build(train,
                   args.min_freq,
                   (Embedding.load(args.embed, args.unk) if args.embed else None))
        if TAG is not None:
            TAG.build(train)
        if CHAR is not None:
            CHAR.build(train)
        if LEMMA is not None:
            LEMMA.build(train)
    LABEL.build(train)
    args.update({
        'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init,
        'n_labels': len(LABEL.vocab),
        'n_tags': len(TAG.vocab) if TAG is not None else None,
        'n_chars': len(CHAR.vocab) if CHAR is not None else None,
        'char_pad_index': CHAR.pad_index if CHAR is not None else None,
        'n_lemmas': len(LEMMA.vocab) if LEMMA is not None else None,
        'bert_pad_index': BERT.pad_index if BERT is not None else None,
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index
    })
    logger.info(f"{transform}")
    logger.info("Building the model")
    model = cls.MODEL(**args).load_pretrained(
        WORD.embed if hasattr(WORD, 'embed') else None).to(args.device)
    logger.info(f"{model}\n")
    return cls(args, model, transform)