コード例 #1
0
ファイル: predict.py プロジェクト: sunjiao123sun/coref
def _load_for_inference(module, file_template, config, device):
    """Load saved weights into *module* and switch it to evaluation mode.

    The checkpoint path is
    ``os.path.join(config['model_path'], file_template.format(config['model_num']))``.

    Args:
        module: a ``torch.nn.Module`` that has already been moved to *device*.
        file_template: checkpoint file-name pattern with a single ``{}``
            placeholder for the model number, e.g. ``"span_repr_{}"``.
        config: mapping providing ``'model_path'`` and ``'model_num'``.
        device: forwarded to ``torch.load`` as ``map_location`` so that tensors
            saved on a different device are remapped onto this one.

    Returns:
        The same *module* object, now in eval mode.

    Errors raised by ``torch.load`` or ``load_state_dict`` (for example a
    missing checkpoint file or mismatched parameter names) propagate unchanged.
    """
    checkpoint_path = os.path.join(config['model_path'],
                                   file_template.format(config['model_num']))
    module.load_state_dict(torch.load(checkpoint_path, map_location=device))
    module.eval()
    return module


def init_models(config, device):
    """Build the span embedder, span scorer and pairwise scorer for inference.

    Each model is constructed, moved to *device*, loaded from its checkpoint
    under ``config['model_path']``, and switched to eval mode. The three models
    are processed one after another, in the order they are returned.

    The checkpoint files are ``span_repr_<n>``, ``span_scorer_<n>`` and
    ``pairwise_scorer_<n>``, where ``<n>`` is ``config['model_num']``.

    Args:
        config: mapping with at least ``'model_path'`` and ``'model_num'``; it
            is also passed to each model's constructor.
        device: target device for the models and ``map_location`` for loading.

    Returns:
        Tuple ``(span_repr, span_scorer, pairwise_scorer)``, each in eval mode.
    """
    span_repr = _load_for_inference(
        SpanEmbedder(config, device).to(device), "span_repr_{}", config, device)
    span_scorer = _load_for_inference(
        SpanScorer(config).to(device), "span_scorer_{}", config, device)
    pairwise_scorer = _load_for_inference(
        SimplePairWiseClassifier(config).to(device), "pairwise_scorer_{}",
        config, device)

    return span_repr, span_scorer, pairwise_scorer
コード例 #2
0
    # Excerpt from a training routine whose enclosing `def` is not part of this
    # snippet. `config` is read with attribute syntax (config.model_path, ...);
    # presumably a pyhocon ConfigTree or similar object — TODO confirm.
    create_folder(config.model_path)

    # init train and dev set
    # Training batches are shuffled each epoch; dev batches keep dataset order.
    train = CrossEncoderDatasetFull(config, 'train')
    train_loader = data.DataLoader(train, batch_size=config.batch_size, shuffle=True)
    dev = CrossEncoderDatasetFull(config, 'dev')
    dev_loader = data.DataLoader(dev, batch_size=config.batch_size, shuffle=False)

    # gpu_num is indexed, so it is a sequence of GPU ids; the first id becomes
    # the primary device. NOTE(review): there is no CPU fallback here, so the
    # .to(device) calls below presumably fail on a machine without CUDA.
    device_ids = config.gpu_num
    device = torch.device("cuda:{}".format(device_ids[0]))


    ## Models' initiation
    logger.info('Init models')
    span_repr = SpanEmbedder(config, device).to(device)
    span_scorer = SpanScorer(config).to(device)
    cross_encoder_single = FullCrossEncoder(config).to(device)
    # Replicate the cross-encoder across every GPU id in gpu_num. DataParallel
    # expects the wrapped module's parameters on device_ids[0], which is
    # `device`. The unwrapped module remains reachable as cross_encoder_single.
    cross_encoder = torch.nn.DataParallel(cross_encoder_single, device_ids=device_ids)


    # Reload a previously trained span embedder/scorer unless gold mentions
    # are used. NOTE(review): within this excerpt these two modules are not
    # switched to eval() and are not passed to the optimizer below — confirm
    # whether that happens later.
    if config.training_method in ('pipeline', 'continue') and not config.use_gold_mentions:
        span_repr.load_state_dict(torch.load(config.span_repr_path, map_location=device))
        span_scorer.load_state_dict(torch.load(config.span_scorer_path, map_location=device))


    ## Optimizer and loss function
    criterion = get_loss_function(config)
    # Only the (DataParallel-wrapped) cross-encoder is given to the optimizer.
    optimizer = get_optimizer(config, [cross_encoder])
    # total_steps = epochs x number of training batches (len() of a DataLoader
    # counts batches); presumably one scheduler step per batch — confirm.
    scheduler = get_scheduler(optimizer, total_steps=config.epochs * len(train_loader))

コード例 #3
0
        # (Excerpt begins inside an `if` branch whose condition is not shown.)
        # This branch selects config['gpu_num'] as the current CUDA device.
        torch.cuda.set_device(config['gpu_num'])
    else:
        # Otherwise everything below runs on the CPU.
        device = 'cpu'

    # read and tokenize data
    # NOTE(review): add_special_tokens is normally an encode-time argument;
    # passed to from_pretrained it is presumably just stored as a tokenizer
    # init kwarg — confirm it has the intended effect.
    bert_tokenizer = AutoTokenizer.from_pretrained(config['bert_model'],
                                                   add_special_tokens=True)
    training_set = create_corpus(config, bert_tokenizer, 'train')
    dev_set = create_corpus(config, bert_tokenizer, 'dev')

    # Mention extractor configuration
    logger.info('Init models')
    bert_model = AutoModel.from_pretrained(config['bert_model']).to(device)
    # Record the encoder's hidden size in config before the span modules are
    # built, presumably so they can size their layers from it — confirm.
    config['bert_hidden_size'] = bert_model.config.hidden_size
    span_repr = SpanEmbedder(config, device).to(device)
    span_scorer = SpanScorer(config).to(device)
    # Only the two span modules are handed to the optimizer; bert_model is
    # not included in this excerpt.
    optimizer = get_optimizer(config, [span_scorer, span_repr])
    criterion = get_loss_function(config)

    logger.info('Number of parameters of mention extractor: {}'.format(
        count_parameters(span_repr) + count_parameters(span_scorer)))

    # File paths for the span modules' weights:
    #   <model_path>/<mention_type>_span_repr_<exp_num>
    #   <model_path>/<mention_type>_span_scorer_<exp_num>
    # (how they are used happens outside this excerpt).
    span_repr_path = os.path.join(
        config['model_path'], '{}_span_repr_{}'.format(config['mention_type'],
                                                       config['exp_num']))
    span_scorer_path = os.path.join(
        config['model_path'],
        '{}_span_scorer_{}'.format(config['mention_type'], config['exp_num']))

    logger.info('Number of topics: {}'.format(len(training_set.topic_list)))
    # Best-dev tracker initialised to (0, None); presumably (score, epoch or
    # checkpoint) — its use is outside this excerpt.
    max_dev = (0, None)
コード例 #4
0
    # Use the first GPU id in gpu_num when CUDA is available, else the CPU.
    # Note `device` is a torch.device in the CUDA case but the plain string
    # 'cpu' otherwise; both forms are accepted by .to() and map_location.
    # NOTE(review): config is read both as an attribute (config.gpu_num) and
    # as a mapping (config['bert_model']); this only works if the config
    # object supports both (e.g. a pyhocon ConfigTree) — confirm.
    device = torch.device('cuda:{}'.format(
        config.gpu_num[0])) if torch.cuda.is_available() else 'cpu'

    logger.info('Using device {}'.format(device))
    # init train and dev set
    bert_tokenizer = AutoTokenizer.from_pretrained(config['bert_model'])
    training_set = create_corpus(config, bert_tokenizer, 'train')
    dev_set = create_corpus(config, bert_tokenizer, 'dev')

    ## Model initiation
    logger.info('Init models')
    bert_model = AutoModel.from_pretrained(config['bert_model']).to(device)
    # Record the encoder's hidden size in config before the span modules are
    # built, presumably so they can size their layers from it — confirm.
    config['bert_hidden_size'] = bert_model.config.hidden_size

    span_repr = SpanEmbedder(config, device).to(device)
    span_scorer = SpanScorer(config).to(device)

    # In 'pipeline' / 'continue' training without gold mentions, reload a
    # previously trained span embedder/scorer from the configured paths.
    if config['training_method'] in (
            'pipeline', 'continue') and not config['use_gold_mentions']:
        span_repr.load_state_dict(
            torch.load(config['span_repr_path'], map_location=device))
        span_scorer.load_state_dict(
            torch.load(config['span_scorer_path'], map_location=device))

    # Switch the span modules to evaluation mode (this affects layers such as
    # dropout, if present). This is done whether or not weights were loaded.
    span_repr.eval()
    span_scorer.eval()

    pairwise_model = SimplePairWiseClassifier(config).to(device)

    ## Optimizer and loss function
    # Models to optimize; the list presumably continues past this excerpt.
    models = [pairwise_model]