コード例 #1
0
    def train(self,
              t_x,
              t_y,
              v_x,
              v_y,
              lrv,
              char2idx,
              sess,
              epochs,
              batch_size=10,
              reset=True):
        """Train the seq2seq model and keep the best validation checkpoint.

        When ``reset`` is true, or no ``<self.trained>_weights`` checkpoint
        exists yet, the model is trained for ``epochs`` epochs. After each
        epoch the validation inputs are decoded (``argmax=True``) and scored
        against the gold outputs with ``evaluation.trans_evaluator``. Whenever
        the token F score (``c_scores[1]``) improves, the weights are saved.
        Afterwards the saved weights are restored into ``sess``. When training
        is skipped, the pre-existing checkpoint is restored instead.

        Args:
            t_x: training inputs, passed with ``t_y`` as ``data`` to
                ``Batch.train_seq2seq``.
            t_y: training targets.
            v_x: validation inputs, decoded with ``Batch.predict_seq2seq``.
            v_y: validation targets. Zeros are trimmed from both ends of each
                row (``np.trim_zeros``) before the row is converted to text.
                Presumably 0 is the padding id -- TODO confirm.
            lrv: value forwarded as ``lrv`` to ``Batch.train_seq2seq``.
            char2idx: mapping from character to integer id. It is inverted
                here to turn id sequences back into text.
            sess: session used for training, prediction, save and restore.
            epochs: number of training epochs.
            batch_size: training batch size (validation decoding uses 100).
            reset: if false and a checkpoint already exists, training is
                skipped and that checkpoint is restored.
        """
        weights_path = self.trained + '_weights'
        # An existing checkpoint is detected through its '.index' file.
        needs_training = reset or not os.path.isfile(weights_path + '.index')
        # True once this call has written a checkpoint to weights_path.
        saved = False

        if needs_training:
            # Invert the vocabulary: integer id -> character.
            idx2char = {idx: ch for ch, idx in char2idx.items()}
            gold_out = [
                toolbox.generate_trans_out(np.trim_zeros(v_y_t), idx2char)
                for v_y_t in v_y
            ]

            best_score = 0
            for epoch in range(epochs):
                Batch.train_seq2seq(sess,
                                    model=self.en_vec + self.trans_labels,
                                    decoding=self.feed_previous,
                                    batch_size=batch_size,
                                    config=self.trans_train,
                                    lr=self.trans_l_rate,
                                    lrv=lrv,
                                    data=[t_x, t_y])
                pred = Batch.predict_seq2seq(sess,
                                             model=self.en_vec + self.de_vec +
                                             self.trans_output,
                                             decoding=self.feed_previous,
                                             decode_len=self.decode_step,
                                             data=[v_x],
                                             argmax=True,
                                             batch_size=100)
                pred_out = [
                    toolbox.generate_trans_out(pre_t, idx2char)
                    for pre_t in pred
                ]

                # c_scores[0] is reported as accuracy and c_scores[1] as
                # token F score.
                c_scores = evaluation.trans_evaluator(gold_out, pred_out)

                print('epoch: %d' % (epoch + 1))
                print('ACC: %f' % c_scores[0])
                print('Token F score: %f' % c_scores[1])
                sys.stdout.flush()

                if c_scores[1] > best_score:
                    best_score = c_scores[1]
                    self.saver.save(sess,
                                    weights_path,
                                    write_meta_graph=False)
                    saved = True

        # Restore only when a checkpoint is known to exist. That is either one
        # saved above or the pre-existing one found when training was skipped.
        # If no epoch beat a score of 0, nothing was saved, so the session
        # keeps its final-epoch weights instead of failing on a missing file.
        if saved or not needs_training:
            self.saver.restore(sess, weights_path)
コード例 #2
0
    def tag(self, t_x, char2idx, sess, batch_size=100):
        """Decode a batch of input id sequences into output strings.

        Each input sequence is clipped to ``self.encode_step`` ids and
        zero-padded to that length with ``toolbox.pad_zeros``. The batch is
        then decoded with ``Batch.predict_seq2seq`` (``argmax=True``, up to
        ``self.decode_step`` steps). Each predicted id sequence is mapped back
        to text with the inverted ``char2idx`` vocabulary.

        Args:
            t_x: iterable of input id sequences.
            char2idx: mapping from character to integer id.
            sess: session used for prediction.
            batch_size: batch size used during decoding.

        Returns:
            A list with one ``toolbox.generate_trans_out`` result per
            predicted sequence.
        """
        # Invert the vocabulary so predicted ids map back to characters.
        index_to_char = dict((idx, ch) for ch, idx in char2idx.items())

        # Clip every sequence to the encoder length, then zero-pad to it.
        limit = self.encode_step
        clipped = [seq[:limit] for seq in t_x]
        encoder_input = toolbox.pad_zeros(clipped, limit)

        decode_model = self.en_vec + self.de_vec + self.trans_output
        predictions = Batch.predict_seq2seq(sess,
                                            model=decode_model,
                                            decoding=self.feed_previous,
                                            decode_len=self.decode_step,
                                            data=[encoder_input],
                                            argmax=True,
                                            batch_size=batch_size)
        return [toolbox.generate_trans_out(p, index_to_char)
                for p in predictions]