コード例 #1
0
 def _get_train_op(self):
     """Builds the op that minimizes the policy-gradient loss.

     Returns:
         Whatever ``opt.get_train_op`` returns for ``self._pg_loss``
         over ``self._trainable_variables``, using learning rate
         ``self._lr`` and the ``optimization`` hyperparameters as a dict.
     """
     pg_loss = self._pg_loss
     params = self._trainable_variables
     lr = self._lr
     opt_hparams = self._hparams.optimization.todict()
     return opt.get_train_op(loss=pg_loss,
                             variables=params,
                             learning_rate=lr,
                             hparams=opt_hparams)
コード例 #2
0
ファイル: optimization_test.py プロジェクト: zkmake520/texar
 def test_get_train_op(self):
     """Checks that ``opt.get_train_op`` yields a tensor for a trivial loss.
     """
     weight = tf.Variable(0.)
     l2 = tf.nn.l2_loss(weight)
     op = opt.get_train_op(l2)
     produced_tensor = tf.contrib.framework.is_tensor(op)
     self.assertTrue(produced_tensor)
コード例 #3
0
 def _get_train_op(self, loss):
     """Builds the training op for ``loss`` over the seq2seq sub-modules.

     Args:
         loss: The loss to minimize.

     Returns:
         The result of ``get_train_op`` restricted to the variables that
         ``collect_trainable_variables`` gathers from the source/target
         embedders, encoder, connector and decoder.
     """
     submodules = [self._src_embedder, self._tgt_embedder, self._encoder,
                   self._connector, self._decoder]
     var_list = collect_trainable_variables(submodules)
     opt_hparams = self._hparams.optimization
     return get_train_op(loss, variables=var_list, hparams=opt_hparams)
コード例 #4
0
ファイル: pg_agent.bak.py プロジェクト: ml-lab/Text_Infilling
    def __init__(self, actions, state_shape, hparams=None):
        """Builds the policy network, its loss and train op, and a session.

        Args:
            actions: Forwarded unchanged to ``AgentBase.__init__``.
            state_shape: Shape of one state; a leading ``None`` (batch)
                dimension is prepended to form the state placeholder.
            hparams: Optional hyperparameters, forwarded to ``AgentBase``.
        """
        AgentBase.__init__(self, actions, state_shape, hparams=hparams)
        hp = self._hparams
        self.discount_factor = hp.discount_factor

        self.network = get_instance(
            hp.network.type,
            {"hparams": hp.network.hparams},
            module_paths=['texar.modules', 'texar.custom'])

        with tf.variable_scope(self.network.variable_scope):
            # Batched states, float64.
            self.state_input = tf.placeholder(
                dtype=tf.float64, shape=[None] + list(state_shape))
            # One int32 action index per sample.
            self.action_inputs = tf.placeholder(dtype=tf.int32, shape=[None])
            # One float64 value per sample; fed to the loss as `advantages`.
            self.qvalues = tf.placeholder(dtype=tf.float64, shape=[None])

            self.outputs = self.network(self.state_input)
            self.probs = tf.nn.softmax(self.outputs)

            trainer_hp = hp.trainer
            self.loss = trainer_hp.loss_fn(
                outputs=self.outputs,
                action_inputs=self.action_inputs,
                advantages=self.qvalues)
            # NOTE(review): variables=None presumably means "all trainable
            # variables" inside opt.get_train_op -- confirm there.
            self.trainer = opt.get_train_op(
                loss=self.loss,
                variables=None,
                hparams=trainer_hp.optimization_hparams)

        self.record = []

        self.sess = tf.Session()
        # Initializes every global variable in the default graph, not only
        # the ones created by this agent.
        self.sess.run(tf.global_variables_initializer())
コード例 #5
0
def main():
    """Entrypoint: trains Seq2SeqAttn and reports val/test BLEU per epoch.
    """
    train_data = tx.data.PairedTextData(hparams=config_data.train)
    val_data = tx.data.PairedTextData(hparams=config_data.val)
    test_data = tx.data.PairedTextData(hparams=config_data.test)
    data_iterator = tx.data.TrainTestDataIterator(train=train_data,
                                                  val=val_data,
                                                  test=test_data)

    model = Seq2SeqAttn(train_data)
    optimizer = get_optimizer(model.parameters(), config_model.opt)
    train_op = get_train_op(optimizer, config_model.opt)

    def _train_epoch():
        """Runs one pass over the training split, updating the model."""
        # Evaluation switches the model to eval mode; restore training-mode
        # behavior (e.g. dropout) before updating parameters.
        model.train()
        data_iterator.switch_to_train_data()
        iterator = data_iterator.get_iterator()

        # Anomaly detection was removed: it is a debugging aid that slows
        # every backward pass and should not stay on in a training loop.
        for step, batch in enumerate(iterator):
            loss = model(batch, mode="train")
            # NOTE(review): retain_graph=True is kept from the original. It is
            # normally unnecessary and keeps the graph alive; drop it once it
            # is confirmed no tensor with autograd history is reused across
            # batches.
            loss.backward(retain_graph=True)
            train_op()
            if step % config_data.display == 0:
                print("step={}, loss={:.4f}".format(step, loss.item()))

    def _eval_epoch(mode):
        """Decodes the val split (``mode == 'val'``) or otherwise the test
        split, and returns its corpus BLEU."""
        # Disable training-only behavior (e.g. dropout) while decoding.
        model.eval()
        if mode == 'val':
            data_iterator.switch_to_val_data()
        else:
            data_iterator.switch_to_test_data()
        iterator = data_iterator.get_iterator()

        refs, hypos = [], []
        # Decoding needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            for batch in iterator:
                infer_outputs = model(batch, mode="infer")
                output_ids = infer_outputs.sample_id
                # Drop the first token of each reference (presumably BOS).
                target_texts_ori = [text[1:] for text in batch['target_text']]
                target_texts = tx.utils.strip_special_tokens(
                    target_texts_ori, is_token_list=True)
                output_texts = tx.data.vocabulary.map_ids_to_strs(
                    ids=output_ids, vocab=val_data.target_vocab)

                for hypo, ref in zip(output_texts, target_texts):
                    hypos.append(hypo)
                    refs.append([ref])

        return tx.evals.corpus_bleu_moses(list_of_references=refs,
                                          hypotheses=hypos)

    best_val_bleu = -1.
    for i in range(config_data.num_epochs):
        _train_epoch()

        val_bleu = _eval_epoch('val')
        best_val_bleu = max(best_val_bleu, val_bleu)
        print('val epoch={}, BLEU={:.4f}; best-ever={:.4f}'.format(
            i, val_bleu, best_val_bleu))

        test_bleu = _eval_epoch('test')
        print('test epoch={}, BLEU={:.4f}'.format(i, test_bleu))

        print('=' * 50)
コード例 #6
0
 def _get_train_op(self):
     """Builds the op that minimizes the summed squared TD error.

     Returns:
         The result of ``opt.get_train_op`` over
         ``self._qnet.trainable_variables``, configured by the
         ``optimization`` hyperparameters converted to a dict.
     """
     td_loss = tf.reduce_sum(self._td_error**2)
     qnet_vars = self._qnet.trainable_variables
     opt_hparams = self._hparams.optimization.todict()
     return opt.get_train_op(loss=td_loss,
                             variables=qnet_vars,
                             hparams=opt_hparams)