def train():
    """Train the Seq2Seq model, logging and checkpointing periodically.

    Builds the model from the command-line FLAGS, restores an existing
    checkpoint or initializes fresh variables, loads the word2vec
    embedding, then runs the batch training loop until FLAGS.max_epochs
    epochs have completed.
    """
    config_proto = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement,
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config_proto) as sess:
        # Build the model; sorting the flags gives a stable key order in
        # the JSON config dumps below.
        config = OrderedDict(sorted(FLAGS.__flags.items()))
        model = Seq2SeqModel(config, 'train')

        # Create a log writer object for TensorBoard summaries.
        log_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        # Create a saver.
        # Using var_list=None returns the list of all saveable variables.
        saver = tf.train.Saver(var_list=None)

        # Initialize global variables or reload an existing checkpoint.
        load_or_create_model(sess, model, saver, FLAGS)

        # Load word2vec embedding and push it into the model's variables.
        embedding = get_word_embedding(FLAGS.hidden_units,
                                       alignment=FLAGS.align_word2vec)
        print(embedding.shape)
        model.init_vars(sess, embedding=embedding)

        step_time, loss = 0.0, 0.0
        sents_seen = 0
        start_time = time.time()

        print('Training...')
        for epoch_idx in range(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print('Training is already complete.',
                      'Current epoch: {}, Max epoch: {}'.format(
                          model.global_epoch_step.eval(), FLAGS.max_epochs))
                break

            # Prepare batch training data
            # TODO(sdsuo): Make corresponding changes in data_utils
            for source, source_len, target, target_len in gen_batch_train_data(
                    FLAGS.batch_size,
                    prev=FLAGS.prev_data,
                    rev=FLAGS.rev_data,
                    align=FLAGS.align_data,
                    cangtou=FLAGS.cangtou_data):
                step_loss, summary = model.train(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len,
                    decoder_inputs=target,
                    decoder_inputs_length=target_len)

                # Accumulate a running average over the display window.
                loss += float(step_loss) / FLAGS.display_freq
                sents_seen += float(source.shape[0])  # batch_size

                # Display information
                if model.global_step.eval() % FLAGS.display_freq == 0:
                    # Guard math.exp against overflow for large losses.
                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float('inf')
                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq
                    sents_per_sec = sents_seen / time_elapsed
                    print('Epoch ', model.global_epoch_step.eval(),
                          'Step ', model.global_step.eval(),
                          'Perplexity {0:.2f}'.format(avg_perplexity),
                          'Step-time ', step_time,
                          '{0:.2f} sents/s'.format(sents_per_sec))
                    loss = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print('Saving the model..')
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess, saver, checkpoint_path,
                               global_step=model.global_step)
                    # json.dump writes str, so the file must be opened in
                    # text mode ('w', not 'wb'); the context manager also
                    # closes the handle the original left open.
                    dump_path = '%s-%d.json' % (checkpoint_path,
                                                model.global_step.eval())
                    with open(dump_path, 'w') as config_file:
                        json.dump(model.config, config_file, indent=2)

            # Increase the epoch index of the model
            model.increment_global_epoch_step_op.eval()
            print('Epoch {0:} DONE'.format(model.global_epoch_step.eval()))

        # Final checkpoint + config dump after the training loop finishes.
        print('Saving the last model')
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, saver, checkpoint_path,
                   global_step=model.global_step)
        dump_path = '%s-%d.json' % (checkpoint_path,
                                    model.global_step.eval())
        with open(dump_path, 'w') as config_file:
            json.dump(model.config, config_file, indent=2)

    print('Training terminated')
def train():
    """Train the Seq2Seq model using a hard-coded hyperparameter config.

    NOTE(review): unlike the flag-driven variant of train(), the model
    config here is a fixed dict, while runtime options (paths, batch
    size, display/save frequencies) still come from FLAGS — confirm the
    dict and the flags are kept in sync.
    """
    config_proto = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement,
        gpu_options=tf.GPUOptions(allow_growth=True)
        # device_count = {'GPU': 0}
    )
    with tf.Session(config=config_proto) as sess:
        # Build the model from a fixed hyperparameter dictionary.
        config = {
            'cangtou_data': False,
            'rev_data': True,
            'align_data': True,
            'prev_data': True,
            'align_word2vec': True,
            'cell_type': 'lstm',
            'attention_type': 'bahdanau',
            'hidden_units': 128,
            'depth': 4,
            'embedding_size': 128,
            'num_encoder_symbols': 30000,
            'num_decoder_symbols': 30000,
            'vocab_size': 6000,
            'use_residual': True,
            'attn_input_feeding': False,
            'use_dropout': True,
            'dropout_rate': 0.3,
            'learning_rate': 0.0002,
            'max_gradient_norm': 1.0,
            'batch_size': 64,
            'max_epochs': 10000,
            'max_load_batches': 20,
            'max_seq_length': 50,
            'display_freq': 100,
            'save_freq': 100,
            'valid_freq': 1150000,
            'optimizer': 'adam',
            'model_dir': 'model',
            'summary_dir': 'model/summary',
            'model_name': 'translate.ckpt',
            'shuffle_each_epoch': True,
            'sort_by_length': True,
            'use_fp16': False,
            'bidirectional': True,
            'train_mode': 'ground_truth',
            'sampling_probability': 0.1,
            'start_token': 0,
            'end_token': 5999,
            'allow_soft_placement': True,
            'log_device_placement': False
        }
        model = Seq2SeqModel(config, 'train')

        # Create a log writer object for TensorBoard summaries.
        log_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        # Create a saver.
        # Using var_list=None returns the list of all saveable variables.
        saver = tf.train.Saver(var_list=None)

        # Initialize global variables or reload an existing checkpoint.
        load_or_create_model(sess, model, saver, FLAGS)

        # Load word2vec embedding and push it into the model's variables.
        embedding = get_word_embedding(FLAGS.hidden_units,
                                       alignment=FLAGS.align_word2vec)
        model.init_vars(sess, embedding=embedding)

        step_time, loss = 0.0, 0.0
        sents_seen = 0
        start_time = time.time()

        print('Training...')
        for epoch_idx in range(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print(
                    'Training is already complete.',
                    'Current epoch: {}, Max epoch: {}'.format(
                        model.global_epoch_step.eval(), FLAGS.max_epochs))
                break

            # Prepare batch training data
            # TODO(sdsuo): Make corresponding changes in data_utils
            for source, source_len, target, target_len in gen_batch_train_data(
                    FLAGS.batch_size,
                    prev=FLAGS.prev_data,
                    rev=FLAGS.rev_data,
                    align=FLAGS.align_data,
                    cangtou=FLAGS.cangtou_data):
                step_loss, summary = model.train(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len,
                    decoder_inputs=target,
                    decoder_inputs_length=target_len)

                # Accumulate a running average over the display window.
                loss += float(step_loss) / FLAGS.display_freq
                sents_seen += float(source.shape[0])  # batch_size

                # Display information.  The modulus was a hard-coded 100;
                # use FLAGS.display_freq so it matches the step_time and
                # running-average denominators below.
                if model.global_step.eval() % FLAGS.display_freq == 0:
                    # Guard math.exp against overflow for large losses.
                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float('inf')
                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq
                    sents_per_sec = sents_seen / time_elapsed
                    print('Epoch ', model.global_epoch_step.eval(),
                          'Step ', model.global_step.eval(),
                          'Perplexity {0:.2f}'.format(avg_perplexity),
                          'Step-time ', step_time,
                          '{0:.2f} sents/s'.format(sents_per_sec))
                    loss = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                # Save the model checkpoint (was a hard-coded % 100; use
                # FLAGS.save_freq, as the flag-driven variant does).
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print('Saving the model..')
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess, saver, checkpoint_path,
                               global_step=model.global_step)
                    # Context manager closes the handle the original leaked.
                    dump_path = '%s-%d.json' % (checkpoint_path,
                                                model.global_step.eval())
                    with open(dump_path, 'w') as config_file:
                        json.dump(model.config, config_file, indent=2)

            # Increase the epoch index of the model
            model.increment_global_epoch_step_op.eval()
            print('Epoch {0:} DONE'.format(model.global_epoch_step.eval()))

        # Final checkpoint + config dump after the training loop finishes.
        print('Saving the last model')
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, saver, checkpoint_path,
                   global_step=model.global_step)
        dump_path = '%s-%d.json' % (checkpoint_path,
                                    model.global_step.eval())
        with open(dump_path, 'w') as config_file:
            json.dump(model.config, config_file, indent=2)

    print('Training terminated')