Example 1
def test_major_sent(synthese):
    # logger.info('START: model testing...')
    dataset_type = 'synthese' if synthese else 'lead'
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)

    n_iter, total_loss = 0, 0.0
    n_samples, total_hamming = 0, 0.0

    cf_mats, precision_list, recall_list = list(), list(), list()

    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1

        y_true = batch['labels'].cpu().numpy()
        d_batch = len(y_true)
        des_sent_info = batch['des_sent_info'].cpu().numpy()
        n_samples += np.sum(des_sent_info[:, -1])

        # logger.info('batch_size: {0}'.format(y_true.shape[0]))

        if synthese:
            hyp_scores = np.tile(y_pred_vec, (d_batch, 1))
            fids = batch['fids'].cpu().numpy()
            eval_args = {
                'hyp_scores': hyp_scores,
                'fids': fids,
                'is_hiernet': True
            }
            eval_res = metrics.metric_eval_for_syn_doc(**eval_args)
        else:
            hyp_scores = np.tile(y_pred_vec, (d_batch, max_n_sents, 1))
            eval_args = {
                'y_true': y_true,
                'hyp_scores': hyp_scores,
                'des_sent_info': des_sent_info,
            }
            eval_res = metrics.metric_eval(**eval_args)

        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']

    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(
        precision_list=precision_list, recall_list=recall_list)
    hamming = total_hamming / n_samples

    eval_log_info = {
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }

    res_str = 'example_based_f1: {example_based_f1:.6f}, ' \
              'avg_f1: {avg_f1:.6f}, cls_f1: {cls_f1}, hamming: {hamming:.6f}'

    logger.info(res_str.format(**eval_log_info))
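All of these baseline tests read a module-level y_pred_vec that is defined outside the snippets shown here. A minimal sketch of how such a majority-label vector could be built; the name train_labels and the 0.5 frequency threshold are assumptions, not code from this repository:

import numpy as np

def build_majority_pred_vec(train_labels):
    # train_labels: (n_samples, n_doms) binary gold-label matrix (assumed).
    # Predict, for every sample alike, exactly those domains that are
    # positive for at least half of the training samples.
    return (train_labels.mean(axis=0) >= 0.5).astype(np.float32)

# y_pred_vec = build_majority_pred_vec(train_labels)  # shape: (n_doms,)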
Example 2
def test_model_sent_mturk():
    logger.info('START: testing Baseline [MAJOR] on [MTURK SENTS]')

    data_loader = pipe.DomDetDataLoader(dataset_type='mturk')

    n_iter, total_loss = 0, 0.0
    n_samples, total_hamming = 0, 0.0
    cf_mats, precision_list, recall_list = list(), list(), list()

    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1

        y_true = batch['sent_labels'].cpu().numpy()  # d_batch * max_n_sents * n_doms
        d_batch = len(y_true)
        hyp_scores = np.tile(y_pred_vec, (d_batch, 1))
        # hyp_scores = np.tile(y_pred_vec, (d_batch, max_n_sents, 1))

        n_sents = batch['n_sents'].cpu().numpy()
        n_samples += np.sum(n_sents)

        logger.info('batch_size: {0}'.format(y_true.shape[0]))

        eval_args = {
            'y_true': y_true,
            'hyp_scores': hyp_scores,
            'n_sents': n_sents,
            'is_hiernet': True,
        }
        eval_res = metrics.metric_eval_for_mturk(**eval_args)

        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']

    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(
        precision_list=precision_list, recall_list=recall_list)
    hamming = total_hamming / n_samples

    eval_log_info = {
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }

    res_str = 'example_based_f1: {example_based_f1:.6f}, ' \
              'avg_f1: {avg_f1:.6f}, cls_f1: {cls_f1}, hamming: {hamming:.6f}'

    logger.info(res_str.format(**eval_log_info))
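For reference, a minimal sketch of the example-based F1 that metrics.compute_example_based_f1 presumably derives from the collected per-example precision and recall values, following the standard multi-label definition; the actual implementation is not shown in these snippets:

def example_based_f1(precision_list, recall_list):
    # Per-example F1 = 2PR / (P + R), averaged over all examples.
    f1s = [0.0 if p + r == 0 else 2 * p * r / (p + r)
           for p, r in zip(precision_list, recall_list)]
    return sum(f1s) / len(f1s)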
Example 3
def get_word_scores(model, n_iter, doc_ids, restore=None):
    if config.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config.device)

    if config.placement in ('auto', 'single'):
        model.cuda()

    checkpoint = join(path_parser.model_save, config.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')

    filter_keys = None
    if config.reset_size_for_test and not config.set_sep_des_size:
        filter_keys = [
            'module.word_det.des_ids', 'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask'
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)
    model.eval()

    dataset_type = 'dev'
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type,
                                        collect_doc_ids=True,
                                        doc_ids=doc_ids)

    ws_list = []
    for _, batch in enumerate(data_loader):  # only one batch
        logger.info('Batch size: {}'.format(len(batch)))
        c = copy.deepcopy
        feed_dict = c(batch)

        del feed_dict['doc_ids']

        for (k, v) in feed_dict.items():
            feed_dict[k] = Variable(v, requires_grad=False,
                                    volatile=True)  # fix ids and masks

        feed_dict['return_sent_attn'] = True
        feed_dict['return_word_attn'] = True
        _, _, _, word_scores, _, word_attn = model(**feed_dict)

        word_scores = word_scores.data.cpu().numpy()  # n_batch * n_sents * n_words * n_doms
        ws_list.append(word_scores)

    ws = np.concatenate(ws_list, axis=0)
    return ws
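Variable(..., volatile=True) is pre-0.4 PyTorch. On current PyTorch the same no-gradient inference loop would be written with torch.no_grad(); a sketch under that assumption, with model and data_loader as above:

import torch

ws_list = []
with torch.no_grad():  # replaces requires_grad=False / volatile=True
    for batch in data_loader:
        feed_dict = {k: v for k, v in batch.items() if k != 'doc_ids'}
        feed_dict['return_sent_attn'] = True
        feed_dict['return_word_attn'] = True
        _, _, _, word_scores, _, word_attn = model(**feed_dict)
        ws_list.append(word_scores.cpu().numpy())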
Example 4
def test_model_word_mturk(matching_mode=None, corpus='wiki'):
    logger.info('START: model testing on [MTURK WORDS]')

    grain = 'word'
    dataset_type = '-'.join(('mturk', corpus, grain))
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)

    n_samples = 0
    p_list = list()
    r_list = list()

    for batch_idx, batch in enumerate(data_loader):
        # turn vars to numpy arrays
        y_true_sents = batch['sent_labels'].cpu().numpy()  # d_batch * max_n_sents * n_doms
        y_true_words = batch['word_labels'].cpu().numpy()  # d_batch * max_n_sents * max_n_words
        n_sents = batch['n_sents'].cpu().numpy()
        n_words = batch['n_words'].cpu().numpy()
        n_samples += np.sum(n_sents)

        d_batch = len(y_true_sents)
        hyp_scores = np.tile(y_pred_vec, (d_batch, 1))

        logger.info('batch_size: {0}'.format(y_true_words.shape[0]))

        eval_args = {
            'hyp_scores': hyp_scores,
            'y_true_sents': y_true_sents,
            'y_true_words': y_true_words,
            'n_sents': n_sents,
            'n_words': n_words,
            'pred_grain': 'doc',
            'max_alter': True,
            'matching_mode': matching_mode,
        }

        eval_res = metrics_word_eval_binary.metric_eval_for_mturk_words_with_ir(
            **eval_args)

        p_list.extend(eval_res['p_list'])
        r_list.extend(eval_res['r_list'])

    exam_f1 = metrics.compute_example_based_f1(p_list, r_list)
    logger.info('word-eval. exam_f1: {0:.6f}'.format(exam_f1))
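The two tiling patterns used across these tests differ only in the target grain. A standalone shape check (the sizes are arbitrary):

import numpy as np

y_pred_vec = np.array([1.0, 0.0, 1.0])       # (n_doms,) with n_doms = 3
doc_grain = np.tile(y_pred_vec, (2, 1))      # (d_batch, n_doms) = (2, 3)
sent_grain = np.tile(y_pred_vec, (2, 4, 1))  # (d_batch, max_n_sents, n_doms) = (2, 4, 3)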
Example 5
    def select_pos_neg_labels_for_en_words(self, corpus):
        if lang == 'en':
            if corpus == 'wiki':
                out_fp = path_parser.label_en_wiki
            elif corpus == 'nyt':
                out_fp = path_parser.label_en_nyt
            else:
                raise ValueError('Invalid corpus with EN: {}'.format(corpus))
        else:
            if corpus == 'wiki':
                out_fp = path_parser.label_zh_wiki
            else:
                raise ValueError('Invalid corpus with ZH: {}'.format(corpus))

        grain = 'word'
        dataset_type = '-'.join(('mturk', corpus, grain))
        data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)

        records = list()
        for batch_idx, batch in enumerate(data_loader):
            word_labels = batch['word_labels'].cpu().numpy()
            n_sents = batch['n_sents'].cpu().numpy()
            n_words = batch['n_words'].cpu().numpy()

            n_doc = len(word_labels)
            logger.info('n_doc: {}'.format(n_doc))

            for doc_idx in range(n_doc):
                for sent_idx in range(n_sents[doc_idx, 0]):
                    nw = n_words[doc_idx, sent_idx]
                    labels_w = word_labels[doc_idx, sent_idx]
                    pos_w_ids = [str(w_id) for w_id in range(nw) if labels_w[w_id] == 1]
                    pos_str = '|'.join(pos_w_ids)
                    neg_w_ids = [str(w_id) for w_id in range(nw) if labels_w[w_id] != 1]
                    neg_str = '|'.join(neg_w_ids)

                    record = '\t'.join((str(doc_idx), str(sent_idx), pos_str, neg_str))
                    records.append(record)

        logger.info('n_records: {}'.format(len(records)))
        with io.open(out_fp, mode='a', encoding='utf-8') as out_f:
            out_f.write('\n'.join(records) + '\n')  # trailing newline so later appends stay separated

        logger.info('Selection has been successfully saved to: {}'.format(out_fp))
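Each record written above is a tab-separated line: doc_idx, sent_idx, the '|'-joined positive word ids, and the '|'-joined negative word ids. A sketch for loading the file back; load_pos_neg_labels is a hypothetical helper, not part of the repository:

import io

def load_pos_neg_labels(fp):
    labels = []
    with io.open(fp, encoding='utf-8') as f:
        for line in f:
            doc_idx, sent_idx, pos_str, neg_str = line.rstrip('\n').split('\t')
            pos = [int(w) for w in pos_str.split('|') if w]
            neg = [int(w) for w in neg_str.split('|') if w]
            labels.append((int(doc_idx), int(sent_idx), pos, neg))
    return labels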
Example 6
def test_major_doc():
    data_loader = pipe.DomDetDataLoader(dataset_type='test')
    n_iter, total_loss = 0, 0.0
    n_samples, total_hamming = 0, 0.0
    cf_mats, precision_list, recall_list = list(), list(), list()

    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1

        y_true = batch['labels'].cpu().numpy()  # turn vars to numpy arrays
        d_batch = len(y_true)
        y_pred = np.tile(y_pred_vec, (d_batch, 1))

        eval_args = {
            'y_true': y_true,
            'hyp_scores': y_pred,
        }

        n_samples += d_batch
        # logger.info('batch_size: {0}'.format(d_batch))
        eval_res = metrics.metric_eval(**eval_args)

        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']

    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(
        precision_list=precision_list, recall_list=recall_list)
    hamming = total_hamming / n_samples

    eval_log_info = {
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }

    res_str = 'example_based_f1: {example_based_f1:.6f}, ' \
              'avg_f1: {avg_f1:.6f}, cls_f1: {cls_f1}, hamming: {hamming:.6f}'

    logger.info(res_str.format(**eval_log_info))
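The final hamming = total_hamming / n_samples suggests metrics.metric_eval returns an unnormalized Hamming sum per batch. A sketch of that quantity under the standard multi-label definition; the 0.5 thresholding rule is an assumption:

import numpy as np

def hamming_sum(y_true, hyp_scores, threshold=0.5):
    # Sum over samples of (#mismatched labels / n_labels); dividing the
    # running total by the sample count yields the usual Hamming loss.
    y_pred = (hyp_scores >= threshold).astype(y_true.dtype)
    return float(np.sum(y_pred != y_true)) / y_true.shape[-1]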
Example 7
def eval_model(model, phase, save_pred=False, save_gold=False):
    assert phase in ('dev', 'test')
    data_loader = pipe.DomDetDataLoader(dataset_type=phase)
    model.eval()

    n_iter, total_loss = 0, 0.0
    n_samples, total_hamming = 0, 0.0
    cf_mats, precision_list, recall_list = list(), list(), list()

    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1

        c = copy.deepcopy
        feed_dict = c(batch)

        for (k, v) in feed_dict.items():
            feed_dict[k] = Variable(v, requires_grad=False,
                                    volatile=True)  # fix ids and masks

        loss, doc_scores = model(**feed_dict)[:2]
        total_loss += loss.data[0]

        y_true = batch['labels'].cpu().numpy()  # turn vars to numpy arrays
        hyp_scores = doc_scores.data.cpu().numpy()
        eval_args = {
            'y_true': y_true,
            'hyp_scores': hyp_scores,
        }

        if save_pred:
            eval_args['save_pred_to'] = join(path_parser.pred_doc,
                                             config_loader.meta_model_name)

        if save_gold:
            eval_args['save_true_to'] = join(path_parser.pred_doc, 'gold')

        # del model_res
        n_samples += y_true.shape[0]
        # logger.info('batch_size: {0}'.format(y_true.shape[0]))
        eval_res = metrics.metric_eval(**eval_args)

        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']

    avg_loss = total_loss / n_iter
    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(
        precision_list=precision_list, recall_list=recall_list)
    hamming = total_hamming / n_samples

    eval_log_info = {
        'ph': phase,
        'loss': avg_loss,
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }

    return eval_log_info
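Unlike the baseline tests above, eval_model returns its metrics instead of logging them. A hypothetical call site; the format string mirrors the one used by the baseline tests and is not taken from the repository:

eval_log_info = eval_model(model, phase='dev')
res_str = '{ph} loss: {loss:.6f}, example_based_f1: {example_based_f1:.6f}, ' \
          'avg_f1: {avg_f1:.6f}, cls_f1: {cls_f1}, hamming: {hamming:.6f}'
logger.info(res_str.format(**eval_log_info))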
Example 8
def test_model_sent_syn_with_checkpoints(model,
                                         save_pred=False,
                                         save_gold=False,
                                         n_iter=None,
                                         restore=False):
    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)

    if config_loader.placement in ('auto', 'single'):
        model.cuda()

    logger.info('START: model testing on [SENTS with SYNTHETIC CONTEXT]')

    checkpoint = join(path_parser.model_save, config_loader.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')

    filter_keys = None
    if config_loader.reset_size_for_test and not config_loader.set_sep_des_size:
        logger.info('Filter DES pretrained paras...')
        filter_keys = [
            'module.word_det.des_ids', 'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask'
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)

    data_loader = pipe.DomDetDataLoader(dataset_type='synthese')
    model.eval()

    n_iter, total_loss = 0, 0.0
    n_samples, total_hamming = 0, 0.0

    cf_mats, precision_list, recall_list = list(), list(), list()

    is_hiernet = config_loader.meta_model_name == 'hiernet'
    no_word_scores = config_loader.meta_model_name in ('detnet1', 'milnet')

    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1

        c = copy.deepcopy
        feed_dict = c(batch)

        # del paras not for forward()
        del feed_dict['des_sent_info']
        del feed_dict['fids']

        for (k, v) in feed_dict.items():
            feed_dict[k] = Variable(v, requires_grad=False,
                                    volatile=True)  # fix ids and masks

        if is_hiernet:
            loss, doc_scores = model(**feed_dict)
            hyp_scores = doc_scores.data.cpu().numpy()
        elif no_word_scores:
            loss, doc_scores, sent_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        else:
            loss, doc_scores, sent_scores, word_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()

        total_loss += loss.data[0]
        # turn vars to numpy arrays
        y_true = batch['labels'].cpu().numpy()
        des_sent_info = batch['des_sent_info'].cpu().numpy()
        n_samples += np.sum(des_sent_info[:, -1])

        fids = batch['fids'].cpu().numpy()
        eval_args = {
            'hyp_scores': hyp_scores,
            'fids': fids,
            'is_hiernet': is_hiernet,
        }

        if save_pred:
            pred_save_fp = join(path_parser.pred_syn,
                                config_loader.meta_model_name)
            eval_args['save_pred_to'] = pred_save_fp

        if save_gold:
            true_save_fp = join(path_parser.pred_syn, 'gold')
            eval_args['save_true_to'] = true_save_fp

        eval_res = metrics.metric_eval_for_syn_doc(**eval_args)

        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']

    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)

    eval_log_info = {
        'ph': 'Test',
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
    }

    res_str = build_res_str(stage=None,
                            use_loss=False,
                            use_exam_f1=False,
                            use_hamming=False)

    logger.info(res_str.format(**eval_log_info))
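load_checkpoint's filter_keys argument plausibly drops the listed DES buffers from the saved state dict before loading, so that size-mismatched description tensors are re-initialized rather than restored. A sketch of that behavior; this is an assumption, the real helper is not shown in these snippets:

def filter_state_dict(state_dict, filter_keys):
    if not filter_keys:
        return state_dict
    return {k: v for k, v in state_dict.items() if k not in filter_keys}

# model.load_state_dict(filter_state_dict(ckpt['state_dict'], filter_keys),
#                       strict=False)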
Example 9
def test_model_sent_mturk_with_checkpoints(model,
                                           corpus='wiki',
                                           save_pred=False,
                                           save_gold=False,
                                           n_iter=None,
                                           restore=False):
    if corpus == 'nyt' and lang != 'en':
        raise ValueError('Set lang to en when NYT corpus is used')

    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)

    if config_loader.placement in ('auto', 'single'):
        model.cuda()

    logger.info('START: model testing on [SENTS with MTURK]')

    checkpoint = join(path_parser.model_save, config_loader.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')

    filter_keys = None
    if config_loader.reset_size_for_test and not config_loader.set_sep_des_size:
        logger.info('Filter DES pretrained paras...')
        filter_keys = [
            'module.word_det.des_ids', 'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask'
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)

    grain = 'sent'
    dataset_type = '-'.join(('mturk', corpus, grain))
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)
    model.eval()

    n_iter, total_loss = 0, 0.0
    n_samples = 0

    cf_mats, precision_list, recall_list = list(), list(), list()

    is_hiernet = config_loader.meta_model_name == 'hiernet'
    no_word_scores = config_loader.meta_model_name in ('detnet1', 'milnet')

    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1

        c = copy.deepcopy
        feed_dict = c(batch)

        del feed_dict['n_sents']
        del feed_dict['sent_labels']
        for (k, v) in feed_dict.items():
            feed_dict[k] = Variable(v, requires_grad=False,
                                    volatile=True)  # fix ids and masks

        if is_hiernet:
            loss, doc_scores = model(**feed_dict)
            hyp_scores = doc_scores.data.cpu().numpy()
        elif no_word_scores:
            loss, doc_scores, sent_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        else:
            loss, doc_scores, sent_scores, word_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()

            # for doc_id, score in enumerate(hyp_scores):
            #     for sent_id, s in enumerate(score):
            #         logger.info('{0}.{1}: {2}'.format(doc_id, sent_id, score))

        total_loss += loss.data[0]
        # turn vars to numpy arrays
        y_true = batch['sent_labels'].cpu().numpy()  # d_batch * max_n_sents * n_doms
        n_sents = batch['n_sents'].cpu().numpy()
        # logger.info('n_sents: {0}'.format(n_sents))
        n_samples += np.sum(n_sents)

        # logger.info('batch_size: {0}'.format(y_true.shape[0]))

        eval_args = {
            'y_true': y_true,
            'hyp_scores': hyp_scores,
            'n_sents': n_sents,
            'is_hiernet': is_hiernet,
        }

        save_dir = path_parser.pred_mturk_wiki if corpus == 'wiki' else path_parser.pred_mturk_nyt
        if save_pred:
            pred_save_fp = join(save_dir, config_loader.meta_model_name)
            eval_args['save_pred_to'] = pred_save_fp

        if save_gold:
            true_save_fp = join(save_dir, 'gold')
            eval_args['save_true_to'] = true_save_fp

        eval_res = metrics.metric_eval_for_mturk(**eval_args)

        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])

    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)

    eval_log_info = {
        'ph': 'Test',
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
    }

    res_str = build_res_str(stage=None,
                            use_loss=False,
                            use_exam_f1=False,
                            use_hamming=False)
    logger.info(res_str.format(**eval_log_info))
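The three-way unpacking of the model outputs repeats in several of these tests. Since the loss is always first, doc scores second, and sentence scores third when present, the dispatch can be collapsed; a sketch relying on that output-order assumption:

outputs = model(**feed_dict)
loss = outputs[0]
scores = outputs[1] if is_hiernet else outputs[2]  # doc vs. sent grain
hyp_scores = scores.data.cpu().numpy()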
Example 10
def test_model_word_mturk_with_checkpoints(model,
                                           matching_mode=None,
                                           corpus='wiki',
                                           save_pred=False,
                                           save_gold=False,
                                           n_iter=None,
                                           restore=False):
    if corpus == 'wiki':
        save_dir = path_parser.pred_mturk_wiki
    elif corpus == 'nyt':
        if lang != 'en':
            raise ValueError('Set lang to en when NYT corpus is used')
        save_dir = path_parser.pred_mturk_nyt
    else:
        raise ValueError('Invalid corpus: {}'.format(corpus))

    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)

    if config_loader.placement in ('auto', 'single'):
        model.cuda()

    logger.info('START: model testing on [MTURK WORDS]')

    checkpoint = join(path_parser.model_save, config_loader.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')

    filter_keys = None
    if config_loader.reset_size_for_test and not config_loader.set_sep_des_size:
        logger.info('Filter DES pretrained paras...')
        filter_keys = [
            'module.word_det.des_ids', 'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask'
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)

    grain = 'word'
    dataset_type = '-'.join(('mturk', corpus, grain))
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)
    model.eval()

    c = copy.deepcopy
    pred_grain = get_model_pred_grain()
    p_list = list()
    r_list = list()
    y_true_sents_list = list()
    n_sents_list = list()

    for batch_idx, batch in enumerate(data_loader):

        feed_dict = c(batch)

        del feed_dict['sent_labels']
        del feed_dict['word_labels']
        del feed_dict['n_sents']
        del feed_dict['n_words']

        for (k, v) in feed_dict.items():
            feed_dict[k] = Variable(v, requires_grad=False,
                                    volatile=True)  # fix ids and masks

        if pred_grain == 'doc':
            _, doc_scores = model(**feed_dict)
            hyp_scores = doc_scores.data.cpu().numpy()
        elif pred_grain == 'sent':
            _, _, sent_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        elif pred_grain == 'word':
            feed_dict['return_sent_attn'] = True
            feed_dict['return_word_attn'] = True
            _, _, _, word_scores, _, word_attn = model(**feed_dict)

            hyp_scores = word_scores.data.cpu().numpy()  # n_batch * n_sents * n_words * n_doms
        else:
            raise ValueError('Invalid prediction grain: {}'.format(pred_grain))

        # turn vars to numpy arrays
        y_true_sents = batch['sent_labels'].cpu().numpy()  # d_batch * max_n_sents * n_doms
        y_true_words = batch['word_labels'].cpu().numpy()  # d_batch * max_n_sents * max_n_words
        n_sents = batch['n_sents'].cpu().numpy()
        n_words = batch['n_words'].cpu().numpy()

        logger.info('batch_size: {0}'.format(y_true_words.shape[0]))

        eval_args = {
            'hyp_scores': hyp_scores,
            'y_true_sents': y_true_sents,
            'y_true_words': y_true_words,
            'n_sents': n_sents,
            'n_words': n_words,
            'pred_grain': pred_grain,
            'max_alter': True,
            'matching_mode': matching_mode,
        }

        if save_pred:
            fn = '_'.join((grain, config_loader.meta_model_name))
            pred_save_fp = join(save_dir, fn)
            eval_args['save_pred_to'] = pred_save_fp

        if save_gold:
            fn = '_'.join((grain, 'gold'))
            true_save_fp = join(save_dir, fn)
            eval_args['save_true_to'] = true_save_fp

        eval_res = metrics_word_eval_binary.metric_eval_for_mturk_words_with_ir(
            **eval_args)

        p_list.extend(eval_res['p_list'])
        r_list.extend(eval_res['r_list'])
        y_true_sents_list.append(y_true_sents)
        n_sents_list.append(n_sents)

    exam_f1 = metrics.compute_example_based_f1(p_list, r_list)
    logger.info('word-eval. exam_f1: {0:.6f}'.format(exam_f1))

    report_dom_specific_f1(p_list, r_list, y_true_sents_list[0],
                           n_sents_list[0])
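get_model_pred_grain is not shown in these snippets. Given the output-arity dispatch used elsewhere, a plausible sketch is a mapping from the meta model name to its finest prediction grain; this is purely an assumption:

def get_model_pred_grain():
    # Assumed mapping: hiernet predicts at doc grain, detnet1/milnet at
    # sent grain, everything else down to word grain.
    if config_loader.meta_model_name == 'hiernet':
        return 'doc'
    if config_loader.meta_model_name in ('detnet1', 'milnet'):
        return 'sent'
    return 'word'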
Example 11
def train_model_with_checkpoints(model,
                                 restore=False,
                                 batch_log=True,
                                 batch_eval=True):
    logger.info('Config: {0}'.format(config_loader.model_name))
    max_n_iter = config_model['n_batches']

    checkpoint = join(path_parser.model_save, config_loader.model_name)
    performance_save_fp = join(path_parser.performances,
                               '{0}.txt'.format(config_loader.model_name))

    data_loader = pipe.DomDetDataLoader(dataset_type='train')
    optimizer = make_opt(model)

    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)

    if config_loader.placement in ('auto', 'single'):
        model.cuda()

    if restore:
        global_n_iter = load_checkpoint(checkpoint,
                                        model=model,
                                        optimizer=optimizer,
                                        no_iter_strategy='last')
        global_n_iter -= 1  # backward compatible
        logger.info(
            'MODE: restore a pre-trained model and resume training from {}'.
            format(global_n_iter))
        checkpoint = join(checkpoint, 'resume')
        max_n_iter += global_n_iter
    else:
        logger.info('MODE: create a new model')
        global_n_iter = 0

    train_skip_iter = 50
    eval_skip_iter = 200

    checkpoint_dict = dict()  # {n_batches: f1}
    batch_loss = 0.0
    for epoch_idx in range(config_model['n_epochs']):
        for batch_idx, batch in enumerate(data_loader):
            model.train(mode=True)
            global_n_iter += 1

            feed_dict = copy.deepcopy(batch)

            for (k, v) in feed_dict.items():
                feed_dict[k] = Variable(
                    v, requires_grad=False)  # fix ids and masks

            loss = model(**feed_dict)[0]

            if config_loader.placement == 'auto':
                loss = loss.mean()  # gather loss from multiple gpus

            batch_loss += loss.data[0]

            optimizer.zero_grad()  # clear history gradients

            loss.backward(retain_graph=True)

            if config_model['grad_clip'] is not None:
                nn.utils.clip_grad_norm(model.parameters(),
                                        config_model['grad_clip'])

            optimizer.step()

            if batch_log and global_n_iter % train_skip_iter == 0:
                check_grads(model)
                logger.info('iter: {0}, loss: {1:.6f}'.format(
                    global_n_iter, loss.data[0]))

            if batch_eval and global_n_iter % eval_skip_iter == 0:
                eval_log_info = eval_model(model=model, phase='dev')
                eval_log_info['n_iter'] = global_n_iter

                # res_str = build_res_str(stage='n_iter')
                # logger.info(res_str.format(**eval_log_info))

                mean_batch_loss = batch_loss / eval_skip_iter
                eval_loss = eval_log_info['loss']
                f1 = eval_log_info['avg_f1']

                checkpoint_dict, update, is_best = update_checkpoint_dict(
                    checkpoint_dict, k=global_n_iter, v=f1)
                perf_rec = '{0}\t{1:.6f}\t{2:.6f}\t{3:.6f}\n'.format(
                    global_n_iter, mean_batch_loss, eval_loss, f1)

                state = {
                    'n_iters': global_n_iter + 1,
                    'state_dict': model.state_dict(),
                    'optimizer_dict': optimizer.state_dict()
                }
                if update:
                    save_checkpoint(state, checkpoint, global_n_iter, is_best)
                    clean_outdated_checkpoints(checkpoint, checkpoint_dict)

                with open(performance_save_fp, 'a', encoding='utf-8') as f:
                    f.write(perf_rec)
                logger.info(perf_rec.strip('\n'))

                batch_loss = 0.0

            if global_n_iter == max_n_iter:
                break

        if global_n_iter == max_n_iter:
            logger.info(
                'finished expected training: {0} batches!'.format(max_n_iter))
            break
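update_checkpoint_dict is also external to these snippets. A minimal sketch consistent with how it is used above, keeping the best max_kept checkpoints by dev F1; max_kept=3 and the pruning rule are arbitrary assumptions:

def update_checkpoint_dict(checkpoint_dict, k, v, max_kept=3):
    # k: iteration number, v: dev F1. Returns the pruned dict, whether
    # the new checkpoint is worth saving, and whether it is the best yet.
    is_best = not checkpoint_dict or v > max(checkpoint_dict.values())
    checkpoint_dict[k] = v
    update = True
    if len(checkpoint_dict) > max_kept:
        worst = min(checkpoint_dict, key=checkpoint_dict.get)
        update = worst != k
        del checkpoint_dict[worst]
    return checkpoint_dict, update, is_best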