def get_sentence_pair(top_k, d_list, p_level_results_list, is_training, debug_mode=False):
    """Build sentence-level forward items from upstream paragraph-level results.

    The paragraph-level results are attached to their questions (matched on
    ``qid`` against each question's ``_id``). Then the ``top_k`` documents per
    question are selected, and sentence forward items are built for them from
    the wiki DB.

    Args:
        top_k: number of upstream documents kept per question.
        d_list: HotpotQA question dicts, each carrying an ``'_id'`` key.
        p_level_results_list: paragraph-level result dicts, each carrying a
            ``'qid'`` key.
        is_training: forwarded to ``build_sentence_forward_item``.
        debug_mode: when true, only the first 100 questions (and their
            results) are used.

    Returns:
        Whatever ``build_sentence_forward_item`` produces for the selection.
    """
    db_cursor = wiki_db_tool.get_cursor(config.WHOLE_PROCESS_FOR_RINDEX_DB)

    questions = d_list
    p_results = p_level_results_list
    if debug_mode:
        questions = questions[:100]
        kept_qids = {question['_id'] for question in questions}
        p_results = [result for result in p_level_results_list if result['qid'] in kept_qids]

    # Work on a deep copy of the id -> question mapping.
    # append_subfield_from_list_to_dict appears to write into its target in
    # place, and the copy keeps the caller's question dicts untouched.
    qid_to_question = copy.deepcopy(list_dict_data_tool.list_to_dict(questions, '_id'))
    list_dict_data_tool.append_subfield_from_list_to_dict(p_results, qid_to_question,
                                                          'qid', 'fid', check=True)

    top_docs = select_top_k_and_to_results_dict(qid_to_question, top_k=top_k, filter_value=None)
    return build_sentence_forward_item(top_docs, questions, is_training=is_training,
                                       db_cursor=db_cursor)
def inspect_upstream_eval(results_path=None):
    """Score stored sentence-level dev predictions against HotpotQA dev (fullwiki).

    The per-sentence results are attached to the dev questions (``qid`` ->
    ``_id``). At most 5 sentences per question are kept, using the ``prob``
    score with ``filter_value=0.2``. The supporting-fact metrics from
    ``ext_hotpot_eval.eval`` are then printed.

    Args:
        results_path: sentence-level results ``.jsonl`` to evaluate. Defaults to
            the dev results of the doc_top_k:2 sentence-level run used so far.

    Returns:
        The printed logging dict (``{'label': ..., 'v02': metrics}``).
    """
    if results_path is None:
        results_path = config.PRO_ROOT / "data/p_hotpotqa/hotpotqa_sentence_level/04-19-02:17:11_hotpot_v0_slevel_retri_(doc_top_k:2)/i(12000)|e(2)|v02_f1(0.7153646038858843)|v02_recall(0.7114645831323757)|v05_f1(0.7153646038858843)|v05_recall(0.7114645831323757)|seed(12)/dev_s_level_bert_v1_results.jsonl"

    dev_list = common.load_json(config.DEV_FULLWIKI_FILE)
    dev_o_dict = list_dict_data_tool.list_to_dict(dev_list, '_id')
    dev_eval_results_list = common.load_jsonl(results_path)

    copied_dev_o_dict = copy.deepcopy(dev_o_dict)
    list_dict_data_tool.append_subfield_from_list_to_dict(
        dev_eval_results_list, copied_dev_o_dict, 'qid', 'fid', check=True)

    # Only the 0.2 threshold is evaluated here. The former code still read
    # 0.5-threshold metrics (metrics_v5) whose computation had been commented
    # out, which made this function always fail with NameError.
    cur_results_dict_v02 = select_top_k_and_to_results_dict(
        copied_dev_o_dict, top_k=5, score_field_name='prob',
        filter_value=0.2, result_field='sp')
    _, metrics_v2 = ext_hotpot_eval.eval(cur_results_dict_v02, dev_list, verbose=False)

    # NOTE(review): the 'ema' label is kept for output compatibility. The
    # default results directory carries no 'ema_' prefix, so it may not come
    # from an EMA model — confirm.
    logging_item = {
        'label': 'ema',
        'v02': metrics_v2,
    }
    print(logging_item)
    return logging_item
def eval_hotpot_s():
    """Evaluate saved sentence-level dev predictions (built from top-100 docs).

    The predictions are attached to the HotpotQA fullwiki dev questions. At
    most 5 sentences per question are kept, using ``prob`` with
    ``filter_value=0.5``. The metrics dict is printed, followed by the
    supporting-fact EM, precision, recall and F1 on one line.
    """
    s_level_results = common.load_jsonl(
        config.PRO_ROOT / "data/p_hotpotqa/hotpot_p_level_effects/hotpot_s_level_dev_results_top_k_doc_100.jsonl"
    )
    dev_questions = common.load_json(config.DEV_FULLWIKI_FILE)

    # Deep copy before attaching the per-sentence results to the questions.
    qid_to_question = copy.deepcopy(list_dict_data_tool.list_to_dict(dev_questions, '_id'))
    list_dict_data_tool.append_subfield_from_list_to_dict(
        s_level_results, qid_to_question, 'qid', 'fid', check=True)

    sp_predictions = select_top_k_and_to_results_dict(
        qid_to_question,
        top_k=5,
        score_field_name='prob',
        filter_value=0.5,
        result_field='sp',
    )
    _, v05_metrics = ext_hotpot_eval.eval(sp_predictions, dev_questions, verbose=False)

    print({'v05': v05_metrics})
    print(*(v05_metrics[key] for key in ('sp_em', 'sp_prec', 'sp_recall', 'sp_f1')))
def get_qa_item_with_upstream_sentence(d_list, sentence_level_results, is_training,
                                       tokenizer: BertTokenizer, max_context_length, max_query_length,
                                       doc_stride=128, debug_mode=False, top_k=5, filter_value=0.2):
    """Convert upstream sentence-level retrieval results into QA feature items.

    Sentence results are attached to their questions (``qid`` -> ``_id``). Up to
    ``top_k`` sentences per question are selected, using ``prob`` and
    ``filter_value``. QA forward items are built from those sentences via the
    wiki DB, and are then tokenized into span-extraction features.

    Args:
        d_list: HotpotQA question dicts with an ``'_id'`` key.
        sentence_level_results: per-sentence result dicts with a ``'qid'`` key.
        is_training: forwarded to the item builders / feature conversion.
        tokenizer: BERT tokenizer used for feature conversion.
        max_context_length, max_query_length, doc_stride: forwarded to
            ``span_preprocess_tool.eitems_to_fitems``.
        debug_mode: when true, only the first 100 questions are used.
        top_k, filter_value: sentence selection parameters.

    Returns:
        ``(fitems_dict, read_fitems_list, sp_results)``. The first two come from
        ``eitems_to_fitems``; ``sp_results`` is the ``'sp'`` entry of the
        selected-sentence results dict.
    """
    db_cursor = wiki_db_tool.get_cursor(config.WHOLE_PROCESS_FOR_RINDEX_DB)

    if debug_mode:
        d_list = d_list[:100]
        wanted_qids = {question['_id'] for question in d_list}
        sentence_level_results = [
            result for result in sentence_level_results if result['qid'] in wanted_qids
        ]

    # Attach results to a deep copy so the caller's question dicts stay untouched.
    qid_to_question = copy.deepcopy(list_dict_data_tool.list_to_dict(d_list, '_id'))
    list_dict_data_tool.append_subfield_from_list_to_dict(
        sentence_level_results, qid_to_question, 'qid', 'fid', check=True)

    selected = select_top_k_and_to_results_dict(
        qid_to_question,
        top_k=top_k,
        score_field_name='prob',
        filter_value=filter_value,
        result_field='sp',
    )

    raw_items = build_qa_forword_item(selected, d_list, is_training, db_cursor)
    example_items = format_convert(raw_items, is_training)

    fitems_dict, read_fitems_list = span_preprocess_tool.eitems_to_fitems(
        example_items, tokenizer, is_training, max_context_length, max_query_length,
        doc_stride, False)
    return fitems_dict, read_fitems_list, selected['sp']
def eval_model_for_downstream_ablation(model_saved_path, doc_top_k=2, tag='dev'):
    """Run a trained sentence-level BERT selector over one HotpotQA split.

    Upstream paragraph-level results for the split select the top
    ``doc_top_k`` documents per question, and ``get_sentence_pair`` turns them
    into sentence items. These are scored by the model loaded from
    ``model_saved_path``, and the per-sentence results are written to a
    ``.jsonl`` file.

    For ``tag == 'dev'`` the results are also evaluated with ``ext_hotpot_eval``
    (top-5 sentences, ``prob`` thresholds 0.5 and 0.2). The metrics are printed
    and saved to a JSON file whose name encodes the 0.5-threshold EM,
    precision, recall and F1.

    Args:
        model_saved_path: path to a ``state_dict`` of
            ``BertMultiLayerSeqClassification``.
        doc_top_k: number of upstream documents kept per question.
        tag: ``'train'``, ``'dev'`` or ``'test'``.

    Returns:
        The dev logging dict ``{'v02': ..., 'v05': ...}`` for ``'dev'``.
        ``None`` for ``'train'``/``'test'``; these only write their predictions
        and now return instead of terminating the interpreter.

    Raises:
        ValueError: if ``tag`` is not one of the supported splits.
    """
    print(f"Run doc_top_k:{doc_top_k}")

    question_files = {
        'train': config.TRAIN_FILE,
        'dev': config.DEV_FULLWIKI_FILE,
        'test': config.TEST_FULLWIKI_FILE,
    }
    if tag not in question_files:
        raise ValueError(f"Unsupported tag {tag!r}; expected 'train', 'dev' or 'test'.")

    # Upstream paragraph-level results (input for every split).
    p_level_dir = "data/p_hotpotqa/hotpotqa_paragraph_level/04-10-17:44:54_hotpot_v0_cs/i(40000)|e(4)|t5_doc_recall(0.8793382849426064)|t5_sp_recall(0.879496479212887)|t10_doc_recall(0.888656313301823)|t5_sp_recall(0.8888325134240054)|seed(12)/"
    # Output directory for train/test sentence-level results.
    s_level_dir = "data/p_hotpotqa/hotpotqa_sentence_level/04-19-02:17:11_hotpot_v0_slevel_retri_(doc_top_k:2)/i(12000)|e(2)|v02_f1(0.7153646038858843)|v02_recall(0.7114645831323757)|v05_f1(0.7153646038858843)|v05_recall(0.7114645831323757)|seed(12)/"

    bert_pretrain_path = config.PRO_ROOT / '.pytorch_pretrained_bert'
    seed = 12
    torch.manual_seed(seed)
    bert_model_name = 'bert-base-uncased'
    lazy = True
    forward_size = 256
    do_lower_case = True
    document_top_k = doc_top_k
    debug_mode = False
    num_class = 1

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1
    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # work around for initiating vocabulary.
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace("false", namespace="labels")  # 0
    vocab.add_token_to_namespace("true", namespace="labels")  # 1
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden", -2, namespace='labels')

    # Load only the split that is actually evaluated.
    d_list = common.load_json(question_files[tag])
    p_level_results_list = common.load_jsonl(
        config.PRO_ROOT / f"{p_level_dir}{tag}_p_level_bert_v1_results.jsonl")

    fitems = get_sentence_pair(document_top_k, d_list, p_level_results_list,
                               is_training=(tag == 'train'), debug_mode=debug_mode)
    if debug_mode:
        # get_sentence_pair keeps only the first 100 questions in debug mode;
        # score against the same subset.
        d_list = d_list[:100]

    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=do_lower_case,
                                                   cache_dir=bert_pretrain_path)
    bert_cs_reader = BertContentSelectionReader(bert_tokenizer, lazy, is_paired=True,
                                                example_filter=lambda x: len(x['context']) == 0,
                                                max_l=128, element_fieldname='element')

    bert_encoder = BertModel.from_pretrained(bert_model_name, cache_dir=bert_pretrain_path)
    model = BertMultiLayerSeqClassification(bert_encoder, num_labels=num_class, num_of_pooling_layer=1,
                                            act_type='tanh', use_pretrained_pooler=True, use_sigmoid=True)

    # Load onto CPU first so checkpoints saved on GPU also load on CPU-only
    # hosts; the model is moved to `device` right after.
    model.load_state_dict(torch.load(model_saved_path, map_location='cpu'))

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    instances = bert_cs_reader.read(fitems)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    data_iter = biterator(instances, num_epochs=1, shuffle=False)
    print(len(fitems))
    print("Forward size:", forward_size)

    s_level_results_list = eval_model(model, data_iter, device_num, with_probs=True, show_progress=True)

    if tag != 'dev':
        # train/test: only the raw sentence-level predictions are written.
        common.save_jsonl(
            s_level_results_list,
            config.PRO_ROOT / f"{s_level_dir}{tag}_s_level_bert_v1_results.jsonl")
        return None

    common.save_jsonl(
        s_level_results_list,
        f"hotpot_s_level_{tag}_results_top_k_doc_{document_top_k}.jsonl")

    dev_o_dict = list_dict_data_tool.list_to_dict(d_list, '_id')
    copied_dev_o_dict = copy.deepcopy(dev_o_dict)
    list_dict_data_tool.append_subfield_from_list_to_dict(
        s_level_results_list, copied_dev_o_dict, 'qid', 'fid', check=True)

    cur_results_dict_v05 = select_top_k_and_to_results_dict(
        copied_dev_o_dict, top_k=5, score_field_name='prob', filter_value=0.5, result_field='sp')
    cur_results_dict_v02 = select_top_k_and_to_results_dict(
        copied_dev_o_dict, top_k=5, score_field_name='prob', filter_value=0.2, result_field='sp')

    _, metrics_v5 = ext_hotpot_eval.eval(cur_results_dict_v05, d_list, verbose=False)
    _, metrics_v2 = ext_hotpot_eval.eval(cur_results_dict_v02, d_list, verbose=False)

    logging_item = {
        'v02': metrics_v2,
        'v05': metrics_v5,
    }
    print(logging_item)

    f1 = metrics_v5['sp_f1']
    em = metrics_v5['sp_em']
    pr = metrics_v5['sp_prec']
    rec = metrics_v5['sp_recall']
    common.save_json(logging_item, f"top_k_doc:{document_top_k}_em:{em}_pr:{pr}_rec:{rec}_f1:{f1}")
    return logging_item
def model_go():
    """Train the sentence-level BERT retriever for HotpotQA, evaluating on dev periodically.

    Data and model:
      - Sentence-level train/dev items are built with ``get_sentence_pair`` from
        the top ``document_top_k`` documents of fixed upstream paragraph-level
        results.
      - The model is a ``BertMultiLayerSeqClassification`` with one output and
        sigmoid, trained with ``BertAdam``.

    Evaluation and checkpointing:
      - Every ``eval_frequency`` optimizer updates the model is scored on dev
        (top-5 sentences, ``prob`` thresholds 0.5 and 0.2) with
        ``ext_hotpot_eval``.
      - Outside debug mode, the metrics are appended to
        ``<file_path_prefix>/log.json``. A ``state_dict`` checkpoint is saved
        whose file name encodes the update step, epoch and scores.
      - When ``do_ema`` is set, the same evaluation and saving is repeated for
        the EMA weights, with an ``ema_`` file-name prefix.
      - A copy of this source file is written into the experiment directory
        before training.

    All hyper-parameters are local constants. The function takes no arguments
    and returns None.

    NOTE(review): the EMA helper is built with ``device_num=1``, and EMA
    inference moves the model to device 1. This looks like it requires at least
    two GPUs — confirm before running on smaller machines.
    """
    seed = 12
    torch.manual_seed(seed)
    # bert_model_name = 'bert-large-uncased'
    bert_pretrain_path = config.PRO_ROOT / '.pytorch_pretrained_bert'
    bert_model_name = 'bert-base-uncased'
    lazy = False
    # lazy = True
    forward_size = 128
    # batch_size = 64
    batch_size = 128
    # Forward/backward passes per optimizer step (1 with the sizes above).
    gradient_accumulate_step = int(batch_size / forward_size)
    warmup_proportion = 0.1
    learning_rate = 5e-5
    num_train_epochs = 5
    eval_frequency = 2000  # evaluate on dev every this many optimizer updates
    pos_ratio = 0.2  # only referenced by the commented-out down-sampling below
    do_lower_case = True
    document_top_k = 2
    experiment_name = f'hotpot_v0_slevel_retri_(doc_top_k:{document_top_k})'

    debug_mode = False
    do_ema = True
    # est_datasize = 900_000

    num_class = 1
    # num_train_optimization_steps

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # work around for initiating vocabulary.
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace("false", namespace="labels")  # 0
    vocab.add_token_to_namespace("true", namespace="labels")  # 1
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden", -2, namespace='labels')

    # Load Dataset
    train_list = common.load_json(config.TRAIN_FILE)
    dev_list = common.load_json(config.DEV_FULLWIKI_FILE)

    # train_fitems = sentence_level_sampler.get_train_sentence_pair(document_top_k, True, debug_mode)
    # dev_fitems = sentence_level_sampler.get_dev_sentence_pair(document_top_k, False, debug_mode)

    # Load the upstream paragraph-level results for train and dev.
    cur_train_eval_results_list = common.load_jsonl(
        config.PRO_ROOT / "data/p_hotpotqa/hotpotqa_paragraph_level/04-10-17:44:54_hotpot_v0_cs/"
                          "i(40000)|e(4)|t5_doc_recall(0.8793382849426064)|t5_sp_recall(0.879496479212887)|t10_doc_recall(0.888656313301823)|t5_sp_recall(0.8888325134240054)|seed(12)/train_p_level_bert_v1_results.jsonl"
    )

    cur_dev_eval_results_list = common.load_jsonl(
        config.PRO_ROOT / "data/p_hotpotqa/hotpotqa_paragraph_level/04-10-17:44:54_hotpot_v0_cs/"
                          "i(40000)|e(4)|t5_doc_recall(0.8793382849426064)|t5_sp_recall(0.879496479212887)|t10_doc_recall(0.888656313301823)|t5_sp_recall(0.8888325134240054)|seed(12)/dev_p_level_bert_v1_results.jsonl"
    )

    train_fitems = get_sentence_pair(document_top_k, train_list, cur_train_eval_results_list, is_training=True,
                                     debug_mode=debug_mode)
    dev_fitems = get_sentence_pair(document_top_k, dev_list, cur_dev_eval_results_list, is_training=False,
                                   debug_mode=debug_mode)

    if debug_mode:
        # Match the 100-question subset get_sentence_pair uses in debug mode.
        dev_list = dev_list[:100]
        eval_frequency = 2
        # print(dev_list[-1]['_id'])
        # exit(0)

    # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
    est_datasize = len(train_fitems)

    dev_o_dict = list_dict_data_tool.list_to_dict(dev_list, '_id')
    # print(dev_o_dict)

    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=do_lower_case,
                                                   cache_dir=bert_pretrain_path)
    bert_cs_reader = BertContentSelectionReader(bert_tokenizer, lazy, is_paired=True,
                                                example_filter=lambda x: len(x['context']) == 0, max_l=128,
                                                element_fieldname='element')

    bert_encoder = BertModel.from_pretrained(bert_model_name, cache_dir=bert_pretrain_path)
    model = BertMultiLayerSeqClassification(bert_encoder, num_labels=num_class, num_of_pooling_layer=1,
                                            act_type='tanh', use_pretrained_pooler=True, use_sigmoid=True)

    ema = None
    if do_ema:
        # NOTE(review): hard-coded device index 1 for the EMA copy — confirm GPU layout.
        ema = EMA(model, model.named_parameters(), device_num=1)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Standard BERT fine-tuning setup: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    # Total optimizer updates over all epochs; BertAdam uses this as t_total for its warmup schedule.
    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    if debug_mode:
        num_train_optimization_steps = 100

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=num_train_optimization_steps)

    dev_instances = bert_cs_reader.read(dev_fitems)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    forbackward_step = 0  # forward/backward passes so far
    update_step = 0  # optimizer updates so far

    logging_agent = save_tool.ScoreLogger({})

    # # # Create Log File
    # `date` is returned by gen_file_prefix but not used below.
    file_path_prefix, date = save_tool.gen_file_prefix(f"{experiment_name}")
    # Save the source code.
    script_name = os.path.basename(__file__)
    with open(os.path.join(file_path_prefix, script_name), 'w') as out_f, open(__file__, 'r') as it:
        out_f.write(it.read())
        out_f.flush()
    # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)
        # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
        # Train items are reshuffled in place and re-read every epoch.
        random.shuffle(train_fitems)
        train_instance = bert_cs_reader.read(train_fitems)
        train_iter = biterator(train_instance, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter):
            # Re-enter train mode every batch, since evaluation may happen between batches.
            model.train()
            batch = move_to_device(batch, device_num)

            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            labels_ids = batch['label']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            # s1_span / s2_span are read but not used below.
            s1_span = batch['bert_s1_span']
            s2_span = batch['bert_s2_span']

            loss = model(paired_sequence, token_type_ids=paired_segments_ids, attention_mask=att_mask,
                         mode=BertMultiLayerSeqClassification.ForwardMode.TRAIN,
                         labels=labels_ids)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                if ema is not None and do_ema:
                    # Feed the EMA the unwrapped module's parameters (strip DataParallel if present).
                    updated_model = model.module if hasattr(model, 'module') else model
                    ema(updated_model.named_parameters())
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    print("Update steps:", update_step)
                    dev_iter = biterator(dev_instances, num_epochs=1, shuffle=False)

                    cur_eval_results_list = eval_model(model, dev_iter, device_num, with_probs=True)
                    copied_dev_o_dict = copy.deepcopy(dev_o_dict)
                    list_dict_data_tool.append_subfield_from_list_to_dict(cur_eval_results_list, copied_dev_o_dict,
                                                                          'qid', 'fid', check=True)
                    # 0.5
                    cur_results_dict_v05 = select_top_k_and_to_results_dict(copied_dev_o_dict, top_k=5,
                                                                            score_field_name='prob',
                                                                            filter_value=0.5,
                                                                            result_field='sp')

                    cur_results_dict_v02 = select_top_k_and_to_results_dict(copied_dev_o_dict, top_k=5,
                                                                            score_field_name='prob',
                                                                            filter_value=0.2,
                                                                            result_field='sp')

                    _, metrics_v5 = ext_hotpot_eval.eval(cur_results_dict_v05, dev_list, verbose=False)

                    _, metrics_v2 = ext_hotpot_eval.eval(cur_results_dict_v02, dev_list, verbose=False)

                    v02_sp_f1 = metrics_v2['sp_f1']
                    v02_sp_recall = metrics_v2['sp_recall']
                    v02_sp_prec = metrics_v2['sp_prec']

                    v05_sp_f1 = metrics_v5['sp_f1']
                    v05_sp_recall = metrics_v5['sp_recall']
                    v05_sp_prec = metrics_v5['sp_prec']

                    logging_item = {
                        'v02': metrics_v2,
                        'v05': metrics_v5,
                    }

                    print(logging_item)

                    # print(logging_item)
                    if not debug_mode:
                        # Checkpoint file name encodes step, epoch, F1/recall at both thresholds, and seed.
                        save_file_name = f'i({update_step})|e({epoch_i})' \
                            f'|v02_f1({v02_sp_f1})|v02_recall({v02_sp_recall})' \
                            f'|v05_f1({v05_sp_f1})|v05_recall({v05_sp_recall})|seed({seed})'

                        # print(save_file_name)
                        logging_agent.incorporate_results({}, save_file_name, logging_item)
                        logging_agent.logging_to_file(Path(file_path_prefix) / "log.json")

                        # Save the unwrapped module's weights (strip DataParallel if present).
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = Path(file_path_prefix) / save_file_name
                        torch.save(model_to_save.state_dict(), str(output_model_file))

                    if do_ema and ema is not None:
                        # Evaluate the EMA weights the same way, on device 1 with the ids from get_ema_gpu_id_list.
                        ema_model = ema.get_inference_model()
                        master_device_num = 1
                        ema_inference_device_ids = get_ema_gpu_id_list(master_device_num=master_device_num)
                        ema_model = ema_model.to(master_device_num)
                        ema_model = torch.nn.DataParallel(ema_model, device_ids=ema_inference_device_ids)
                        dev_iter = biterator(dev_instances, num_epochs=1, shuffle=False)

                        cur_eval_results_list = eval_model(ema_model, dev_iter, master_device_num, with_probs=True)
                        copied_dev_o_dict = copy.deepcopy(dev_o_dict)
                        list_dict_data_tool.append_subfield_from_list_to_dict(cur_eval_results_list, copied_dev_o_dict,
                                                                              'qid', 'fid', check=True)
                        # 0.5
                        cur_results_dict_v05 = select_top_k_and_to_results_dict(copied_dev_o_dict, top_k=5,
                                                                                score_field_name='prob',
                                                                                filter_value=0.5,
                                                                                result_field='sp')

                        cur_results_dict_v02 = select_top_k_and_to_results_dict(copied_dev_o_dict, top_k=5,
                                                                                score_field_name='prob',
                                                                                filter_value=0.2,
                                                                                result_field='sp')

                        _, metrics_v5 = ext_hotpot_eval.eval(cur_results_dict_v05, dev_list, verbose=False)

                        _, metrics_v2 = ext_hotpot_eval.eval(cur_results_dict_v02, dev_list, verbose=False)

                        v02_sp_f1 = metrics_v2['sp_f1']
                        v02_sp_recall = metrics_v2['sp_recall']
                        v02_sp_prec = metrics_v2['sp_prec']

                        v05_sp_f1 = metrics_v5['sp_f1']
                        v05_sp_recall = metrics_v5['sp_recall']
                        v05_sp_prec = metrics_v5['sp_prec']

                        logging_item = {
                            'label': 'ema',
                            'v02': metrics_v2,
                            'v05': metrics_v5,
                        }

                        print(logging_item)

                        if not debug_mode:
                            save_file_name = f'ema_i({update_step})|e({epoch_i})' \
                                f'|v02_f1({v02_sp_f1})|v02_recall({v02_sp_recall})' \
                                f'|v05_f1({v05_sp_f1})|v05_recall({v05_sp_recall})|seed({seed})'

                            # print(save_file_name)
                            logging_agent.incorporate_results({}, save_file_name, logging_item)
                            logging_agent.logging_to_file(Path(file_path_prefix) / "log.json")

                            model_to_save = ema_model.module if hasattr(ema_model, 'module') else ema_model
                            output_model_file = Path(file_path_prefix) / save_file_name
                            torch.save(model_to_save.state_dict(), str(output_model_file))