def load(cls, path, **kwargs):
    r"""
    Load data fields and model parameters from a pretrained parser.

    Args:
        path (str):
            - a string with the shortcut name of a pre-trained parser defined in
              supar.PRETRAINED to load from cache or download, e.g., `crf-dep-en`.
            - a path to a directory containing a pre-trained parser, e.g., `./<path>/model`.
        kwargs (dict): A dict holding the unconsumed arguments.

    Returns:
        The loaded parser.
    """
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if os.path.exists(path):
        state = torch.load(path, map_location=args.device)
    else:
        # assumption: supar.PRETRAINED maps shortcut names (e.g. `crf-dep-en`)
        # to checkpoint URLs, as described in the docstring
        state = torch.hub.load_state_dict_from_url(supar.PRETRAINED[path],
                                                   map_location=args.device)
    args = state['args'].update(args)
    model = cls.MODEL(**args)
    transform = state['transform']
    # build the parser first so its fields are available in both branches below
    parser = cls(args, model, transform)
    if state['pretrained']:
        model.load_pretrained(state['pretrained']).to(args.device)
    else:
        # no pretrained embeddings in the checkpoint; fall back to the ones
        # attached to the WORD field
        model.load_pretrained(parser.WORD.embed).to(args.device)
    model.load_state_dict(state['state_dict'])
    model.eval()
    model.to(args.device)
    # restore the fields that downstream code reads off the parser
    if parser.args.feat in ('char', 'bert'):
        parser.WORD, parser.FEAT = parser.transform.WORD
    else:
        parser.WORD, parser.FEAT = parser.transform.WORD, parser.transform.POS
    parser.EDU_BREAK = parser.transform.EDU_BREAK
    parser.GOLD_METRIC = parser.transform.GOLD_METRIC
    try:
        parser.CHART = parser.transform.CHART
        parser.PARSINGORDER = parser.transform.PARSINGORDER
    except AttributeError:
        logger.info('parser.CHART and parser.PARSINGORDER are not available for this model.')
    return parser
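
# A hedged usage sketch for `load` (the local checkpoint path is a placeholder,
# and `Parser` stands in for whatever concrete subclass exposes this method):
#
#     >>> parser = Parser.load('./exp/discourse/model')
#     >>> # or, via a shortcut name resolved through supar.PRETRAINED:
#     >>> parser = Parser.load('crf-dep-en')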
def build(cls, path, min_freq=2, fix_len=20, data_path='', **kwargs):
    """
    Build a brand-new Parser, including initialization of all data fields
    and model parameters.

    Args:
        path (str): The path of the model to be saved.
        min_freq (int):
            The minimum frequency needed to include a token in the vocabulary. Default: 2.
        fix_len (int):
            The max length of all subword pieces. The excess part of each piece will be truncated.
            Required if using CharLSTM/BERT. Default: 20.
        data_path (str):
            The directory holding the training data; the training file is expected at
            `<data_path>/train_approach1`. Default: ''.
        kwargs (dict): A dict holding the unconsumed arguments.

    Returns:
        The created parser.
    """
    train = os.path.join(data_path, "train_approach1")
    args = Config(**locals())
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # reuse an existing checkpoint unless a rebuild is explicitly requested
    if os.path.exists(path) and not args.build:
        parser = cls.load(**args)
        parser.model = cls.MODEL(**parser.args)
        parser.model.load_pretrained(parser.WORD.embed).to(args.device)
        return parser

    logger.info("Build the fields")
    WORD = Field('words', pad=pad, unk=unk, bos=bos, eos=eos, lower=True)
    if args.feat == 'char':
        FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, eos=eos,
                            fix_len=args.fix_len)
    elif args.feat == 'bert':
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.bert)
        FEAT = SubwordField('bert',
                            pad=tokenizer.pad_token,
                            unk=tokenizer.unk_token,
                            bos=tokenizer.cls_token or tokenizer.bos_token,
                            eos=tokenizer.sep_token or tokenizer.eos_token,
                            fix_len=args.fix_len,
                            tokenize=tokenizer.tokenize)
        FEAT.vocab = tokenizer.get_vocab()
    else:
        # only CharLSTM ('char') and BERT ('bert') features are supported by this build
        raise ValueError(f"unknown feat type {args.feat!r}; expected 'char' or 'bert'")
    GOLD_METRIC = RawField('golden_metric')
    ORIGINAL_EDU_BREAK = RawField('original_edu_break')
    SENT_BREAK = UnitBreakField('sent_break')
    EDU_BREAK = UnitBreakField('edu_break')
    PARSING_LABEL_TOKEN = ChartDiscourseField('charts_discourse_token')
    PARSING_LABEL_EDU = ChartDiscourseField('charts_discourse_edu')
    PARSING_ORDER_EDU = ParsingOrderField('parsing_order_edu')
    PARSING_ORDER_TOKEN = ParsingOrderField('parsing_order_token')
    PARSING_ORDER_SELF_POINTING_TOKEN = ParsingOrderField('parsing_order_self_pointing_token')
    transform = DiscourseTreeDocEduGold(WORD=(WORD, FEAT),
                                        ORIGINAL_EDU_BREAK=ORIGINAL_EDU_BREAK,
                                        GOLD_METRIC=GOLD_METRIC,
                                        SENT_BREAK=SENT_BREAK,
                                        EDU_BREAK=EDU_BREAK,
                                        PARSING_LABEL_TOKEN=PARSING_LABEL_TOKEN,
                                        PARSING_LABEL_EDU=PARSING_LABEL_EDU,
                                        PARSING_ORDER_EDU=PARSING_ORDER_EDU,
                                        PARSING_ORDER_TOKEN=PARSING_ORDER_TOKEN,
                                        PARSING_ORDER_SELF_POINTING_TOKEN=PARSING_ORDER_SELF_POINTING_TOKEN)

    train = Dataset(transform, args.train)
    WORD.build(train, args.min_freq,
               (Embedding.load(args.embed, args.unk) if args.embed else None))
    FEAT.build(train)
    PARSING_LABEL_TOKEN.build(train)
    PARSING_LABEL_EDU.build(train)
    args.update({
        'n_words': WORD.vocab.n_init,
        'n_feats': len(FEAT.vocab),
        'n_labels': len(PARSING_LABEL_TOKEN.vocab),
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index,
        'eos_index': WORD.eos_index,
        'feat_pad_index': FEAT.pad_index
    })
    model = cls.MODEL(**args)
    model.load_pretrained(WORD.embed).to(args.device)
    return cls(args, model, transform)
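
# A hedged sketch of building a parser from scratch. The paths, the `bert`
# model name, and the `feat`/`embed`/`build` keyword values are illustrative
# assumptions; they reach the method through **kwargs and are consumed via the
# Config object:
#
#     >>> parser = Parser.build('./exp/discourse/model',
#     ...                       data_path='./data',
#     ...                       feat='bert',
#     ...                       bert='bert-base-uncased',
#     ...                       embed=None,
#     ...                       build=True)
#
# With `build=True`, any existing checkpoint at `path` is ignored and the
# fields, vocabularies and model are re-created from
# `<data_path>/train_approach1`.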