def _train_session_config(use_gpu):
    """Select the visible CUDA device(s) and return a matching session config.

    Side effect: overwrites ``os.environ['CUDA_VISIBLE_DEVICES']``.

    Args:
        use_gpu: truthy to run on a GPU, falsy to hide all GPUs.

    Returns:
        A ``tf.ConfigProto`` to pass to ``tf.Session``.
    """
    if not use_gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        return tf.ConfigProto()

    if 'X_SGE_CUDA_DEVICE' in os.environ:
        print('running on the stack...')
        cuda_device = os.environ['X_SGE_CUDA_DEVICE']
        print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
    else:
        # development only e.g. air202
        print('running locally...')
        os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # choose the device (GPU) here

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    # Whether the GPU memory usage can grow dynamically.
    sess_config.gpu_options.allow_growth = True
    # The fraction of GPU memory that the process can use.
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.95
    return sess_config


def _train_load_embeddings(sess, model, config, src_word2id, tgt_word2id):
    """Overwrite the model's embedding variables with pre-trained ones.

    Nothing is loaded for a side whose ``config['load_embedding_*']`` entry
    is ``None``. When both sides name the same embedding file, the matrix
    built for the source side is reused for the target side.
    """
    src_path = config['load_embedding_src']
    tgt_path = config['load_embedding_tgt']

    src_matrix = None
    if src_path is not None:
        current_src = sess.run(model.src_word_embeddings)
        src_matrix = load_pretrained_embedding(src_word2id, current_src,
                                               src_path)
        sess.run(model.src_word_embeddings.assign(src_matrix))

    if tgt_path is not None:
        if tgt_path == src_path:
            # Same file on both sides: src_path is not None here, so
            # src_matrix has been built above.
            tgt_matrix = src_matrix
        else:
            current_tgt = sess.run(model.tgt_word_embeddings)
            tgt_matrix = load_pretrained_embedding(tgt_word2id, current_tgt,
                                                   tgt_path)
        sess.run(model.tgt_word_embeddings.assign(tgt_matrix))


def _train_probe_feed(model, src_word2id, max_sentence_length, learning_rate):
    """Build the feed dict for the fixed sentences decoded during training.

    Out-of-vocabulary words map to ``'<unk>'``; every sentence is right-padded
    with the ``'</s>'`` id up to ``max_sentence_length`` while the fed lengths
    are the unpadded ones. Dropout is disabled for these decodes.
    """
    sentences = [
        'this is test . </s>',
        'this is confirm my reservation at hotel . </s>',
        'playing tennis good for you . </s>',
        'when talking about successful longterm business relationships '
        'customer services are important element </s>',
    ]

    sent_ids = [[
        src_word2id[word] if word in src_word2id else src_word2id['<unk>']
        for word in sentence.split()
    ] for sentence in sentences]
    sent_lengths = [len(ids) for ids in sent_ids]

    eos_id = src_word2id['</s>']
    padded_ids = [
        ids + [eos_id] * (max_sentence_length - len(ids)) for ids in sent_ids
    ]

    return {
        model.src_word_ids: padded_ids,
        model.src_sentence_lengths: sent_lengths,
        model.dropout: 0.0,
        model.learning_rate: learning_rate,
    }


def train(config):
    """Train an ``EncoderDecoder`` model and checkpoint it after every epoch.

    Reads these ``config`` entries: ``save`` (output directory), ``load``
    (``None`` or a checkpoint prefix to resume from), ``random_seed``,
    ``learning_rate``, ``decay_rate``, ``use_gpu``, ``load_embedding_src``,
    ``load_embedding_tgt``, ``max_sentence_length``, ``num_epochs`` and
    ``dropout``. ``config`` is also handed as a whole to ``write_config``,
    ``construct_training_data_batches`` and ``EncoderDecoder``.

    Side effects: creates ``config['save']`` if needed, writes
    ``config.txt`` and the checkpoints there, sets
    ``CUDA_VISIBLE_DEVICES``, seeds ``random`` and prints progress to stdout.
    Training stops early (without saving that epoch) when the summed epoch
    loss becomes NaN.
    """
    save_path = config['save']  # path to store model
    saved_model = config['load']  # None or path

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    write_config(save_path + '/config.txt', config)

    random.seed(config['random_seed'])

    batches, vocab_size, src_word2id, tgt_word2id = \
        construct_training_data_batches(config)

    # Index words by id explicitly instead of relying on the dict's key order
    # happening to match the id order.
    tgt_id2word = [
        word
        for word, _ in sorted(tgt_word2id.items(), key=lambda item: item[1])
    ]

    params = {
        'vocab_src_size': vocab_size['src'],
        'vocab_tgt_size': vocab_size['tgt'],
        'go_id': tgt_word2id['<go>'],
        'eos_id': tgt_word2id['</s>'],
    }

    model = EncoderDecoder(config, params)
    model.build_network()

    learning_rate = config['learning_rate']
    decay_rate = config['decay_rate']

    for variable in tf.trainable_variables():
        print(variable)

    # save & restore model
    saver = tf.train.Saver(max_to_keep=1)

    sess_config = _train_session_config(config['use_gpu'])

    with tf.Session(config=sess_config) as sess:
        if saved_model is None:
            sess.run(tf.global_variables_initializer())
            _train_load_embeddings(sess, model, config, src_word2id,
                                   tgt_word2id)
        else:
            # Restore straight into the variables built above. Importing the
            # checkpoint's meta graph here would add a second copy of the
            # network to the graph that already holds this model.
            saver.restore(sess, saved_model)
            print('loaded model...', saved_model)

        # Fixed sentences decoded periodically to eyeball training progress.
        infer_dict = _train_probe_feed(model, src_word2id,
                                       config['max_sentence_length'],
                                       learning_rate)

        num_epochs = config['num_epochs']
        for epoch in range(num_epochs):
            print("num_batches = ", len(batches))

            random.shuffle(batches)
            epoch_loss = 0

            for i, batch in enumerate(batches):
                feed_dict = {
                    model.src_word_ids: batch['src_word_ids'],
                    model.tgt_word_ids: batch['tgt_word_ids'],
                    model.src_sentence_lengths: batch['src_sentence_lengths'],
                    model.tgt_sentence_lengths: batch['tgt_sentence_lengths'],
                    model.dropout: config['dropout'],
                    model.learning_rate: learning_rate,
                }

                _, loss = sess.run([model.train_op, model.train_loss],
                                   feed_dict=feed_dict)
                epoch_loss += loss

                if i % 100 == 0:
                    # to print out training status
                    print("batch: {} --- avg train loss: {:.5f}".format(
                        i, epoch_loss / (i + 1)))
                    sys.stdout.flush()

                if i % 500 == 0:
                    [my_translations] = sess.run([model.translations],
                                                 feed_dict=infer_dict)
                    for my_sent in my_translations:
                        my_words = [tgt_id2word[word_id] for word_id in my_sent]
                        print(' '.join(my_words))

            # NOTE(review): the source formatting was flattened; the counter
            # increment and learning-rate decay are taken to run once per
            # epoch -- confirm against the original layout.
            model.increment_counter()
            learning_rate *= decay_rate

            print("---------------------------------------------------")
            print("epoch {} done".format(epoch + 1))
            print("total training loss = {}".format(epoch_loss))
            print("---------------------------------------------------")

            if math.isnan(epoch_loss):
                print("stop training - loss/gradient exploded")
                break

            saver.save(sess, save_path + '/model', global_step=epoch)
def adapt(config):
    """Adapt a previously trained ``EncoderDecoder`` on new data.

    The network is rebuilt, ``model.adapt_weights`` is called with the names
    of the decoder dense-layer bias and kernel, the checkpoint
    ``config['load'] + '/model-<n>'`` is restored, and ``model.adapt_op`` is
    then run over the shuffled batches for a fixed number of epochs. ``<n>``
    is ``config['model_number']``, or ``config['num_epochs'] - 1`` when that
    entry is ``None``.

    Also reads ``config['save']`` (output directory), ``config['dropout']``
    and ``config['decoding_method']``; ``config`` as a whole is passed to
    ``write_config``, ``construct_training_data_batches`` and
    ``EncoderDecoder``.

    Side effects: sets ``CUDA_VISIBLE_DEVICES``, creates ``config['save']``
    if needed, writes ``config.txt`` and one checkpoint per epoch there, and
    prints progress to stdout.
    """
    num_adapt_epochs = 10  # fixed number of passes over the adaptation data

    if 'X_SGE_CUDA_DEVICE' in os.environ:
        print('running on the stack...')
        cuda_device = os.environ['X_SGE_CUDA_DEVICE']
        print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
    else:
        # development only e.g. air202
        print('running locally...')
        os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # choose the device (GPU) here

    sess_config = tf.ConfigProto()

    batches, _, src_word2id, tgt_word2id = \
        construct_training_data_batches(config)

    params = {
        'vocab_src_size': len(src_word2id),
        'vocab_tgt_size': len(tgt_word2id),
        'go_id': tgt_word2id['<go>'],
        'eos_id': tgt_word2id['</s>'],
    }

    # build the model
    model = EncoderDecoder(config, params)
    model.build_network()

    # -------- Adaption work -------- #
    bias_name = 'decoder/decode_with_shared_attention/decoder/dense/bias:0'
    weight_name = 'decoder/decode_with_shared_attention/decoder/dense/kernel:0'
    param_names = [bias_name, weight_name]
    model.adapt_weights(param_names)
    # ------------------------------- #

    new_save_path = config['save']
    if not os.path.exists(new_save_path):
        os.makedirs(new_save_path)
    write_config(new_save_path + '/config.txt', config)

    # save & restore model
    # NOTE(review): this Saver is created after adapt_weights(); if that call
    # adds variables that are absent from the checkpoint, restore() below
    # would fail -- confirm.
    saver = tf.train.Saver(max_to_keep=1)

    save_path = config['load']
    model_number = config['model_number']
    if model_number is None:
        # Default to the checkpoint written by the last training epoch.
        model_number = config['num_epochs'] - 1
    full_save_path_to_model = save_path + '/model-' + str(model_number)

    with tf.Session(config=sess_config) as sess:
        # Restore variables from disk.
        saver.restore(sess, full_save_path_to_model)

        # The decoding method does not change while adapting.
        report_infer_loss = config['decoding_method'] != 'beamsearch'

        for epoch in range(num_adapt_epochs):
            print("num_batches = ", len(batches))

            random.shuffle(batches)

            for i, batch in enumerate(batches):
                feed_dict = {
                    model.src_word_ids: batch['src_word_ids'],
                    model.tgt_word_ids: batch['tgt_word_ids'],
                    model.src_sentence_lengths: batch['src_sentence_lengths'],
                    model.tgt_sentence_lengths: batch['tgt_sentence_lengths'],
                    model.dropout: config['dropout'],
                }

                sess.run([model.adapt_op], feed_dict=feed_dict)

                if i % 100 == 0:
                    # to print out training status
                    if report_infer_loss:
                        train_loss, infer_loss = sess.run(
                            [model.train_loss, model.infer_loss],
                            feed_dict=feed_dict)
                        print(
                            "batch: {} --- train_loss: {:.5f} | inf_loss: {:.5f}"
                            .format(i, train_loss, infer_loss))
                    else:
                        # --- beam search --- #
                        [train_loss] = sess.run([model.train_loss],
                                                feed_dict=feed_dict)
                        print("BEAMSEARCH - batch: {} --- train_loss: {:.5f}".
                              format(i, train_loss))

                    sys.stdout.flush()

            # NOTE(review): the source formatting was flattened; the counter
            # increment is taken to run once per epoch -- confirm against the
            # original layout.
            model.increment_counter()
            print("################## EPOCH {} done ##################".format(
                epoch))
            saver.save(sess, new_save_path + '/model', global_step=epoch)