Example #1
def main(_):
  if os.path.exists(config.forward_log_path) and config.mode=='forward':
    os.system('rm '+config.forward_log_path)
  if os.path.exists(config.backward_log_path) and config.mode=='backward':
    os.system('rm '+config.backward_log_path)
  if os.path.exists(config.use_output_path):
    os.system('rm '+config.use_output_path)
  for item in config.record_time:
    if os.path.exists(config.use_output_path+str(item)):
      os.system('rm '+config.use_output_path+str(item))
  if os.path.exists(config.use_log_path):
    os.system('rm '+config.use_log_path)
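  #note: shelling out to `rm` via os.system is Unix-only; os.remove(path)
  #would be the portable equivalent (a possible refactor, not the original code)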
  if config.mode=='forward' or config.mode=='use':
    with tf.name_scope("forward_train"):
      with tf.variable_scope("forward", reuse=None):
        m_forward = PTBModel(is_training=True)
    with tf.name_scope("forward_test"):
      with tf.variable_scope("forward", reuse=True):
        mtest_forward = PTBModel(is_training=False)
    var=tf.trainable_variables()
    var_forward=[x for x in var if x.name.startswith('forward')]
    saver_forward=tf.train.Saver(var_forward, max_to_keep=1)
  if config.mode=='backward' or config.mode=='use':
    with tf.name_scope("backward_train"):
      with tf.variable_scope("backward", reuse=None):
        m_backward = PTBModel(is_training=True)
    with tf.name_scope("backward_test"):
      with tf.variable_scope("backward", reuse=True):
        mtest_backward = PTBModel(is_training=False)
    var=tf.trainable_variables()
    var_backward=[x for x in var if x.name.startswith('backward')]
    saver_backward=tf.train.Saver(var_backward, max_to_keep=1)
    
  init = tf.global_variables_initializer()
  

  with tf.Session() as session:
    session.run(init)
    if config.mode=='forward':
      train_data, test_data = reader.read_data(config.data_path, config.num_steps)
      test_mean_old=15.0
      
      for epoch in range(config.max_epoch):
        train_ppl_list=[]
        test_ppl_list=[]
        for i in range(train_data.length//config.batch_size):
          input, sequence_length, target=train_data(m_forward.batch_size, i)
          train_perplexity = run_epoch(session, m_forward,input, sequence_length, target, mode='train')
          train_ppl_list.append(train_perplexity)
          print("Epoch:%d, Iter: %d Train NLL: %.3f" % (epoch, i + 1, train_perplexity))
        for i in range(test_data.length//config.batch_size):
          input, sequence_length, target=test_data(mtest_forward.batch_size, i)
          test_perplexity = run_epoch(session, mtest_forward, input, sequence_length, target, mode='test')
          test_ppl_list.append(test_perplexity)
          print("Epoch:%d, Iter: %d Test NLL: %.3f" % (epoch, i + 1, test_perplexity))
        test_mean=np.mean(test_ppl_list)
        if test_mean<test_mean_old:
          test_mean_old=test_mean
          saver_forward.save(session, config.forward_save_path)
        write_log('train ppl:'+str(np.mean(train_ppl_list))+'\t'+'test ppl:'+str(test_mean), config.forward_log_path)
    
    if config.mode=='backward':
      train_data, test_data = reader.read_data(config.data_path, config.num_steps)
      test_mean_old=15.0
      for epoch in range(config.max_epoch):
        train_ppl_list=[]
        test_ppl_list=[]
      
        for i in range(train_data.length//config.batch_size):
          input, sequence_length, target=train_data(m_backward.batch_size, i)
          input, sequence_length, target=reverse_seq(input, sequence_length, target)
          train_perplexity = run_epoch(session, m_backward,input, sequence_length, target, mode='train')
          train_ppl_list.append(train_perplexity)
          print("Epoch:%d, Iter: %d Train NLL: %.3f" % (epoch, i + 1, train_perplexity))
        for i in range(test_data.length//config.batch_size):
          input, sequence_length, target=test_data(mtest_backward.batch_size, i)
          input, sequence_length, target=reverse_seq(input, sequence_length, target)
          test_perplexity = run_epoch(session, mtest_backward, input, sequence_length, target, mode='test')
          test_ppl_list.append(test_perplexity)
          print("Epoch:%d, Iter: %d Test NLL: %.3f" % (epoch, i + 1, test_perplexity))
        test_mean=np.mean(test_ppl_list)
        if test_mean<test_mean_old:
          test_mean_old=test_mean
          saver_backward.save(session, config.backward_save_path)
        write_log('train ppl:'+str(np.mean(train_ppl_list))+'\t'+'test ppl:'+str(test_mean), config.backward_log_path)
  
    if config.mode=='use':
      sim=config.sim
      #keyword stable
      sta_vec=list(np.zeros([config.num_steps-1]))

      saver_forward.restore(session, config.forward_save_path)
      saver_backward.restore(session, config.backward_save_path)
      config.shuffle=False
      if config.keyboard_input==True:
        key_input=raw_input('please input a sentence in lower case\n')
        if key_input=='':
          use_data = reader.read_data_use(config.use_data_path, config.num_steps)
        else:
          key_input=key_input.split()
          key_input=sen2id(key_input)
          use_data = reader.array_data([key_input], config.num_steps, config.dict_size)
      else:
        use_data, sta_vec_list = reader.read_data_use(config.use_data_path, config.num_steps)
      config.batch_size=1
      for sen_id in range(use_data.length):
        if config.keyboard_input==False:
          sta_vec=sta_vec_list[sen_id%len(sta_vec_list)]  #cycle through the loaded keyword vectors
        print(sta_vec)
        input, sequence_length, _=use_data(1, sen_id)
        input_original=input[0]
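        #mark rare (but in-vocabulary) words as keywords to keep stable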
        for i in range(1,config.num_steps):
          if input[0][i]>config.rare_since and input[0][i]<config.dict_size:
            sta_vec[i-1]=1
        pos=0

        for iter in range(config.sample_time):
          #ind is the index of the selected word, regardless of the beginning token.
          ind=pos%(sequence_length[0]-1)
          action=choose_action(config.action_prob)
          #log the current sentence and periodically record it to file
          print(' '.join(id2sen(input[0])))
          if iter in config.record_time:
            with open(config.use_output_path+str(iter), 'a') as g:
              g.write(' '.join(id2sen(input[0]))+'\n')
          #print(sta_vec, sequence_length[0], ind)
          #disabled: skip replace/delete actions at keyword positions
          #if sta_vec[ind]==1 and action in [0, 2]:
          #  action=3
          #change word (action: 0)
          if action==0:
            prob_old=run_epoch(session, mtest_forward, input, sequence_length, mode='use')
            if config.double_LM==True:
              input_backward, _, _ =reverse_seq(input, sequence_length, input)
              prob_old=(prob_old+run_epoch(session, mtest_backward, input_backward, sequence_length, mode='use'))*0.5

            tem=1
            for j in range(sequence_length[0]-1):
              tem*=prob_old[0][j][input[0][j+1]]
            tem*=prob_old[0][j+1][config.dict_size+1]
            prob_old_prob=tem
            if sim!=None:
              similarity_old=similarity(input[0], input_original, sta_vec)
              prob_old_prob*=similarity_old
            else:
              similarity_old=-1
            input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(input, sequence_length, ind, mode=action)
            prob_forward=run_epoch(session, mtest_forward, input_forward, sequence_length_forward, mode='use')[0, ind%(sequence_length[0]-1),:]
            prob_backward=run_epoch(session, mtest_backward, input_backward, sequence_length_backward, mode='use')[0, sequence_length[0]-1-ind%(sequence_length[0]-1),:]
            prob_mul=(prob_forward*prob_backward)
            input_candidate, sequence_length_candidate=generate_candidate_input(input, sequence_length, ind, prob_mul, config.search_size, mode=action)
            prob_candidate_pre=run_epoch(session, mtest_forward, input_candidate, sequence_length_candidate, mode='use')
            if config.double_LM==True:
              input_candidate_backward, _, _ =reverse_seq(input_candidate, sequence_length_candidate, input_candidate)
              prob_candidate_pre=(prob_candidate_pre+run_epoch(session, mtest_backward, input_candidate_backward, sequence_length_candidate, mode='use'))*0.5
            prob_candidate=[]
            for i in range(config.search_size):
              tem=1
              for j in range(sequence_length[0]-1):
                tem*=prob_candidate_pre[i][j][input_candidate[i][j+1]]
              tem*=prob_candidate_pre[i][j+1][config.dict_size+1]
              prob_candidate.append(tem)
          
            prob_candidate=np.array(prob_candidate)
            #similarity_candidate=np.array([similarity(x, input_original) for x in input_candidate])
            if sim!=None:
              similarity_candidate=similarity_batch(input_candidate, input_original,sta_vec)
              prob_candidate=prob_candidate*similarity_candidate
            prob_candidate_norm=normalize(prob_candidate)
            prob_candidate_ind=sample_from_candidate(prob_candidate_norm)
            prob_candidate_prob=prob_candidate[prob_candidate_ind]
            if input_candidate[prob_candidate_ind][ind+1]<config.dict_size and (prob_candidate_prob>prob_old_prob*config.threshold or just_acc()==0):
              input=input_candidate[prob_candidate_ind:prob_candidate_ind+1]
            pos+=1
            #old_place=len(prob_mul)-list(np.argsort(prob_mul)).index(input[0][ind+1])
            #write_log('step:'+str(iter)+'action:0 prob_old:'+str(prob_old_prob)+' prob_new:'+str(prob_candidate_prob)+' '+str(old_place)+' '+str(sta_vec.index(1))+' '+str(ind), config.use_log_path)
            print('action:0', 1, prob_old_prob, prob_candidate_prob, prob_candidate_norm[prob_candidate_ind], similarity_old)

          #add word
          if action==1: 
            if sequence_length[0]>=config.num_steps:
              action=3
            else:
              input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(input, sequence_length, ind, mode=action)
              prob_forward=run_epoch(session, mtest_forward, input_forward, sequence_length_forward, mode='use')[0, ind%(sequence_length[0]-1),:]
              prob_backward=run_epoch(session, mtest_backward, input_backward, sequence_length_backward, mode='use')[0, sequence_length[0]-1-ind%(sequence_length[0]-1),:]
              prob_mul=(prob_forward*prob_backward)
              input_candidate, sequence_length_candidate=generate_candidate_input(input, sequence_length, ind, prob_mul, config.search_size, mode=action)
              prob_candidate_pre=run_epoch(session, mtest_forward, input_candidate, sequence_length_candidate, mode='use')
              if config.double_LM==True:
                input_candidate_backward, _, _ =reverse_seq(input_candidate, sequence_length_candidate, input_candidate)
                prob_candidate_pre=(prob_candidate_pre+run_epoch(session, mtest_backward, input_candidate_backward, sequence_length_candidate, mode='use'))*0.5

              prob_candidate=[]
              for i in range(config.search_size):
                tem=1
                for j in range(sequence_length_candidate[0]-1):
                  tem*=prob_candidate_pre[i][j][input_candidate[i][j+1]]
                tem*=prob_candidate_pre[i][j+1][config.dict_size+1]
                prob_candidate.append(tem)
              prob_candidate=np.array(prob_candidate)
              #similarity_candidate=np.array([similarity(x, input_original) for x in input_candidate])
              if sim!=None:
                similarity_candidate=similarity_batch(input_candidate, input_original,sta_vec)
                prob_candidate=prob_candidate*similarity_candidate
              prob_candidate_norm=normalize(prob_candidate)

              prob_candidate_ind=sample_from_candidate(prob_candidate_norm)
              prob_candidate_prob=prob_candidate[prob_candidate_ind]

              prob_old=run_epoch(session, mtest_forward, input, sequence_length, mode='use')
              if config.double_LM==True:
                input_backward, _, _ =reverse_seq(input, sequence_length, input)
                prob_old=(prob_old+run_epoch(session, mtest_backward, input_backward, sequence_length, mode='use'))*0.5

              tem=1
              for j in range(sequence_length[0]-1):
                tem*=prob_old[0][j][input[0][j+1]]
              tem*=prob_old[0][j+1][config.dict_size+1]
            
              prob_old_prob=tem
              if sim!=None:
                similarity_old=similarity(input[0], input_original,sta_vec)
                prob_old_prob=prob_old_prob*similarity_old
              else:
                similarity_old=-1
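              #Metropolis-Hastings acceptance for an insertion: the reverse move
              #is a deletion (prob action_prob[2]); the forward move inserts with
              #prob action_prob[1] times the proposal density
              #prob_candidate_norm[prob_candidate_ind], giving
              #alpha = min(1, p(new)*q(reverse) / (p(old)*q(forward)))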
              alpha=min(1, prob_candidate_prob*config.action_prob[2]/(prob_old_prob*config.action_prob[1]*prob_candidate_norm[prob_candidate_ind]))
              #alpha=min(1, prob_candidate_prob*config.action_prob[2]/(prob_old_prob*config.action_prob[1]))
              print ('action:1',alpha, prob_old_prob,prob_candidate_prob, prob_candidate_norm[prob_candidate_ind], similarity_old)
            
              if choose_action([alpha, 1-alpha])==0 and input_candidate[prob_candidate_ind][ind]<config.dict_size and (prob_candidate_prob>prob_old_prob* config.threshold or just_acc()==0):
              #write_log('step:'+str(iter)+'action:1 prob_old:'+str(prob_old_prob)+' prob_new:'+str(prob_candidate_prob)+' '+str(sta_vec.index(1))+' '+str(ind), config.use_log_path)
                input=input_candidate[prob_candidate_ind:prob_candidate_ind+1]
                sequence_length+=1
                pos+=2
                sta_vec.insert(ind, 0.0)
                del(sta_vec[-1])
              else:
                action=3
       
       
          #delete word (action: 2)
          if action==2:
            if sequence_length[0]<=2:
              action=3
            else:

              prob_old=run_epoch(session, mtest_forward, input, sequence_length, mode='use')
              if config.double_LM==True:
                input_backward, _, _ =reverse_seq(input, sequence_length, input)
                prob_old=(prob_old+run_epoch(session, mtest_backward, input_backward, sequence_length, mode='use'))*0.5

              tem=1
              for j in range(sequence_length[0]-1):
                tem*=prob_old[0][j][input[0][j+1]]
              tem*=prob_old[0][j+1][config.dict_size+1]
              prob_old_prob=tem
              if sim!=None:
                similarity_old=similarity(input[0], input_original,sta_vec)
                prob_old_prob=prob_old_prob*similarity_old
              else:
                similarity_old=-1
              input_candidate, sequence_length_candidate=generate_candidate_input(input, sequence_length, ind, None , config.search_size, mode=2)
              prob_new=run_epoch(session, mtest_forward, input_candidate, sequence_length_candidate, mode='use')
              tem=1
              for j in range(sequence_length_candidate[0]-1):
                tem*=prob_new[0][j][input_candidate[0][j+1]]
              tem*=prob_new[0][j+1][config.dict_size+1]
              prob_new_prob=tem
              if sim!=None:
                similarity_new=similarity_batch(input_candidate, input_original,sta_vec)
                prob_new_prob=prob_new_prob*similarity_new
            
              input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(input, sequence_length, ind, mode=0)
              prob_forward=run_epoch(session, mtest_forward, input_forward, sequence_length_forward, mode='use')[0, ind%(sequence_length[0]-1),:]
              prob_backward=run_epoch(session, mtest_backward, input_backward, sequence_length_backward, mode='use')[0, sequence_length[0]-1-ind%(sequence_length[0]-1),:]
              prob_mul=(prob_forward*prob_backward)
              input_candidate, sequence_length_candidate=generate_candidate_input(input, sequence_length, ind, prob_mul, config.search_size, mode=0)
              prob_candidate_pre=run_epoch(session, mtest_forward, input_candidate, sequence_length_candidate, mode='use')
              if config.double_LM==True:
                input_candidate_backward, _, _ =reverse_seq(input_candidate, sequence_length_candidate, input_candidate)
                prob_candidate_pre=(prob_candidate_pre+run_epoch(session, mtest_backward, input_candidate_backward, sequence_length_candidate, mode='use'))*0.5

              prob_candidate=[]
              for i in range(config.search_size):
                tem=1
                for j in range(sequence_length[0]-1):
                  tem*=prob_candidate_pre[i][j][input_candidate[i][j+1]]
                tem*=prob_candidate_pre[i][j+1][config.dict_size+1]
                prob_candidate.append(tem)
              prob_candidate=np.array(prob_candidate)
            
              #similarity_candidate=np.array([similarity(x, input_original) for x in input_candidate])
              if sim!=None:
                similarity_candidate=similarity_batch(input_candidate, input_original,sta_vec)
                prob_candidate=prob_candidate*similarity_candidate
            
              #unsolved problem: the reverse (re-insertion) proposal is only
              #defined when the shortened sentence appears among the candidates
              prob_candidate_norm=normalize(prob_candidate)
              candidate_ind=0  #default so the log line below is defined when alpha==0
              if input[0] in input_candidate:
                for candidate_ind in range(len(input_candidate)):
                  if input[0] in input_candidate[candidate_ind: candidate_ind+1]:
                    break
                alpha=min(prob_candidate_norm[candidate_ind]*prob_new_prob*config.action_prob[1]/(config.action_prob[2]*prob_old_prob), 1)
              else:
                alpha=0
              #alpha=min(prob_new_prob*config.action_prob[1]/(config.action_prob[2]*prob_old_prob), 1)
              print('action:2', alpha, prob_old_prob, prob_new_prob, prob_candidate_norm[candidate_ind], similarity_old)
             
              if choose_action([alpha, 1-alpha])==0 and (prob_new_prob> prob_old_prob*config.threshold or just_acc()==0):
                #write_log('step:'+str(iter)+'action:2 prob_old:'+str(prob_old_prob)+' prob_new:'+str(prob_new_prob)+' '+str(sta_vec.index(1))+' '+str(ind), config.use_log_path)
                input=np.concatenate([input[:,:ind+1], input[:,ind+2:], input[:,:1]*0+config.dict_size+1], axis=1)
                sequence_length-=1
                pos+=0  #position pointer unchanged after a deletion
                del(sta_vec[ind])
                sta_vec.append(0)
              else:
                action=3
          #do nothing
          if action==3:
            #write_log('step:'+str(iter)+'action:3', config.use_log_path)
            pos+=1
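
The listing assumes several helpers (choose_action, normalize, sample_from_candidate, just_acc) whose definitions are not shown. A minimal sketch of plausible implementations, inferred from the call sites rather than taken from the project's utils module:

import numpy as np

def normalize(p, eps=1e-12):
    #rescale a nonnegative score vector into a probability distribution
    p = np.asarray(p, dtype=np.float64)
    return p / max(p.sum(), eps)

def choose_action(action_prob):
    #draw an action index (0=replace, 1=insert, 2=delete, 3=skip)
    return int(np.random.choice(len(action_prob), p=normalize(action_prob)))

def sample_from_candidate(prob_candidate_norm):
    #draw a candidate index from an already-normalized distribution
    return int(np.random.choice(len(prob_candidate_norm), p=prob_candidate_norm))

def just_acc(acc_rate=0.0):
    #hypothetical: return 0 (force acceptance) with probability acc_rate,
    #matching the `just_acc()==0` guards above; semantics inferred, not confirmed
    return 0 if np.random.random() < acc_rate else 1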
Example #2
import tensorflow as tf
import reader

from config import config
config = config()

from utils import *

from model import LangModel

m_forward = LangModel(config.forward_save_path)
m_backward = LangModel(config.backward_save_path)
m_forward.restore()
m_backward.restore()

dataset, sequence_lengths, sta_vec_list = reader.read_data_use(
    config.use_data_path, config.num_steps)
config.batch_size = 1

for sen_id, data in enumerate(dataset.as_numpy_iterator()):
    # For each sentence in the list of sentences
    input = data[0]
    sequence_length = sequence_lengths[sen_id]

    if config.keyboard_input == False:
        sta_vec = sta_vec_list[sen_id % len(sta_vec_list)]
    print(sta_vec)

    pos = 0
    outputs = []
    output_p = []
    for iter in range(config.sample_time):
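
Every example feeds the backward language model through reverse_seq, which is not shown here. A plausible sketch, assuming right-padded batches where each row should be reversed only within its true length (an assumption, not the project's definition):

import numpy as np

def reverse_seq(input, sequence_length, target):
    #reverse each row of `input` and `target` within its true length,
    #leaving any right-padding in place
    input_r = np.copy(input)
    target_r = np.copy(target)
    for k, n in enumerate(sequence_length):
        input_r[k, :n] = input[k, :n][::-1]
        target_r[k, :n] = target[k, :n][::-1]
    return input_r, sequence_length, target_r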
Example #3
def main(_):
    if os.path.exists(config.forward_log_path) and config.mode == 'forward':
        os.system('rm ' + config.forward_log_path)
    if os.path.exists(config.backward_log_path) and config.mode == 'backward':
        os.system('rm ' + config.backward_log_path)
    if os.path.exists(config.use_output_path):
        os.system('rm ' + config.use_output_path)
    for item in config.record_time:
        if os.path.exists(config.use_output_path + str(item)):
            os.system('rm ' + config.use_output_path + str(item))
    if os.path.exists(config.use_log_path):
        os.system('rm ' + config.use_log_path)
    if config.mode == 'forward' or config.mode == 'use':
        with tf.name_scope("forward_train"):
            with tf.variable_scope("forward", reuse=None):
                m_forward = PTBModel(is_training=True)
        with tf.name_scope("forward_test"):
            with tf.variable_scope("forward", reuse=True):
                mtest_forward = PTBModel(is_training=False)
        var = tf.trainable_variables()
        var_forward = [x for x in var if x.name.startswith('forward')]
        saver_forward = tf.train.Saver(var_forward, max_to_keep=1)
    if config.mode == 'backward' or config.mode == 'use':
        with tf.name_scope("backward_train"):
            with tf.variable_scope("backward", reuse=None):
                m_backward = PTBModel(is_training=True)
        with tf.name_scope("backward_test"):
            with tf.variable_scope("backward", reuse=True):
                mtest_backward = PTBModel(is_training=False)
        var = tf.trainable_variables()
        var_backward = [x for x in var if x.name.startswith('backward')]
        saver_backward = tf.train.Saver(var_backward, max_to_keep=1)

    init = tf.global_variables_initializer()

    with tf.Session() as session:
        session.run(init)
        if config.mode == 'forward':
            #train forward language model
            train_data, test_data = reader.read_data(config.data_path,
                                                     config.num_steps)
            test_mean_old = 15.0

            for epoch in range(config.max_epoch):
                train_ppl_list = []
                test_ppl_list = []
                for i in range(train_data.length // config.batch_size):
                    input, sequence_length, target = train_data(
                        m_forward.batch_size, i)
                    train_perplexity = run_epoch(session,
                                                 m_forward,
                                                 input,
                                                 sequence_length,
                                                 target,
                                                 mode='train')
                    train_ppl_list.append(train_perplexity)
                    print("Epoch:%d, Iter: %d Train NLL: %.3f" %
                          (epoch, i + 1, train_perplexity))
                for i in range(test_data.length // config.batch_size):
                    input, sequence_length, target = test_data(
                        mtest_forward.batch_size, i)
                    test_perplexity = run_epoch(session,
                                                mtest_forward,
                                                input,
                                                sequence_length,
                                                target,
                                                mode='test')
                    test_ppl_list.append(test_perplexity)
                    print("Epoch:%d, Iter: %d Test NLL: %.3f" %
                          (epoch, i + 1, test_perplexity))
                test_mean = np.mean(test_ppl_list)
                if test_mean < test_mean_old:
                    test_mean_old = test_mean
                    saver_forward.save(session, config.forward_save_path)
                write_log(
                    'train ppl:' + str(np.mean(train_ppl_list)) + '\t' +
                    'test ppl:' + str(test_mean), config.forward_log_path)

        if config.mode == 'backward':
            #train backward language model
            train_data, test_data = reader.read_data(config.data_path,
                                                     config.num_steps)
            test_mean_old = 15.0
            for epoch in range(config.max_epoch):
                train_ppl_list = []
                test_ppl_list = []

                for i in range(train_data.length // config.batch_size):
                    input, sequence_length, target = train_data(
                        m_backward.batch_size, i)
                    input, sequence_length, target = reverse_seq(
                        input, sequence_length, target)
                    train_perplexity = run_epoch(session,
                                                 m_backward,
                                                 input,
                                                 sequence_length,
                                                 target,
                                                 mode='train')
                    train_ppl_list.append(train_perplexity)
                    print("Epoch:%d, Iter: %d Train NLL: %.3f" %
                          (epoch, i + 1, train_perplexity))
                for i in range(test_data.length // config.batch_size):
                    input, sequence_length, target = test_data(
                        mtest_backward.batch_size, i)
                    input, sequence_length, target = reverse_seq(
                        input, sequence_length, target)
                    test_perplexity = run_epoch(session,
                                                mtest_backward,
                                                input,
                                                sequence_length,
                                                target,
                                                mode='test')
                    test_ppl_list.append(test_perplexity)
                    print("Epoch:%d, Iter: %d Test NLL: %.3f" %
                          (epoch, i + 1, test_perplexity))
                test_mean = np.mean(test_ppl_list)
                if test_mean < test_mean_old:
                    test_mean_old = test_mean
                    saver_backward.save(session, config.backward_save_path)
                write_log(
                    'train ppl:' + str(np.mean(train_ppl_list)) + '\t' +
                    'test ppl:' + str(test_mean), config.backward_log_path)

        if config.mode == 'use':
            #CGMH sampling for sentence_correction
            sim = config.sim
            sta_vec = list(np.zeros([config.num_steps - 1]))

            saver_forward.restore(session, config.forward_save_path)
            saver_backward.restore(session, config.backward_save_path)
            config.shuffle = False
            #erroneous sentence input
            if config.keyboard_input == True:
                #input from keyboard if key_input is not empty
                key_input = raw_input('please input a sentence\n')
                if key_input == '':
                    use_data = reader.read_data_use(config.use_data_path,
                                                    config.num_steps)
                else:
                    sta_vec_list = [sen2sta_vec(key_input)]
                    key_input = key_input.split()
                    #key_input=sen2id(key_input)
                    use_data = [key_input]
            else:
                #load keywords from file
                use_data = []
                with open(config.use_data_path) as f:
                    for line in f:
                        use_data.append(line.strip().split())
            config.batch_size = 1

            for sen_id in range(len(use_data)):
                #generate for each sentence
                input_ = use_data[sen_id]
                pos = 0

                for iter in range(config.sample_time):
                    #ind is the index of the selected word, regardless of the beginning token.
                    sta_vec = sen2sta_vec(' '.join(input_))
                    input__ = reader.array_data([sen2id(input_)],
                                                config.num_steps,
                                                config.dict_size)
                    input, sequence_length, _ = input__(1, 0)
                    input_original = input[0]

                    ind = pos % (sequence_length[0] - 1)
                    print(' '.join(input_))

                    if iter in config.record_time:
                        with open(config.use_output_path + str(iter),
                                  'a') as g:
                            g.write(' '.join(input_) + '\n')

                    #score in-place replacement candidates for position ind (prob_change)
                    if True:
                        prob_old = run_epoch(session,
                                             mtest_forward,
                                             input,
                                             sequence_length,
                                             mode='use')
                        if config.double_LM == True:
                            input_backward, _, _ = reverse_seq(
                                input, sequence_length, input)
                            prob_old = (prob_old + run_epoch(session,
                                                             mtest_backward,
                                                             input_backward,
                                                             sequence_length,
                                                             mode='use')) * 0.5

                        tem = 1
                        for j in range(sequence_length[0] - 1):
                            tem *= prob_old[0][j][input[0][j + 1]]
                        tem *= prob_old[0][j + 1][config.dict_size + 1]
                        prob_old_prob = tem

                        if sim != None:
                            similarity_old = similarity(
                                input[0], input_original)
                            prob_old_prob *= similarity_old
                        else:
                            similarity_old = -1

                        input_candidate_ = generate_change_candidate(
                            input_, ind)
                        tem = reader.array_data(
                            [sen2id(x) for x in input_candidate_],
                            config.num_steps, config.dict_size)
                        input_candidate, sequence_length_candidate, _ = tem(
                            len(input_candidate_), 0)

                        prob_candidate_pre = run_epoch(
                            session,
                            mtest_forward,
                            input_candidate,
                            sequence_length_candidate,
                            mode='use')
                        if config.double_LM == True:
                            input_candidate_backward, _, _ = reverse_seq(
                                input_candidate, sequence_length_candidate,
                                input_candidate)
                            prob_candidate_pre = (
                                prob_candidate_pre +
                                run_epoch(session,
                                          mtest_backward,
                                          input_candidate_backward,
                                          sequence_length_candidate,
                                          mode='use')) * 0.5
                        prob_candidate = []
                        for i in range(len(input_candidate_)):
                            tem = 1
                            for j in range(sequence_length[0] - 1):
                                tem *= prob_candidate_pre[i][j][
                                    input_candidate[i][j + 1]]
                            tem *= prob_candidate_pre[i][j +
                                                         1][config.dict_size +
                                                            1]
                            prob_candidate.append(tem)

                        prob_candidate = np.array(prob_candidate)
                        if sim != None:
                            similarity_candidate = similarity_batch(
                                input_candidate, input_original)
                            prob_candidate = prob_candidate * similarity_candidate
                        prob_candidate_norm = normalize(prob_candidate)
                        prob_candidate_ind = sample_from_candidate(
                            prob_candidate_norm)
                        prob_change_prob = prob_candidate[prob_candidate_ind]
                        input_change_ = input_candidate_[prob_candidate_ind]

                    #word replacement via LM proposals (action: 0, prob_changeanother)
                    if True:
                        if False:  #disabled branch
                            pass
                        else:
                            input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(
                                input, sequence_length, ind, mode=0)
                            prob_forward = run_epoch(
                                session,
                                mtest_forward,
                                input_forward,
                                sequence_length_forward,
                                mode='use')[0,
                                            ind % (sequence_length[0] - 1), :]
                            prob_backward = run_epoch(
                                session,
                                mtest_backward,
                                input_backward,
                                sequence_length_backward,
                                mode='use')[0, sequence_length[0] - 1 - ind %
                                            (sequence_length[0] - 1), :]
                            prob_mul = (prob_forward * prob_backward)
                            input_candidate, sequence_length_candidate = generate_candidate_input(
                                input,
                                sequence_length,
                                ind,
                                prob_mul,
                                config.search_size,
                                mode=1)
                            prob_candidate_pre = run_epoch(
                                session,
                                mtest_forward,
                                input_candidate,
                                sequence_length_candidate,
                                mode='use')
                            if config.double_LM == True:
                                input_candidate_backward, _, _ = reverse_seq(
                                    input_candidate, sequence_length_candidate,
                                    input_candidate)
                                prob_candidate_pre = (
                                    prob_candidate_pre +
                                    run_epoch(session,
                                              mtest_backward,
                                              input_candidate_backward,
                                              sequence_length_candidate,
                                              mode='use')) * 0.5

                            prob_candidate = []
                            for i in range(config.search_size):
                                tem = 1
                                for j in range(sequence_length_candidate[0] -
                                               1):
                                    tem *= prob_candidate_pre[i][j][
                                        input_candidate[i][j + 1]]
                                tem *= prob_candidate_pre[i][j + 1][
                                    config.dict_size + 1]
                                prob_candidate.append(tem)
                            prob_candidate = np.array(prob_candidate)
                            if config.sim_word == True:
                                similarity_candidate = similarity_batch(
                                    input_candidate[:, ind + 1:ind + 2],
                                    input_original[ind + 1:ind + 2])
                                prob_candidate = prob_candidate * similarity_candidate
                            prob_candidate_norm = normalize(prob_candidate)

                            prob_candidate_ind = sample_from_candidate(
                                prob_candidate_norm)
                            prob_candidate_prob = prob_candidate[
                                prob_candidate_ind]

                            prob_changeanother_prob = prob_candidate_prob
                            word = id2sen(
                                input_candidate[prob_candidate_ind])[ind]
                            input_changeanother_ = input_[:ind] + [
                                word
                            ] + input_[ind + 1:]

                    #word insertion (action: 1)
                    if True:
                        if sequence_length[0] >= config.num_steps:
                            prob_add_prob = 0
                            pass
                        else:
                            input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(
                                input, sequence_length, ind, mode=1)
                            prob_forward = run_epoch(
                                session,
                                mtest_forward,
                                input_forward,
                                sequence_length_forward,
                                mode='use')[0,
                                            ind % (sequence_length[0] - 1), :]
                            prob_backward = run_epoch(
                                session,
                                mtest_backward,
                                input_backward,
                                sequence_length_backward,
                                mode='use')[0, sequence_length[0] - 1 - ind %
                                            (sequence_length[0] - 1), :]
                            prob_mul = (prob_forward * prob_backward)
                            input_candidate, sequence_length_candidate = generate_candidate_input(
                                input,
                                sequence_length,
                                ind,
                                prob_mul,
                                config.search_size,
                                mode=1)
                            prob_candidate_pre = run_epoch(
                                session,
                                mtest_forward,
                                input_candidate,
                                sequence_length_candidate,
                                mode='use')
                            if config.double_LM == True:
                                input_candidate_backward, _, _ = reverse_seq(
                                    input_candidate, sequence_length_candidate,
                                    input_candidate)
                                prob_candidate_pre = (
                                    prob_candidate_pre +
                                    run_epoch(session,
                                              mtest_backward,
                                              input_candidate_backward,
                                              sequence_length_candidate,
                                              mode='use')) * 0.5

                            prob_candidate = []
                            for i in range(config.search_size):
                                tem = 1
                                for j in range(sequence_length_candidate[0] -
                                               1):
                                    tem *= prob_candidate_pre[i][j][
                                        input_candidate[i][j + 1]]
                                tem *= prob_candidate_pre[i][j + 1][
                                    config.dict_size + 1]
                                prob_candidate.append(tem)
                            prob_candidate = np.array(prob_candidate)
                            #similarity_candidate=np.array([similarity(x, input_original) for x in input_candidate])
                            if sim != None:
                                similarity_candidate = similarity_batch(
                                    input_candidate, input_original)
                                prob_candidate = prob_candidate * similarity_candidate
                            prob_candidate_norm = normalize(prob_candidate)

                            prob_candidate_ind = sample_from_candidate(
                                prob_candidate_norm)
                            prob_candidate_prob = prob_candidate[
                                prob_candidate_ind]

                            prob_add_prob = prob_candidate_prob
                            word = id2sen(
                                input_candidate[prob_candidate_ind])[ind]
                            input_add_ = input_[:ind] + [word] + input_[ind:]

                    #word deletion (action: 2)
                    if True:
                        if sequence_length[0] <= 2:
                            prob_delete_prob = 0
                            pass
                        else:
                            input_candidate, sequence_length_candidate = generate_candidate_input(
                                input,
                                sequence_length,
                                ind,
                                None,
                                config.search_size,
                                mode=2)
                            prob_new = run_epoch(session,
                                                 mtest_forward,
                                                 input_candidate,
                                                 sequence_length_candidate,
                                                 mode='use')
                            tem = 1
                            for j in range(sequence_length_candidate[0] - 1):
                                tem *= prob_new[0][j][input_candidate[0][j +
                                                                         1]]
                            tem *= prob_new[0][j + 1][config.dict_size + 1]
                            prob_new_prob = tem
                            if sim != None:
                                similarity_new = similarity_batch(
                                    input_candidate, input_original)
                                prob_new_prob = prob_new_prob * similarity_new
                            prob_delete_prob = prob_new_prob
                        input_delete_ = input_[:ind] + input_[ind + 1:]
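                    #greedily keep the highest-scoring edit; the 0.3/0.1/0.001
                    #factors are hand-tuned priors biasing toward smaller edits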
                    b = np.argmax([
                        prob_old_prob, prob_change_prob,
                        prob_changeanother_prob * 0.3, prob_add_prob * 0.1,
                        prob_delete_prob * 0.001
                    ])
                    print([
                        prob_old_prob, prob_change_prob,
                        prob_changeanother_prob, prob_add_prob,
                        prob_delete_prob
                    ])
                    print([
                        input_, input_change_, input_changeanother_,
                        input_add_, input_delete_
                    ])
                    input_ = [
                        input_, input_change_, input_changeanother_,
                        input_add_, input_delete_
                    ][b]
                    pos += 1
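
Examples #1 and #3 repeatedly score a sentence with the same `tem *= ...` loop over per-step softmax outputs. Factored into a helper for clarity (a sketch of the existing computation; eos_id corresponds to config.dict_size + 1 in the loops above):

def sentence_prob(step_probs, ids, length, eos_id):
    #step_probs[j] is the softmax over the vocabulary at position j;
    #ids[j+1] is the token actually emitted there (ids[0] is the begin token)
    p = 1.0
    for j in range(length - 1):
        p *= step_probs[j][ids[j + 1]]
    p *= step_probs[length - 1][eos_id]  #end-of-sentence probability
    return p

#equivalent to, e.g.:
#prob_old_prob = sentence_prob(prob_old[0], input[0], sequence_length[0], config.dict_size + 1)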
Example #4
def main(_):
    if os.path.exists(config.forward_log_path) and config.mode == 'forward':
        os.system('rm ' + config.forward_log_path)
    if os.path.exists(config.backward_log_path) and config.mode == 'backward':
        os.system('rm ' + config.backward_log_path)
    if os.path.exists(config.use_output_path):
        os.system('rm ' + config.use_output_path)
    if os.path.exists(config.use_log_path):
        os.system('rm ' + config.use_log_path)
    if config.mode == 'forward' or config.mode == 'use':
        with tf.name_scope("forward_train"):
            with tf.variable_scope("forward", reuse=None):
                m_forward = PTBModel(is_training=True)
        with tf.name_scope("forward_test"):
            with tf.variable_scope("forward", reuse=True):
                mtest_forward = PTBModel(is_training=False)
        var = tf.trainable_variables()
        var_forward = [x for x in var if x.name.startswith('forward')]
        saver_forward = tf.train.Saver(var_forward, max_to_keep=1)
    if config.mode == 'backward' or config.mode == 'use':
        with tf.name_scope("backward_train"):
            with tf.variable_scope("backward", reuse=None):
                m_backward = PTBModel(is_training=True)
        with tf.name_scope("backward_test"):
            with tf.variable_scope("backward", reuse=True):
                mtest_backward = PTBModel(is_training=False)
        var = tf.trainable_variables()
        var_backward = [x for x in var if x.name.startswith('backward')]
        saver_backward = tf.train.Saver(var_backward, max_to_keep=1)

    init = tf.global_variables_initializer()

    configs = tf.ConfigProto()
    configs.gpu_options.allow_growth = True
    with tf.Session(config=configs) as session:
        session.run(init)
        if config.mode == 'forward':
            #train forward language model
            train_data, test_data = reader.read_data(config.data_path,
                                                     config.num_steps)
            test_mean_old = 15.0

            for epoch in range(config.max_epoch):
                train_ppl_list = []
                test_ppl_list = []
                for i in range(train_data.length // config.batch_size):
                    input, sequence_length, target = train_data(
                        m_forward.batch_size, i)
                    train_perplexity = run_epoch(session,
                                                 m_forward,
                                                 input,
                                                 sequence_length,
                                                 target,
                                                 mode='train')
                    train_ppl_list.append(train_perplexity)
                    print("Epoch:%d, Iter: %d Train NLL: %.3f" %
                          (epoch, i + 1, train_perplexity))
                for i in range(test_data.length // config.batch_size):
                    input, sequence_length, target = test_data(
                        mtest_forward.batch_size, i)
                    test_perplexity = run_epoch(session,
                                                mtest_forward,
                                                input,
                                                sequence_length,
                                                target,
                                                mode='test')
                    test_ppl_list.append(test_perplexity)
                    print("Epoch:%d, Iter: %d Test NLL: %.3f" %
                          (epoch, i + 1, test_perplexity))
                test_mean = np.mean(test_ppl_list)
                if test_mean < test_mean_old:
                    test_mean_old = test_mean
                    saver_forward.save(session, config.forward_save_path)
                write_log(
                    'train ppl:' + str(np.mean(train_ppl_list)) + '\t' +
                    'test ppl:' + str(test_mean), config.forward_log_path)

        if config.mode == 'backward':
            #train backward language model
            train_data, test_data = reader.read_data(config.data_path,
                                                     config.num_steps)
            test_mean_old = 15.0
            for epoch in range(config.max_epoch):
                train_ppl_list = []
                test_ppl_list = []

                for i in range(train_data.length // config.batch_size):
                    input, sequence_length, target = train_data(
                        m_backward.batch_size, i)
                    input, sequence_length, target = reverse_seq(
                        input, sequence_length, target)
                    train_perplexity = run_epoch(session,
                                                 m_backward,
                                                 input,
                                                 sequence_length,
                                                 target,
                                                 mode='train')
                    train_ppl_list.append(train_perplexity)
                    print("Epoch:%d, Iter: %d Train NLL: %.3f" %
                          (epoch, i + 1, train_perplexity))
                for i in range(test_data.length // config.batch_size):
                    input, sequence_length, target = test_data(
                        mtest_backward.batch_size, i)
                    input, sequence_length, target = reverse_seq(
                        input, sequence_length, target)
                    test_perplexity = run_epoch(session,
                                                mtest_backward,
                                                input,
                                                sequence_length,
                                                target,
                                                mode='test')
                    test_ppl_list.append(test_perplexity)
                    print("Epoch:%d, Iter: %d Test NLL: %.3f" %
                          (epoch, i + 1, test_perplexity))
                test_mean = np.mean(test_ppl_list)
                if test_mean < test_mean_old:
                    test_mean_old = test_mean
                    saver_backward.save(session, config.backward_save_path)
                write_log(
                    'train ppl:' + str(np.mean(train_ppl_list)) + '\t' +
                    'test ppl:' + str(test_mean), config.backward_log_path)

        if config.mode == 'use':
            #CGMH sampling for key_gen
            sim = config.sim
            saver_forward.restore(session, config.forward_save_path)
            saver_backward.restore(session, config.backward_save_path)
            config.shuffle = False

            #keyword input
            if config.keyboard_input == True:
                #input from keyboard if key_input is not empty
                key_input = raw_input('please input a sentence\n')
                if key_input == '':
                    use_data = reader.read_data_use(config.use_data_path,
                                                    config.num_steps)
                else:
                    key_input = key_input.split()
                    key_input = sen2id(key_input)
                    sta_vec = list(np.zeros([config.num_steps - 1]))
                    for i in range(len(key_input)):
                        sta_vec[i] = 1
                    use_data = reader.array_data([key_input], config.num_steps,
                                                 config.dict_size)
            else:
                #load keywords from file
                use_data, sta_vec_list = reader.read_data_use(
                    config.use_data_path, config.num_steps)
            config.batch_size = 1

            for sen_id in range(use_data.length):
                #generate for each sequence of keywords
                if config.keyboard_input == False:
                    sta_vec = sta_vec_list[sen_id % len(sta_vec_list)]

                print(sta_vec)

                input, sequence_length, _ = use_data(1, sen_id)
                input_original = input[0]

                pos = 0
                outputs = []
                output_p = []
                for iter in range(config.sample_time):
                    #ind is the index of the selected word, regardless of the beginning token.
                    #sample config.sample_time times for each set of keywords
                    config.sample_prior = [1, 10.0 / sequence_length[0], 1, 1]
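                    #alternate the acceptance threshold: 10 unconstrained steps,
                    #then 10 steps with a stricter threshold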
                    if iter % 20 < 10:
                        config.threshold = 0
                    else:
                        config.threshold = 0.5
                    ind = pos % (sequence_length[0])
                    action = choose_action(config.action_prob)
                    print(' '.join(id2sen(input[0])))

                    if sta_vec[ind] == 1 and action in [0, 2]:
                        #skip words that we do not change (original keywords)
                        action = 3

                    #word replacement (action: 0)
                    if action == 0 and ind < sequence_length[0] - 1:
                        prob_old = run_epoch(session,
                                             mtest_forward,
                                             input,
                                             sequence_length,
                                             mode='use')
                        if config.double_LM == True:
                            input_backward, _, _ = reverse_seq(
                                input, sequence_length, input)
                            prob_old = (prob_old + run_epoch(session,
                                                             mtest_backward,
                                                             input_backward,
                                                             sequence_length,
                                                             mode='use')) * 0.5

                        tem = 1
                        for j in range(sequence_length[0] - 1):
                            tem *= prob_old[0][j][input[0][j + 1]]
                        tem *= prob_old[0][j + 1][config.dict_size + 1]
                        prob_old_prob = tem

                        if sim != None:
                            similarity_old = similarity(
                                input[0], input_original, sta_vec)
                            prob_old_prob *= similarity_old
                        else:
                            similarity_old = -1
                        input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(
                            input, sequence_length, ind, mode=action)
                        prob_forward = run_epoch(
                            session,
                            mtest_forward,
                            input_forward,
                            sequence_length_forward,
                            mode='use')[0, ind % (sequence_length[0] - 1), :]
                        prob_backward = run_epoch(
                            session,
                            mtest_backward,
                            input_backward,
                            sequence_length_backward,
                            mode='use')[0, sequence_length[0] - 1 - ind %
                                        (sequence_length[0] - 1), :]
                        prob_mul = (prob_forward * prob_backward)
                        input_candidate, sequence_length_candidate = generate_candidate_input(
                            input,
                            sequence_length,
                            ind,
                            prob_mul,
                            config.search_size,
                            mode=action)
                        prob_candidate_pre = run_epoch(
                            session,
                            mtest_forward,
                            input_candidate,
                            sequence_length_candidate,
                            mode='use')
                        if config.double_LM == True:
                            input_candidate_backward, _, _ = reverse_seq(
                                input_candidate, sequence_length_candidate,
                                input_candidate)
                            prob_candidate_pre = (
                                prob_candidate_pre +
                                run_epoch(session,
                                          mtest_backward,
                                          input_candidate_backward,
                                          sequence_length_candidate,
                                          mode='use')) * 0.5
                        prob_candidate = []
                        for i in range(config.search_size):
                            tem = 1
                            for j in range(sequence_length[0] - 1):
                                tem *= prob_candidate_pre[i][j][
                                    input_candidate[i][j + 1]]
                            tem *= prob_candidate_pre[i][j + 1][config.dict_size + 1]
                            prob_candidate.append(tem)

                        prob_candidate = np.array(prob_candidate)
                        if sim is not None:
                            similarity_candidate = similarity_batch(
                                input_candidate, input_original, sta_vec)
                            prob_candidate = prob_candidate * similarity_candidate
                        prob_candidate_norm = normalize(prob_candidate)
                        prob_candidate_ind = sample_from_candidate(
                            prob_candidate_norm)
                        prob_candidate_prob = prob_candidate[
                            prob_candidate_ind]
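                        # Accept the sampled replacement only if it is a real
                        # vocabulary word and clears the probability threshold
                        # (just_acc() == 0 appears to force occasional
                        # acceptance regardless).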
                        if input_candidate[prob_candidate_ind][
                                ind + 1] < config.dict_size and (
                                    prob_candidate_prob > prob_old_prob *
                                    config.threshold or just_acc() == 0):
                            input = input_candidate[
                                prob_candidate_ind:prob_candidate_ind + 1]
                        pos += 1
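                        # The literal 1 below stands in for the acceptance
                        # ratio; action 0 appears to rely on the threshold test
                        # above instead of an explicit MH alpha.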
                        print('action:0', 1, prob_old_prob,
                              prob_candidate_prob,
                              prob_candidate_norm[prob_candidate_ind],
                              similarity_old)
                        if ' '.join(id2sen(input[0])) not in output_p:
                            outputs.append(
                                [' '.join(id2sen(input[0])), prob_old_prob])

                    # word insertion (action: 1); falls back to skipping when
                    # the sentence is already at the maximum length num_steps
                    if action == 1:
                        if sequence_length[0] >= config.num_steps:
                            action = 3
                        else:
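                            # Same bidirectional proposal as in action 0, but
                            # the candidates now insert a new word at position
                            # ind, lengthening the sequence by one.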
                            input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(
                                input, sequence_length, ind, mode=action)
                            prob_forward = run_epoch(
                                session,
                                mtest_forward,
                                input_forward,
                                sequence_length_forward,
                                mode='use')[0,
                                            ind % (sequence_length[0] - 1), :]
                            prob_backward = run_epoch(
                                session,
                                mtest_backward,
                                input_backward,
                                sequence_length_backward,
                                mode='use')[0, sequence_length[0] - 1 - ind %
                                            (sequence_length[0] - 1), :]
                            prob_mul = (prob_forward * prob_backward)
                            input_candidate, sequence_length_candidate = generate_candidate_input(
                                input,
                                sequence_length,
                                ind,
                                prob_mul,
                                config.search_size,
                                mode=action)
                            prob_candidate_pre = run_epoch(
                                session,
                                mtest_forward,
                                input_candidate,
                                sequence_length_candidate,
                                mode='use')
                            if config.double_LM:
                                input_candidate_backward, _, _ = reverse_seq(
                                    input_candidate, sequence_length_candidate,
                                    input_candidate)
                                prob_candidate_pre = (
                                    prob_candidate_pre +
                                    run_epoch(session,
                                              mtest_backward,
                                              input_candidate_backward,
                                              sequence_length_candidate,
                                              mode='use')) * 0.5

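                            # Score the lengthened candidates;
                            # config.sample_prior[1] appears to act as a prior
                            # weight on insertion proposals.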
                            prob_candidate = []
                            for i in range(config.search_size):
                                tem = 1
                                for j in range(sequence_length_candidate[0] -
                                               1):
                                    tem *= prob_candidate_pre[i][j][
                                        input_candidate[i][j + 1]]
                                tem *= prob_candidate_pre[i][j + 1][
                                    config.dict_size + 1]
                                prob_candidate.append(tem)
                            prob_candidate = np.array(
                                prob_candidate) * config.sample_prior[1]
                            if sim is not None:
                                similarity_candidate = similarity_batch(
                                    input_candidate, input_original, sta_vec)
                                prob_candidate = prob_candidate * similarity_candidate
                            prob_candidate_norm = normalize(prob_candidate)

                            prob_candidate_ind = sample_from_candidate(
                                prob_candidate_norm)
                            prob_candidate_prob = prob_candidate[
                                prob_candidate_ind]

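                            # Score the current (pre-insertion) sentence for
                            # the denominator of the acceptance ratio.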
                            prob_old = run_epoch(session,
                                                 mtest_forward,
                                                 input,
                                                 sequence_length,
                                                 mode='use')
                            if config.double_LM:
                                input_backward, _, _ = reverse_seq(
                                    input, sequence_length, input)
                                prob_old = (prob_old +
                                            run_epoch(session,
                                                      mtest_backward,
                                                      input_backward,
                                                      sequence_length,
                                                      mode='use')) * 0.5

                            tem = 1
                            for j in range(sequence_length[0] - 1):
                                tem *= prob_old[0][j][input[0][j + 1]]
                            tem *= prob_old[0][j + 1][config.dict_size + 1]

                            prob_old_prob = tem
                            if sim is not None:
                                similarity_old = similarity(
                                    input[0], input_original, sta_vec)
                                prob_old_prob = prob_old_prob * similarity_old
                            else:
                                similarity_old = -1
                            # alpha is the MH acceptance ratio of the insertion
                            # proposal: p(x') q(delete) / (p(x) q(insert) g(x'|x))
                            alpha = min(
                                1,
                                prob_candidate_prob * config.action_prob[2] /
                                (prob_old_prob * config.action_prob[1] *
                                 prob_candidate_norm[prob_candidate_ind]))
                            print('action:1', alpha, prob_old_prob,
                                  prob_candidate_prob,
                                  prob_candidate_norm[prob_candidate_ind],
                                  similarity_old)
                            if ' '.join(id2sen(input[0])) not in output_p:
                                outputs.append([
                                    ' '.join(id2sen(input[0])), prob_old_prob
                                ])
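                            # Draw accept/reject with probability alpha
                            # (assuming choose_action samples an index from the
                            # given distribution); on acceptance, keep the
                            # lengthened sentence and pad sta_vec, which
                            # appears to mark keyword-constrained positions.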
                            if choose_action([
                                    alpha, 1 - alpha
                            ]) == 0 and input_candidate[prob_candidate_ind][
                                    ind + 1] < config.dict_size and (
                                        prob_candidate_prob > prob_old_prob *
                                        config.threshold or just_acc() == 0):
                                input = input_candidate[
                                    prob_candidate_ind:prob_candidate_ind + 1]
                                sequence_length += 1
                                pos += 2
                                sta_vec.insert(ind, 0.0)
                                del sta_vec[-1]
                            else:
                                action = 3

                    # word deletion (action: 2); falls back to skipping when
                    # the sentence is too short to shrink
                    if action == 2 and ind < sequence_length[0] - 1:
                        if sequence_length[0] <= 2:
                            action = 3
                        else:

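                            # Score the current sentence first; its probability
                            # forms the denominator of the deletion acceptance
                            # ratio.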
                            prob_old = run_epoch(session,
                                                 mtest_forward,
                                                 input,
                                                 sequence_length,
                                                 mode='use')
                            if config.double_LM:
                                input_backward, _, _ = reverse_seq(
                                    input, sequence_length, input)
                                prob_old = (prob_old +
                                            run_epoch(session,
                                                      mtest_backward,
                                                      input_backward,
                                                      sequence_length,
                                                      mode='use')) * 0.5

                            tem = 1
                            for j in range(sequence_length[0] - 1):
                                tem *= prob_old[0][j][input[0][j + 1]]
                            tem *= prob_old[0][j + 1][config.dict_size + 1]
                            prob_old_prob = tem
                            if sim is not None:
                                similarity_old = similarity(
                                    input[0], input_original, sta_vec)
                                prob_old_prob = prob_old_prob * similarity_old
                            else:
                                similarity_old = -1
                            input_candidate, sequence_length_candidate = generate_candidate_input(
                                input,
                                sequence_length,
                                ind,
                                None,
                                config.search_size,
                                mode=2)
                            prob_new = run_epoch(session,
                                                 mtest_forward,
                                                 input_candidate,
                                                 sequence_length_candidate,
                                                 mode='use')
                            tem = 1
                            for j in range(sequence_length_candidate[0] - 1):
                                tem *= prob_new[0][j][input_candidate[0][j + 1]]
                            tem *= prob_new[0][j + 1][config.dict_size + 1]
                            prob_new_prob = tem
                            if sim is not None:
                                similarity_new = similarity_batch(
                                    input_candidate, input_original, sta_vec)
                                prob_new_prob = prob_new_prob * similarity_new

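                            # To evaluate the reverse (insertion) proposal
                            # g(x|x'), rebuild the replacement-candidate
                            # distribution at position ind with mode=0 and look
                            # up the deleted sentence among the candidates.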
                            input_forward, input_backward, sequence_length_forward, sequence_length_backward = cut_from_point(
                                input, sequence_length, ind, mode=0)
                            prob_forward = run_epoch(
                                session,
                                mtest_forward,
                                input_forward,
                                sequence_length_forward,
                                mode='use')[0,
                                            ind % (sequence_length[0] - 1), :]
                            prob_backward = run_epoch(
                                session,
                                mtest_backward,
                                input_backward,
                                sequence_length_backward,
                                mode='use')[0, sequence_length[0] - 1 - ind %
                                            (sequence_length[0] - 1), :]
                            prob_mul = (prob_forward * prob_backward)
                            input_candidate, sequence_length_candidate = generate_candidate_input(
                                input,
                                sequence_length,
                                ind,
                                prob_mul,
                                config.search_size,
                                mode=0)
                            prob_candidate_pre = run_epoch(
                                session,
                                mtest_forward,
                                input_candidate,
                                sequence_length_candidate,
                                mode='use')
                            if config.double_LM:
                                input_candidate_backward, _, _ = reverse_seq(
                                    input_candidate, sequence_length_candidate,
                                    input_candidate)
                                prob_candidate_pre = (
                                    prob_candidate_pre +
                                    run_epoch(session,
                                              mtest_backward,
                                              input_candidate_backward,
                                              sequence_length_candidate,
                                              mode='use')) * 0.5

                            prob_candidate = []
                            for i in range(config.search_size):
                                tem = 1
                                for j in range(sequence_length[0] - 1):
                                    tem *= prob_candidate_pre[i][j][
                                        input_candidate[i][j + 1]]
                                tem *= prob_candidate_pre[i][j + 1][
                                    config.dict_size + 1]
                                prob_candidate.append(tem)
                            prob_candidate = np.array(prob_candidate)

                            if sim is not None:
                                similarity_candidate = similarity_batch(
                                    input_candidate, input_original, sta_vec)
                                prob_candidate = prob_candidate * similarity_candidate

                            # alpha is the MH acceptance ratio of the deletion
                            # proposal: g(x|x') p(x') q(insert) / (p(x) q(delete))
                            prob_candidate_norm = normalize(prob_candidate)
                            if input[0] in input_candidate:
                                for candidate_ind in range(
                                        len(input_candidate)):
                                    if input[0] in input_candidate[
                                            candidate_ind:candidate_ind + 1]:
                                        break
                                alpha = min(
                                    prob_candidate_norm[candidate_ind] *
                                    prob_new_prob * config.action_prob[1] /
                                    (config.action_prob[2] * prob_old_prob), 1)
                            else:
                                # the reverse insertion can never recover the
                                # original sentence, so reject; candidate_ind
                                # is set only so the print below cannot fail
                                candidate_ind = 0
                                alpha = 0
                            print('action:2', alpha, prob_old_prob,
                                  prob_new_prob,
                                  prob_candidate_norm[candidate_ind],
                                  similarity_old)
                            if ' '.join(id2sen(input[0])) not in output_p:
                                outputs.append([
                                    ' '.join(id2sen(input[0])), prob_old_prob
                                ])
                            if choose_action([
                                    alpha, 1 - alpha
                            ]) == 0 and (prob_new_prob > prob_old_prob *
                                         config.threshold or just_acc() == 0):
                                input = np.concatenate([
                                    input[:, :ind + 1], input[:, ind + 2:],
                                    input[:, :1] * 0 + config.dict_size + 1
                                ], axis=1)
                                sequence_length -= 1
                                pos += 0  # position is unchanged after a deletion
                                del sta_vec[ind]
                                sta_vec.append(0)
                            else:
                                action = 3
                    # skip word (action: 3)
                    if action == 3:
                        #write_log('step:'+str(iter)+'action:3', config.use_log_path)
                        pos += 1
                    print(outputs)
                    if outputs != []:
                        output_p.append(outputs[-1][0])

                # choose the final output from the collected samples: prefer
                # longer sentences by relaxing the minimum length until at
                # least one sample qualifies, then emit the most probable one
                for num in range(config.min_length, 0, -1):
                    outputss = [x for x in outputs if len(x[0].split()) >= num]
                    print(num, outputss)
                    if outputss != []:
                        break
                if outputss == []:
                    outputss.append([' '.join(id2sen(input[0])), 1])
                outputss = sorted(outputss, key=lambda x: x[1])[::-1]
                with open(config.use_output_path, 'a') as g:
                    g.write(outputss[0][0] + '\n')