def generate_keyword_mask(self, context_ids):
    """Generate a mask that keeps only the related keywords' Q-values.
    """
    self.kg_adjacency_matrix = tf.cond(
        pred=tf.equal(tx.global_mode(), tf.estimator.ModeKeys.TRAIN),
        true_fn=lambda: self.train_kg_adjacency_matrix,
        false_fn=lambda: tf.cond(
            pred=tf.equal(tx.global_mode(), tf.estimator.ModeKeys.EVAL),
            true_fn=lambda: self.valid_kg_adjacency_matrix,
            false_fn=lambda: self.test_kg_adjacency_matrix))

    context_ids = tf.cast(context_ids, tf.int64)
    # shape of adj_matrix_context_ids: [_cur_keywords_len,]
    adj_matrix_context_ids = self.map_vocab_ids_to_adj_matrix_ids(
        context_ids)
    # shape of context_related_adj_matrix:
    # [#adj_matrix_context_ids, adj_matrix_size]
    context_related_adj_matrix = tf.gather(self.kg_adjacency_matrix,
                                           adj_matrix_context_ids)
    # shape of keyword_mask: [adj_matrix_size,]
    keyword_mask = tf.reduce_max(context_related_adj_matrix, axis=0)

    num_related_keywords = tf.reduce_sum(
        tf.cast(tf.equal(keyword_mask, 1.), tf.float32))
    no_related_keywords = tf.equal(num_related_keywords, 0.)
    ones_tensor = tf.ones(shape=[self.adj_matrix_size])
    # If the current context has no related keywords,
    # keep all keywords' Q-values.
    keyword_mask = tf.cond(pred=no_related_keywords,
                           true_fn=lambda: ones_tensor,
                           false_fn=lambda: keyword_mask)
    # Remove the <PAD> dimension.
    keyword_mask = keyword_mask[1:]
    return keyword_mask

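# A minimal usage sketch (hypothetical, not taken from the surrounding code) of how
# such a mask is typically applied before selecting a keyword: entries for related
# keywords keep their Q-values, while all other entries are pushed to a very
# negative value so that argmax/top-k cannot pick them.
#
#     masked_qvalues = qvalues * keyword_mask + (1. - keyword_mask) * tf.float32.min
#     best_keyword = tf.argmax(masked_qvalues, axis=-1)
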
def _train_epoch(sess, initial=False):
    """Trains on the training set, and evaluates on the dev set periodically.
    """
    iterator.restart_dataset(sess, 'train')

    while True:
        try:
            # (1) Get data and yy sample
            fetches_data = {
                'batch': batch,
                'batch_size': batch_size,
            }
            feed_dict_data = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
            }
            rets_data = sess.run(fetches_data, feed_dict_data)

            # (2) Optimize loss
            feed_dict = {
                #x1_ids: rets_data['batch']['x1_ids'],
                x1_len: rets_data['batch']['x1_len'],
                x1x4_ids: rets_data['batch']['x1x4_ids'],
                x1x4_len: rets_data['batch']['x1x4_len'],
                tau: config_train.tau,
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            fetches = {
                'train_op': train_op,
                'step': global_step,
            }
            fetches.update(loss_dict)

            rets = sess.run(fetches, feed_dict)
            step = rets['step']

            dis_steps = config_train.display_steps
            if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                _log_losses(rets, step)

            eval_steps = config_train.eval_steps
            if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                _dev_epoch(sess)

            sample_steps = config_train.sample_steps
            if _is_head() and sample_steps > 0 and step % sample_steps == 0:
                print('-----------testing-----------------')
                _test_epoch(sess, step=step)

            ckpt_steps = config_train.checkpoint_steps
            if _is_head() and ckpt_steps > 0 and step % ckpt_steps == 0:
                ckpt_fn = os.path.join(output_dir, 'model.ckpt')
                ckpt_fn = saver.save(sess, ckpt_fn, global_step=step)
                _log('Checkpoint to {}'.format(ckpt_fn))

        except tf.errors.OutOfRangeError:
            break

def _eval_epoch(sess, epoch, mode):
    references, hypotheses = [], []
    bsize = test_batch_size
    fetches = {
        'inferred_ids': inferred_ids,
    }
    bno = 0

    while True:
        #print("Temp",temp)
        try:
            print("Batch", bno)
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'eval'),
                tx.global_mode(): tf.estimator.ModeKeys.EVAL,
            }
            op = sess.run([batch], feed_dict)
            feed_dict = {
                src_input_ids: op[0]['src_input_ids'],
                src_segment_ids: op[0]['src_segment_ids'],
                tx.global_mode(): tf.estimator.ModeKeys.EVAL
            }
            fetches_ = sess.run(fetches, feed_dict=feed_dict)
            labels = op[0]['tgt_labels']

            hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
            references.extend(r.tolist() for r in labels)
            hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
            references = utils.list_strip_eos(references, eos_token_id)
            bno = bno + 1
        except tf.errors.OutOfRangeError:
            break

    if mode == 'eval':
        # Writes results to files to evaluate BLEU
        # For 'eval' mode, the BLEU is based on token ids (rather than
        # text tokens) and serves only as a surrogate metric to monitor
        # the training process
        fname = os.path.join(model_dir, 'tmp.eval')
        hypotheses = tx.utils.str_join(hypotheses)
        references = tx.utils.str_join(references)
        hyp_fn, ref_fn = tx.utils.write_paired_text(
            hypotheses, references, fname, mode='s')
        eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
        eval_bleu = 100. * eval_bleu
        logger.info('epoch: %d, eval_bleu %.4f', epoch, eval_bleu)
        print('epoch: %d, eval_bleu %.4f' % (epoch, eval_bleu))

        if eval_bleu > best_results['score']:
            logger.info('epoch: %d, best bleu: %.4f', epoch, eval_bleu)
            best_results['score'] = eval_bleu
            best_results['epoch'] = epoch
            model_path = os.path.join(model_dir, 'best-model.ckpt')
            logger.info('saving model to %s', model_path)
            print('saving model to %s' % model_path)
            saver.save(sess, model_path)

def train(self):
    batch = self.iterator.get_next()
    loss_t, acc_t, _ = self.predict_keywords(batch)
    kw_saver = tf.train.Saver()
    loss, acc, _ = self.forward(batch)

    retrieval_step = tf.Variable(0, name='retrieval_step')
    train_op = tx.core.get_train_op(loss,
                                    global_step=retrieval_step,
                                    hparams=self.config.opt_hparams)

    max_val_acc, stopping_flag = 0, 0
    with tf.Session(config=self.gpu_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        kw_saver.restore(sess, self.config._kernel_save_path)
        saver = tf.train.Saver()

        for epoch_id in range(self.config._max_epoch):
            self.iterator.switch_to_train_data(sess)
            cur_step = 0
            cnt_acc, cnt_kwacc = [], []
            while True:
                try:
                    cur_step += 1
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                    # Use `loss_` so the `loss` tensor above is not shadowed.
                    loss_, acc_, acc_kw = sess.run([train_op, acc, acc_t],
                                                   feed_dict=feed)
                    cnt_acc.append(acc_)
                    cnt_kwacc.append(acc_kw)
                    if cur_step % 200 == 0:
                        print('batch {}, loss={}, acc1={}, kw_acc1={}'.format(
                            cur_step, loss_, np.mean(cnt_acc[-200:]),
                            np.mean(cnt_kwacc[-200:])))
                except tf.errors.OutOfRangeError:
                    break

            self.iterator.switch_to_val_data(sess)
            cnt_acc, cnt_kwacc = [], []
            while True:
                try:
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                    acc_, acc_kw = sess.run([acc, acc_t], feed_dict=feed)
                    cnt_acc.append(acc_)
                    cnt_kwacc.append(acc_kw)
                except tf.errors.OutOfRangeError:
                    mean_acc = np.mean(cnt_acc)
                    print('valid acc1={}, kw_acc1={}'.format(
                        mean_acc, np.mean(cnt_kwacc)))
                    if mean_acc > max_val_acc:
                        max_val_acc = mean_acc
                        saver.save(sess, self.config._save_path)
                    else:
                        stopping_flag += 1
                    break

            if stopping_flag >= self.config._early_stopping:
                break

def _train_epoch(sess, epoch, step, smry_writer):
    fetches = {
        'step': global_step,
        'train_op': train_op,
        'smry': summary_merged,
        'mle_loss': mle_loss,
        'type_cls_loss': type_cls_loss,
        'conn_cls_loss': conn_cls_loss,
        'total_loss': total_loss,
    }

    print("------ Epoch number", epoch + 1, "out of", total_epochs,
          "epochs ------")
    for train_batch in range(num_train_batches_in_epoch):
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            op = sess.run([batch], feed_dict)
            feed_dict = {
                src_input_ids: op[0]['src_input_ids'],
                src_segment_ids: op[0]['src_segment_ids'],
                tgt_input_ids: op[0]['tgt_input_ids'],
                labels: op[0]['tgt_labels'],
                type_label: op[0]['type_label'],
                conn_label: op[0]['conn_label'],
                learning_rate: utils.get_lr(step, lr),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)
            step, m_loss, t_loss, c_loss = (fetches_['step'],
                                            fetches_['mle_loss'],
                                            fetches_['type_cls_loss'],
                                            fetches_['conn_cls_loss'])

            if step and step % display_steps == 0:
                logger.info(
                    'batch: %d/%d, mle_loss: %.4f, type_cls_loss: %.4f, '
                    'conn_cls_loss: %.4f',
                    train_batch, num_train_batches_in_epoch,
                    m_loss, t_loss, c_loss)
                print('batch: %d/%d, mle_loss: %.4f, type_cls_loss: %.4f, '
                      'conn_cls_loss: %.4f'
                      % (train_batch + 1, num_train_batches_in_epoch,
                         m_loss, t_loss, c_loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

        except tf.errors.OutOfRangeError:
            break

    model_path = model_dir + "/model_" + str(step) + ".ckpt"
    logger.info('saving model to %s', model_path)
    print('saving model to %s' % model_path)
    saver.save(sess, model_path)

    print("---EVAL---")
    _eval_epoch(sess, epoch, mode='eval')

    return step

def train(self):
    batch = self.iterator.get_next()
    kw_loss, kw_acc, _ = self.predict_keywords(batch)
    kw_saver = tf.train.Saver()
    loss, acc, rank = self.forward_response_retrieval(batch)

    op_step = tf.Variable(0, name='retrieval_step')
    train_op = tx.core.get_train_op(
        loss,
        global_step=op_step,
        hparams=self.model_config._retrieval_opt_hparams)

    max_val_acc = 0.
    with tf.Session(config=self.gpu_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        kw_saver.restore(sess, self.model_config._kp_save_path)
        saver = tf.train.Saver()

        for epoch_id in range(self.model_config._max_epoch):
            self.iterator.switch_to_train_data(sess)
            cur_step = 0
            cnt_acc = []
            while True:
                try:
                    cur_step += 1
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                    # Use `loss_` so the `loss` tensor above is not shadowed.
                    loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)
                    cnt_acc.append(acc_)
                    if cur_step % 200 == 0:
                        logs_loss_acc = 'batch {}, loss={}, acc1={}'.format(
                            cur_step, loss_, np.mean(cnt_acc[-200:]))
                        add_log(self.logs_save_path, logs_loss_acc)
                except tf.errors.OutOfRangeError:
                    break

            self.iterator.switch_to_val_data(sess)
            cnt_acc, cnt_kwacc = [], []
            while True:
                try:
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                    acc_, kw_acc_ = sess.run([acc, kw_acc], feed_dict=feed)
                    cnt_acc.append(acc_)
                    cnt_kwacc.append(kw_acc_)
                except tf.errors.OutOfRangeError:
                    mean_acc = np.mean(cnt_acc)
                    logs_loss_acc = 'epoch_id {}, valid acc1={}, kw_acc1={}'.format(
                        epoch_id + 1, mean_acc, np.mean(cnt_kwacc))
                    add_log(self.logs_save_path, logs_loss_acc)
                    if mean_acc > max_val_acc:
                        max_val_acc = mean_acc
                        saver.save(sess,
                                   self.model_config._retrieval_save_path)
                    break

def train_keywords(self):
    batch = self.iterator.get_next()
    acc = self.predict_keywords(batch)

    with tf.Session(config=self.gpu_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        self.iterator.switch_to_train_data(sess)
        batchid = 0
        while True:
            try:
                batchid += 1
                if batchid % 200 == 0:
                    print(batchid)
                feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                source_keywords, target_keywords = sess.run(
                    [batch['context_text_ids'], batch['keywords_text_ids']],
                    feed_dict=feed)
                for i in range(len(source_keywords)):
                    for skw_id in source_keywords[i]:
                        if skw_id == 0:
                            break
                        for tkw_id in target_keywords[i]:
                            if skw_id >= 3 and tkw_id >= 3:
                                tkw = self.config._vocab[tkw_id - 4]
                                if tkw in self.data_config._keywords_candi:
                                    tkw_id = self.data_config._keywords_dict[tkw]
                                    self.pmi_matrix[skw_id][tkw_id] += 1
            except tf.errors.OutOfRangeError:
                break

        self.pmi_matrix += 0.5
        self.pmi_matrix = self.pmi_matrix / (
            np.sum(self.pmi_matrix, axis=0) + 1)
        with open(self.config._matrix_save_path, 'wb') as f:
            pickle.dump(self.pmi_matrix, f)
        self.pmi_matrix = tf.convert_to_tensor(self.pmi_matrix,
                                               dtype=tf.float32)

        self.iterator.switch_to_val_data(sess)
        cnt_acc = []
        while True:
            try:
                feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                acc_ = sess.run(acc, feed_dict=feed)
                cnt_acc.append(acc_)
            except tf.errors.OutOfRangeError:
                print('valid acc1={}'.format(np.mean(cnt_acc)))
                break

def _dev_epoch(sess):
    """Evaluates on the dev set.
    """
    iterator.restart_dataset(sess, 'dev')

    results = tx.utils.AverageRecorder()
    nsamples = 0
    fetches = {}
    fetches.update(loss_dict)
    # i = 0

    while True:
        try:
            # (1) Get data and yy sample
            fetches_data = {
                'batch': batch,
                'batch_size': batch_size,
            }
            feed_dict_data = {
                iterator.handle: iterator.get_handle(sess, 'dev'),
                tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
            }
            rets_data = sess.run(fetches_data, feed_dict_data)

            # (2) Eval loss
            feed_dict = {
                #x1_ids: rets_data['batch']['x1_ids'],
                x1_len: rets_data['batch']['x1_len'],
                x1x4_ids: rets_data['batch']['x1x4_ids'],
                x1x4_len: rets_data['batch']['x1x4_len'],
                tau: config_train.tau,
                tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
            }

            rets = sess.run(fetches, feed_dict)

            results.add(rets, weight=rets_data['batch_size'])
            nsamples += rets_data['batch_size']
        except tf.errors.OutOfRangeError:
            break

    _log_losses(results.avg())
    _log('nsamples: %d' % nsamples)

    avg_loss = results.avg('loss')
    if FLAGS.do_train and avg_loss < dev_best['loss']:
        dev_best.update(results.avg())
        ckpt_fn = os.path.join(output_dir, 'model_best.ckpt')
        ckpt_fn = saver_best.save(sess, ckpt_fn)
        _log('Checkpoint best to {}'.format(ckpt_fn))

def train_keywords(self):
    batch = self.iterator.get_next()
    loss, acc = self.predict_keywords(batch)

    op_step = tf.Variable(0, name='op_step')
    train_op = tx.core.get_train_op(loss,
                                    global_step=op_step,
                                    hparams=self.config.kernel_opt_hparams)

    max_val_acc, stopping_flag = 0, 0
    self.saver = tf.train.Saver()
    with tf.Session(config=self.gpu_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        for epoch_id in range(self.config._max_epoch):
            self.iterator.switch_to_train_data(sess)
            cur_step = 0
            cnt_acc = []
            while True:
                try:
                    cur_step += 1
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                    loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)
                    cnt_acc.append(acc_)
                    if cur_step % 100 == 0:
                        print('batch {}, loss={}, acc1={}'.format(
                            cur_step, loss_, np.mean(cnt_acc[-100:])))
                except tf.errors.OutOfRangeError:
                    break

            self.iterator.switch_to_val_data(sess)
            cnt_acc = []
            while True:
                try:
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                    acc_ = sess.run(acc, feed_dict=feed)
                    cnt_acc.append(acc_)
                except tf.errors.OutOfRangeError:
                    mean_acc = np.mean(cnt_acc)
                    if mean_acc > max_val_acc:
                        max_val_acc = mean_acc
                        self.saver.save(sess, self.config._kernel_save_path)
                    else:
                        stopping_flag += 1
                    print('epoch_id {}, valid acc1={}'.format(
                        epoch_id + 1, mean_acc))
                    break

            if stopping_flag >= self.config._early_stopping:
                break

def _train_epoch(sess, epoch, step, smry_writer):
    fetches = {
        'step': global_step,
        'train_op': train_op,
        'smry': summary_merged,
        'loss': mle_loss,
    }

    while True:
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            op = sess.run([batch], feed_dict)
            feed_dict = {
                src_input_ids: op[0]['src_input_ids'],
                src_segment_ids: op[0]['src_segment_ids'],
                tgt_input_ids: op[0]['tgt_input_ids'],
                labels: op[0]['tgt_labels'],
                learning_rate: utils.get_lr(step, lr),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)
            step, loss = fetches_['step'], fetches_['loss']

            if step and step % display_steps == 0:
                logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % checkpoint_steps == 0:
                model_path = model_dir + "/model_" + str(step) + ".ckpt"
                logger.info('saving model to %s', model_path)
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)

            if step > 40000 and step % eval_steps == 0:
                _eval_epoch(sess, epoch, mode='eval')
            if step and step <= 40000 and step % (test_steps * 2) == 0:
                _eval_epoch(sess, epoch, mode='test')
            if step > 40000 and step % test_steps == 0:
                _eval_epoch(sess, epoch, mode='test')

        except tf.errors.OutOfRangeError:
            break

    return step

def _train_epoch(sess):
    """Trains on the training set, and evaluates on the dev set periodically.
    """
    iterator.restart_dataset(sess, 'train')

    fetches = {'loss': train_op, 'step': global_step}

    while True:
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            rets = sess.run(fetches, feed_dict)
            step = rets['step']

            dis_steps = config_train.display_steps
            if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                tf.logging.info('step:%d; loss:%f' % (step, rets['loss']))

            eval_steps = config_train.eval_steps
            if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                _dev_epoch(sess)

            ckpt_steps = config_train.checkpoint_steps
            if _is_head() and ckpt_steps > 0 and step % ckpt_steps == 0:
                ckpt_fn = os.path.join(FLAGS.output_dir, 'model.ckpt')
                ckpt_fn = saver.save(sess, ckpt_fn, global_step=step)
                tf.logging.info('Checkpoint to {}'.format(ckpt_fn))
        except tf.errors.OutOfRangeError:
            break

def _run(sess, mode):
    fetches = {
        'accu': accu,
        'batch_size': batch_size,
        'step': global_step,
        'loss': loss,
    }

    if mode == 'train':
        fetches['train_op'] = train_op
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }
                rets = sess.run(fetches, feed_dict)
                if rets['step'] % 50 == 0:
                    tf.logging.info(
                        'step:%d loss:%f' % (rets['step'], rets['loss']))
                if rets['step'] == num_train_steps:
                    break
            except tf.errors.OutOfRangeError:
                break

    if mode == 'eval':
        cum_acc = 0.0
        nsamples = 0
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'eval'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL,
                }
                rets = sess.run(fetches, feed_dict)
                cum_acc += rets['accu'] * rets['batch_size']
                nsamples += rets['batch_size']
            except tf.errors.OutOfRangeError:
                break
        tf.logging.info('dev accu: {}'.format(cum_acc / nsamples))

    if mode == 'test':
        _all_preds = []
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'test'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                _preds = sess.run(preds, feed_dict=feed_dict)
                _all_preds.extend(_preds.tolist())
            except tf.errors.OutOfRangeError:
                break
        output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_file, "w") as writer:
            writer.write('\n'.join(str(p) for p in _all_preds))

def _train_qnet(self, feed_dict):
    minibatch = self._replay_memory.get(self._sample_batch_size)

    observ_batch = np.array([data['observ'] for data in minibatch])
    action_batch = np.array([data['action'] for data in minibatch])
    reward_batch = np.array([data['reward'] for data in minibatch])
    terminal_batch = np.array([data['terminal'] for data in minibatch])
    next_observ_batch = \
        np.array([data['next_observ'] for data in minibatch])

    target_qvalue = self._sess.run(
        self._target_outputs['qvalues'],
        feed_dict={
            self._observ_inputs: next_observ_batch,
            tx.global_mode(): tf.estimator.ModeKeys.PREDICT
        })

    y_batch = reward_batch
    for i in range(self._sample_batch_size):
        if not terminal_batch[i]:
            y_batch[i] += self._discount_factor * np.max(target_qvalue[i])

    feed_dict_ = {
        self._observ_inputs: observ_batch,
        self._y_inputs: y_batch,
        self._action_inputs: action_batch
    }
    feed_dict_.update(feed_dict or {})

    self._sess.run(self._train_op, feed_dict=feed_dict_)

    self._update_target(feed_dict)

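# The loop above forms the standard Q-learning target
#     y_i = r_i + gamma * max_a Q_target(s'_i, a)
# for non-terminal transitions (terminal transitions keep y_i = r_i). An
# equivalent vectorized sketch, assuming terminal_batch holds booleans, would be:
#
#     y_batch = reward_batch + self._discount_factor * \
#         (1. - terminal_batch.astype(np.float32)) * np.max(target_qvalue, axis=1)
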
def _eval_epoch(sess, mode):
    """`mode` is one of {'val', 'test'}
    """
    iterator.restart_dataset(sess, mode)

    refs, hypos = [], []
    while True:
        try:
            fetches = [
                batch['target_text'][:, 1:],
                infer_outputs.predicted_ids[:, :, 0]
            ]
            feed_dict = {
                tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                iterator.handle: iterator.get_handle(sess, mode)
            }
            target_texts, output_ids = \
                sess.run(fetches, feed_dict=feed_dict)

            target_texts = tx.utils.strip_special_tokens(target_texts)
            output_texts = tx.utils.map_ids_to_strs(
                ids=output_ids, vocab=val_data.target_vocab)

            for hypo, ref in zip(output_texts, target_texts):
                hypos.append(hypo)
                refs.append([ref])
        except tf.errors.OutOfRangeError:
            break

    return tx.evals.corpus_bleu_moses(list_of_references=refs,
                                      hypotheses=hypos)

def _qvalues_from_target(self, observ):
    return self._sess.run(
        self._target_outputs['qvalues'],
        feed_dict={
            self._observ_inputs: np.array([observ]),
            tx.global_mode(): tf.estimator.ModeKeys.PREDICT
        })

def _train_epoch(sess, summary_writer, mode, train_op, summary_op,
                 mle_outputs):
    print('in _train_epoch')

    data_iterator.restart_dataset(sess, mode)
    feed_dict = {
        tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
        data_iterator.handle: data_iterator.get_handle(sess, mode),
    }

    cnt = 0
    while True:
        try:
            loss, summary, tf_outputs, ground_truth, _gamma = \
                sess.run((train_op, summary_op, mle_outputs, tgt_y_ids,
                          gamma), feed_dict)

            step = tf.train.global_step(sess, global_step)

            print('step {:d}: loss = {:.6f} gamma = {:.4f}'.format(
                step, loss, _gamma))
            summary_writer.add_summary(summary, step)

            # if step % config_train.steps_per_eval == 0:
            #     _eval_epoch(sess, summary_writer, 'val')
            # #     _eval_epoch(sess, summary_writer, 'test')
            # if step > 921 and (step % 100 == 0):
            #     _eval_epoch(sess, summary_writer, 'val')
        except tf.errors.OutOfRangeError:
            break

    print('end _train_epoch')

def _run_epoch(sess, data_iter, epoch, is_train=False):
    loss = 0.
    iters = 0

    fetches = {"mle_loss": mle_loss}
    if is_train:
        fetches["train_op"] = train_op

    mode = (tf.estimator.ModeKeys.TRAIN
            if is_train
            else tf.estimator.ModeKeys.EVAL)

    for _, (x, y) in enumerate(data_iter):
        batch_size = x.shape[0]
        feed_dict = {
            inputs: x,
            targets: y,
            learning_rate: lr,
            tx.global_mode(): mode,
        }

        rets = sess.run(fetches, feed_dict)
        loss += rets["mle_loss"]
        iters += batch_size

    ppl = np.exp(loss / iters)
    return ppl

def _get_align(sess, mode):
    print('in _get_align')

    data_iterator.restart_dataset(sess, mode)
    feed_dict = {
        tx.global_mode(): tf.estimator.ModeKeys.EVAL,
        data_iterator.handle: data_iterator.get_handle(sess, mode),
    }

    with open('align.pkl', 'wb') as out_file:
        while True:
            try:
                batch = sess.run(data_batch, feed_dict)
                sd_texts, sent_texts = (
                    [batch['{}{}_text'.format(field, ref_strs[1])]
                     for field in fields]
                    for fields in (sd_fields, sent_fields))
                aligns = batch_get_align(*(sd_texts + sent_texts))
                sd_texts, sent_texts = (
                    [batch_strip_special_tokens_of_list(texts)
                     for texts, field in zip(all_texts, fields)]
                    for all_texts, fields in zip(
                        (sd_texts, sent_texts), (sd_fields, sent_fields)))
                if FLAGS.verbose:
                    batch_print_align(*(sd_texts + sent_texts + [aligns]))
                for align in aligns:
                    pickle.dump(align, out_file)
            except tf.errors.OutOfRangeError:
                break

    print('end _get_align')

def _train_epoch(sess):
    """Trains on the training set, and evaluates on the dev set periodically.
    """
    iterator.restart_dataset(sess, 'train')

    fetches = {
        'train_op': train_op,
        'loss': loss,
        'batch_size': batch_size,
        'step': global_step
    }

    while True:
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            rets = sess.run(fetches, feed_dict)
            step = rets['step']

            dis_steps = config_data.display_steps
            if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                tf.logging.info('step:%d; loss:%f;' % (step, rets['loss']))

            eval_steps = config_data.eval_steps
            if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                _eval_epoch(sess)
        except tf.errors.OutOfRangeError:
            break

def test(self):
    batch = self.iterator.get_next()
    loss, acc, rank = self.forward(batch)

    with tf.Session(config=self.gpu_config) as sess:
        sess.run(tf.tables_initializer())
        self.saver = tf.train.Saver()
        self.saver.restore(sess, self.config._save_path)

        self.iterator.switch_to_test_data(sess)
        rank_cnt = []
        while True:
            try:
                feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
                ranks, labels = sess.run([rank, batch['label']],
                                         feed_dict=feed)
                for i in range(len(ranks)):
                    rank_cnt.append(np.where(ranks[i] == labels[i])[0][0])
            except tf.errors.OutOfRangeError:
                rec = [0, 0, 0, 0, 0]
                MRR = 0
                # Use `r` so the `rank` tensor above is not shadowed.
                for r in rank_cnt:
                    for i in range(5):
                        rec[i] += (r <= i)
                    MRR += 1 / (r + 1)
                print(
                    'test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'
                    .format(rec[0] / len(rank_cnt), rec[2] / len(rank_cnt),
                            rec[4] / len(rank_cnt), MRR / len(rank_cnt)))
                break

def _train_epoch(sess, summary_writer, mode, train_ops, summary_ops, names):
    print('in _train_epoch')

    data_iterator.restart_dataset(sess, mode)
    feed_dict = {
        tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
        data_iterator.handle: data_iterator.get_handle(sess, mode),
    }

    while True:
        try:
            losses, summaries = sess.run((train_ops, summary_ops), feed_dict)

            step = tf.train.global_step(sess, global_step)

            print('step {:d}:\t{}'.format(
                step, '\t'.join('{}: {:.6f}'.format(name, losses[name])
                                for name in names)))

            for summary in summaries.values():
                summary_writer.add_summary(summary, step)

            if step % config_train.steps_per_eval == 0:
                _eval_epoch(sess, summary_writer, 'val')
        except tf.errors.OutOfRangeError:
            break

    print('end _train_epoch')

def train_epoch(self, sess, summary_writer, mode, train_op, summary_op):
    print("in _train_epoch")

    self.data_iterator.restart_dataset(sess, mode)
    feed_dict = {
        tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
        self.data_iterator.handle:
            self.data_iterator.get_handle(sess, mode),
    }

    while True:
        try:
            loss, summary = sess.run((train_op, summary_op), feed_dict)

            step = tf.train.global_step(sess, self.global_step)

            print("step {:d}: loss = {:.6f}".format(step, loss))
            summary_writer.add_summary(summary, step)

            # if step % config_train.steps_per_eval == 0:
            #     _eval_epoch(sess, summary_writer, 'val')
        except tf.errors.OutOfRangeError:
            break

    print("end _train_epoch")

def test_keywords(self):
    batch = self.iterator.get_next()
    loss, acc, kws = self.predict_keywords(batch)
    saver = tf.train.Saver()

    with tf.Session(config=self.gpu_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        saver.restore(sess, self.config._neural_save_path)

        self.iterator.switch_to_test_data(sess)
        cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []
        while True:
            try:
                feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
                acc_, kw_ans, kw_labels = sess.run(
                    [acc, kws, batch['keywords_text_ids']], feed_dict=feed)
                cnt_acc.append(acc_)

                rec = [0, 0, 0, 0, 0]
                sum_kws = 0
                for i in range(len(kw_ans)):
                    sum_kws += sum(kw_labels[i] > 3)
                    for j in range(5):
                        if kw_ans[i][j] in kw_labels[i]:
                            for k in range(j, 5):
                                rec[k] += 1
                cnt_rec1.append(rec[0] / sum_kws)
                cnt_rec3.append(rec[2] / sum_kws)
                cnt_rec5.append(rec[4] / sum_kws)
            except tf.errors.OutOfRangeError:
                print(
                    'test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'
                    .format(np.mean(cnt_acc), np.mean(cnt_rec1),
                            np.mean(cnt_rec3), np.mean(cnt_rec5)))
                break

def _eval_epoch(sess, mode):
    if mode == 'val':
        data_iterator.switch_to_val_data(sess)
    else:
        data_iterator.switch_to_test_data(sess)

    refs, hypos = [], []
    while True:
        try:
            fetches = [
                batch['target_text'][:, 1:],
                infer_outputs.predicted_ids[:, :, 0]
            ]
            feed_dict = {
                tx.global_mode(): tf.estimator.ModeKeys.EVAL
            }
            target_texts_ori, output_ids = \
                sess.run(fetches, feed_dict=feed_dict)

            target_texts = tx.utils.strip_special_tokens(target_texts_ori)
            output_texts = tx.utils.map_ids_to_strs(
                ids=output_ids, vocab=val_data.target_vocab)

            for hypo, ref in zip(output_texts, target_texts):
                hypos.append(hypo)
                refs.append([ref])
        except tf.errors.OutOfRangeError:
            break

    return tx.evals.corpus_bleu_moses(list_of_references=refs,
                                      hypotheses=hypos)

def _eval(sess, epoch, data_tag):
    fetches = {
        "predicts": predicts,
    }
    mode = tf.estimator.ModeKeys.EVAL
    file_name = 'tmp/%s%d' % (data_tag, epoch)
    writer = CoNLLWriter(i2w, i2n)
    writer.start(file_name)
    data = data_dev if data_tag == 'dev' else data_test

    for batch in iterate_batch(data, config.batch_size, shuffle=False):
        word, char, ner, mask, length = batch
        feed_dict = {
            inputs: word,
            chars: char,
            targets: ner,
            masks: mask,
            seq_lengths: length,
            global_step: epoch,
            tx.global_mode(): mode,
        }
        rets = sess.run(fetches, feed_dict)
        predictions = rets['predicts']
        writer.write(word, predictions, ner, length)
    writer.close()

    acc, precision, recall, f1 = scores.scores(file_name)
    print('%s acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
          % (data_tag, acc, precision, recall, f1))
    return acc, precision, recall, f1

def retrieve_init(self, sess):
    data_batch = self.iterator.get_next()
    loss, acc, _ = self.forward(data_batch)

    self.corpus = self.data_config._corpus
    self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)
    corpus_iterator = tx.data.DataIterator(self.corpus_data)
    batch = corpus_iterator.get_next()
    corpus_embed = self.embedder(batch['corpus_text_ids'])
    utter_code = self.target_encoder(
        corpus_embed, sequence_length=batch['corpus_length'])[1]

    self.corpus_code = np.zeros([0, self.config._code_len])
    corpus_iterator.switch_to_dataset(sess)
    sess.run(tf.tables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, self.config._save_path)

    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
    while True:
        try:
            utter_code_ = sess.run(utter_code, feed_dict=feed)
            self.corpus_code = np.concatenate(
                [self.corpus_code, utter_code_], axis=0)
        except tf.errors.OutOfRangeError:
            break

    self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))
    self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
    self.history_input = tf.placeholder(
        dtype=object, shape=(9, self.data_config._max_seq_len + 2))

    history_ids = self.vocab.map_tokens_to_ids(self.history_input)
    history_embed = self.embedder(history_ids)
    history_code = self.source_encoder(
        tf.expand_dims(history_embed, axis=0),
        sequence_length_minor=self.minor_length_input,
        sequence_length_major=self.major_length_input)[1]

    select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)
    feature_code = self.linear_matcher(select_corpus * history_code)
    self.ans_output = tf.nn.top_k(
        tf.squeeze(feature_code, 1),
        k=self.data_config._retrieval_candidates)[1]

def _eval_epoch(sess, mode):
    if mode == 'valid':
        data_iterator.switch_to_val_data(sess)
    else:
        data_iterator.switch_to_test_data(sess)

    refs, hypos = [], []
    i = 0
    while True:
        try:
            fetches = [
                data_batch['target_text_ids'][:, 1:],
                infer_outputs.predicted_ids[:, :, 0]
            ]
            feed_dict = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
            target_ids, output_ids = sess.run(fetches, feed_dict=feed_dict)

            target_texts = tx.utils.map_ids_to_strs(
                ids=target_ids, vocab=valid_data.target_vocab)
            output_texts = tx.utils.map_ids_to_strs(
                ids=output_ids, vocab=valid_data.target_vocab)

            if i == 0:
                print('Target and Output texts')
                print(target_texts)
                print(output_texts)
            i += 1

            for hypo, ref in zip(output_texts, target_texts):
                hypos.append(hypo)
                refs.append([ref])
        except tf.errors.OutOfRangeError:
            break

    return tx.evals.corpus_bleu(list_of_references=refs, hypotheses=hypos)

def _test_epochs_bleu(sess, epoch):
    iterator.switch_to_test_data(sess)

    bleu_prec = [[] for i in range(1, 5)]
    bleu_recall = [[] for i in range(1, 5)]

    def bleus(ref, sample):
        res = []
        for weight in [[1, 0, 0, 0],
                       [1, 0, 0, 0],
                       [0, 1, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]]:
            res.append(sentence_bleu(
                [ref], sample,
                smoothing_function=SmoothingFunction().method7,
                weights=weight))
        return res

    while True:
        try:
            feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
            beam_samples, beam_length, references, refs_cnt = \
                sess.run([beam_sample_text, beam_lengths,
                          data_batch['refs_text'][:, :, 1:],
                          data_batch['refs_utterance_cnt']],
                         feed_dict=feed)

            beam_samples = np.transpose(beam_samples, (0, 2, 1))
            beam_samples = [[sample[:l] for sample, l in zip(beam, lens)]
                            for beam, lens in
                            zip(beam_samples.tolist(), beam_length)]
            references = [[ref[:ref.index(b'<EOS>')] for ref in refs[:cnt]]
                          for refs, cnt in
                          zip(references.tolist(), refs_cnt)]

            for beam, refs in zip(beam_samples, references):
                bleu_scores = np.array(
                    [[bleus(ref, sample) for i, ref in enumerate(refs)]
                     for j, sample in enumerate(beam)])
                bleu_scores = np.transpose(bleu_scores, (2, 0, 1))
                for i in range(1, 5):
                    bleu_i = bleu_scores[i]
                    bleu_i_precision = bleu_i.max(axis=1).mean()
                    bleu_i_recall = bleu_i.max(axis=0).mean()
                    bleu_prec[i - 1].append(bleu_i_precision)
                    bleu_recall[i - 1].append(bleu_i_recall)
        except tf.errors.OutOfRangeError:
            break

    bleu_prec = [np.mean(x) for x in bleu_prec]
    bleu_recall = [np.mean(x) for x in bleu_recall]

    print('epoch {}:'.format(epoch))
    for i in range(1, 5):
        print(' -- bleu-{} prec={}, recall={}'.format(
            i, bleu_prec[i - 1], bleu_recall[i - 1]))

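# Note on the metrics above: after the transpose, bleu_scores[i] has shape
# [num_samples, num_refs] for each n-gram order. "Precision" scores each generated
# sample against its best-matching reference (max over refs, then mean over
# samples), while "recall" scores each reference against its best-matching sample
# (max over samples, then mean over refs).
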
def _train_epoch(sess, epoch, step, smry_writer):
    fetches = {
        'step': global_step,
        'train_op': train_op,
        'smry': summary_merged,
        'loss': mle_loss,
    }

    while True:
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            op = sess.run([batch], feed_dict)
            feed_dict = {
                src_input_ids: op[0]['src_input_ids'],
                src_segment_ids: op[0]['src_segment_ids'],
                tgt_input_ids: op[0]['tgt_input_ids'],
                labels: op[0]['tgt_labels'],
                learning_rate: utils.get_lr(step, lr),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)
            step, loss = fetches_['step'], fetches_['loss']

            display_steps = 100
            if step and step % display_steps == 0:
                # with open(os.path.join('/var/scratch/vro220',
                #                        'modeltime_check_loss.txt'), 'a') as file_obj:
                #     print(step, loss, file=file_obj)
                # logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % 1000 == 0:
                model_path = ("/var/scratch/vro220/models10/model_"
                              + str(step) + ".ckpt")
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)
                # _eval_epoch(sess, epoch, step, mode='eval')
        except tf.errors.OutOfRangeError:
            break

    return step

def generate(sess, saver, fname=None):
    if tf.train.checkpoint_exists(FLAGS.model):
        saver.restore(sess, FLAGS.model)
    else:
        raise ValueError("cannot find checkpoint model")

    batch_size = train_data.batch_size

    dst = tfd.MultivariateNormalDiag(
        loc=tf.zeros([batch_size, config.latent_dims]),
        scale_diag=tf.ones([batch_size, config.latent_dims]))

    dcdr_states, latent_z = connector_stoch(dst)

    # to concatenate latent variable to input word embeddings
    def _cat_embedder(ids):
        embedding = decoder_embedder(ids)
        return tf.concat([embedding, latent_z], axis=1)

    vocab = train_data.vocab
    start_tokens = tf.ones(batch_size, tf.int32) * vocab.bos_token_id
    end_token = vocab.eos_token_id
    if config.decoder_hparams["type"] == "lstm":
        outputs, _, _ = decoder(initial_state=dcdr_states,
                                decoding_strategy="infer_sample",
                                embedding=_cat_embedder,
                                max_decoding_length=100,
                                start_tokens=start_tokens,
                                end_token=end_token)
    else:
        outputs, _ = decoder(memory=dcdr_states,
                             decoding_strategy="infer_sample",
                             memory_sequence_length=tf.ones(
                                 tf.shape(dcdr_states)[0]),
                             max_decoding_length=100,
                             start_tokens=start_tokens,
                             end_token=end_token)

    sample_tokens = vocab.map_ids_to_tokens(outputs.sample_id)
    sess.run(tf.tables_initializer())

    mode_key = tf.estimator.ModeKeys.EVAL
    feed = {tx.global_mode(): mode_key}
    sample_tokens_ = sess.run(sample_tokens, feed_dict=feed)

    if fname is None:
        fh = sys.stdout
    else:
        fh = open(fname, 'w', encoding='utf-8')

    for sent in sample_tokens_:
        sent = list(sent)
        end_id = sent.index(vocab.eos_token)
        fh.write(' '.join(sent[:end_id + 1]) + '\n')

    fh.close()