def train():
    # Load parallel data to train
    print 'Loading training data..'
    train_set = SegTextIterator(source=FLAGS.source_train_data,
                                time_step=FLAGS.time_step,
                                batch_size=FLAGS.batch_size,
                                set_type=0,
                                test_set_partition=FLAGS.test_set_partition,
                                balance=FLAGS.data_balance,
                                noise=FLAGS.noise)

    if FLAGS.source_valid_data:
        print 'Loading validation data..'
        valid_set = SegTextIterator(source=FLAGS.source_valid_data,
                                    time_step=FLAGS.time_step,
                                    batch_size=FLAGS.batch_size,
                                    set_type=1,
                                    test_set_partition=FLAGS.test_set_partition,
                                    balance=FLAGS.data_balance,
                                    noise=FLAGS.noise)
    else:
        valid_set = None

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,
                                          log_device_placement=FLAGS.log_device_placement,
                                          gpu_options=tf.GPUOptions(allow_growth=True))) as sess:
        # Create a log writer object
        log_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        # Create a new model or reload existing checkpoint
        model = create_model(sess, FLAGS)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        # Training loop
        print 'Training..'
        for epoch_idx in xrange(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print 'Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs)
                break

            for source_data, source_word, source_char, source_word_mask, source_char_mask, \
                    target_seq, true_batch_size in train_set:
                # Pad the batch; skip it if nothing fits under max_seq_length
                source_data, source_len, target, target_len = prepare_train_batch(source_data, target_seq)
                if source_data is None or target is None:
                    print 'No samples under max_seq_length ', FLAGS.max_seq_length
                    continue

                # Pack word- and character-level inputs plus their masks for the encoder
                source = source_data, source_word, source_word_mask, source_char, source_char_mask

                # Execute a single training step
                step_loss, summary = model.train(sess, encoder_inputs=source, encoder_inputs_length=source_len,
                                                 decoder_inputs=target, decoder_inputs_length=target_len)

                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len + target_len))
                sents_seen += float(FLAGS.batch_size)  # batch_size

                if model.global_step.eval() % FLAGS.display_freq == 0:
                    avg_perplexity = math.exp(float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print 'Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.5f}'.format(avg_perplexity), ' Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                # Execute a validation step
                if valid_set and model.global_step.eval() % FLAGS.valid_freq == 0:
                    print 'Validation step'
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    for source_data, source_word, source_char, source_word_mask, source_char_mask, \
                            target_seq, true_batch_size in valid_set:
                        # Get a batch from validation parallel data
                        source_data, source_len, target, target_len = prepare_train_batch(source_data, target_seq)
                        source = source_data, source_word, source_word_mask, source_char, source_char_mask

                        # Compute validation loss: average per word cross entropy loss
                        step_loss, summary = model.eval(sess, encoder_inputs=source, encoder_inputs_length=source_len,
                                                        decoder_inputs=target, decoder_inputs_length=target_len)

                        batch_size = true_batch_size
                        valid_loss += step_loss * batch_size
                        valid_sents_seen += batch_size
                        print '  {} samples seen'.format(valid_sents_seen)

                    valid_loss = valid_loss / valid_sents_seen
                    print 'Valid perplexity: {0:.2f}'.format(math.exp(valid_loss))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print 'Saving the model..'
                    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
                    model.save(sess, checkpoint_path, global_step=model.global_step)
                    json.dump(model.config,
                              open('%s-%d.json' % (checkpoint_path, model.global_step.eval()), 'wb'),
                              indent=2)

            # Increase the epoch index of the model
            model.global_epoch_step_op.eval()
            print 'Epoch {0:} DONE'.format(model.global_epoch_step.eval())

        print 'Saving the last model..'
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, checkpoint_path, global_step=model.global_step)
        json.dump(model.config,
                  open('%s-%d.json' % (checkpoint_path, model.global_step.eval()), 'wb'),
                  indent=2)

    print 'Training Terminated'
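# A minimal launch sketch (an assumption, not part of the original file): train() above reads
# its settings from FLAGS, which in TF 1.x scripts is conventionally built with tf.app.flags.
# The flag names below are taken from the code; the default values are purely illustrative.
#
#   tf.app.flags.DEFINE_string('source_train_data', 'data/train.seg', 'Path to training data')
#   tf.app.flags.DEFINE_string('model_dir', 'model/', 'Directory for checkpoints and summaries')
#   tf.app.flags.DEFINE_integer('batch_size', 128, 'Mini-batch size')
#   FLAGS = tf.app.flags.FLAGS
#
#   def main(_):
#       train()
#
#   if __name__ == '__main__':
#       tf.app.run()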
def train(config):
    # Load parallel data to train
    print 'Loading training data..'
    train_set = BiTextIterator(source=config.src_train,
                               target=config.tgt_train,
                               source_dict=config.src_vocab,
                               target_dict=config.tgt_vocab,
                               batch_size=config.batch_size,
                               maxlen=config.max_seq_len,
                               n_words_source=config.num_enc_symbols,
                               n_words_target=config.num_dec_symbols,
                               shuffle_each_epoch=config.shuffle,
                               sort_by_length=config.sort_by_len,
                               maxibatch_size=config.maxi_batches)

    valid_set = None
    if config.src_valid and config.tgt_valid:
        print 'Loading validation data..'
        valid_set = BiTextIterator(source=config.src_valid,
                                   target=config.tgt_valid,
                                   source_dict=config.src_vocab,
                                   target_dict=config.tgt_vocab,
                                   batch_size=config.batch_size,
                                   maxlen=None,
                                   n_words_source=config.num_enc_symbols,
                                   n_words_target=config.num_dec_symbols,
                                   shuffle_each_epoch=False,
                                   sort_by_length=config.sort_by_len,
                                   maxibatch_size=config.maxi_batches)

    # Create a Quasi-RNN model
    model, model_state = create_model(config)

    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=data_utils.pad_token)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    loss = 0.0
    words_seen, sents_seen = 0, 0
    start_time = time.time()

    # Training loop
    print 'Training..'
    for epoch_idx in xrange(config.max_epochs):
        if model_state['epoch'] >= config.max_epochs:
            print 'Training is already complete.', \
                  'current epoch:{}, max epoch:{}'.format(model_state['epoch'], config.max_epochs)
            break

        for source_seq, target_seq in train_set:
            # Get a batch from training parallel data
            enc_input, enc_len, dec_input, dec_target, dec_len = \
                prepare_train_batch(source_seq, target_seq, config.max_seq_len)
            if enc_input is None or dec_input is None or dec_target is None:
                print 'No samples under max_seq_length ', config.max_seq_len
                continue

            if use_cuda:
                enc_input = Variable(enc_input.cuda())
                enc_len = Variable(enc_len.cuda())
                dec_input = Variable(dec_input.cuda())
                dec_target = Variable(dec_target.cuda())
                dec_len = Variable(dec_len.cuda())
            else:
                enc_input = Variable(enc_input)
                enc_len = Variable(enc_len)
                dec_input = Variable(dec_input)
                dec_target = Variable(dec_target)
                dec_len = Variable(dec_len)

            # Execute a single training step
            optimizer.zero_grad()
            dec_logits = model(enc_input, enc_len, dec_input)
            step_loss = criterion(dec_logits, dec_target.view(-1))

            step_loss.backward()
            nn.utils.clip_grad_norm(model.parameters(), config.max_grad_norm)
            optimizer.step()

            loss += float(step_loss.data[0]) / config.display_freq
            words_seen += torch.sum(enc_len + dec_len).data[0]
            sents_seen += enc_input.size(0)  # batch_size
            model_state['train_steps'] += 1

            # Display training status
            if model_state['train_steps'] % config.display_freq == 0:
                avg_perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
                time_elapsed = time.time() - start_time
                step_time = time_elapsed / config.display_freq

                words_per_sec = words_seen / time_elapsed
                sents_per_sec = sents_seen / time_elapsed

                print 'Epoch ', model_state['epoch'], 'Step ', model_state['train_steps'], \
                      'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time {0:.2f}'.format(step_time), \
                      '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                loss = 0.0
                words_seen, sents_seen = 0, 0
                start_time = time.time()

            # Execute a validation process
            if valid_set and model_state['train_steps'] % config.valid_freq == 0:
                model.eval()
                print 'Validation step'
                valid_steps = 0
                valid_loss = 0.0
                valid_sents_seen = 0
                for source_seq, target_seq in valid_set:
                    # Get a batch from validation parallel data
                    enc_input, enc_len, dec_input, dec_target, _ = \
                        prepare_train_batch(source_seq, target_seq)

                    if use_cuda:
                        enc_input = Variable(enc_input.cuda())
                        enc_len = Variable(enc_len.cuda())
                        dec_input = Variable(dec_input.cuda())
                        dec_target = Variable(dec_target.cuda())
                    else:
                        enc_input = Variable(enc_input)
                        enc_len = Variable(enc_len)
                        dec_input = Variable(dec_input)
                        dec_target = Variable(dec_target)

                    dec_logits = model(enc_input, enc_len, dec_input)
                    step_loss = criterion(dec_logits, dec_target.view(-1))

                    valid_steps += 1
                    valid_loss += float(step_loss.data[0])
                    valid_sents_seen += enc_input.size(0)
                    print '  {} samples seen'.format(valid_sents_seen)

                model.train()
                print 'Valid perplexity: {0:.2f}'.format(math.exp(valid_loss / valid_steps))

            # Save the model checkpoint
            if model_state['train_steps'] % config.save_freq == 0:
                print 'Saving the model..'
                model_state['state_dict'] = model.state_dict()
                model_path = os.path.join(config.model_dir, config.model_name)
                torch.save(model_state, model_path)

        # Increase the epoch index of the model
        model_state['epoch'] += 1
        print 'Epoch {0:} DONE'.format(model_state['epoch'])
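# A minimal resume sketch (an assumption, not part of the original file): the checkpoint
# written above is the whole model_state dict, including 'state_dict', 'epoch' and
# 'train_steps', so restoring a run could look roughly like:
#
#   model_state = torch.load(os.path.join(config.model_dir, config.model_name))
#   model.load_state_dict(model_state['state_dict'])
#   # training then resumes from model_state['epoch'] / model_state['train_steps']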
def train():
    # Load parallel data to train
    print 'Loading training data..'
    train_set = BiTextIterator(source=FLAGS.source_train_data,
                               target=FLAGS.target_train_data,
                               source_dict=FLAGS.source_vocabulary,
                               target_dict=FLAGS.target_vocabulary,
                               batch_size=FLAGS.batch_size,
                               maxlen=FLAGS.max_seq_length,
                               n_words_source=FLAGS.num_encoder_symbols,
                               n_words_target=FLAGS.num_decoder_symbols,
                               shuffle_each_epoch=FLAGS.shuffle_each_epoch,
                               sort_by_length=FLAGS.sort_by_length,
                               maxibatch_size=FLAGS.max_load_batches)

    if FLAGS.source_valid_data and FLAGS.target_valid_data:
        print 'Loading validation data..'
        valid_set = BiTextIterator(source=FLAGS.source_valid_data,
                                   target=FLAGS.target_valid_data,
                                   source_dict=FLAGS.source_vocabulary,
                                   target_dict=FLAGS.target_vocabulary,
                                   batch_size=FLAGS.batch_size,
                                   maxlen=None,
                                   n_words_source=FLAGS.num_encoder_symbols,
                                   n_words_target=FLAGS.num_decoder_symbols)
    else:
        valid_set = None

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,
                                          log_device_placement=FLAGS.log_device_placement,
                                          gpu_options=tf.GPUOptions(allow_growth=True))) as sess:
        # Create a new model or reload existing checkpoint
        model = create_model(sess, FLAGS)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        # Training loop
        print 'Training..'
        for epoch_idx in xrange(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print 'Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs)
                break

            for source_seq, target_seq in train_set:
                # Get a batch from training parallel data
                source, source_len, target, target_len = prepare_train_batch(source_seq, target_seq,
                                                                             FLAGS.max_seq_length)
                if source is None or target is None:
                    print 'No samples under max_seq_length ', FLAGS.max_seq_length
                    continue

                # Execute a single training step
                step_loss = model.train(sess, encoder_inputs=source, encoder_inputs_length=source_len,
                                        decoder_inputs=target, decoder_inputs_length=target_len)

                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len + target_len))
                sents_seen += float(source.shape[0])  # batch_size

                if model.global_step.eval() % FLAGS.display_freq == 0:
                    avg_perplexity = math.exp(float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print 'Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                # Execute a validation step
                if valid_set and model.global_step.eval() % FLAGS.valid_freq == 0:
                    print 'Validation step'
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    for source_seq, target_seq in valid_set:
                        # Get a batch from validation parallel data
                        source, source_len, target, target_len = prepare_train_batch(source_seq, target_seq)

                        # Compute validation loss: average per word cross entropy loss
                        step_loss = model.eval(sess, encoder_inputs=source, encoder_inputs_length=source_len,
                                               decoder_inputs=target, decoder_inputs_length=target_len)

                        batch_size = source.shape[0]
                        valid_loss += step_loss * batch_size
                        valid_sents_seen += batch_size
                        print '  {} samples seen'.format(valid_sents_seen)

                    valid_loss = valid_loss / valid_sents_seen
                    print 'Valid perplexity: {0:.2f}'.format(math.exp(valid_loss))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print 'Saving the model..'
                    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
                    model.save(sess, checkpoint_path, global_step=model.global_step)
                    json.dump(model.config,
                              open('%s-%d.json' % (checkpoint_path, model.global_step.eval()), 'wb'),
                              indent=2)

            # Increase the epoch index of the model
            model.global_epoch_step_op.eval()
            print 'Epoch {0:} DONE'.format(model.global_epoch_step.eval())

        print 'Saving the last model..'
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, checkpoint_path, global_step=model.global_step)
        json.dump(model.config,
                  open('%s-%d.json' % (checkpoint_path, model.global_step.eval()), 'wb'),
                  indent=2)

    print 'Training Terminated'
def decode():
    # Load model config
    _data_word = None
    _alignment = None
    _predict = np.array([])
    _ground_truth = np.array([])
    config = load_config(FLAGS)

    # Load source data to decode
    #test_set = TextIterator(source=config['decode_input'],
    #                        batch_size=config['decode_batch_size'],
    #                        source_dict=config['source_vocabulary'],
    #                        maxlen=None,
    #                        n_words_source=config['num_encoder_symbols'])
    test_set = SegTextIterator(source=config['decode_input'],
                               time_step=config['time_step'],
                               batch_size=config['batch_size'],
                               set_type=1,
                               test_set_partition=config["test_set_partition"],
                               balance=config["data_balance"],
                               noise=config["noise"])

    # Load inverse dictionary used in decoding
    #target_inverse_dict = data_utils.load_inverse_dict(config['target_vocabulary'])

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,
                                          log_device_placement=FLAGS.log_device_placement,
                                          gpu_options=tf.GPUOptions(allow_growth=True))) as sess:
        # Reload existing checkpoint
        model = load_model(sess, config)
        try:
            print 'Decoding {}..'.format(FLAGS.decode_input)
            #if FLAGS.write_n_best:
            #    fout = [data_utils.fopen(("%s_%d" % (FLAGS.decode_output, k)), 'w')
            #            for k in range(FLAGS.beam_width)]
            #else:
            #    fout = [data_utils.fopen(FLAGS.decode_output, 'w')]

            step = 0
            set_length = len(test_set)
            for source_data, source_word, source_char, source_word_mask, source_char_mask, \
                    target_seq, true_batch_size in test_set:
                source_data, source_len, target, target_len = prepare_train_batch(source_data, target_seq)
                source = source_data, source_word, source_word_mask, source_char, source_char_mask
                #source, source_len = prepare_batch(source_seq)

                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                #                BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                predicted_ids, alignment, _ = model.predict(sess, encoder_inputs=source,
                                                            encoder_inputs_length=source_len)

                # Accumulate attention alignments for later inspection
                if config['use_attention']:
                    try:
                        _data_word = np.concatenate((_data_word, source_word))
                    except:
                        _data_word = source_word
                    try:
                        _alignment = np.concatenate((_alignment, alignment))
                    except:
                        _alignment = alignment

                # Accumulate predictions and ground truth for evaluation
                _predict = np.concatenate((_predict, predicted_ids[:, 0, 0]))
                _ground_truth = np.concatenate((_ground_truth, target[:, 0]))

                progress = 100 * float(step) / set_length
                print "progress %.1f%%" % progress
                step += 1

                # Write decoding results
                #for k, f in reversed(list(enumerate(fout))):
                #    for seq in predicted_ids:
                #        f.write(str(data_utils.seq2words(seq[:, k], target_inverse_dict)) + '\n')
                #    if not FLAGS.write_n_best:
                #        break
                #print '  {}th line decoded'.format(idx * FLAGS.decode_batch_size)

            if config['use_attention']:
                idxs = [j for j in range(len(_alignment))]
                for idx in idxs:
                    print ">>>>"
                    print_group_idx([_alignment[idx], word_seq_list(_data_word[idx])])
                    #print_idx(word_seq_list(_data_word[idx]))
                    #print_idx(_alignment[idx])
                    print _predict[idx], _ground_truth[idx]

            evaluate(_predict, _ground_truth)
            print 'Decoding terminated'
        except IOError:
            pass