def test_major_sent(synthese):
    """Evaluate the majority-class baseline on sentence-level data.

    Uses the 'synthese' dataset when `synthese` is truthy, otherwise 'lead'.
    Logs example-based F1, averaged F1, per-class F1 and Hamming loss.
    """
    # logger.info('START: model testing...')
    dataset_type = 'synthese' if synthese else 'lead'
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)
    n_iter = 0
    total_loss = 0.0
    n_samples = 0
    total_hamming = 0.0
    cf_mats = []
    precision_list = []
    recall_list = []
    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1
        y_true = batch['labels'].cpu().numpy()
        d_batch = len(y_true)
        des_sent_info = batch['des_sent_info'].cpu().numpy()
        # last column of des_sent_info carries the per-doc sample count
        n_samples += np.sum(des_sent_info[:, -1])
        # logger.info('batch_size: {0}'.format(y_true.shape[0]))
        if synthese:
            # broadcast the fixed majority prediction across the batch
            hyp_scores = np.tile(y_pred_vec, (d_batch, 1))
            fids = batch['fids'].cpu().numpy()
            eval_res = metrics.metric_eval_for_syn_doc(hyp_scores=hyp_scores,
                                                       fids=fids,
                                                       is_hiernet=True)
        else:
            hyp_scores = np.tile(y_pred_vec, (d_batch, max_n_sents, 1))
            eval_res = metrics.metric_eval(y_true=y_true,
                                           hyp_scores=hyp_scores,
                                           des_sent_info=des_sent_info)
        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']
    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(precision_list=precision_list,
                                                        recall_list=recall_list)
    hamming = total_hamming / n_samples
    eval_log_info = {
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }
    res_str = 'example_based_f1: {example_based_f1:.6f},' \
              'avg_f1: {avg_f1:.6f}, cls_f1: {cls_f1}, hamming: {hamming:.6f}'
    logger.info(res_str.format(**eval_log_info))
def test_model_sent_mturk():
    """Evaluate the majority-class baseline on the MTurk sentence annotations."""
    logger.info('START: testing Baseline [MAJOR] on [MTURK SENTS]')
    data_loader = pipe.DomDetDataLoader(dataset_type='mturk')
    n_iter = 0
    total_loss = 0.0
    n_samples = 0
    total_hamming = 0.0
    cf_mats = []
    precision_list = []
    recall_list = []
    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1
        # d_batch * max_n_sents * n_doms
        y_true = batch['sent_labels'].cpu().numpy()
        d_batch = len(y_true)
        # one fixed majority prediction per document
        hyp_scores = np.tile(y_pred_vec, (d_batch, 1))
        # hyp_scores = np.tile(y_pred_vec, (d_batch, max_n_sents, 1))
        n_sents = batch['n_sents'].cpu().numpy()
        n_samples += np.sum(n_sents)
        logger.info('batch_size: {0}'.format(y_true.shape[0]))
        eval_res = metrics.metric_eval_for_mturk(y_true=y_true,
                                                 hyp_scores=hyp_scores,
                                                 n_sents=n_sents,
                                                 is_hiernet=True)
        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']
    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(precision_list=precision_list,
                                                        recall_list=recall_list)
    hamming = total_hamming / n_samples
    eval_log_info = {
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }
    res_str = 'example_based_f1: {example_based_f1:.6f},' \
              'avg_f1: {avg_f1:.6f}, cls_f1: {cls_f1}, hamming: {hamming:.6f}'
    logger.info(res_str.format(**eval_log_info))
def get_word_scores(model, n_iter, doc_ids, restore=None):
    """Load a checkpoint and return word-level scores for the given dev docs.

    Returns an array concatenated over batches,
    shaped n_batch * n_sents * n_words * n_doms (per the score tensor).
    """
    if config.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config.device)
    if config.placement in ('auto', 'single'):
        model.cuda()
    checkpoint = join(path_parser.model_save, config.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')
    filter_keys = None
    if config.reset_size_for_test and not config.set_sep_des_size:
        # drop DES-size-dependent parameters when sizes were reset for test
        filter_keys = [
            'module.word_det.des_ids',
            'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask',
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)
    model.eval()
    data_loader = pipe.DomDetDataLoader(dataset_type='dev',
                                        collect_doc_ids=True,
                                        doc_ids=doc_ids)
    ws_list = []
    for batch in data_loader:  # only one batch
        logger.info('Batch size: {}'.format(len(batch)))
        # wrap tensors as fixed (non-trainable) inputs; 'doc_ids' is not a model arg
        feed_dict = {
            k: Variable(v, requires_grad=False, volatile=True)
            for k, v in copy.deepcopy(batch).items()
            if k != 'doc_ids'
        }
        feed_dict['return_sent_attn'] = True
        feed_dict['return_word_attn'] = True
        _, _, _, word_scores, _, word_attn = model(**feed_dict)
        # weight sent scores with their attn
        ws_list.append(word_scores.data.cpu().numpy())
    return np.concatenate(ws_list, axis=0)
def test_model_word_mturk(matching_mode=None, corpus='wiki'):
    """Evaluate the majority baseline on MTurk word-level annotations via IR matching."""
    logger.info('START: model testing on [MTURK WORDS]')
    grain = 'word'
    dataset_type = '-'.join(('mturk', corpus, grain))
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)
    n_samples = 0
    p_list = []
    r_list = []
    for batch_idx, batch in enumerate(data_loader):
        # turn vars to numpy arrays
        y_true_sents = batch['sent_labels'].cpu().numpy()  # d_batch * max_n_sents * n_doms
        y_true_words = batch['word_labels'].cpu().numpy()  # d_batch * max_n_sents * max_n_words
        n_sents = batch['n_sents'].cpu().numpy()
        n_words = batch['n_words'].cpu().numpy()
        n_samples += np.sum(n_sents)
        d_batch = len(y_true_sents)
        # fixed doc-level majority prediction, broadcast over the batch
        hyp_scores = np.tile(y_pred_vec, (d_batch, 1))
        logger.info('batch_size: {0}'.format(y_true_words.shape[0]))
        eval_res = metrics_word_eval_binary.metric_eval_for_mturk_words_with_ir(
            hyp_scores=hyp_scores,
            y_true_sents=y_true_sents,
            y_true_words=y_true_words,
            n_sents=n_sents,
            n_words=n_words,
            pred_grain='doc',
            max_alter=True,
            matching_mode=matching_mode)
        p_list.extend(eval_res['p_list'])
        r_list.extend(eval_res['r_list'])
    exam_f1 = metrics.compute_example_based_f1(p_list, r_list)
    logger.info('word-eval. exam_f1: {0:6f}'.format(exam_f1))
def select_pos_neg_labels_for_en_words(self, corpus):
    """Dump positive/negative word ids per (doc, sent) to a label file.

    Writes tab-separated records 'doc_idx\tsent_idx\tpos_ids\tneg_ids' where
    the id lists are '|'-joined. Appends to the corpus/lang-specific file.

    Raises:
        ValueError: if the corpus is not valid for the configured language.
    """
    if lang == 'en':
        if corpus == 'wiki':
            out_fp = path_parser.label_en_wiki
        elif corpus == 'nyt':
            out_fp = path_parser.label_en_nyt
        else:
            raise ValueError('Invalid corpus with EN: {}'.format(corpus))
    else:
        if corpus == 'wiki':
            out_fp = path_parser.label_zh_wiki
        else:
            raise ValueError('Invalid corpus with ZH: {}'.format(corpus))
    grain = 'word'
    dataset_type = '-'.join(('mturk', corpus, grain))
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)
    records = list()
    for batch_idx, batch in enumerate(data_loader):
        word_labels = batch['word_labels'].cpu().numpy()
        n_sents = batch['n_sents'].cpu().numpy()
        n_words = batch['n_words'].cpu().numpy()
        n_doc = len(word_labels)
        logger.info('n_doc: {}'.format(n_doc))
        for doc_idx in range(n_doc):
            for sent_idx in range(n_sents[doc_idx, 0]):
                nw = n_words[doc_idx, sent_idx]
                labels_w = word_labels[doc_idx, sent_idx]
                pos_w_ids = [str(w_id) for w_id in range(nw) if labels_w[w_id] == 1]
                pos_str = '|'.join(pos_w_ids)
                # FIX: test the label directly instead of the former
                # `str(w_id) not in pos_w_ids`, which was an O(n) list scan
                # per word (quadratic per sentence) with string comparison.
                neg_w_ids = [str(w_id) for w_id in range(nw) if labels_w[w_id] != 1]
                neg_str = '|'.join(neg_w_ids)
                record = '\t'.join((str(doc_idx), str(sent_idx), pos_str, neg_str))
                records.append(record)
    logger.info('n_records: {}'.format(len(records)))
    # NOTE: mode='a' appends; re-running accumulates records in the file.
    with io.open(out_fp, mode='a', encoding='utf-8') as out_f:
        out_f.write('\n'.join(records))
    logger.info('Selection has been successfully saved to: {}'.format(out_fp))
def test_major_doc():
    """Evaluate the majority-class baseline on document-level test labels."""
    data_loader = pipe.DomDetDataLoader(dataset_type='test')
    n_iter = 0
    total_loss = 0.0
    n_samples = 0
    total_hamming = 0.0
    cf_mats = []
    precision_list = []
    recall_list = []
    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1
        # turn vars to numpy arrays
        y_true = batch['labels'].cpu().numpy()
        d_batch = len(y_true)
        # repeat the fixed majority prediction for every doc in the batch
        y_pred = np.tile(y_pred_vec, (d_batch, 1))
        n_samples += d_batch
        # logger.info('batch_size: {0}'.format(d_batch))
        eval_res = metrics.metric_eval(y_true=y_true, hyp_scores=y_pred)
        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']
    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(precision_list=precision_list,
                                                        recall_list=recall_list)
    hamming = total_hamming / n_samples
    eval_log_info = {
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }
    res_str = 'example_based_f1: {example_based_f1:.6f},' \
              'avg_f1: {avg_f1:.6f}, cls_f1: {cls_f1}, hamming: {hamming:.6f}'
    logger.info(res_str.format(**eval_log_info))
def eval_model(model, phase, save_pred=False, save_gold=False):
    """Run a trained model over the dev/test split and collect metrics.

    Returns a dict with phase, mean loss, example-based F1, averaged F1,
    per-class F1 and Hamming loss. Optionally saves predictions/gold labels.
    """
    assert phase in ('dev', 'test')
    data_loader = pipe.DomDetDataLoader(dataset_type=phase)
    model.eval()
    n_iter = 0
    total_loss = 0.0
    n_samples = 0
    total_hamming = 0.0
    cf_mats = []
    precision_list = []
    recall_list = []
    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1
        # wrap tensors as fixed (non-trainable) inputs
        feed_dict = copy.deepcopy(batch)
        for k, v in feed_dict.items():
            feed_dict[k] = Variable(v, requires_grad=False, volatile=True)
        loss, doc_scores = model(**feed_dict)[:2]
        total_loss += loss.data[0]
        # turn vars to numpy arrays
        y_true = batch['labels'].cpu().numpy()
        hyp_scores = doc_scores.data.cpu().numpy()
        eval_args = {
            'y_true': y_true,
            'hyp_scores': hyp_scores,
        }
        if save_pred:
            eval_args['save_pred_to'] = join(path_parser.pred_doc,
                                             config_loader.meta_model_name)
        if save_gold:
            eval_args['save_true_to'] = join(path_parser.pred_doc, 'gold')
        n_samples += y_true.shape[0]
        # logger.info('batch_size: {0}'.format(y_true.shape[0]))
        eval_res = metrics.metric_eval(**eval_args)
        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']
    avg_loss = total_loss / n_iter
    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    example_based_f1 = metrics.compute_example_based_f1(precision_list=precision_list,
                                                        recall_list=recall_list)
    hamming = total_hamming / n_samples
    return {
        'ph': phase,
        'loss': avg_loss,
        'example_based_f1': example_based_f1,
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
        'hamming': hamming,
    }
def test_model_sent_syn_with_checkpoints(model,
                                         save_pred=False,
                                         save_gold=False,
                                         n_iter=None,
                                         restore=False):
    """Load a checkpoint and evaluate sentence scores on synthetic-context docs.

    The score granularity used for evaluation depends on the meta model:
    hiernet -> doc scores; detnet1/milnet -> sent scores (no word scores);
    otherwise sent scores from the word-level model.
    """
    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)
    if config_loader.placement in ('auto', 'single'):
        model.cuda()
    logger.info('START: model testing on [SENTS with SYNTHETIC CONTEXT]')
    checkpoint = join(path_parser.model_save, config_loader.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')
    filter_keys = None
    if config_loader.reset_size_for_test and not config_loader.set_sep_des_size:
        logger.info('Filter DES pretrained paras...')
        filter_keys = [
            'module.word_det.des_ids',
            'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask',
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)
    data_loader = pipe.DomDetDataLoader(dataset_type='synthese')
    model.eval()
    n_iter = 0
    total_loss = 0.0
    n_samples = 0
    total_hamming = 0.0
    cf_mats = []
    precision_list = []
    recall_list = []
    is_hiernet = config_loader.meta_model_name == 'hiernet'
    no_word_scores = config_loader.meta_model_name in ('detnet1', 'milnet')
    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1
        feed_dict = copy.deepcopy(batch)
        # del paras not for forward()
        del feed_dict['des_sent_info']
        del feed_dict['fids']
        for k, v in feed_dict.items():
            # fix ids and masks
            feed_dict[k] = Variable(v, requires_grad=False, volatile=True)
        if is_hiernet:
            loss, doc_scores = model(**feed_dict)
            hyp_scores = doc_scores.data.cpu().numpy()
        elif no_word_scores:
            loss, doc_scores, sent_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        else:
            loss, doc_scores, sent_scores, word_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        total_loss += loss.data[0]
        # turn vars to numpy arrays
        y_true = batch['labels'].cpu().numpy()
        des_sent_info = batch['des_sent_info'].cpu().numpy()
        n_samples += np.sum(des_sent_info[:, -1])
        fids = batch['fids'].cpu().numpy()
        eval_args = {
            'hyp_scores': hyp_scores,
            'fids': fids,
            'is_hiernet': is_hiernet,
        }
        if save_pred:
            eval_args['save_pred_to'] = join(path_parser.pred_syn,
                                             config_loader.meta_model_name)
        if save_gold:
            eval_args['save_true_to'] = join(path_parser.pred_syn, 'gold')
        eval_res = metrics.metric_eval_for_syn_doc(**eval_args)
        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
        total_hamming += eval_res['hamming']
    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    eval_log_info = {
        'ph': 'Test',
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
    }
    res_str = build_res_str(stage=None,
                            use_loss=False,
                            use_exam_f1=False,
                            use_hamming=False)
    logger.info(res_str.format(**eval_log_info))
def test_model_sent_mturk_with_checkpoints(model,
                                           corpus='wiki',
                                           save_pred=False,
                                           save_gold=False,
                                           n_iter=None,
                                           restore=False):
    """Load a checkpoint and evaluate sentence-level scores on MTurk data.

    Raises:
        ValueError: if corpus is 'nyt' while the configured lang is not 'en'.
    """
    if corpus == 'nyt' and lang != 'en':
        # FIX: was `raise ('...')` — raising a plain str is itself a
        # TypeError in Python 3; raise a proper exception instead.
        raise ValueError('Set lang to en when NYT corpus is used')
    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)
    if config_loader.placement in ('auto', 'single'):
        model.cuda()
    logger.info('START: model testing on [SENTS with MTURK]')
    checkpoint = join(path_parser.model_save, config_loader.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')
    filter_keys = None
    if config_loader.reset_size_for_test and not config_loader.set_sep_des_size:
        logger.info('Filter DES pretrained paras...')
        filter_keys = [
            'module.word_det.des_ids',
            'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask',
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)
    grain = 'sent'
    dataset_type = '-'.join(('mturk', corpus, grain))
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)
    model.eval()
    n_iter, total_loss = 0, 0.0
    n_samples = 0
    cf_mats, precision_list, recall_list = list(), list(), list()
    is_hiernet = True if config_loader.meta_model_name == 'hiernet' else False
    if config_loader.meta_model_name in ('detnet1', 'milnet'):
        no_word_scores = True
    else:
        no_word_scores = False
    for batch_idx, batch in enumerate(data_loader):
        n_iter += 1
        feed_dict = copy.deepcopy(batch)
        # del paras not for forward()
        del feed_dict['n_sents']
        del feed_dict['sent_labels']
        for (k, v) in feed_dict.items():
            feed_dict[k] = Variable(v, requires_grad=False, volatile=True)  # fix ids and masks
        if is_hiernet:
            loss, doc_scores = model(**feed_dict)
            hyp_scores = doc_scores.data.cpu().numpy()
        elif no_word_scores:
            loss, doc_scores, sent_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        else:
            loss, doc_scores, sent_scores, word_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        # for doc_id, score in enumerate(hyp_scores):
        #     for sent_id, s in enumerate(score):
        #         logger.info('{0}.{1}: {2}'.format(doc_id, sent_id, score))
        total_loss += loss.data[0]
        # turn vars to numpy arrays
        y_true = batch['sent_labels'].cpu().numpy()  # d_batch * max_n_sents * n_doms
        n_sents = batch['n_sents'].cpu().numpy()
        # logger.info('n_sents: {0}'.format(n_sents))
        n_samples += np.sum(n_sents)
        # logger.info('batch_size: {0}'.format(y_true.shape[0]))
        eval_args = {
            'y_true': y_true,
            'hyp_scores': hyp_scores,
            'n_sents': n_sents,
            'is_hiernet': is_hiernet,
        }
        save_dir = path_parser.pred_mturk_wiki if corpus == 'wiki' else path_parser.pred_mturk_nyt
        if save_pred:
            eval_args['save_pred_to'] = join(save_dir, config_loader.meta_model_name)
        if save_gold:
            eval_args['save_true_to'] = join(save_dir, 'gold')
        eval_res = metrics.metric_eval_for_mturk(**eval_args)
        cf_mats.append(eval_res['cf_mat_list'])
        precision_list.extend(eval_res['precision_list'])
        recall_list.extend(eval_res['recall_list'])
    cls_f1, avg_f1 = metrics.compute_f1_with_confusion_mats(cf_mats)
    eval_log_info = {
        'ph': 'Test',
        'avg_f1': avg_f1,
        'cls_f1': cls_f1,
    }
    res_str = build_res_str(stage=None,
                            use_loss=False,
                            use_exam_f1=False,
                            use_hamming=False)
    logger.info(res_str.format(**eval_log_info))
def test_model_word_mturk_with_checkpoints(model,
                                           matching_mode=None,
                                           corpus='wiki',
                                           save_pred=False,
                                           save_gold=False,
                                           n_iter=None,
                                           restore=False):
    """Load a checkpoint and evaluate word-level MTurk annotations via IR matching.

    The prediction granularity (doc/sent/word) is taken from the model config;
    word-granularity runs additionally request sentence and word attention.

    Raises:
        ValueError: on an invalid corpus, a non-'en' lang with NYT, or an
            invalid prediction grain.
    """
    if corpus == 'wiki':
        save_dir = path_parser.pred_mturk_wiki
    elif corpus == 'nyt':
        if lang != 'en':
            raise ValueError('Set lang to en when NYT corpus is used')
        save_dir = path_parser.pred_mturk_nyt
    else:
        raise ValueError('Invalid corpus: {}'.format(corpus))
    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)
    if config_loader.placement in ('auto', 'single'):
        model.cuda()
    logger.info('START: model testing on [MTURK WORDS]')
    checkpoint = join(path_parser.model_save, config_loader.model_name)
    if restore:
        checkpoint = join(checkpoint, 'resume')
    filter_keys = None
    if config_loader.reset_size_for_test and not config_loader.set_sep_des_size:
        logger.info('Filter DES pretrained paras...')
        filter_keys = [
            'module.word_det.des_ids',
            'module.word_det.des_sent_mask',
            'module.word_det.des_word_mask',
        ]
    load_checkpoint(checkpoint=checkpoint,
                    model=model,
                    n_iter=n_iter,
                    filter_keys=filter_keys)
    grain = 'word'
    dataset_type = '-'.join(('mturk', corpus, grain))
    data_loader = pipe.DomDetDataLoader(dataset_type=dataset_type)
    model.eval()
    pred_grain = get_model_pred_grain()
    p_list = []
    r_list = []
    y_true_sents_list = []
    n_sents_list = []
    for batch_idx, batch in enumerate(data_loader):
        feed_dict = copy.deepcopy(batch)
        # labels and counts are for evaluation only, not forward() args
        for key in ('sent_labels', 'word_labels', 'n_sents', 'n_words'):
            del feed_dict[key]
        for k, v in feed_dict.items():
            # fix ids and masks
            feed_dict[k] = Variable(v, requires_grad=False, volatile=True)
        if pred_grain == 'doc':
            _, doc_scores = model(**feed_dict)
            hyp_scores = doc_scores.data.cpu().numpy()
        elif pred_grain == 'sent':
            _, _, sent_scores = model(**feed_dict)
            hyp_scores = sent_scores.data.cpu().numpy()
        elif pred_grain == 'word':
            feed_dict['return_sent_attn'] = True
            feed_dict['return_word_attn'] = True
            _, _, _, word_scores, _, word_attn = model(**feed_dict)
            # n_batch * n_sents * n_words * n_doms
            hyp_scores = word_scores.data.cpu().numpy()
        else:
            raise ValueError('Invalid prediction grain: {}'.format(pred_grain))
        # turn vars to numpy arrays
        y_true_sents = batch['sent_labels'].cpu().numpy()  # d_batch * max_n_sents * n_doms
        y_true_words = batch['word_labels'].cpu().numpy()  # d_batch * max_n_sents * max_n_words
        n_sents = batch['n_sents'].cpu().numpy()
        n_words = batch['n_words'].cpu().numpy()
        logger.info('batch_size: {0}'.format(y_true_words.shape[0]))
        eval_args = {
            'hyp_scores': hyp_scores,
            'y_true_sents': y_true_sents,
            'y_true_words': y_true_words,
            'n_sents': n_sents,
            'n_words': n_words,
            'pred_grain': pred_grain,
            'max_alter': True,
            'matching_mode': matching_mode,
        }
        if save_pred:
            eval_args['save_pred_to'] = join(save_dir,
                                             '_'.join((grain, config_loader.meta_model_name)))
        if save_gold:
            eval_args['save_true_to'] = join(save_dir, '_'.join((grain, 'gold')))
        eval_res = metrics_word_eval_binary.metric_eval_for_mturk_words_with_ir(**eval_args)
        p_list.extend(eval_res['p_list'])
        r_list.extend(eval_res['r_list'])
        y_true_sents_list.append(y_true_sents)
        n_sents_list.append(n_sents)
    exam_f1 = metrics.compute_example_based_f1(p_list, r_list)
    logger.info('word-eval. exam_f1: {0:6f}'.format(exam_f1))
    report_dom_specific_f1(p_list, r_list, y_true_sents_list[0], n_sents_list[0])
def train_model_with_checkpoints(model, restore=False, batch_log=True, batch_eval=True):
    """Train the model, periodically evaluating on dev and checkpointing.

    Trains for up to config_model['n_batches'] iterations (extended when
    resuming), logs the training loss every `train_skip_iter` batches and,
    every `eval_skip_iter` batches, evaluates on dev, records performance and
    saves a checkpoint when the dev F1 improves the kept set.
    """
    logger.info('Config: {0}'.format(config_loader.model_name))
    max_n_iter = config_model['n_batches']
    checkpoint = join(path_parser.model_save, config_loader.model_name)
    performance_save_fp = join(path_parser.performances,
                               '{0}.txt'.format(config_loader.model_name))
    data_loader = pipe.DomDetDataLoader(dataset_type='train')
    optimizer = make_opt(model)
    if config_loader.placement == 'auto':
        model = nn.DataParallel(model, device_ids=config_loader.device)
    if config_loader.placement in ('auto', 'single'):
        model.cuda()
    if restore:
        global_n_iter = load_checkpoint(checkpoint,
                                        model=model,
                                        optimizer=optimizer,
                                        no_iter_strategy='last')
        global_n_iter -= 1  # backward compatible
        logger.info(
            'MODE: restore a pre-trained model and resume training from {}'.
            format(global_n_iter))
        checkpoint = join(checkpoint, 'resume')
        max_n_iter += global_n_iter
    else:
        logger.info('MODE: create a new model')
        global_n_iter = 0
    train_skip_iter = 50  # 50
    eval_skip_iter = 200
    checkpoint_dict = dict()  # {n_batches: f1}
    batch_loss = 0.0
    for epoch_idx in range(config_model['n_epochs']):
        for batch_idx, batch in enumerate(data_loader):
            model.train(mode=True)
            global_n_iter += 1
            feed_dict = copy.deepcopy(batch)
            for (k, v) in feed_dict.items():
                feed_dict[k] = Variable(v, requires_grad=False)  # fix ids and masks
            loss = model(**feed_dict)[0]
            if config_loader.placement == 'auto':
                loss = loss.mean()  # gather loss from multiple gpus
            batch_loss += loss.data[0]
            optimizer.zero_grad()  # clear history gradients
            loss.backward(retain_graph=True)
            if config_model['grad_clip'] is not None:
                nn.utils.clip_grad_norm(model.parameters(), config_model['grad_clip'])
            optimizer.step()
            if batch_log and global_n_iter % train_skip_iter == 0:
                check_grads(model)
                logger.info('iter: {0}, loss: {1:.6f}'.format(global_n_iter, loss.data[0]))
            if batch_eval and global_n_iter % eval_skip_iter == 0:
                eval_log_info = eval_model(model=model, phase='dev')
                eval_log_info['n_iter'] = global_n_iter
                # res_str = build_res_str(stage='n_iter')
                # logger.info(res_str.format(**eval_log_info))
                mean_batch_loss = batch_loss / eval_skip_iter
                eval_loss = eval_log_info['loss']
                f1 = eval_log_info['avg_f1']
                checkpoint_dict, update, is_best = update_checkpoint_dict(
                    checkpoint_dict, k=global_n_iter, v=f1)
                perf_rec = '{0}\t{1:.6f}\t{2:.6f}\t{3:.6f}\n'.format(
                    global_n_iter, mean_batch_loss, eval_loss, f1)
                # FIX: a stray trailing comma after the closing brace made
                # `state` a 1-tuple wrapping the dict, so save_checkpoint
                # received a tuple instead of the state dict.
                state = {
                    'n_iters': global_n_iter + 1,
                    'state_dict': model.state_dict(),
                    'optimizer_dict': optimizer.state_dict(),
                }
                if update:
                    save_checkpoint(state, checkpoint, global_n_iter, is_best)
                    clean_outdated_checkpoints(checkpoint, checkpoint_dict)
                with open(performance_save_fp, 'a', encoding='utf-8') as f:
                    f.write(perf_rec)
                logger.info(perf_rec.strip('\n'))
                batch_loss = 0.0  # reset the running-loss window
            if global_n_iter == max_n_iter:
                break
        if global_n_iter == max_n_iter:
            logger.info('finished expected training: {0} batches!'.format(max_n_iter))
            break