def init_models(config, device):
    """Rebuild the three coreference sub-models and restore their saved weights.

    Each of ``SpanEmbedder``, ``SpanScorer`` and ``SimplePairWiseClassifier``
    is handled the same way:
      1. it is constructed and moved to ``device``;
      2. its state dict is loaded from
         ``os.path.join(config['model_path'], "<prefix>_<config['model_num']>")``;
      3. it is switched to evaluation mode.

    Args:
        config: mapping that provides at least ``'model_path'`` and
            ``'model_num'``, plus whatever keys the model constructors read.
        device: target passed to ``.to()`` and used as ``map_location`` when
            loading, so the checkpoints are materialised directly on it.

    Returns:
        Tuple ``(span_repr, span_scorer, pairwise_scorer)``, all in eval mode.
    """

    def _restore(module, name_template):
        # Resolve the checkpoint file for this sub-model, e.g. "span_repr_<num>".
        checkpoint = os.path.join(config['model_path'], name_template.format(config['model_num']))
        # NOTE(review): torch.load unpickles the file; only point model_path at
        # trusted checkpoints.
        module.load_state_dict(torch.load(checkpoint, map_location=device))
        module.eval()
        return module

    span_repr = _restore(SpanEmbedder(config, device).to(device), "span_repr_{}")
    span_scorer = _restore(SpanScorer(config).to(device), "span_scorer_{}")
    pairwise_scorer = _restore(SimplePairWiseClassifier(config).to(device), "pairwise_scorer_{}")
    return span_repr, span_scorer, pairwise_scorer
# Tokenizer and corpora: both splits are built with the tokenizer of the
# configured BERT checkpoint (config['bert_model']).
bert_tokenizer = AutoTokenizer.from_pretrained(config['bert_model'])
training_set = create_corpus(config, bert_tokenizer, 'train')
dev_set = create_corpus(config, bert_tokenizer, 'dev')

## Model initiation
logger.info('Init models')
bert_model = AutoModel.from_pretrained(config['bert_model']).to(device)
# Store the encoder width in the shared config before the span models are
# built; presumably their constructors read config['bert_hidden_size']
# (TODO confirm in SpanEmbedder / SpanScorer), so keep this ordering.
config['bert_hidden_size'] = bert_model.config.hidden_size

span_repr = SpanEmbedder(config, device).to(device)
span_scorer = SpanScorer(config).to(device)
# 'pipeline' and 'continue' start from previously trained span models, but
# only when mentions are predicted. With gold mentions the span models keep
# their fresh initialisation.
if config['training_method'] in ('pipeline', 'continue') and not config['use_gold_mentions']:
    span_repr.load_state_dict(torch.load(config['span_repr_path'], map_location=device))
    span_scorer.load_state_dict(torch.load(config['span_scorer_path'], map_location=device))

# NOTE(review): both span models are put in eval mode here, yet in
# 'continue' / 'e2e' (without gold mentions) they are also given to the
# optimizer below. Confirm that the training loop calls .train() on them;
# otherwise layers such as dropout stay in inference mode while fine-tuning.
span_repr.eval()
span_scorer.eval()
# The pairwise scorer is always created from scratch and always trained.
pairwise_model = SimplePairWiseClassifier(config).to(device)

## Optimizer and loss function
# Modules whose parameters the optimizer will update. In 'pipeline' mode
# (or with gold mentions) the span models are left out, so they are not updated.
models = [pairwise_model]
if config['training_method'] in ('continue', 'e2e') and not config['use_gold_mentions']:
    models.append(span_repr)
    models.append(span_scorer)
optimizer = get_optimizer(config, models)
criterion = get_loss_function(config)