def run_test(sess, model, test_data, verbose=True):
    """Run inference over ``test_data`` and collect decoder outputs.

    Args:
        sess: live TF session holding the trained graph.
        model: graph wrapper exposing input placeholders and output tensors.
        test_data: tuple of parallel sequences (value, attr, pos_fw, pos_bw, desc);
            the last field is ignored at test time.
        verbose: if True, print a running batch counter.

    Returns:
        (predicted_ids, alignment_history): per-example token-id lists and the
        argmax over the attention distribution at each decode step.
    """
    all_predictions = []
    all_alignments = []
    batches = make_batch_iter(list(zip(*test_data)), config.batch_size, shuffle=False, verbose=verbose)
    batch_no = 0
    for batch in batches:
        batch_no += 1
        values, attrs, pos_fw, pos_bw, _ = list(zip(*batch))
        # true (unpadded) source lengths, taken before padding
        lengths = np.array([len(v) for v in values])
        feed = {
            model.value_inp: np.array(pad_batch(values, config.pad_id)),
            model.attr_inp: np.array(pad_batch(attrs, config.pad_id)),
            model.pos_fw_inp: np.array(pad_batch(pos_fw, config.pad_id)),
            model.pos_bw_inp: np.array(pad_batch(pos_bw, config.pad_id)),
            model.src_len: lengths,
            model.training: False
        }
        pred_batch, align_batch = sess.run(
            [model.predicted_ids, model.alignment_history], feed_dict=feed)
        all_predictions.extend(pred_batch.tolist())
        # keep only the most-attended source position per decode step
        all_alignments.extend(np.argmax(align_batch, axis=-1).tolist())
        if verbose:
            print('\rprocessing batch: {:>6d}'.format(batch_no), end='')
    print()
    return all_predictions, all_alignments
def make_batch_data(batch):
    """Unzip a batch of (topic, triple, src, tgt) examples into padded arrays.

    Returns, for each of the four fields in order: the padded 2-D id array and
    a 1-D array of the original (pre-padding) lengths, interleaved as
    (topic, topic_len, triple, triple_len, src, src_len, tgt, tgt_len).
    """
    columns = list(zip(*batch))
    # lengths are measured before padding so the model can mask correctly
    lengths = [np.array([len(seq) for seq in col]) for col in columns]
    padded = [np.array(pad_batch(col, config.pad_id)) for col in columns]
    topic, triple, src, tgt = padded
    topic_len, triple_len, src_len, tgt_len = lengths
    return topic, topic_len, triple, triple_len, src, src_len, tgt, tgt_len
def run_evaluate(sess, model, valid_data, valid_summary_writer=None, verbose=True):
    """Evaluate the model on ``valid_data`` with teacher forcing.

    Args:
        sess: live TF session.
        model: graph wrapper exposing placeholders and loss/accuracy tensors.
        valid_data: tuple of parallel (src, tgt) id sequences.
        valid_summary_writer: optional tf.summary.FileWriter; receives a summary
            every ``args.log_steps`` batches.
        verbose: if True, print a running batch counter.

    Returns:
        (predicted_ids, alignment_history, mean_loss, mean_accuracy).
        alignment_history stays empty when config.beam_search is set, since the
        beam decoder's attention tensor is not collected here.
    """
    steps = 0
    predicted_ids = []
    alignment_history = []
    total_loss = 0.0
    total_accu = 0.0
    batch_iter = make_batch_iter(list(zip(*valid_data)), config.batch_size, shuffle=False, verbose=verbose)
    for batch in batch_iter:
        src_seq, tgt_seq = list(zip(*batch))
        src_len_seq = np.array([len(src) for src in src_seq])
        tgt_len_seq = np.array([len(tgt) for tgt in tgt_seq])
        src_seq = np.array(pad_batch(src_seq, config.pad_id))
        tgt_seq = np.array(pad_batch(tgt_seq, config.pad_id))
        _predicted_ids, _alignment_history, loss, accu, global_step, summary = sess.run(
            [
                model.predicted_ids, model.alignment_history, model.loss,
                model.accu, model.global_step, model.summary
            ],
            feed_dict={
                model.src_inp: src_seq,
                model.tgt_inp: tgt_seq[:, :-1],  # 1 for eos
                model.tgt_out: tgt_seq[:, 1:],  # 1 for sos
                model.src_len: src_len_seq,
                model.tgt_len: tgt_len_seq - 1,  # 1 for eos
                model.training: False
            })
        predicted_ids.extend(_predicted_ids.tolist())
        if not config.beam_search:
            alignment_history.extend(
                np.argmax(_alignment_history, axis=-1).tolist())
        steps += 1
        total_loss += loss
        total_accu += accu
        if verbose:
            # BUG FIX: steps was already incremented above, so the original
            # 'steps + 1' over-counted by one (first batch printed as 2).
            print('\rprocessing batch: {:>6d}'.format(steps), end='')
        if steps % args.log_steps == 0 and valid_summary_writer is not None:
            valid_summary_writer.add_summary(summary, global_step)
    print()
    return predicted_ids, alignment_history, total_loss / steps, total_accu / steps
def run_train(sess, model, train_data, valid_data, saver, evaluator,
              train_summary_writer=None, valid_summary_writer=None, verbose=True):
    """Train the table-to-text model (value/attr/pos_fw/pos_bw inputs).

    NOTE(review): this function is redefined later in the file by a second
    `run_train` (src/tgt variant), which shadows this one at import time —
    confirm which definition is intended to be live.

    NOTE(review): indentation below is reconstructed from whitespace-mangled
    source; the nesting of the evaluation block inside the save-steps check
    mirrors the second run_train variant — confirm against the original file.

    Checkpoints every ``args.save_steps`` global steps; after
    ``args.pre_train_epochs`` epochs, also evaluates each checkpoint on
    ``valid_data`` and early-stops on stagnating Bleu_4 when
    ``args.early_stop`` is set.
    """
    flag = 0           # consecutive evaluations without a Bleu_4 improvement
    train_log = 0.0    # best Bleu_4 seen so far
    global_step = 0
    for i in range(config.num_epoch):
        print_title('Train Epoch: {}'.format(i + 1))
        steps = 0
        total_loss = 0.0
        total_accu = 0.0
        batch_iter = make_batch_iter(list(zip(*train_data)), config.batch_size, shuffle=True, verbose=verbose)
        for batch in batch_iter:
            start_time = time.time()
            value_seq, attr_seq, pos_fw_seq, pos_bw_seq, desc_seq = list(zip(*batch))
            # unpadded lengths, captured before pad_batch
            src_len_seq = np.array([len(src) for src in value_seq])
            tgt_len_seq = np.array([len(tgt) for tgt in desc_seq])
            value_seq = np.array(pad_batch(value_seq, config.pad_id))
            attr_seq = np.array(pad_batch(attr_seq, config.pad_id))
            pos_fw_seq = np.array(pad_batch(pos_fw_seq, config.pad_id))
            pos_bw_seq = np.array(pad_batch(pos_bw_seq, config.pad_id))
            desc_seq = np.array(pad_batch(desc_seq, config.pad_id))
            _, loss, accu, global_step, summary = sess.run(
                [model.train_op, model.loss, model.accu, model.global_step, model.summary],
                feed_dict={
                    model.value_inp: value_seq,
                    model.attr_inp: attr_seq,
                    model.pos_fw_inp: pos_fw_seq,
                    model.pos_bw_inp: pos_bw_seq,
                    model.desc_inp: desc_seq[:, :-1],  # 1 for eos
                    model.desc_out: desc_seq[:, 1:],  # 1 for sos
                    model.src_len: src_len_seq,
                    model.tgt_len: tgt_len_seq - 1,  # 1 for eos
                    model.training: True
                }
            )
            steps += 1
            total_loss += loss
            total_accu += accu
            if verbose:
                print('\rafter {:>6d} batch(s), train loss is {:>.4f}, train accuracy is {:>.4f}, {:>.4f}s/batch'
                      .format(steps, loss, accu, time.time() - start_time), end='')
            if steps % args.log_steps == 0 and train_summary_writer is not None:
                train_summary_writer.add_summary(summary, global_step)
            if global_step % args.save_steps == 0:
                saver.save(sess, config.model_file, global_step=global_step)
                # evaluate saved models after pre-train epochs
                if i + 1 > args.pre_train_epochs:
                    predicted_ids, alignment_history, valid_loss, valid_accu = run_evaluate(
                        sess, model, valid_data,
                        valid_summary_writer, verbose=False
                    )
                    print_title('Valid Result', sep='*')
                    print('average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'.format(valid_loss, valid_accu))
                    print_title('Saving Result')
                    save_result(predicted_ids, alignment_history, config.id_2_word, config.valid_data_small, config.valid_result)
                    eval_results = evaluator.evaluate(config.valid_data_small, config.valid_result, config.to_lower)
                    # early stop: reset the patience counter on any Bleu_4
                    # improvement; otherwise count up to a hard-coded 5 before
                    # returning (only when args.early_stop is truthy)
                    if eval_results['Bleu_4'] > train_log:
                        flag = 0
                        train_log = eval_results['Bleu_4']
                    elif flag < 5:
                        flag += 1
                    elif args.early_stop:
                        return
        print()
        print_title('Train Result')
        print('average train loss: {:>.4f}, average train accuracy: {:>.4f}'.format(
            total_loss / steps, total_accu / steps))
    saver.save(sess, config.model_file, global_step=global_step)
def run_train(sess, model, train_data, valid_data, saver, evaluator,
              train_summary_writer=None, valid_summary_writer=None, verbose=True):
    """Train the seq2seq model (src/tgt inputs) with periodic validation.

    NOTE(review): this redefines the earlier `run_train` and shadows it —
    confirm the duplicate definition is intentional.

    NOTE(review): indentation below is reconstructed from whitespace-mangled
    source — confirm the nesting of the save/evaluate branches.

    During the first ``args.pre_train_epochs`` epochs checkpoints are saved
    unconditionally every ``args.save_steps`` global steps; afterwards each
    checkpoint interval triggers a validation run, the model is saved only on
    a new best Bleu_4, and training early-stops after ``args.early_stop``
    consecutive evaluations without a Bleu_4 gain of at least
    ``args.early_stop_delta``.

    Returns:
        dict with per-evaluation 'loss', 'accuracy' and 'global_step' lists.
    """
    flag = 0               # consecutive evaluations without sufficient Bleu_4 gain
    valid_log = 0.0        # Bleu_4 of the previous evaluation
    best_valid_log = 0.0   # best Bleu_4 seen so far (gates checkpointing)
    valid_log_history = {'loss': [], 'accuracy': [], 'global_step': []}
    global_step = 0
    for i in range(config.num_epoch):
        print_title('Train Epoch: {}'.format(i + 1))
        steps = 0
        total_loss = 0.0
        total_accu = 0.0
        batch_iter = make_batch_iter(list(zip(*train_data)), config.batch_size, shuffle=True, verbose=verbose)
        for batch in batch_iter:
            start_time = time.time()
            src_seq, tgt_seq = list(zip(*batch))
            # unpadded lengths, captured before pad_batch
            src_len_seq = np.array([len(src) for src in src_seq])
            tgt_len_seq = np.array([len(tgt) for tgt in tgt_seq])
            src_seq = np.array(pad_batch(src_seq, config.pad_id))
            tgt_seq = np.array(pad_batch(tgt_seq, config.pad_id))
            _, loss, accu, global_step, summary = sess.run(
                [
                    model.train_op, model.loss, model.accu, model.global_step,
                    model.summary
                ],
                feed_dict={
                    model.src_inp: src_seq,
                    model.tgt_inp: tgt_seq[:, :-1],  # 1 for eos
                    model.tgt_out: tgt_seq[:, 1:],  # 1 for sos
                    model.src_len: src_len_seq,
                    model.tgt_len: tgt_len_seq - 1,  # 1 for eos
                    model.training: True
                })
            steps += 1
            total_loss += loss
            total_accu += accu
            if verbose:
                print(
                    '\rafter {:>6d} batch(s), train loss is {:>.4f}, train accuracy is {:>.4f}, {:>.4f}s/batch'
                    .format(steps, loss, accu, time.time() - start_time), end='')
            if steps % args.log_steps == 0 and train_summary_writer is not None:
                train_summary_writer.add_summary(summary, global_step)
            if global_step % args.save_steps == 0:
                # evaluate saved models after pre-train epochs
                if i < args.pre_train_epochs:
                    saver.save(sess, config.model_file, global_step=global_step)
                else:
                    predicted_ids, alignment_history, valid_loss, valid_accu = run_evaluate(
                        sess, model, valid_data, valid_summary_writer, verbose=False)
                    print_title('Valid Result', sep='*')
                    print(
                        'average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'
                        .format(valid_loss, valid_accu))
                    print_title('Saving Result')
                    # beam search produces no alignment_history to save
                    if not config.beam_search:
                        save_result_v1(predicted_ids, alignment_history, config.id_2_word, config.valid_data, config.valid_result)
                    else:
                        save_result_v2(predicted_ids, config.id_2_word, config.valid_result)
                    valid_results = evaluator.evaluate(config.valid_data, config.valid_result, config.to_lower)
                    # checkpoint only on a new best Bleu_4
                    if valid_results['Bleu_4'] >= best_valid_log:
                        best_valid_log = valid_results['Bleu_4']
                        saver.save(sess, config.model_file, global_step=global_step)
                    # early stop: require at least early_stop_delta gain over
                    # the PREVIOUS evaluation to reset the patience counter
                    if valid_results[
                            'Bleu_4'] - args.early_stop_delta >= valid_log:
                        flag = 0
                    elif flag < args.early_stop:
                        flag += 1
                    elif args.early_stop:
                        return valid_log_history
                    valid_log = valid_results['Bleu_4']
                    valid_log_history['loss'].append(valid_loss)
                    valid_log_history['accuracy'].append(valid_accu)
                    valid_log_history['global_step'].append(int(global_step))
        print()
        print_title('Train Result')
        print('average train loss: {:>.4f}, average train accuracy: {:>.4f}'.
              format(total_loss / steps, total_accu / steps))
    saver.save(sess, config.model_file, global_step=global_step)
    return valid_log_history