import math
from datetime import datetime
from os import listdir, makedirs, remove
from os.path import exists

import numpy as np
import tensorflow as tf
from nltk.translate.bleu_score import sentence_bleu
from tensorflow.contrib import slim

# Project-internal names used below (DefaultConfig, TrainData, ValData,
# TfExampleTrainDataset, TransformerGraph, Seq2SeqGraph, get_graph_train_data,
# get_graph_val_data, PostProcess, decode, decode_to_output, truncate_sents,
# get_exclude_list, exclude_list, list_config, find_best_ckpt, get_ckpt,
# get_best_sari, get_path, SARIsent, get_fkgl, CorpusFKGL, CorpusSARI,
# MtEval_BLEU, constant, session, checkpoint, utils, args, and the
# print_cpu/gpu usage helpers) come from this repo's own modules; their exact
# import paths depend on the repo layout.


def train(model_config=None):
    """Train loop driven by feed_dict input and tf.train.Supervisor."""
    model_config = DefaultConfig() if model_config is None else model_config
    data = TrainData(model_config)
    if model_config.framework == 'transformer':
        graph = TransformerGraph(data, True, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(data, True, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    graph.create_model_multigpu()

    ckpt_path = None
    if model_config.warm_start:
        if model_config.warm_start == 'recent':
            ckpt_path = find_best_ckpt(model_config)
        else:
            ckpt_path = model_config.warm_start

    var_list = slim.get_variables_to_restore()
    if ckpt_path is not None:
        # Handle missing vars by ourselves: restore only variables that exist
        # in the checkpoint with a matching shape.
        available_vars = {}
        reader = tf.train.NewCheckpointReader(ckpt_path)
        var_dict = {var.op.name: var for var in var_list}
        for var in var_dict:
            if 'global_step' in var:
                continue
            if 'optimization' in var:
                continue
            if reader.has_tensor(var):
                var_ckpt = reader.get_tensor(var)
                var_cur = var_dict[var]
                if any([var_cur.shape[i] != var_ckpt.shape[i]
                        for i in range(len(var_ckpt.shape))]):
                    print('Variable %s missing due to shape.' % var)
                else:
                    available_vars[var] = var_dict[var]
            else:
                print('Variable %s missing.' % var)
        partial_restore_ckpt = slim.assign_from_checkpoint_fn(
            ckpt_path, available_vars,
            ignore_missing_vars=False, reshape_variables=False)

    def init_fn(session):
        # Restore ckpt either from warm start or one picked up automatically
        # when changing the optimizer.
        ckpt_path = None
        if model_config.warm_start:
            ckpt_path = model_config.warm_start
        if ckpt_path is not None:
            if model_config.use_partial_restore:
                partial_restore_ckpt(session)
            else:
                try:
                    graph.saver.restore(session, ckpt_path)
                except Exception as ex:
                    print('Full restore failed, using partial restore instead.\n%s' % str(ex))
                    partial_restore_ckpt(session)
            print('Warm start with checkpoint %s' % ckpt_path)

    sv = tf.train.Supervisor(logdir=model_config.logdir,
                             global_step=graph.global_step,
                             saver=graph.saver,
                             init_fn=init_fn,
                             save_model_secs=model_config.save_model_secs)
    sess = sv.PrepareSession(config=session.get_session_config(model_config))

    perplexitys = []
    start_time = datetime.now()
    while True:
        input_feed = get_graph_train_data(data, graph.objs, model_config)
        fetches = [graph.train_op, graph.loss, graph.global_step,
                   graph.perplexity, graph.ops, graph.logits]
        _, loss, step, perplexity, _, logits = sess.run(fetches, input_feed)
        perplexitys.append(perplexity)

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.'
                  % (perplexity, step, time_span))
            perplexitys.clear()
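
# Minimal standalone sketch of the shape-checked partial-restore pattern used
# in train() above, kept separate for clarity. It assumes TF 1.x with
# tf.contrib.slim; the helper name `build_partial_restore_fn` is hypothetical
# and not part of the repo's API.
def build_partial_restore_fn(ckpt_path):
    """Build an init callable that restores only variables whose names and
    shapes match the checkpoint, skipping the step counter and optimizer
    slots (the same filtering rules as the inline code in train())."""
    reader = tf.train.NewCheckpointReader(ckpt_path)
    var_dict = {v.op.name: v for v in slim.get_variables_to_restore()}
    available_vars = {}
    for name, var in var_dict.items():
        if 'global_step' in name or 'optimization' in name:
            continue  # never warm-start step counters or optimizer state
        if not reader.has_tensor(name):
            continue  # variable absent from the checkpoint
        if var.get_shape().as_list() == list(reader.get_tensor(name).shape):
            available_vars[name] = var  # name and shape both match
    return slim.assign_from_checkpoint_fn(
        ckpt_path, available_vars,
        ignore_missing_vars=False, reshape_variables=False)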
def eval(model_config=None, ckpt=None):
    """Run one evaluation pass over the validation set, report iBLEU, SARI,
    FKGL and perplexity, and return the corpus-level (Joshua) SARI."""
    model_config = DefaultConfig() if model_config is None else model_config
    if not exists(model_config.resultdir):
        makedirs(model_config.resultdir)
    print(list_config(model_config))

    val_data = ValData(model_config)
    if model_config.framework == 'transformer':
        graph = TransformerGraph(val_data, False, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(val_data, False, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    tf.reset_default_graph()
    graph.create_model_multigpu()

    ibleus_all = []
    perplexitys_all = []
    saris_all = []
    decode_outputs_all = []
    targets = []
    targets_raw = []
    sentence_simples = []
    sentence_complexs = []
    sentence_complexs_raw = []

    it = val_data.get_data_iter()

    def init_fn(session):
        graph.saver.restore(session, ckpt)
        print('Restore ckpt:%s.' % ckpt)

    sv = tf.train.Supervisor(init_fn=init_fn)
    sess = sv.PrepareSession(config=session.get_session_config(model_config))
    while True:
        is_finish = False
        (input_feed, output_sentence_simple, output_sentence_complex,
         output_sentence_complex_raw, output_sentence_complex_raw_lines,
         output_mapper, output_ref_raw_lines, out_effective_batch_size,
         output_is_end) = get_graph_val_data(graph.objs, model_config, it, val_data)

        postprocess = PostProcess(model_config, val_data)

        fetches = {
            'decoder_target_list': [obj['decoder_target_list'] for obj in graph.objs],
            'loss': graph.loss,
            'global_step': graph.global_step
        }
        if model_config.replace_unk_by_emb:
            print("########REPLACING UNKS########")
            fetches.update({
                'encoder_embs': [obj['encoder_embs'] for obj in graph.objs],
                'final_outputs': [obj['final_outputs'] for obj in graph.objs]
            })
        if model_config.replace_unk_by_attn:
            fetches.update({'attn_distr': [obj['attn_distr'] for obj in graph.objs]})
        results = sess.run(fetches, input_feed)
        output_target, loss, step = (results['decoder_target_list'],
                                     results['loss'], results['global_step'])
        if model_config.replace_unk_by_emb:
            output_encoder_embs, output_final_outputs = (results['encoder_embs'],
                                                         results['final_outputs'])
        if model_config.replace_unk_by_attn:
            attn_distr = results['attn_distr']

        batch_perplexity = math.exp(loss)
        perplexitys_all.append(batch_perplexity)

        for i, effective_batch_size in enumerate(out_effective_batch_size):
            is_end = output_is_end[i]
            exclude_idxs = get_exclude_list(effective_batch_size, model_config.batch_size)
            sentence_simple = output_sentence_simple[i]
            sentence_complex = output_sentence_complex[i]
            sentence_complex_raw = output_sentence_complex_raw[i]
            sentence_complex_raw_lines = output_sentence_complex_raw_lines[i]
            mapper = output_mapper[i]
            ref_raw_lines = output_ref_raw_lines[i]
            target = output_target[i]
            if model_config.replace_unk_by_emb:
                encoder_embs = output_encoder_embs[i]
                final_outputs = output_final_outputs[i]

            # Drop padding entries that fill out the last (partial) batch.
            if exclude_idxs:
                sentence_complex = exclude_list(sentence_complex, exclude_idxs)
                sentence_complex_raw = exclude_list(sentence_complex_raw, exclude_idxs)
                sentence_complex_raw_lines = exclude_list(
                    sentence_complex_raw_lines, exclude_idxs)
                sentence_simple = exclude_list(sentence_simple, exclude_idxs)
                target = exclude_list(target, exclude_idxs)
                mapper = exclude_list(mapper, exclude_idxs)
                for ref_i in range(model_config.num_refs):
                    ref_raw_lines[ref_i] = exclude_list(
                        ref_raw_lines[ref_i], exclude_idxs)

            target = decode(target, val_data.vocab_simple,
                            model_config.subword_vocab_size > 0)
            target_raw = target
            sentence_complex_marker = [
                [val_data.vocab_simple.encode(w) ==
                 val_data.vocab_simple.encode(constant.SYMBOL_UNK)
                 for w in sent]
                for sent in sentence_complex_raw]
            if model_config.replace_unk_by_attn:
                target_raw = postprocess.replace_unk_by_attn(
                    sentence_complex_raw, attn_distr[0], target_raw)
            elif model_config.replace_unk_by_emb:
                target_raw = postprocess.replace_unk_by_emb(
                    sentence_complex_raw, encoder_embs, final_outputs,
                    target_raw, sentence_complex_marker)
            elif model_config.replace_unk_by_cnt:
                target_raw = postprocess.replace_unk_by_cnt(
                    sentence_complex_raw, target_raw)
            if model_config.replace_ner:
                target_raw = postprocess.replace_ner(target_raw, mapper)
            target_raw = postprocess.replace_others(target_raw)
            sentence_simple = decode(sentence_simple, val_data.vocab_simple,
                                     model_config.subword_vocab_size > 0)
            sentence_complex = decode(sentence_complex, val_data.vocab_complex,
                                      model_config.subword_vocab_size > 0)

            # Replace NER placeholders in sentence_complex_raw; note that
            # sentence_complex_raw_lines and ref_raw_lines are original file lines.
            sentence_complex_raw = postprocess.replace_ner(sentence_complex_raw, mapper)
            sentence_complex_raw = truncate_sents(sentence_complex_raw)

            # Truncate decode results
            target = truncate_sents(target)
            target_raw = truncate_sents(target_raw)
            sentence_simple = truncate_sents(sentence_simple)
            sentence_complex = truncate_sents(sentence_complex)

            targets.extend(target)
            targets_raw.extend(target_raw)
            sentence_simples.extend(sentence_simple)
            sentence_complexs.extend(sentence_complex)
            sentence_complexs_raw.extend(sentence_complex_raw)

            ibleus = []
            saris = []
            fkgls = []
            for batch_i in range(effective_batch_size):
                # Compute iBLEU
                try:
                    batch_ibleu = sentence_bleu([sentence_simple[batch_i]],
                                                target[batch_i])
                except Exception as e:
                    print('Bleu error:\t' + str(e) + '\n' + str(target[batch_i]) + '\n')
                    batch_ibleu = 0
                ibleus_all.append(batch_ibleu)
                ibleus.append(batch_ibleu)
                # Compute SARI
                batch_sari = 0
                if model_config.num_refs > 0:
                    rsents = []
                    for ref_i in range(model_config.num_refs):
                        rsents.append(ref_raw_lines[ref_i][batch_i])
                    try:
                        batch_sari = SARIsent(sentence_complex_raw_lines[batch_i],
                                              ' '.join(target_raw[batch_i]), rsents)
                    except Exception:
                        print('sari error: %s \n %s \n %s. \n'
                              % (sentence_complex_raw_lines[batch_i],
                                 ' '.join(target_raw[batch_i]), rsents))
                saris.append(batch_sari)
                saris_all.append(batch_sari)
                # Compute FKGL
                target_text = ' '.join(target_raw[batch_i])
                batch_fkgl = 0
                if len(target_text) > 0:
                    batch_fkgl = get_fkgl(target_text)
                fkgls.append(batch_fkgl)

            target_output = decode_to_output(
                target, sentence_simple, sentence_complex, effective_batch_size,
                ibleus, target_raw, sentence_complex_raw, saris, fkgls,
                ref_raw_lines, model_config)
            decode_outputs_all.append(target_output)

            if is_end:
                is_finish = True
                break

        if is_finish:
            break

    ibleu = np.mean(ibleus_all)
    perplexity = np.mean(perplexitys_all)
    sari = np.mean(saris_all)
    # Compute FKGL at corpus level
    fkgl = CorpusFKGL(model_config).get_fkgl_from_joshua(step, targets_raw)
    print('Current iBLEU: \t%f' % ibleu)
    print('Current SARI: \t%f' % sari)
    print('Current FKGL: \t%f' % fkgl)
    print('Current perplexity: \t%f' % perplexity)
    print('Current eval done!')

    # MtEval result on raw output
    mteval = MtEval_BLEU(model_config)
    bleu_oi_raw = mteval.get_bleu_from_rawresult(step, targets_raw)
    bleu_or_raw = bleu_oi_raw
    if model_config.num_refs > 0:
        path_ref = (model_config.val_dataset_simple_folder +
                    model_config.val_dataset_simple_rawlines_file_references)
        bleu_or_raw = mteval.get_bleu_from_decoderesult_multirefs(
            step, path_ref, targets_raw, lowercase=model_config.lower_case)
    if model_config.num_refs > 0:
        bleu_raw = 0.9 * bleu_or_raw + 0.1 * bleu_oi_raw
    else:
        bleu_raw = bleu_oi_raw
    print('Current Mteval iBLEU raw: \t%f' % bleu_raw)

    bleu_joshua = mteval.get_bleu_from_joshua(
        step,
        model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file,
        model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file_references,
        targets_raw)

    # Use corpus-level SARI
    corpus_sari = CorpusSARI(model_config)
    sari_joshua = corpus_sari.get_sari_from_joshua(
        step,
        model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file,
        model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file_references,
        model_config.val_dataset_complex_rawlines_file,
        targets_raw)

    decimal_cnt = 5
    fmt = '%.' + str(decimal_cnt) + 'f'
    bleu_raw = fmt % bleu_raw
    bleu_oi_raw = fmt % bleu_oi_raw
    bleu_or_raw = fmt % bleu_or_raw
    ibleu = fmt % ibleu
    bleu_joshua = fmt % bleu_joshua
    sari_joshua = fmt % sari_joshua
    fkgl = fmt % fkgl
    perplexity = fmt % perplexity

    content = '\n'.join([
        'bleu_raw\t' + str(bleu_raw),
        'bleu_oi_raw\t' + str(bleu_oi_raw),
        'bleu_or_raw\t' + str(bleu_or_raw),
        'ibleu\t' + str(ibleu),
        'bleu_joshua\t' + str(bleu_joshua),
        'sari\t' + str(sari_joshua),
        'fkgl\t' + str(fkgl)
    ])

    # Output result: a summary file plus a '.result' file with the decodes.
    result_path = (model_config.resultdir + '/step' + str(step) +
                   '-bleuraw' + str(bleu_raw) +
                   '-bleurawoi' + str(bleu_oi_raw) +
                   '-bleurawor' + str(bleu_or_raw) +
                   '-bleuj' + str(bleu_joshua) +
                   '-perplexity' + str(perplexity) +
                   '-bleunltk' + str(ibleu) +
                   '-sari' + str(sari_joshua) +
                   '-fkgl' + str(fkgl))
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write(content)
    with open(result_path + '.result', 'w', encoding='utf-8') as f:
        f.write('\n'.join(decode_outputs_all))

    return sari_joshua
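
# Usage sketch (hypothetical values): evaluate a specific checkpoint and get
# the corpus-level SARI back as a formatted string. WikiDressLargeEvalDefault
# is one of the repo's config classes (see the eval hook in train() below);
# the checkpoint path is an example.
#
#   sari = eval(WikiDressLargeEvalDefault(), ckpt='/path/to/model.ckpt-50000')
#   print('SARI on the validation set: %s' % sari)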
def train(model_config=None):
    """Train loop (newer variant): tf.data input, optional BERT warm start,
    and tf.train.MonitoredTrainingSession. Redefines train() above."""
    model_config = DefaultConfig() if model_config is None else model_config
    if model_config.fetch_mode == 'tf_example_dataset':
        data = TfExampleTrainDataset(model_config)
    else:
        data = TrainData(model_config)

    if model_config.framework == 'transformer':
        graph = TransformerGraph(data, True, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(data, True, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    graph.create_model_multigpu()

    ckpt_path = None
    if model_config.warm_start:
        ckpt_path = model_config.warm_start

    var_list = slim.get_variables_to_restore()
    if ckpt_path is not None:
        # Handle missing vars by ourselves: restore only variables that exist
        # in the checkpoint with a matching shape.
        available_vars = {}
        reader = tf.train.NewCheckpointReader(ckpt_path)
        var_dict = {var.op.name: var for var in var_list}
        for var in var_dict:
            if 'global_step' in var and 'optim' not in model_config.warm_config:
                print('Ignore var: %s' % var)
                continue
            if 'optimization' in var and 'optim' not in model_config.warm_config:
                print('Ignore var: %s' % var)
                continue
            if reader.has_tensor(var):
                var_ckpt = reader.get_tensor(var)
                var_cur = var_dict[var]
                if any([var_cur.shape[i] != var_ckpt.shape[i]
                        for i in range(len(var_ckpt.shape))]):
                    print('Variable %s missing due to shape.' % var)
                else:
                    available_vars[var] = var_dict[var]
            else:
                print('Variable %s missing.' % var)
        partial_restore_ckpt = slim.assign_from_checkpoint_fn(
            ckpt_path, available_vars,
            ignore_missing_vars=False, reshape_variables=False)

    if model_config.bert_mode:
        bert_restore_ckpt = utils.restore_bert(ckpt=model_config.bert_ckpt)
        if 'direct' in model_config.memory:
            bert_direct_restore_ckpt = utils.restore_bert(
                ckpt=model_config.bert_ckpt, model='direct/')

    sess = tf.train.MonitoredTrainingSession(
        checkpoint_dir=model_config.logdir,
        save_checkpoint_secs=model_config.save_model_secs,
        config=session.get_session_config(model_config),
        hooks=[
            tf.train.CheckpointSaverHook(
                model_config.logdir,
                save_secs=model_config.save_model_secs,
                saver=graph.saver)
        ],
        save_summaries_steps=None,
        save_summaries_secs=None,  # Disable tf.summary
    )

    if checkpoint.is_fresh_run(model_config.logdir) and 'init' in model_config.bert_mode:
        if model_config.bert_mode:
            if 'direct' in model_config.memory:
                bert_direct_restore_ckpt(sess)
            # else:
            bert_restore_ckpt(sess)
            print('BERT init')

    if checkpoint.is_fresh_run(model_config.logdir):
        if ckpt_path is not None:
            partial_restore_ckpt(sess)
            print('Restore from %s' % ckpt_path)

    perplexitys = []
    start_time = datetime.now()

    # Initialize the tf.Example dataset reader
    if model_config.fetch_mode == 'tf_example_dataset':
        if model_config.dmode == 'listalter':
            assert type(data.training_init_op) == list
            for init_op in data.training_init_op:
                sess.run(init_op)
        else:
            sess.run(data.training_init_op)
        print('Init dataset iterator.')
        if model_config.dmode == 'alter':
            sess.run(data.training_init_op2)
            print('Init dataset2 iterator.')

    while True:
        fetches = [graph.train_op, graph.loss, graph.global_step,
                   graph.perplexity, graph.ops, graph.increment_global_step,
                   graph.loss_style]
        if model_config.fetch_mode:
            _, loss, step, perplexity, _, _, loss_style = sess.run(fetches)
        else:
            input_feed = get_graph_train_data(data, graph.objs, model_config)
            _, loss, step, perplexity, _, _, loss_style = sess.run(fetches, input_feed)
        perplexitys.append(perplexity)

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.'
                  % (perplexity, step, time_span))
            if 'pred' in model_config.tune_mode:
                print('Loss:%s\tLoss_style:%s' % (loss, loss_style))
            perplexitys.clear()
            if step / model_config.model_print_freq == 1:
                print_cpu_usage()
                print_cpu_memory()
                print_gpu_memory()

        if model_config.model_eval_freq > 0 and step % model_config.model_eval_freq == 0:
            if args.mode == 'dress':
                from model.model_config import (WikiDressLargeDefault,
                                                WikiDressLargeEvalDefault,
                                                WikiDressLargeTestDefault)
                model_config = WikiDressLargeDefault()
                ckpt = get_ckpt(model_config.modeldir, model_config.logdir)
                vconfig = WikiDressLargeEvalDefault()
                best_sari = get_best_sari(vconfig.resultdir)
                sari_point = eval(vconfig, ckpt)
                eval(WikiDressLargeTestDefault(), ckpt)
                if args.memory is not None and 'rule' in args.memory:
                    for rcand in [15, 30, 50]:
                        vconfig.max_cand_rules = rcand
                        vconfig.resultdir = get_path(
                            '../' + vconfig.output_folder +
                            '/result/eightref_val_cand' + str(rcand),
                            vconfig.environment)
                        eval(vconfig, ckpt)
                print('=====================Current Best SARI:%s====================='
                      % best_sari)
                if float(sari_point) < best_sari:
                    # New checkpoint is worse: delete its files.
                    remove(ckpt + '.index')
                    remove(ckpt + '.meta')
                    remove(ckpt + '.data-00000-of-00001')
                    print('remove ckpt:%s' % ckpt)
                else:
                    # New best: delete every other checkpoint in the model dir.
                    for file in listdir(model_config.modeldir):
                        step = ckpt[ckpt.rindex('model.ckpt-') + len('model.ckpt-'):-1]
                        if step not in file:
                            remove(model_config.modeldir + file)
                    print('Get Best Model, remove ckpt except:%s.' % ckpt)
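
# Entry-point sketch: the repo drives train()/eval() from a CLI that sets the
# global `args` (referenced by the eval hook above). A minimal standalone
# launch, assuming only the default config, could look like:
#
#   if __name__ == '__main__':
#       train(DefaultConfig())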