def _tf_save_checkpoint(sess, model):
    """Save a checkpoint of ``model`` and a JSON dump of ``model.config``.

    The checkpoint prefix is ``os.path.join(FLAGS.model_dir, FLAGS.model_name)``
    and ``model.save`` is called with ``global_step=model.global_step``. The
    config is written to ``<prefix>-<global_step>.json``.
    """
    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
    model.save(sess, checkpoint_path, global_step=model.global_step)
    config_path = '%s-%d.json' % (checkpoint_path, model.global_step.eval())
    # json.dump writes text; the context manager guarantees the file is
    # flushed and closed instead of leaving the handle to the GC.
    with open(config_path, 'w') as config_file:
        json.dump(model.config, config_file, indent=2)


def _tf_report_validation(sess, model, valid_set):
    """Run ``model.eval`` over every batch of ``valid_set`` and print perplexity.

    Each batch loss is weighted by its sentence count (``source.shape[0]``)
    before averaging. Prints an explicit notice instead of dividing by zero
    when no validation sample was seen.
    """
    print('Validation step')
    valid_loss = 0.0
    valid_sents_seen = 0
    for source_seq, target_seq in valid_set:
        # Get a batch from validation parallel data
        source, source_len, target, target_len = prepare_train_batch(
            source_seq, target_seq)
        if source is None or target is None:
            # Same guard as the training loop: nothing usable in this batch.
            continue
        # Compute validation loss: average per word cross entropy loss
        step_loss = model.eval(sess,
                               encoder_inputs=source,
                               encoder_inputs_length=source_len,
                               decoder_inputs=target,
                               decoder_inputs_length=target_len)
        batch_size = source.shape[0]
        valid_loss += step_loss * batch_size
        valid_sents_seen += batch_size
        print(' {} samples seen'.format(valid_sents_seen))

    if valid_sents_seen == 0:
        print('Valid perplexity: n/a (no validation samples)')
        return
    valid_loss = valid_loss / valid_sents_seen
    # Same overflow guard as the training perplexity display.
    valid_perplexity = math.exp(valid_loss) if valid_loss < 300 else float('inf')
    print('Valid perplexity: {0:.2f}'.format(valid_perplexity))


def train():
    """Train the TensorFlow seq2seq model configured by ``FLAGS``.

    Loads the training data (and validation data when both
    ``FLAGS.source_valid_data`` and ``FLAGS.target_valid_data`` are set),
    obtains a model from ``create_model`` and trains for up to
    ``FLAGS.max_epochs`` epochs, counted by ``model.global_epoch_step``.
    Every ``FLAGS.display_freq`` global steps it prints perplexity and
    throughput. Every ``FLAGS.valid_freq`` steps it prints validation
    perplexity, and every ``FLAGS.save_freq`` steps it saves a checkpoint.
    A final checkpoint is saved after the epoch loop.
    """
    # Load parallel data to train
    print('Loading training data..')
    train_set = BiTextIterator(source=FLAGS.source_train_data,
                               target=FLAGS.target_train_data,
                               source_dict=FLAGS.source_vocabulary,
                               target_dict=FLAGS.target_vocabulary,
                               batch_size=FLAGS.batch_size,
                               maxlen=FLAGS.max_seq_length,
                               n_words_source=FLAGS.num_encoder_symbols,
                               n_words_target=FLAGS.num_decoder_symbols,
                               shuffle_each_epoch=FLAGS.shuffle_each_epoch,
                               sort_by_length=FLAGS.sort_by_length,
                               maxibatch_size=FLAGS.max_load_batches)

    valid_set = None
    if FLAGS.source_valid_data and FLAGS.target_valid_data:
        print('Loading validation data..')
        valid_set = BiTextIterator(source=FLAGS.source_valid_data,
                                   target=FLAGS.target_valid_data,
                                   source_dict=FLAGS.source_vocabulary,
                                   target_dict=FLAGS.target_vocabulary,
                                   batch_size=FLAGS.batch_size,
                                   maxlen=None,
                                   n_words_source=FLAGS.num_encoder_symbols,
                                   n_words_target=FLAGS.num_decoder_symbols)

    # Initiate TF session
    session_config = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement,
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        # Create a new model or reload existing checkpoint
        model = create_model(sess, FLAGS)

        # Accumulators for the current display window.
        loss = 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        # Training loop
        print('Training..')
        for _ in xrange(FLAGS.max_epochs):
            # The epoch counter lives in the model, so a reloaded checkpoint
            # may already have reached max_epochs.
            current_epoch = model.global_epoch_step.eval()
            if current_epoch >= FLAGS.max_epochs:
                print('Training is already complete. '
                      'current epoch:{}, max epoch:{}'.format(
                          current_epoch, FLAGS.max_epochs))
                break

            for source_seq, target_seq in train_set:
                # Get a batch from training parallel data
                source, source_len, target, target_len = prepare_train_batch(
                    source_seq, target_seq, FLAGS.max_seq_length)
                if source is None or target is None:
                    print('No samples under max_seq_length  {}'.format(
                        FLAGS.max_seq_length))
                    continue

                # Execute a single training step
                step_loss = model.train(sess,
                                        encoder_inputs=source,
                                        encoder_inputs_length=source_len,
                                        decoder_inputs=target,
                                        decoder_inputs_length=target_len)

                # Pre-divide so `loss` is the window mean at display time.
                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len + target_len))
                sents_seen += float(source.shape[0])  # batch_size

                if model.global_step.eval() % FLAGS.display_freq == 0:
                    # Guard against math.exp overflow on a diverging loss.
                    avg_perplexity = (math.exp(float(loss)) if loss < 300
                                      else float('inf'))
                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq
                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed
                    print('Epoch  {} Step  {} Perplexity {:.2f} Step-time  {} '
                          '{:.2f} sents/s {:.2f} words/s'.format(
                              model.global_epoch_step.eval(),
                              model.global_step.eval(),
                              avg_perplexity, step_time,
                              sents_per_sec, words_per_sec))
                    loss = 0.0
                    words_seen, sents_seen = 0, 0
                    start_time = time.time()

                # Execute a validation step
                if valid_set and model.global_step.eval() % FLAGS.valid_freq == 0:
                    valid_started = time.time()
                    _tf_report_validation(sess, model, valid_set)
                    # Keep validation wall-time out of the training
                    # throughput window reported at the next display step.
                    start_time += time.time() - valid_started

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print('Saving the model..')
                    _tf_save_checkpoint(sess, model)

            # Increase the epoch index of the model
            model.global_epoch_step_op.eval()
            print('Epoch {0:} DONE'.format(model.global_epoch_step.eval()))

        print('Saving the last model..')
        _tf_save_checkpoint(sess, model)

    print('Training Terminated')
def _torch_to_variable(tensor, volatile=False):
    """Wrap ``tensor`` in a ``Variable``, moving it to the GPU first when the
    module-level ``use_cuda`` flag is set.

    Pass ``volatile=True`` for inference-only inputs (validation). In the
    Variable-era PyTorch API this file uses, that stops autograd history from
    being recorded. Newer releases ignore the flag with a warning.
    """
    if use_cuda:
        tensor = tensor.cuda()
    if volatile:
        return Variable(tensor, volatile=True)
    return Variable(tensor)


def _torch_report_validation(model, criterion, valid_set):
    """Evaluate ``model`` on ``valid_set`` and print the validation perplexity.

    Puts ``model`` in eval mode for the pass and back in train mode afterwards.
    Per-batch ``criterion`` values are averaged with equal weight per batch.
    Prints an explicit notice instead of dividing by zero when no batch was
    evaluated.
    """
    model.eval()
    print('Validation step')
    valid_steps = 0
    valid_loss = 0.0
    valid_sents_seen = 0
    for source_seq, target_seq in valid_set:
        # Get a batch from validation parallel data
        enc_input, enc_len, dec_input, dec_target, _ = \
            prepare_train_batch(source_seq, target_seq)
        if enc_input is None or dec_input is None or dec_target is None:
            # Same guard as the training loop: nothing usable in this batch.
            continue
        enc_input = _torch_to_variable(enc_input, volatile=True)
        enc_len = _torch_to_variable(enc_len, volatile=True)
        dec_input = _torch_to_variable(dec_input, volatile=True)
        dec_target = _torch_to_variable(dec_target, volatile=True)

        dec_logits = model(enc_input, enc_len, dec_input)
        step_loss = criterion(dec_logits, dec_target.view(-1))
        valid_steps += 1
        valid_loss += float(step_loss.data[0])
        valid_sents_seen += enc_input.size(0)
        print(' {} samples seen'.format(valid_sents_seen))
    model.train()

    if valid_steps == 0:
        print('Valid perplexity: n/a (no validation samples)')
        return
    mean_loss = valid_loss / valid_steps
    # Same overflow guard as the training perplexity display.
    valid_perplexity = math.exp(mean_loss) if mean_loss < 300 else float('inf')
    print('Valid perplexity: {0:.2f}'.format(valid_perplexity))


def train(config):
    """Train the Quasi-RNN model described by ``config`` with PyTorch.

    Loads the training data (and validation data when both ``config.src_valid``
    and ``config.tgt_valid`` are set), builds the model via ``create_model`` and
    optimises it with Adam on a padding-ignoring cross-entropy loss, clipping
    the gradient norm to ``config.max_grad_norm``. Progress is tracked in
    ``model_state['epoch']`` and ``model_state['train_steps']``. Every
    ``config.display_freq`` steps it prints perplexity and throughput. Every
    ``config.valid_freq`` steps it prints validation perplexity, and every
    ``config.save_freq`` steps it writes ``model_state`` (including the state
    dict) to ``config.model_dir/config.model_name``.
    """
    # Load parallel data to train
    print('Loading training data..')
    train_set = BiTextIterator(source=config.src_train,
                               target=config.tgt_train,
                               source_dict=config.src_vocab,
                               target_dict=config.tgt_vocab,
                               batch_size=config.batch_size,
                               maxlen=config.max_seq_len,
                               n_words_source=config.num_enc_symbols,
                               n_words_target=config.num_dec_symbols,
                               shuffle_each_epoch=config.shuffle,
                               sort_by_length=config.sort_by_len,
                               maxibatch_size=config.maxi_batches)

    valid_set = None
    if config.src_valid and config.tgt_valid:
        print('Loading validation data..')
        valid_set = BiTextIterator(source=config.src_valid,
                                   target=config.tgt_valid,
                                   source_dict=config.src_vocab,
                                   target_dict=config.tgt_vocab,
                                   batch_size=config.batch_size,
                                   maxlen=None,
                                   n_words_source=config.num_enc_symbols,
                                   n_words_target=config.num_dec_symbols,
                                   shuffle_each_epoch=False,
                                   sort_by_length=config.sort_by_len,
                                   maxibatch_size=config.maxi_batches)

    # Create a Quasi-RNN model
    model, model_state = create_model(config)

    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=data_utils.pad_token)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # Accumulators for the current display window.
    loss = 0.0
    words_seen, sents_seen = 0, 0
    start_time = time.time()

    # Training loop
    print('Training..')
    for _ in xrange(config.max_epochs):
        # model_state may come from a reloaded checkpoint that already
        # reached max_epochs.
        if model_state['epoch'] >= config.max_epochs:
            print('Training is already complete. '
                  'current epoch:{}, max epoch:{}'.format(
                      model_state['epoch'], config.max_epochs))
            break

        for source_seq, target_seq in train_set:
            # Get a batch from training parallel data
            enc_input, enc_len, dec_input, dec_target, dec_len = \
                prepare_train_batch(source_seq, target_seq, config.max_seq_len)
            if enc_input is None or dec_input is None or dec_target is None:
                print('No samples under max_seq_length  {}'.format(
                    config.max_seq_len))
                continue

            enc_input = _torch_to_variable(enc_input)
            enc_len = _torch_to_variable(enc_len)
            dec_input = _torch_to_variable(dec_input)
            dec_target = _torch_to_variable(dec_target)
            dec_len = _torch_to_variable(dec_len)

            # Execute a single training step
            optimizer.zero_grad()
            dec_logits = model(enc_input, enc_len, dec_input)
            step_loss = criterion(dec_logits, dec_target.view(-1))
            step_loss.backward()
            nn.utils.clip_grad_norm(model.parameters(), config.max_grad_norm)
            optimizer.step()

            # Pre-divide so `loss` is the window mean at display time.
            loss += float(step_loss.data[0]) / config.display_freq
            words_seen += torch.sum(enc_len + dec_len).data[0]
            sents_seen += enc_input.size(0)  # batch_size
            model_state['train_steps'] += 1

            # Display training status
            if model_state['train_steps'] % config.display_freq == 0:
                # Guard against math.exp overflow on a diverging loss.
                avg_perplexity = (math.exp(float(loss)) if loss < 300
                                  else float('inf'))
                time_elapsed = time.time() - start_time
                step_time = time_elapsed / config.display_freq
                words_per_sec = words_seen / time_elapsed
                sents_per_sec = sents_seen / time_elapsed
                print('Epoch  {} Step  {} Perplexity {:.2f} Step-time {:.2f} '
                      '{:.2f} sents/s {:.2f} words/s'.format(
                          model_state['epoch'], model_state['train_steps'],
                          avg_perplexity, step_time,
                          sents_per_sec, words_per_sec))
                loss = 0.0
                words_seen, sents_seen = 0, 0
                start_time = time.time()

            # Execute a validation process
            if valid_set and model_state['train_steps'] % config.valid_freq == 0:
                valid_started = time.time()
                _torch_report_validation(model, criterion, valid_set)
                # Keep validation wall-time out of the training throughput
                # window reported at the next display step.
                start_time += time.time() - valid_started

            # Save the model checkpoint
            if model_state['train_steps'] % config.save_freq == 0:
                print('Saving the model..')
                model_state['state_dict'] = model.state_dict()
                model_path = os.path.join(config.model_dir, config.model_name)
                torch.save(model_state, model_path)

        # Increase the epoch index of the model.
        # NOTE(review): within this function the incremented epoch only reaches
        # disk at the next save_freq checkpoint, and nothing is saved after the
        # loop -- confirm a final save is not needed here.
        model_state['epoch'] += 1
        print('Epoch {0:} DONE'.format(model_state['epoch']))