Example #1
    def generate_keyword_mask(self, context_ids):
        """Generate mask to only keep the related keywords' Q-value.
        """
        self.kg_adjacency_matrix = tf.cond(
            pred=tf.equal(tx.global_mode(), tf.estimator.ModeKeys.TRAIN),
            true_fn=lambda: self.train_kg_adjacency_matrix,
            false_fn=lambda: tf.cond(
                pred=tf.equal(tx.global_mode(), tf.estimator.ModeKeys.EVAL),
                true_fn=lambda: self.valid_kg_adjacency_matrix,
                false_fn=lambda: self.test_kg_adjacency_matrix))

        context_ids = tf.cast(context_ids, tf.int64)
        # shape of adj_matrix_context_ids: [_cur_keywords_len,]
        adj_matrix_context_ids = self.map_vocab_ids_to_adj_matrix_ids(
            context_ids)
        # shape of context_related_adj_matrix: [#adj_matrix_context_ids, adj_matrix_size]
        context_related_adj_matrix = tf.gather(self.kg_adjacency_matrix,
                                               adj_matrix_context_ids)
        # shape of keyword_mask: [adj_matrix_size,]
        keyword_mask = tf.reduce_max(context_related_adj_matrix, axis=0)

        num_related_keywords = tf.reduce_sum(
            tf.cast(tf.equal(keyword_mask, 1.), tf.float32))
        no_related_keywords = tf.equal(num_related_keywords, 0.)
        ones_tensor = tf.ones(shape=[self.adj_matrix_size])
        # if no related keywords for current context, keep all keywords' Q-value
        keyword_mask = tf.cond(pred=no_related_keywords,
                               true_fn=lambda: ones_tensor,
                               false_fn=lambda: keyword_mask)
        # remove the <PAD> dimension
        keyword_mask = keyword_mask[1:]

        return keyword_mask
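
The core of the mask construction above is a gather-then-max trick: gather the adjacency rows of the context keywords, then take a column-wise max so the mask is 1 for every keyword adjacent to at least one context keyword. A minimal NumPy sketch of the same logic, using a made-up 4x4 adjacency matrix:

import numpy as np

# Toy symmetric adjacency matrix: adj[i][j] == 1. iff keywords i and j are related.
adj = np.array([[0., 1., 0., 0.],
                [1., 0., 1., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 0.]], dtype=np.float32)

context_ids = [0, 2]                 # keywords appearing in the context
related_rows = adj[context_ids]      # same role as tf.gather(adj, context_ids)
keyword_mask = related_rows.max(0)   # same role as tf.reduce_max(..., axis=0)
print(keyword_mask)                  # [0. 1. 0. 1.]

# Fallback from the snippet above: if nothing is related, keep all Q-values.
if keyword_mask.sum() == 0.:
    keyword_mask = np.ones_like(keyword_mask)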
Example #2
    def _train_epoch(sess, initial=False):
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.restart_dataset(sess, 'train')

        while True:
            try:
                # (1) Get data and yy sample
                fetches_data = {
                    'batch': batch,
                    'batch_size': batch_size,
                }
                feed_dict_data = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                rets_data = sess.run(fetches_data, feed_dict_data)


                # (2) Optimize loss
                feed_dict = {
                    #x1_ids: rets_data['batch']['x1_ids'],
                    x1_len: rets_data['batch']['x1_len'],
                    x1x4_ids: rets_data['batch']['x1x4_ids'],
                    x1x4_len: rets_data['batch']['x1x4_len'],
                    tau: config_train.tau,
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }

                fetches = {
                    'train_op': train_op,
                    'step': global_step,
                }
                fetches.update(loss_dict)

                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                dis_steps = config_train.display_steps

                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    _log_losses(rets, step)

                eval_steps = config_train.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _dev_epoch(sess)
                sample_steps = config_train.sample_steps
                if _is_head() and sample_steps > 0 and step % sample_steps == 0:
                    print('-----------testing-----------------')
                    _test_epoch(sess, step=step)

                ckpt_steps = config_train.checkpoint_steps
                if _is_head() and ckpt_steps > 0 and step % ckpt_steps == 0:
                    ckpt_fn = os.path.join(output_dir, 'model.ckpt')
                    ckpt_fn = saver.save(sess, ckpt_fn, global_step=step)
                    _log('Checkpoint to {}'.format(ckpt_fn))

            except tf.errors.OutOfRangeError:
                break
Example #3
def _eval_epoch(sess, epoch, mode):

        references, hypotheses = [], []
        bsize = test_batch_size
        fetches = {
            'inferred_ids': inferred_ids,
        }
        bno = 0
        while True:
            try:
                print("Batch", bno)
                # Fetch the raw data batch first, then feed it back in to
                # run inference on the source side.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'eval'),
                    tx.global_mode(): tf.estimator.ModeKeys.EVAL,
                }
                op = sess.run([batch], feed_dict)
                feed_dict = {
                    src_input_ids: op[0]['src_input_ids'],
                    src_segment_ids: op[0]['src_segment_ids'],
                    tx.global_mode(): tf.estimator.ModeKeys.EVAL,
                }
                fetches_ = sess.run(fetches, feed_dict=feed_dict)
                labels = op[0]['tgt_labels']
                hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
                references.extend(r.tolist() for r in labels)
                bno += 1
            except tf.errors.OutOfRangeError:
                break

        # Strip everything after the first EOS token once, after all batches
        # have been collected.
        hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
        references = utils.list_strip_eos(references, eos_token_id)


        if mode == 'eval':
            # Writes results to files to evaluate BLEU
            # For 'eval' mode, the BLEU is based on token ids (rather than
            # text tokens) and serves only as a surrogate metric to monitor
            # the training process
            fname = os.path.join(model_dir, 'tmp.eval')
            hypotheses = tx.utils.str_join(hypotheses)
            references = tx.utils.str_join(references)
            hyp_fn, ref_fn = tx.utils.write_paired_text(
                hypotheses, references, fname, mode='s')
            eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
            eval_bleu = 100. * eval_bleu
            logger.info('epoch: %d, eval_bleu %.4f', epoch, eval_bleu)
            print('epoch: %d, eval_bleu %.4f' % (epoch, eval_bleu))

            if eval_bleu > best_results['score']:
                logger.info('epoch: %d, best bleu: %.4f', epoch, eval_bleu)
                best_results['score'] = eval_bleu
                best_results['epoch'] = epoch
                model_path = os.path.join(model_dir, 'best-model.ckpt')
                logger.info('saving model to %s', model_path)
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)
Example #4
 def train(self):
     batch = self.iterator.get_next()
     loss_t, acc_t, _ = self.predict_keywords(batch)
     kw_saver = tf.train.Saver()
     loss, acc, _ = self.forward(batch)
     retrieval_step = tf.Variable(0, name='retrieval_step')
     train_op = tx.core.get_train_op(loss,
                                     global_step=retrieval_step,
                                     hparams=self.config.opt_hparams)
     max_val_acc, stopping_flag = 0, 0
     with tf.Session(config=self.gpu_config) as sess:
         sess.run(tf.tables_initializer())
         sess.run(tf.global_variables_initializer())
         sess.run(tf.local_variables_initializer())
         kw_saver.restore(sess, self.config._kernel_save_path)
         saver = tf.train.Saver()
         for epoch_id in range(self.config._max_epoch):
             self.iterator.switch_to_train_data(sess)
             cur_step = 0
             cnt_acc, cnt_kwacc = [], []
             while True:
                 try:
                     cur_step += 1
                     feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                     loss, acc_, acc_kw = sess.run([train_op, acc, acc_t],
                                                   feed_dict=feed)
                     cnt_acc.append(acc_)
                     cnt_kwacc.append(acc_kw)
                     if cur_step % 200 == 0:
                         print('batch {}, loss={}, acc1={}, kw_acc1={}'.
                               format(cur_step, loss,
                                      np.mean(cnt_acc[-200:]),
                                      np.mean(cnt_kwacc[-200:])))
                 except tf.errors.OutOfRangeError:
                     break
             self.iterator.switch_to_val_data(sess)
             cnt_acc, cnt_kwacc = [], []
             while True:
                 try:
                     feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                     acc_, acc_kw = sess.run([acc, acc_t], feed_dict=feed)
                     cnt_acc.append(acc_)
                     cnt_kwacc.append(acc_kw)
                 except tf.errors.OutOfRangeError:
                     mean_acc = np.mean(cnt_acc)
                     print('valid acc1={}, kw_acc1={}'.format(
                         mean_acc, np.mean(cnt_kwacc)))
                     if mean_acc > max_val_acc:
                         max_val_acc = mean_acc
                         saver.save(sess, self.config._save_path)
                     else:
                         stopping_flag += 1
                     break
             if stopping_flag >= self.config._early_stopping:
                 break
Example #5
def _train_epoch(sess, epoch, step, smry_writer):

    fetches = {
        'step': global_step,
        'train_op': train_op,
        'smry': summary_merged,
        'mle_loss': mle_loss,
        'type_cls_loss': type_cls_loss,
        'conn_cls_loss': conn_cls_loss,
        'total_loss': total_loss,
    }

    print("------ Epoch number", epoch + 1, "out of", total_epochs, "epochs ------")
    for train_batch in range(num_train_batches_in_epoch):
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            op = sess.run([batch], feed_dict)

            feed_dict = {
                src_input_ids: op[0]['src_input_ids'],
                src_segment_ids: op[0]['src_segment_ids'],
                tgt_input_ids: op[0]['tgt_input_ids'],
                labels: op[0]['tgt_labels'],
                type_label: op[0]['type_label'],
                conn_label: op[0]['conn_label'],
                learning_rate: utils.get_lr(step, lr),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)
            step, m_loss, t_loss, c_loss = (
                fetches_['step'], fetches_['mle_loss'],
                fetches_['type_cls_loss'], fetches_['conn_cls_loss'])

            if step and step % display_steps == 0:
                logger.info(
                    'batch: %d/%d, mle_loss: %.4f, type_cls_loss: %.4f, '
                    'conn_cls_loss: %.4f', train_batch + 1,
                    num_train_batches_in_epoch, m_loss, t_loss, c_loss)
                print('batch: %d/%d, mle_loss: %.4f, type_cls_loss: %.4f, '
                      'conn_cls_loss: %.4f' % (train_batch + 1,
                                               num_train_batches_in_epoch,
                                               m_loss, t_loss, c_loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

        except tf.errors.OutOfRangeError:
            break

    model_path = model_dir + "/model_" + str(step) + ".ckpt"
    logger.info('saving model to %s', model_path)
    print('saving model to %s' % model_path)
    saver.save(sess, model_path)
    print("---EVAL---")
    _eval_epoch(sess, epoch, mode='eval')

    return step
Example #6
    def train(self):
        batch = self.iterator.get_next()
        kw_loss, kw_acc, _ = self.predict_keywords(batch)
        kw_saver = tf.train.Saver()
        loss, acc, rank = self.forward_response_retrieval(batch)
        op_step = tf.Variable(0, name='retrieval_step')
        train_op = tx.core.get_train_op(
            loss,
            global_step=op_step,
            hparams=self.model_config._retrieval_opt_hparams)
        max_val_acc = 0.
        with tf.Session(config=self.gpu_config) as sess:
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            kw_saver.restore(sess, self.model_config._kp_save_path)
            saver = tf.train.Saver()
            for epoch_id in range(self.model_config._max_epoch):
                self.iterator.switch_to_train_data(sess)
                cur_step = 0
                cnt_acc = []
                while True:
                    try:
                        cur_step += 1
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                        loss, acc_ = sess.run([train_op, acc], feed_dict=feed)
                        cnt_acc.append(acc_)
                        if cur_step % 200 == 0:
                            logs_loss_acc = 'batch {}, loss={}, acc1={}'.format(
                                cur_step, loss, np.mean(cnt_acc[-200:]))
                            add_log(self.logs_save_path, logs_loss_acc)
                    except tf.errors.OutOfRangeError:
                        break

                self.iterator.switch_to_val_data(sess)
                cnt_acc, cnt_kwacc = [], []
                while True:
                    try:
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                        acc_, kw_acc_ = sess.run([acc, kw_acc], feed_dict=feed)
                        cnt_acc.append(acc_)
                        cnt_kwacc.append(kw_acc_)
                    except tf.errors.OutOfRangeError:
                        mean_acc = np.mean(cnt_acc)
                        logs_loss_acc = 'epoch_id {}, valid acc1={}, kw_acc1={}'.format(
                            epoch_id + 1, mean_acc, np.mean(cnt_kwacc))
                        add_log(self.logs_save_path, logs_loss_acc)
                        if mean_acc > max_val_acc:
                            max_val_acc = mean_acc
                            saver.save(sess,
                                       self.model_config._retrieval_save_path)
                        break
Example #7
    def train_keywords(self):
        batch = self.iterator.get_next()
        acc = self.predict_keywords(batch)
        with tf.Session(config=self.gpu_config) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            self.iterator.switch_to_train_data(sess)

            batchid = 0
            while True:
                try:
                    batchid += 1
                    if batchid % 200 == 0:
                        print(batchid)
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                    source_keywords, target_keywords = sess.run(
                        [batch['context_text_ids'], batch['keywords_text_ids']],
                        feed_dict=feed)
                    for i in range(len(source_keywords)):
                        for skw_id in source_keywords[i]:
                            if skw_id == 0:
                                break
                            for tkw_id in target_keywords[i]:
                                if skw_id >= 3 and tkw_id >= 3:
                                    tkw = self.config._vocab[tkw_id - 4]
                                    if tkw in self.data_config._keywords_candi:
                                        tkw_id = self.data_config._keywords_dict[tkw]
                                        self.pmi_matrix[skw_id][tkw_id] += 1

                except tf.errors.OutOfRangeError:
                    break
            self.pmi_matrix += 0.5
            self.pmi_matrix = self.pmi_matrix / (
                np.sum(self.pmi_matrix, axis=0) + 1)
            with open(self.config._matrix_save_path, 'wb') as f:
                pickle.dump(self.pmi_matrix, f)

            self.pmi_matrix = tf.convert_to_tensor(self.pmi_matrix,
                                                   dtype=tf.float32)
            self.iterator.switch_to_val_data(sess)
            cnt_acc = []
            while True:
                try:
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                    acc_ = sess.run(acc, feed_dict=feed)
                    cnt_acc.append(acc_)
                except tf.errors.OutOfRangeError:
                    print('valid acc1={}'.format(np.mean(cnt_acc)))
                    break
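
After counting co-occurrences, the snippet smooths and normalizes the matrix in two lines: add 0.5 to every cell, then divide each column by its smoothed column sum plus 1. A standalone NumPy sketch of just that step, with a made-up count matrix:

import numpy as np

counts = np.array([[3., 0.],
                   [1., 2.]])

counts += 0.5                                # additive smoothing, as above
counts = counts / (counts.sum(axis=0) + 1)   # normalize by column totals
print(counts)   # each column now sums to just under 1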
Example #8
    def _dev_epoch(sess):
        """Evaluates on the dev set.
        """
        iterator.restart_dataset(sess, 'dev')

        results = tx.utils.AverageRecorder()
        nsamples = 0
        fetches = {}
        fetches.update(loss_dict)
        # i = 0

        while True:
            try:

                # (1) Get data and yy sample
                fetches_data = {
                    'batch': batch,
                    'batch_size': batch_size,
                }
                feed_dict_data = {
                    iterator.handle: iterator.get_handle(sess, 'dev'),
                    tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                rets_data = sess.run(fetches_data, feed_dict_data)


                # (2) eval loss
                feed_dict = {
                    #x1_ids: rets_data['batch']['x1_ids'],
                    x1_len: rets_data['batch']['x1_len'],
                    x1x4_ids: rets_data['batch']['x1x4_ids'],
                    x1x4_len: rets_data['batch']['x1x4_len'],
                    tau: config_train.tau,
                    tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }

                rets = sess.run(fetches, feed_dict)

                results.add(rets, weight=rets_data['batch_size'])
                nsamples += rets_data['batch_size']
            except tf.errors.OutOfRangeError:
                break

        _log_losses(results.avg())
        _log('nsamples: %d' % nsamples)

        avg_loss = results.avg('loss')
        if FLAGS.do_train and avg_loss < dev_best['loss']:
            dev_best.update(results.avg())
            ckpt_fn = os.path.join(output_dir, 'model_best.ckpt')
            ckpt_fn = saver_best.save(sess, ckpt_fn)
            _log('Checkpoint best to {}'.format(ckpt_fn))
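
`tx.utils.AverageRecorder` with `weight=batch_size` keeps a weighted running mean, so each batch contributes in proportion to its size and a partial final batch does not skew the epoch average. The same arithmetic in plain NumPy (values are illustrative):

import numpy as np

batch_losses = [0.9, 1.1]   # per-batch averages
batch_sizes = [64, 32]      # weights used in results.add(..., weight=...)
epoch_avg = np.average(batch_losses, weights=batch_sizes)
print(epoch_avg)            # (0.9*64 + 1.1*32) / 96 ≈ 0.967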
Example #9
    def train_keywords(self):
        batch = self.iterator.get_next()
        loss, acc = self.predict_keywords(batch)
        op_step = tf.Variable(0, name='op_step')
        train_op = tx.core.get_train_op(loss,
                                        global_step=op_step,
                                        hparams=self.config.kernel_opt_hparams)
        max_val_acc, stopping_flag = 0, 0
        self.saver = tf.train.Saver()
        with tf.Session(config=self.gpu_config) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            for epoch_id in range(self.config._max_epoch):
                self.iterator.switch_to_train_data(sess)
                cur_step = 0
                cnt_acc = []
                while True:
                    try:
                        cur_step += 1
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                        loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)
                        cnt_acc.append(acc_)
                        if cur_step % 100 == 0:
                            print('batch {}, loss={}, acc1={}'.format(
                                cur_step, loss_, np.mean(cnt_acc[-100:])))
                    except tf.errors.OutOfRangeError:
                        break

                self.iterator.switch_to_val_data(sess)
                cnt_acc = []
                while True:
                    try:
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                        acc_ = sess.run(acc, feed_dict=feed)
                        cnt_acc.append(acc_)
                    except tf.errors.OutOfRangeError:
                        mean_acc = np.mean(cnt_acc)
                        if mean_acc > max_val_acc:
                            max_val_acc = mean_acc
                            self.saver.save(sess,
                                            self.config._kernel_save_path)
                        else:
                            stopping_flag += 1
                        print('epoch_id {}, valid acc1={}'.format(
                            epoch_id + 1, mean_acc))
                        break
                if stopping_flag >= self.config._early_stopping:
                    break
Example #10
def _train_epoch(sess, epoch, step, smry_writer):

    fetches = {
        'step': global_step,
        'train_op': train_op,
        'smry': summary_merged,
        'loss': mle_loss,
    }

    while True:
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            op = sess.run([batch], feed_dict)
            feed_dict = {
                src_input_ids: op[0]['src_input_ids'],
                src_segment_ids: op[0]['src_segment_ids'],
                tgt_input_ids: op[0]['tgt_input_ids'],
                labels: op[0]['tgt_labels'],
                learning_rate: utils.get_lr(step, lr),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)
            step, loss = fetches_['step'], fetches_['loss']
            if step and step % display_steps == 0:
                logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % checkpoint_steps == 0:
                model_path = model_dir + "/model_" + str(step) + ".ckpt"
                logger.info('saving model to %s', model_path)
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)
            if step > 40000 and step % eval_steps == 0:
                _eval_epoch(sess, epoch, mode='eval')

            if step and step <= 40000 and step % (test_steps * 2) == 0:
                _eval_epoch(sess, epoch, mode='test')
            if step > 40000 and step % test_steps == 0:
                _eval_epoch(sess, epoch, mode='test')
        except tf.errors.OutOfRangeError:
            break

    return step
Example #11
    def _train_epoch(sess):
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.restart_dataset(sess, 'train')

        fetches = {'loss': train_op, 'step': global_step}

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }
                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                dis_steps = config_train.display_steps
                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    tf.logging.info('step:%d; loss:%f' % (step, rets['loss']))

                eval_steps = config_train.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _dev_epoch(sess)

                ckpt_steps = config_train.checkpoint_steps
                if _is_head() and ckpt_steps > 0 and step % ckpt_steps == 0:
                    ckpt_fn = os.path.join(FLAGS.output_dir, 'model.ckpt')
                    ckpt_fn = saver.save(sess, ckpt_fn, global_step=step)
                    tf.logging.info('Checkpoint to {}'.format(ckpt_fn))

            except tf.errors.OutOfRangeError:
                break
Example #12
    def _run(sess, mode):
        fetches = {
            'accu': accu,
            'batch_size': batch_size,
            'step': global_step,
            'loss': loss,
        }

        if mode == 'train':
            fetches['train_op'] = train_op
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'train'),
                        tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    }
                    rets = sess.run(fetches, feed_dict)
                    if rets['step'] % 50 == 0:
                        tf.logging.info('step:%d loss:%f' %
                                        (rets['step'], rets['loss']))
                    if rets['step'] == num_train_steps:
                        break
                except tf.errors.OutOfRangeError:
                    break

        if mode == 'eval':
            cum_acc = 0.0
            nsamples = 0
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'eval'),
                        tx.global_mode(): tf.estimator.ModeKeys.EVAL,
                    }
                    rets = sess.run(fetches, feed_dict)

                    cum_acc += rets['accu'] * rets['batch_size']
                    nsamples += rets['batch_size']
                except tf.errors.OutOfRangeError:
                    break

            tf.logging.info('dev accu: {}'.format(cum_acc / nsamples))

        if mode == 'test':
            _all_preds = []
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'test'),
                        tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                    }
                    _preds = sess.run(preds, feed_dict=feed_dict)
                    _all_preds.extend(_preds.tolist())
                except tf.errors.OutOfRangeError:
                    break

            output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
            with tf.gfile.GFile(output_file, "w") as writer:
                writer.write('\n'.join(str(p) for p in _all_preds))
Example #13
    def _train_qnet(self, feed_dict):
        minibatch = self._replay_memory.get(self._sample_batch_size)
        observ_batch = np.array([data['observ'] for data in minibatch])
        action_batch = np.array([data['action'] for data in minibatch])
        reward_batch = np.array([data['reward'] for data in minibatch])
        terminal_batch = np.array([data['terminal'] for data in minibatch])
        next_observ_batch = \
            np.array([data['next_observ'] for data in minibatch])

        target_qvalue = self._sess.run(self._target_outputs['qvalues'],
                                       feed_dict={
                                           self._observ_inputs:
                                           next_observ_batch,
                                           tx.global_mode():
                                           tf.estimator.ModeKeys.PREDICT
                                       })

        y_batch = reward_batch
        for i in range(self._sample_batch_size):
            if not terminal_batch[i]:
                y_batch[i] += self._discount_factor * np.max(target_qvalue[i])

        feed_dict_ = {
            self._observ_inputs: observ_batch,
            self._y_inputs: y_batch,
            self._action_inputs: action_batch
        }
        feed_dict_.update(feed_dict or {})

        self._sess.run(self._train_op, feed_dict=feed_dict_)

        self._update_target(feed_dict)
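
The y_batch construction above is the standard Q-learning (Bellman) target: y_i = r_i for terminal transitions, and y_i = r_i + γ · max_a Q_target(s'_i, a) otherwise. The same update, vectorized in NumPy with illustrative values:

import numpy as np

discount = 0.99
reward_batch = np.array([1.0, 0.0, -1.0])
terminal_batch = np.array([False, False, True])
target_qvalue = np.array([[0.2, 0.8],    # Q_target(s', a) for each action
                          [0.5, 0.1],
                          [0.0, 0.0]])

# Bootstrap from the target network only when the episode continues.
y_batch = reward_batch + discount * target_qvalue.max(axis=1) * ~terminal_batch
print(y_batch)   # [ 1.792  0.495 -1.   ]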
Example #14
    def _eval_epoch(sess, mode):
        """`mode` is one of {'val', 'test'}
        """
        iterator.restart_dataset(sess, mode)

        refs, hypos = [], []
        while True:
            try:
                fetches = [
                    batch['target_text'][:, 1:],
                    infer_outputs.predicted_ids[:, :, 0]
                ]
                feed_dict = {
                    tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                    iterator.handle: iterator.get_handle(sess, mode)
                }
                target_texts, output_ids = \
                    sess.run(fetches, feed_dict=feed_dict)

                target_texts = tx.utils.strip_special_tokens(target_texts)
                output_texts = tx.utils.map_ids_to_strs(
                    ids=output_ids, vocab=val_data.target_vocab)

                for hypo, ref in zip(output_texts, target_texts):
                    hypos.append(hypo)
                    refs.append([ref])
            except tf.errors.OutOfRangeError:
                break

        return tx.evals.corpus_bleu_moses(list_of_references=refs,
                                          hypotheses=hypos)
Example #15
 def _qvalues_from_target(self, observ):
     return self._sess.run(self._target_outputs['qvalues'],
                           feed_dict={
                               self._observ_inputs: np.array([observ]),
                               tx.global_mode():
                               tf.estimator.ModeKeys.PREDICT
                           })
Example #16
    def _train_epoch(sess, summary_writer, mode, train_op, summary_op,
                     mle_outputs):
        print('in _train_epoch')

        data_iterator.restart_dataset(sess, mode)
        feed_dict = {
            tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            data_iterator.handle: data_iterator.get_handle(sess, mode),
        }

        cnt = 0

        while True:
            try:
                loss, summary, tf_outputs, ground_truth, _gamma = sess.run(
                    (train_op, summary_op, mle_outputs, tgt_y_ids, gamma),
                    feed_dict)

                step = tf.train.global_step(sess, global_step)

                print('step {:d}: loss = {:.6f} gamma = {:.4f}'.format(
                    step, loss, _gamma))

                summary_writer.add_summary(summary, step)

                # if step % config_train.steps_per_eval == 0:
                #     _eval_epoch(sess, summary_writer, 'val')
                #     # _eval_epoch(sess, summary_writer, 'test')

                # if step > 921 and (step % 100 == 0):
                #     _eval_epoch(sess, summary_writer, 'val')

            except tf.errors.OutOfRangeError:
                break

        print('end _train_epoch')
Example #17
    def _run_epoch(sess, data_iter, epoch, is_train=False):
        loss = 0.
        iters = 0

        fetches = {"mle_loss": mle_loss}
        if is_train:
            fetches["train_op"] = train_op

        mode = (tf.estimator.ModeKeys.TRAIN
                if is_train else tf.estimator.ModeKeys.EVAL)

        for _, (x, y) in enumerate(data_iter):
            batch_size = x.shape[0]
            feed_dict = {
                inputs: x,
                targets: y,
                learning_rate: lr,
                tx.global_mode(): mode,
            }

            rets = sess.run(fetches, feed_dict)
            loss += rets["mle_loss"]
            iters += batch_size

        ppl = np.exp(loss / iters)
        return ppl
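
The returned perplexity is the exponentiated average loss, exp(total loss / total count); note that `iters` here accumulates `x.shape[0]`, so `mle_loss` must already be summed at the matching granularity for the ratio to be a true per-unit average. A two-line sanity check with made-up numbers:

import numpy as np

losses = [240.0, 260.0]   # per-batch summed cross-entropy (illustrative)
counts = [50, 50]         # matching denominators
print(np.exp(sum(losses) / sum(counts)))   # exp(5.0) ≈ 148.41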
Example #18
	def _get_align(sess, mode):
		print('in _get_align')

		data_iterator.restart_dataset(sess, mode)
		feed_dict = {
			tx.global_mode(): tf.estimator.ModeKeys.EVAL,
			data_iterator.handle: data_iterator.get_handle(sess, mode),
		}

		with open('align.pkl', 'wb') as out_file:
			while True:
				try:
					batch = sess.run(data_batch, feed_dict)
					sd_texts, sent_texts = (
						[batch['{}{}_text'.format(field, ref_strs[1])]
						 for field in fields]
						for fields in (sd_fields, sent_fields))
					aligns = batch_get_align(*(sd_texts + sent_texts))
					sd_texts, sent_texts = (
						[batch_strip_special_tokens_of_list(texts)
						 for texts, field in zip(all_texts, fields)]
						for all_texts, fields in zip(
						(sd_texts, sent_texts), (sd_fields, sent_fields)))
					if FLAGS.verbose:
						batch_print_align(*(sd_texts + sent_texts + [aligns]))
					for align in aligns:
						pickle.dump(align, out_file)

				except tf.errors.OutOfRangeError:
					break

		print('end _get_align')
Example #19
    def _train_epoch(sess):
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.restart_dataset(sess, 'train')

        fetches = {
            'train_op': train_op,
            'loss': loss,
            'batch_size': batch_size,
            'step': global_step
        }

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }
                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                dis_steps = config_data.display_steps
                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    tf.logging.info('step:%d; loss:%f;' % (step, rets['loss']))

                eval_steps = config_data.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _eval_epoch(sess)

            except tf.errors.OutOfRangeError:
                break
Example #20
 def test(self):
     batch = self.iterator.get_next()
     loss, acc, rank = self.forward(batch)
     with tf.Session(config=self.gpu_config) as sess:
         sess.run(tf.tables_initializer())
         self.saver = tf.train.Saver()
         self.saver.restore(sess, self.config._save_path)
         self.iterator.switch_to_test_data(sess)
         rank_cnt = []
         while True:
             try:
                 feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
                 ranks, labels = sess.run([rank, batch['label']],
                                          feed_dict=feed)
                 for i in range(len(ranks)):
                     rank_cnt.append(np.where(ranks[i] == labels[i])[0][0])
             except tf.errors.OutOfRangeError:
                 rec = [0, 0, 0, 0, 0]
                 MRR = 0.
                 for r in rank_cnt:
                     for i in range(5):
                         rec[i] += (r <= i)
                     MRR += 1. / (r + 1)
                 print(
                     'test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'
                     .format(rec[0] / len(rank_cnt), rec[2] / len(rank_cnt),
                             rec[4] / len(rank_cnt), MRR / len(rank_cnt)))
                 break
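
The except branch converts the collected 0-based ranks of the correct response into recall@k (rank below k) and MRR (the mean of 1/(rank+1)). The same metrics as a small standalone function (a sketch; names are illustrative):

import numpy as np

def retrieval_metrics(ranks, ks=(1, 3, 5)):
    """ranks: 0-based positions of the correct candidate."""
    ranks = np.asarray(ranks)
    rec = {k: float(np.mean(ranks < k)) for k in ks}   # recall@k
    mrr = float(np.mean(1.0 / (ranks + 1)))            # mean reciprocal rank
    return rec, mrr

print(retrieval_metrics([0, 2, 5]))
# ({1: 0.333..., 3: 0.666..., 5: 0.666...}, 0.5)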
Example #21
    def _train_epoch(sess, summary_writer, mode, train_ops, summary_ops,
                     names):
        print('in _train_epoch')

        data_iterator.restart_dataset(sess, mode)
        feed_dict = {
            tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            data_iterator.handle: data_iterator.get_handle(sess, mode),
        }

        while True:
            try:
                losses, summaries = sess.run((train_ops, summary_ops),
                                             feed_dict)

                step = tf.train.global_step(sess, global_step)

                print('step {:d}:\t{}'.format(
                    step, '\t'.join('{}: {:.6f}'.format(name, losses[name])
                                    for name in names)))

                for summary in summaries.values():
                    summary_writer.add_summary(summary, step)

                if step % config_train.steps_per_eval == 0:
                    _eval_epoch(sess, summary_writer, 'val')

            except tf.errors.OutOfRangeError:
                break

        print('end _train_epoch')
Example #22
    def train_epoch(self, sess, summary_writer, mode, train_op, summary_op):
        print("in _train_epoch")

        self.data_iterator.restart_dataset(sess, mode)

        feed_dict = {
            tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            self.data_iterator.handle:
            self.data_iterator.get_handle(sess, mode),
        }

        while True:
            try:
                loss, summary = sess.run((train_op, summary_op), feed_dict)

                step = tf.train.global_step(sess, self.global_step)

                print("step {:d}: loss = {:.6f}".format(step, loss))

                summary_writer.add_summary(summary, step)

                # if step % config_train.steps_per_eval == 0:
                #     _eval_epoch(sess, summary_writer, 'val')

            except tf.errors.OutOfRangeError:
                break

        print("end _train_epoch")
Example #23
    def test_keywords(self):
        batch = self.iterator.get_next()
        loss, acc, kws = self.predict_keywords(batch)
        saver = tf.train.Saver()
        with tf.Session(config=self.gpu_config) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            saver.restore(sess, self.config._neural_save_path)
            self.iterator.switch_to_test_data(sess)
            cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []
            while True:
                try:
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
                    acc_, kw_ans, kw_labels = sess.run(
                        [acc, kws, batch['keywords_text_ids']], feed_dict=feed)
                    cnt_acc.append(acc_)
                    rec = [0, 0, 0, 0, 0]
                    sum_kws = 0
                    for i in range(len(kw_ans)):
                        sum_kws += sum(kw_labels[i] > 3)
                        for j in range(5):
                            if kw_ans[i][j] in kw_labels[i]:
                                for k in range(j, 5):
                                    rec[k] += 1
                    cnt_rec1.append(rec[0] / sum_kws)
                    cnt_rec3.append(rec[2] / sum_kws)
                    cnt_rec5.append(rec[4] / sum_kws)

                except tf.errors.OutOfRangeError:
                    print(
                        'test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'
                        .format(np.mean(cnt_acc), np.mean(cnt_rec1),
                                np.mean(cnt_rec3), np.mean(cnt_rec5)))
                    break
Example #24
    def _eval_epoch(sess, mode):
        if mode == 'val':
            data_iterator.switch_to_val_data(sess)
        else:
            data_iterator.switch_to_test_data(sess)

        refs, hypos = [], []
        while True:
            try:
                fetches = [
                    batch['target_text'][:, 1:],
                    infer_outputs.predicted_ids[:, :, 0]
                ]
                feed_dict = {
                    tx.global_mode(): tf.estimator.ModeKeys.EVAL
                }
                target_texts_ori, output_ids = \
                    sess.run(fetches, feed_dict=feed_dict)

                target_texts = tx.utils.strip_special_tokens(target_texts_ori)
                output_texts = tx.utils.map_ids_to_strs(
                    ids=output_ids, vocab=val_data.target_vocab)

                for hypo, ref in zip(output_texts, target_texts):
                    hypos.append(hypo)
                    refs.append([ref])
            except tf.errors.OutOfRangeError:
                break

        return tx.evals.corpus_bleu_moses(list_of_references=refs,
                                          hypotheses=hypos)
Example #25
def _eval(sess, epoch, data_tag):
    fetches = {
        "predicts": predicts,
    }
    mode = tf.estimator.ModeKeys.EVAL
    file_name = 'tmp/%s%d' % (data_tag, epoch)
    writer = CoNLLWriter(i2w, i2n)
    writer.start(file_name)
    data = data_dev if data_tag == 'dev' else data_test
    for batch in iterate_batch(data, config.batch_size, shuffle=False):
        word, char, ner, mask, length = batch
        feed_dict = {
            inputs: word,
            chars: char,
            targets: ner,
            masks: mask,
            seq_lengths: length,
            global_step: epoch,
            tx.global_mode(): mode,
        }
        rets = sess.run(fetches, feed_dict)
        predictions = rets['predicts']
        writer.write(word, predictions, ner, length)
    writer.close()
    acc, precision, recall, f1 = scores.scores(file_name)
    print('%s acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' %
          (data_tag, acc, precision, recall, f1))
    return acc, precision, recall, f1
Example #26
    def retrieve_init(self, sess):
        data_batch = self.iterator.get_next()
        loss, acc, _ = self.forward(data_batch)
        self.corpus = self.data_config._corpus
        self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)
        corpus_iterator = tx.data.DataIterator(self.corpus_data)
        batch = corpus_iterator.get_next()
        corpus_embed = self.embedder(batch['corpus_text_ids'])
        utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
        self.corpus_code = np.zeros([0, self.config._code_len])
        corpus_iterator.switch_to_dataset(sess)
        sess.run(tf.tables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, self.config._save_path)
        feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
        while True:
            try:
                utter_code_ = sess.run(utter_code, feed_dict=feed)
                self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)
            except tf.errors.OutOfRangeError:
                break

        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))
        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
        self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))

        history_ids = self.vocab.map_tokens_to_ids(self.history_input)
        history_embed = self.embedder(history_ids)
        history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),
                                           sequence_length_minor=self.minor_length_input,
                                           sequence_length_major=self.major_length_input)[1]
        select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)
        feature_code = self.linear_matcher(select_corpus * history_code)
        self.ans_output = tf.nn.top_k(tf.squeeze(feature_code, 1), k=self.data_config._retrieval_candidates)[1]
Example #27
    def _eval_epoch(sess, mode):
        if mode == 'valid':
            data_iterator.switch_to_val_data(sess)
        else:
            data_iterator.switch_to_test_data(sess)

        refs, hypos = [], []
        i = 0
        while True:
            try:
                fetches = [
                    data_batch['target_text_ids'][:, 1:],
                    infer_outputs.predicted_ids[:, :, 0]
                ]
                feed_dict = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
                target_ids, output_ids = sess.run(fetches, feed_dict=feed_dict)

                target_texts = tx.utils.map_ids_to_strs(
                    ids=target_ids, vocab=valid_data.target_vocab)
                output_texts = tx.utils.map_ids_to_strs(
                    ids=output_ids, vocab=valid_data.target_vocab)
                if i == 0:
                    print('Target and Output texts')
                    print(target_texts)
                    print(output_texts)
                i += 1
                for hypo, ref in zip(output_texts, target_texts):
                    hypos.append(hypo)
                    refs.append([ref])
            except tf.errors.OutOfRangeError:
                break

        return tx.evals.corpus_bleu(list_of_references=refs, hypotheses=hypos)
Example #28
    def _test_epochs_bleu(sess, epoch):
        iterator.switch_to_test_data(sess)

        bleu_prec = [[] for i in range(1, 5)]
        bleu_recall = [[] for i in range(1, 5)]

        def bleus(ref, sample):
            res = []
            # the first row is a placeholder so that res[i] lines up with the
            # i-gram score used in the loop below
            for weight in [[1, 0, 0, 0],
                           [1, 0, 0, 0],
                           [0, 1, 0, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]]:
                res.append(sentence_bleu([ref], sample,
                    smoothing_function=SmoothingFunction().method7,
                    weights=weight))
            return res

        while True:
            try:
                feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}

                beam_samples, beam_length, references, refs_cnt = \
                    sess.run([beam_sample_text, beam_lengths,
                        data_batch['refs_text'][:, :, 1:],
                        data_batch['refs_utterance_cnt']],
                    feed_dict=feed)

                beam_samples = np.transpose(beam_samples, (0, 2, 1))
                beam_samples = [[sample[:l] for sample, l in zip(beam, lens)]
                    for beam, lens in zip(beam_samples.tolist(), beam_length)]
                references = [[ref[:ref.index(b'<EOS>')] for ref in refs[:cnt]]
                    for refs, cnt in zip(references.tolist(), refs_cnt)]

                for beam, refs in zip(beam_samples, references):
                    bleu_scores = np.array([[bleus(ref, sample)
                        for i, ref in enumerate(refs)]
                        for j, sample in enumerate(beam)])
                    bleu_scores = np.transpose(bleu_scores, (2, 0, 1))

                    for i in range(1, 5):
                        bleu_i = bleu_scores[i]
                        bleu_i_precision = bleu_i.max(axis=1).mean()
                        bleu_i_recall = bleu_i.max(axis=0).mean()

                        bleu_prec[i-1].append(bleu_i_precision)
                        bleu_recall[i-1].append(bleu_i_recall)


            except tf.errors.OutOfRangeError:
                break

        bleu_prec = [np.mean(x) for x in bleu_prec]
        bleu_recall = [np.mean(x) for x in bleu_recall]

        print('epoch {}:'.format(epoch))
        for i in range(1, 5):
            print(' -- bleu-{} prec={}, recall={}'.format(
                i, bleu_prec[i-1], bleu_recall[i-1]))
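
Given the per-pair score matrix built above (rows = beam samples, columns = references), precision takes the best reference for each sample while recall takes the best sample for each reference. The reduction on its own, with a made-up score matrix:

import numpy as np

# scores[i][j] = BLEU of beam sample i against reference j (illustrative)
scores = np.array([[0.2, 0.6],
                   [0.4, 0.1],
                   [0.3, 0.3]])

precision = scores.max(axis=1).mean()   # best ref per sample, then average
recall = scores.max(axis=0).mean()      # best sample per ref, then average
print(precision, recall)                # ≈0.433 0.5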
Example #29
def _train_epoch(sess, epoch, step, smry_writer):

    fetches = {
        'step': global_step,
        'train_op': train_op,
        'smry': summary_merged,
        'loss': mle_loss,
    }

    while True:
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, 'train'),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
            }
            op = sess.run([batch], feed_dict)
            feed_dict = {
                src_input_ids: op[0]['src_input_ids'],
                src_segment_ids: op[0]['src_segment_ids'],
                tgt_input_ids: op[0]['tgt_input_ids'],
                labels: op[0]['tgt_labels'],
                learning_rate: utils.get_lr(step, lr),
                tx.global_mode(): tf.estimator.ModeKeys.TRAIN
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)
            step, loss = fetches_['step'], fetches_['loss']
            display_steps = 100
            if step and step % display_steps == 0:
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % 1000 == 0:
                model_path = "/var/scratch/vro220/models10/model_" + str(
                    step) + ".ckpt"
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)
            # _eval_epoch(sess, epoch,step,mode='eval')
        except tf.errors.OutOfRangeError:
            break

    return step
Example #30
    def generate(sess, saver, fname=None):
        if tf.train.checkpoint_exists(FLAGS.model):
            saver.restore(sess, FLAGS.model)
        else:
            raise ValueError("cannot find checkpoint model")

        batch_size = train_data.batch_size

        dst = tfd.MultivariateNormalDiag(
            loc=tf.zeros([batch_size, config.latent_dims]),
            scale_diag=tf.ones([batch_size, config.latent_dims]))

        dcdr_states, latent_z = connector_stoch(dst)

        # to concatenate latent variable to input word embeddings
        def _cat_embedder(ids):
            embedding = decoder_embedder(ids)
            return tf.concat([embedding, latent_z], axis=1)

        vocab = train_data.vocab
        start_tokens = tf.ones(batch_size, tf.int32) * vocab.bos_token_id
        end_token = vocab.eos_token_id

        if config.decoder_hparams["type"] == "lstm":
            outputs, _, _ = decoder(initial_state=dcdr_states,
                                    decoding_strategy="infer_sample",
                                    embedding=_cat_embedder,
                                    max_decoding_length=100,
                                    start_tokens=start_tokens,
                                    end_token=end_token)
        else:
            outputs, _ = decoder(memory=dcdr_states,
                                 decoding_strategy="infer_sample",
                                 memory_sequence_length=tf.ones(
                                     tf.shape(dcdr_states)[0]),
                                 max_decoding_length=100,
                                 start_tokens=start_tokens,
                                 end_token=end_token)

        sample_tokens = vocab.map_ids_to_tokens(outputs.sample_id)
        sess.run(tf.tables_initializer())

        mode_key = tf.estimator.ModeKeys.EVAL
        feed = {tx.global_mode(): mode_key}
        sample_tokens_ = sess.run(sample_tokens, feed_dict=feed)
        if fname is None:
            fh = sys.stdout
        else:
            fh = open(fname, 'w', encoding='utf-8')

        for sent in sample_tokens_:
            sent = list(sent)
            end_id = sent.index(vocab.eos_token)
            fh.write(' '.join(sent[:end_id + 1]) + '\n')

        fh.close()
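
`_cat_embedder` above conditions every decoding step on the sampled latent code by appending z to the step's word embedding, growing the decoder input from emb_dim to emb_dim + latent_dims. A NumPy sketch of just the shape arithmetic (sizes are illustrative):

import numpy as np

latent_z = np.random.randn(2, 16).astype(np.float32)    # [batch, latent_dims]

def cat_embedder(step_embedding):
    # step_embedding: [batch, emb_dim]; the same z is appended at every step.
    return np.concatenate([step_embedding, latent_z], axis=1)

step_embedding = np.random.randn(2, 32).astype(np.float32)
print(cat_embedder(step_embedding).shape)   # (2, 48)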