Ejemplo n.º 1
0
    def eval(self, sess, args):
        """Run the time-prediction head over the test set and dump the results.

        For every test batch two tab-separated lines are written to
        ``<args.logdir>/output_t.txt``: ``pred_t:`` followed by
        ``exp(self.pred_t)`` for the batch, and ``targ_t:`` followed by the
        flattened target times of the same batch.

        Args:
            sess: TensorFlow session used to restore weights and run the graph.
            args: parsed command-line arguments; reads ``logdir``,
                ``eval_only`` and ``weights``.
        """
        if not os.path.exists(args.logdir + '/output'):
            os.makedirs(args.logdir + '/output')

        if args.eval_only:
            self.test_data = read_data.load_test_dataset(self.dataset_file)

        if args.weights is not None:
            self.saver.restore(sess, args.weights)
            print_in_file("Saved")

        i_e, t_e, i_t, t_t = read_data.data_iterator(self.test_data,
                                                     self.num_steps,
                                                     self.length)

        # One set of sampled times is drawn once and reused for every batch.
        sample_t = read_data.generate_sample_t(self.batch_size, i_t, t_t)

        batch_num = len(
            list(read_data.generate_batch(self.batch_size, i_e, t_e, i_t,
                                          t_t)))
        logging.info('Evaluation Batch Num {}'.format(batch_num))

        # `with` guarantees the output is flushed and closed even on error.
        with open(os.path.join(args.logdir, "output_t.txt"), 'w+') as f:
            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, i_e, t_e, i_t, t_t):
                # Times are log-transformed and clamped to be non-negative.
                feed_dict = {
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.target_t: t_y,
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }
                pred_t = sess.run(self.pred_t, feed_dict=feed_dict)

                # sess.run already returned a numpy array, so exponentiate
                # with numpy rather than adding a tf.exp op to the graph on
                # every batch.
                f.write('pred_t: ' +
                        '\t'.join([str(v) for v in np.exp(pred_t)]))
                f.write('\n')
                # Targets of the current batch (was the undefined t_y_list[i]).
                f.write('targ_t: ' +
                        '\t'.join([str(v) for v in np.array(t_y).flatten()]))
                f.write('\n')
Ejemplo n.º 2
0
    def eval(self, sess, args):
        """Run the event-prediction head over the test set and dump the results.

        For every test batch two tab-separated lines are written to
        ``<args.logdir>/output_e.txt``: ``pred_e:`` with the top-1 (arg-max)
        event index at each position, flattened, and ``targ_e:`` with the
        flattened target events of the same batch.

        Args:
            sess: TensorFlow session used to restore weights and run the graph.
            args: parsed command-line arguments; reads ``logdir`` and
                ``weights``.
        """
        if not os.path.exists(args.logdir + '/output'):
            os.makedirs(args.logdir + '/output')

        # NOTE: this variant does not load the test set for --eval_only;
        # self.test_data must already be populated by the caller.

        if args.weights is not None:
            self.saver.restore(sess, args.weights)
            print_in_file("Saved")

        input_event_data, target_event_data, input_time_data, target_time_data = read_data.data_iterator(
            self.test_data, self.num_steps, self.length)

        # Draw the samples with the same batch size used by generate_batch
        # below, so the sample_t feed matches the other inputs (this was
        # previously hard-coded to 100).
        sample_t = read_data.generate_sample_t(self.batch_size,
                                               input_time_data,
                                               target_time_data)

        with open(os.path.join(args.logdir, "output_e.txt"), 'w+') as f:
            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, input_event_data, target_event_data,
                    input_time_data, target_time_data):
                # Times are log-transformed and clamped to be non-negative.
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }

                pred_e = sess.run(self.pred_e, feed_dict=feed_dict)

                # Top-1 index along the last axis, flattened. This matches
                # tf.nn.top_k(pred_e, 1) + squeeze + reshape([-1]), including
                # lower-index-first tie breaking, but no graph ops are created
                # per batch.
                pred_e_index = np.argmax(pred_e, axis=-1).reshape(-1)
                f.write('pred_e: ' +
                        '\t'.join([str(v) for v in pred_e_index]))
                f.write('\n')
                # Whole batch of targets so it lines up with pred_e. The old
                # e_y[i], with i never incremented, wrote only the first row.
                f.write('targ_e: ' +
                        '\t'.join([str(v) for v in np.array(e_y).flatten()]))
                f.write('\n')
Ejemplo n.º 3
0
# Batch period of the discriminator update; only referenced by the
# commented-out discriminator step at the end of the loop body.
gap = 6

# Loss-weighting hyper-parameters taken from the command line and fed into
# the model's `alpha` / `gamma` inputs on every batch.
alpha = args.alpha
gamma = args.gamma

for epoch in range(args.train_iter):
    logging.info('Training epoch: {}'.format(epoch))
    t0 = time.time()  # epoch start time; not read in the visible code
    i = 0  # batch counter; incremented before use, so 1-based
    # Reset the running metric variables at the start of each epoch,
    # presumably so MAE / precision / recall are accumulated per epoch.
    sess.run(model.running_MAE_vars_initializer)
    sess.run(model.running_precision_vars_initializer)
    sess.run(model.running_recall_vars_initializer)
    for e_x, e_y, t_x, t_y in read_data.generate_batch(args.batch_size,
                                                       train_i_e, train_t_e,
                                                       train_i_t, train_t_t):
        # A fresh set of sampled times is drawn for every batch.
        sample_t = read_data.generate_sample_t(args.batch_size, train_i_t,
                                               train_t_t)
        i += 1
        feed_dict = {
            model.alpha: alpha,
            model.gamma: gamma,
            model.inputs_e: e_x,
            # Clamp times to >= 1e-4 before the log so zero gaps do not give
            # -inf, then clamp the log itself to be non-negative.
            model.inputs_t: np.maximum(np.log(np.maximum(t_x, 1e-4)), 0),
            model.target_t: t_y,
            model.target_e: e_y,
            model.sample_t: np.maximum(np.log(np.maximum(sample_t, 1e-4)), 0)
        }

        # NOTE(review): as shown, the loop builds feed_dict but runs no op;
        # the training step appears to be missing or truncated here.
        # Jointly train
        # If it's discriminator iter
        # if i % gap == 0:
        #     _, _ = sess.run([model.train_disc_op, model.train_w_clip_op], feed_dict=feed_dict)
Ejemplo n.º 4
0
    def train(self, sess, args):
        """Adversarially train the model and re-score it after every epoch.

        Each epoch reloads the batches and walks them in cycles of
        ``gap = g_iters + 1``. The first ``g_iters`` batches of a cycle each
        run one generator step (``self.g_train_op``). The last batch runs one
        discriminator step (``self.d_train_op``), followed by ``self.clip_op``
        when ``self.cell_type == 'T_LSTMCell'``. Precision and deviation are
        printed roughly ten times per pass. After training, every batch is run
        once more without updates and the same statistics are printed. The
        model is saved once, after the last epoch.

        Args:
            sess: TensorFlow session.
            args: parsed command-line arguments; reads ``logdir``, ``weights``
                and ``iters`` (number of epochs).
        """
        # Pick a fresh timestamped log dir; back off randomly while it exists.
        self.logdir = args.logdir + parse_time()
        while os.path.exists(self.logdir):
            time.sleep(random.randint(1, 5))
            self.logdir = args.logdir + parse_time()
        os.makedirs(self.logdir)

        if not os.path.exists('%s/logs' % self.logdir):
            os.makedirs('%s/logs' % self.logdir)

        if args.weights is not None:
            self.saver.restore(sess, args.weights)

        g_iters = 5  # generator steps per discriminator step
        gap = g_iters + 1  # batches consumed per generator/discriminator cycle

        def make_feed(e_x, e_y, t_x, t_y, s_t):
            # Build the feed for one batch; times are log-transformed and
            # clamped to be non-negative.
            return {
                self.input_e: e_x,
                self.inputs_t: np.maximum(np.log(t_x), 0),
                self.target_t: t_y,
                self.targets_e: e_y,
                self.sample_t: np.maximum(np.log(s_t), 0)
            }

        for epoch in range(args.iters):
            # ---------------- training ----------------
            sum_correct_pred = 0.0
            sum_iter = 0.0
            sum_deviation = 0.0

            input_len, input_event_data, target_event_data, input_time_data, target_time_data = read_data.data_iterator(
                self.train_data, self.event_to_id, self.num_steps, self.length)

            batch_num, e_x_list, e_y_list, t_x_list, t_y_list = read_data.generate_batch(
                input_len, self.batch_size, input_event_data,
                target_event_data, input_time_data, target_time_data)

            _, sample_t_list = read_data.generate_sample_t(
                input_len, self.batch_size, input_time_data, target_time_data)

            iterations = batch_num // gap
            print(iterations)
            # Report ~10 times per pass; max() avoids a ZeroDivisionError
            # when there are fewer than 10 cycles.
            log_every = max(1, iterations // 10)

            for i in range(iterations):
                base = i * gap

                for j in range(g_iters):
                    k = base + j
                    _, correct_pred, deviation, pred_e, pred_t, d_loss, g_loss, gen_e_cost, gen_t_cost, disc_cost_1, gradient_penalty = sess.run(
                        [
                            self.g_train_op, self.correct_pred, self.deviation,
                            self.pred_e, self.pred_t, self.d_cost, self.g_cost,
                            self.gen_e_cost, self.gen_t_cost, self.disc_cost_1,
                            self.gradient_penalty
                        ],
                        feed_dict=make_feed(e_x_list[k], e_y_list[k],
                                            t_x_list[k], t_y_list[k],
                                            sample_t_list[k]))

                k = base + g_iters
                sess.run(self.d_train_op,
                         feed_dict=make_feed(e_x_list[k], e_y_list[k],
                                             t_x_list[k], t_y_list[k],
                                             sample_t_list[k]))

                if self.cell_type == 'T_LSTMCell':
                    sess.run(self.clip_op)

                # Statistics come from the last generator step of the cycle.
                sum_correct_pred = sum_correct_pred + correct_pred
                sum_iter = sum_iter + 1
                sum_deviation = sum_deviation + deviation

                if i % log_every == 0:
                    denom = sum_iter * self.batch_size * self.length
                    print(
                        '[epoch: %d, %d] precision: %f, deviation: %f, d_loss: %f, g_loss: %f'
                        % (epoch, int(i // log_every), sum_correct_pred / denom,
                           sum_deviation / denom, d_loss, g_loss))
                    print()
                    print(
                        'gen_e_loss: %f, gen_t_loss: %f, d_1_loss: %f, g_penal_loss: %f'
                        % (gen_e_cost, gen_t_cost, disc_cost_1,
                           gradient_penalty))
                    print()

            # ---------------- evaluation ----------------
            # NOTE(review): this pass reloads self.train_data, so it measures
            # training-set performance; confirm validation data wasn't meant.
            input_len, input_event_data, target_event_data, input_time_data, target_time_data = read_data.data_iterator(
                self.train_data, self.event_to_id, self.num_steps, self.length)

            batch_num, e_x_list, e_y_list, t_x_list, t_y_list = read_data.generate_batch(
                input_len, self.batch_size, input_event_data,
                target_event_data, input_time_data, target_time_data)

            _, sample_t_list = read_data.generate_sample_t(
                input_len, self.batch_size, input_time_data, target_time_data)

            iterations = batch_num
            sum_correct_pred = 0.0
            sum_iter = 0.0
            sum_deviation = 0.0
            log_every = max(1, iterations // 10)

            for i in range(iterations):
                correct_pred, deviation, pred_e, pred_t, d_loss, g_loss, gen_e_cost, gen_t_cost, disc_cost_1, gradient_penalty = sess.run(
                    [
                        self.correct_pred, self.deviation, self.pred_e,
                        self.pred_t, self.d_cost, self.g_cost, self.gen_e_cost,
                        self.gen_t_cost, self.disc_cost_1,
                        self.gradient_penalty
                    ],
                    feed_dict=make_feed(e_x_list[i], e_y_list[i], t_x_list[i],
                                        t_y_list[i], sample_t_list[i]))
                sum_correct_pred = sum_correct_pred + correct_pred
                sum_iter = sum_iter + 1
                sum_deviation = sum_deviation + deviation

                if i % log_every == 0:
                    denom = sum_iter * self.batch_size * self.length
                    print(
                        '%f, precision: %f, deviation: %f, d_loss: %f, g_loss: %f'
                        % (i // log_every, sum_correct_pred / denom,
                           sum_deviation / denom, d_loss, g_loss))

        self.save_model(sess, self.logdir, args.iters)
Ejemplo n.º 5
0
    def eval(self, sess, args):
        """Run both prediction heads over the test set and dump the results.

        For every test batch four tab-separated lines are written to
        ``<args.logdir>/output.txt``:

        * ``pred_e:`` — top-1 (arg-max) event index per position, flattened;
        * ``targ_e:`` — flattened target events;
        * ``pred_t:`` — ``exp(self.pred_t)`` for the batch;
        * ``targ_t:`` — flattened target times.

        Args:
            sess: TensorFlow session used to restore weights and run the graph.
            args: parsed command-line arguments; reads ``logdir``,
                ``eval_only`` and ``weights``.
        """
        if not os.path.exists(args.logdir + '/output'):
            os.makedirs(args.logdir + '/output')

        if args.eval_only:
            self.test_data = read_data.load_test_dataset(self.dataset_file)

        if args.weights is not None:
            self.saver.restore(sess, args.weights)
            print_in_file("Saved")

        input_event_data, target_event_data, input_time_data, target_time_data = read_data.data_iterator(
            self.test_data, self.num_steps, self.length)

        # Same batch size as generate_batch below so the sample_t feed matches
        # the other inputs (previously hard-coded to 100).
        sample_t = read_data.generate_sample_t(self.batch_size,
                                               input_time_data,
                                               target_time_data)

        with open(os.path.join(args.logdir, "output.txt"), 'w+') as f:
            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, input_event_data, target_event_data,
                    input_time_data, target_time_data):
                # Times are log-transformed and clamped to be non-negative.
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }
                pred_e, pred_t = sess.run([self.pred_e, self.pred_t],
                                          feed_dict=feed_dict)

                # Post-process in numpy: the previous tf.nn.top_k / tf.exp
                # calls with .eval() added new graph ops on every batch.
                # argmax over the last axis equals top_k(k=1) indices,
                # including lower-index-first tie breaking.
                pred_e_index = np.argmax(pred_e, axis=-1).reshape(-1)
                f.write('pred_e: ' +
                        '\t'.join([str(v) for v in pred_e_index]))
                f.write('\n')
                f.write('targ_e: ' +
                        '\t'.join([str(v) for v in np.array(e_y).flatten()]))
                f.write('\n')
                f.write('pred_t: ' +
                        '\t'.join([str(v) for v in np.exp(pred_t)]))
                f.write('\n')
                f.write('targ_t: ' +
                        '\t'.join([str(v) for v in np.array(t_y).flatten()]))
                f.write('\n')
Ejemplo n.º 6
0
    def train(self, sess, args):
        """Adversarially train the model and validate it after every epoch.

        Training: every ``gap``-th batch (``i % gap == 0``) runs a
        discriminator step together with ``self.w_clip_op``. Every other batch
        runs, in order:

        * the event cross-entropy step;
        * the time huber-loss step;
        * a joint generator step, which also updates the running
          precision/recall ops.

        Validation: every batch of ``self.valid_data`` is run without
        parameter updates. The mean of ``gen_t_cost / gen_e_cost`` is stored
        in ``self.alpha`` and the mean of ``gen_t_cost_1 / huber_t_loss`` in
        ``self.gamma``, both as tensors. The model is saved once, after the
        last epoch.

        Args:
            sess: TensorFlow session.
            args: parsed command-line arguments; reads ``logdir``, ``weights``
                and ``iters`` (number of epochs).
        """
        # Pick a fresh timestamped log dir; back off randomly while it exists.
        self.logdir = args.logdir + parse_time()
        while os.path.exists(self.logdir):
            time.sleep(random.randint(1, 5))
            self.logdir = args.logdir + parse_time()
        os.makedirs(self.logdir)

        if not os.path.exists('%s/logs' % self.logdir):
            os.makedirs('%s/logs' % self.logdir)

        if args.weights is not None:
            self.saver.restore(sess, args.weights)

        self.lr = self.learning_rate

        for epoch in range(args.iters):
            # ---------------- training ----------------
            # Re-initialise the running precision/recall variables so the
            # metrics are computed per epoch.
            sess.run([
                self.running_precision_vars_initializer,
                self.running_recall_vars_initializer
            ])
            batch_precision, batch_recall = 0.0, 0.0
            average_deviation, sum_deviation = 0.0, 0.0
            d_loss, g_loss, gen_e_cost, gen_t_cost, huber_t_loss = 0.0, 0.0, 0.0, 0.0, 0.0

            i_e, t_e, i_t, t_t = read_data.data_iterator(
                self.train_data, self.num_steps, self.length)

            sample_t = read_data.generate_sample_t(self.batch_size, i_t, t_t)

            i = 0
            gap = 6  # one discriminator batch every `gap` batches
            sum_iter = 0.0
            batch_num = len(
                list(
                    read_data.generate_batch(self.batch_size, i_e, t_e, i_t,
                                             t_t)))
            logging.info('Training batch num {}'.format(batch_num))
            # Log ~10 times per pass; max() avoids a ZeroDivisionError when
            # there are fewer than 10 batches.
            log_every = max(1, batch_num // 10)

            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, i_e, t_e, i_t, t_t):
                # Times are log-transformed and clamped to be non-negative.
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.target_t: t_y,
                    self.targets_e: e_y,
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }

                if i % gap == 0:
                    # Discriminator step followed by weight clipping.
                    sess.run([self.d_train_op, self.w_clip_op],
                             feed_dict=feed_dict)
                else:
                    # train event cross-entropy
                    sess.run(self.g_e_train_op, feed_dict=feed_dict)
                    # train time huber-loss
                    sess.run(self.g_t_train_op, feed_dict=feed_dict)
                    # jointly update
                    _, _, _, deviation, batch_precision, batch_recall, d_loss, g_loss, gen_e_cost, gen_t_cost, huber_t_loss = sess.run(
                        [
                            self.g_train_op, self.batch_precision_op,
                            self.batch_recall_op, self.deviation,
                            self.batch_precision, self.batch_recall,
                            self.d_cost, self.g_cost, self.gen_e_cost,
                            self.gen_t_cost, self.huber_t_loss
                        ],
                        feed_dict=feed_dict)

                    sum_iter = sum_iter + 1.0
                    sum_deviation = sum_deviation + deviation
                    average_deviation = sum_deviation / sum_iter

                if i % log_every == 0:
                    logging.info(
                        '[epoch: {}, {}] precision: {}, recall: {}, deviation: {}'
                        .format(epoch,
                                float(i) / log_every, batch_precision,
                                batch_recall, average_deviation))
                    logging.info(
                        'd_loss: {}, g_loss: {}, gen_e_loss: {}, gen_t_loss: {}, hunber_t_loss: {}'
                        .format(d_loss, g_loss, gen_e_cost, gen_t_cost,
                                huber_t_loss))
                i += 1

            # ---------------- evaluation ----------------
            # Re-initialise the running metrics again for the validation pass.
            sess.run([
                self.running_precision_vars_initializer,
                self.running_recall_vars_initializer
            ])

            i_e, t_e, i_t, t_t = read_data.data_iterator(
                self.valid_data, self.num_steps, self.length)

            sample_t = read_data.generate_sample_t(self.batch_size, i_t, t_t)

            sum_iter = 0.0
            # Reset the deviation accumulator as well. Previously it kept the
            # training-pass sum, which inflated the validation average.
            sum_deviation = 0.0
            i = 0
            gen_cost_ratio = []
            t_cost_ratio = []
            batch_num = len(
                list(
                    read_data.generate_batch(self.batch_size, i_e, t_e, i_t,
                                             t_t)))
            logging.info('Evaluation Batch Num {}'.format(batch_num))
            log_every = max(1, batch_num // 10)

            self.lr = self.learning_rate

            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, i_e, t_e, i_t, t_t):
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.target_t: t_y,
                    self.targets_e: e_y,
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }

                _, _, deviation, batch_precision, batch_recall, d_loss, g_loss, gen_e_cost, gen_t_cost, gen_t_cost_1, huber_t_loss = sess.run(
                    [
                        self.batch_precision_op, self.batch_recall_op,
                        self.deviation, self.batch_precision,
                        self.batch_recall, self.d_cost, self.g_cost,
                        self.gen_e_cost, self.gen_t_cost, self.gen_t_cost_1,
                        self.huber_t_loss
                    ],
                    feed_dict=feed_dict)

                sum_iter = sum_iter + 1.0
                sum_deviation = sum_deviation + deviation
                gen_cost_ratio.append(gen_t_cost / gen_e_cost)
                t_cost_ratio.append(gen_t_cost_1 / huber_t_loss)

                if i % log_every == 0:
                    logging.info(
                        '{}, precision {}, recall {}, deviation {}, d_loss {}, g_loss {}, huber_t_loss {}'
                        .format(
                            float(i) / log_every, batch_precision,
                            batch_recall, sum_deviation / sum_iter, d_loss,
                            g_loss, huber_t_loss))
                i += 1

            # NOTE(review): this rebinds self.alpha / self.gamma to new tensors
            # each epoch, adding graph ops; ops already built from earlier
            # values are not affected. Confirm this is the intended use.
            self.alpha = tf.reduce_mean(gen_cost_ratio)
            self.gamma = tf.reduce_mean(t_cost_ratio)
            logging.info('alpha: {}, gamma: {}'.format(sess.run(self.alpha),
                                                       sess.run(self.gamma)))

        self.save_model(sess, self.logdir, args.iters)
Ejemplo n.º 7
0
    def train(self, sess, args):
        """Train the event generator and evaluate it after every epoch.

        Training runs ``self.g_train_op`` on every batch. It accumulates
        ``self.hit_count`` and logs hit rate, ``gen_e_cost`` and the running
        precision/recall. Evaluation runs every batch of ``self.valid_data``
        without updates and logs the same statistics. ``self.lr`` is
        multiplied by 2/3 at every ``args.iters // 5``-th epoch (minimum 1)
        and reset to ``self.learning_rate`` before each evaluation pass. The
        model is saved once, after the last epoch.

        Args:
            sess: TensorFlow session.
            args: parsed command-line arguments; reads ``logdir``, ``weights``
                and ``iters`` (number of epochs).
        """
        # Pick a fresh timestamped log dir; back off randomly while it exists.
        self.logdir = args.logdir + parse_time()
        while os.path.exists(self.logdir):
            time.sleep(random.randint(1, 5))
            self.logdir = args.logdir + parse_time()
        os.makedirs(self.logdir)

        if not os.path.exists('%s/logs' % self.logdir):
            os.makedirs('%s/logs' % self.logdir)

        if args.weights is not None:
            self.saver.restore(sess, args.weights)

        self.lr = self.learning_rate
        # max() avoids a ZeroDivisionError when fewer than 5 epochs are run.
        decay_every = max(1, args.iters // 5)

        for epoch in range(args.iters):
            # ---------------- training ----------------
            # Re-initialise the running precision/recall variables so the
            # metrics are computed per epoch.
            sess.run([
                self.running_precision_vars_initializer,
                self.running_recall_vars_initializer
            ])

            # NOTE(review): self.lr is reset to self.learning_rate before each
            # evaluation below, so this decay does not compound across epochs.
            if epoch > 0 and epoch % decay_every == 0:
                self.lr = self.lr * 2. / 3

            batch_precision, batch_recall = 0.0, 0.0
            sum_iter = 0.0

            i_e, t_e, i_t, t_t = read_data.data_iterator(
                self.train_data, self.num_steps, self.length)

            sample_t = read_data.generate_sample_t(self.batch_size, i_t, t_t)

            i = 0
            hit_sum = 0.0
            batch_num = len(
                list(
                    read_data.generate_batch(self.batch_size, i_e, t_e, i_t,
                                             t_t)))
            logging.info("Total Batch Number {}".format(batch_num))
            # Log ~10 times per pass; max() avoids a ZeroDivisionError when
            # there are fewer than 10 batches.
            log_every = max(1, batch_num // 10)

            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, i_e, t_e, i_t, t_t):
                # Times are log-transformed and clamped to be non-negative.
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.target_t: t_y,
                    self.targets_e: e_y,
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }

                _, gen_e_cost, hit_count, batch_precision, batch_recall = sess.run(
                    [
                        self.g_train_op, self.gen_e_cost, self.hit_count,
                        self.batch_precision_op, self.batch_recall_op
                    ],
                    feed_dict=feed_dict)
                sum_iter = sum_iter + 1
                hit_sum += hit_count

                if i % log_every == 0:
                    logging.info(
                        '[epoch: {}, {}] hit10: {}, gen_e_loss: {}, precision: {}, recall: {}'
                        .format(
                            epoch,
                            float(i) / batch_num, hit_sum /
                            (sum_iter * self.batch_size * self.length),
                            gen_e_cost, batch_precision, batch_recall))
                i += 1

            # ---------------- evaluation ----------------
            # NOTE(review): the running precision/recall variables are not
            # re-initialised here, and this loop reads self.batch_precision /
            # self.batch_recall without running their update ops.
            i_e, t_e, i_t, t_t = read_data.data_iterator(
                self.valid_data, self.num_steps, self.length)

            sample_t = read_data.generate_sample_t(self.batch_size, i_t, t_t)

            sum_iter = 0.0
            hit_sum = 0.0
            i = 0

            self.lr = self.learning_rate
            batch_num = len(
                list(
                    read_data.generate_batch(self.batch_size, i_e, t_e, i_t,
                                             t_t)))
            logging.info(
                'Total Batch Number For Evaluation {}'.format(batch_num))
            log_every = max(1, batch_num // 10)

            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, i_e, t_e, i_t, t_t):
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.target_t: t_y,
                    self.targets_e: e_y,
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }

                gen_e_cost, hit_count, batch_precision, batch_recall = sess.run(
                    [
                        self.gen_e_cost, self.hit_count, self.batch_precision,
                        self.batch_recall
                    ],
                    feed_dict=feed_dict)
                sum_iter = sum_iter + 1
                hit_sum += hit_count
                i += 1  # incremented before the check, unlike training

                if i % log_every == 0:
                    logging.info(
                        '{}, gen_e_cost: {}, hit10: {}, precision: {}, recall: {}'
                        .format(
                            float(i) / batch_num, gen_e_cost, hit_sum /
                            (sum_iter * self.batch_size * self.length),
                            batch_precision, batch_recall))
        self.save_model(sess, self.logdir, args.iters)
Ejemplo n.º 8
0
    def train(self, sess, args):
        """Train the time generator and evaluate its deviation after each epoch.

        Training runs ``self.g_train_op`` on every batch and logs the running
        mean of ``self.deviation`` roughly ten times per pass. ``self.lr`` is
        multiplied by 2/3 at each logging point after the first. Evaluation
        runs every batch of ``self.valid_data`` without updates and logs the
        running mean deviation. The model is saved once, after the last
        epoch.

        Args:
            sess: TensorFlow session.
            args: parsed command-line arguments; reads ``logdir``, ``weights``
                and ``iters`` (number of epochs).
        """
        # Pick a fresh timestamped log dir; back off randomly while it exists.
        self.logdir = args.logdir + parse_time()
        while os.path.exists(self.logdir):
            time.sleep(random.randint(1, 5))
            self.logdir = args.logdir + parse_time()
        os.makedirs(self.logdir)

        if not os.path.exists('%s/logs' % self.logdir):
            os.makedirs('%s/logs' % self.logdir)

        if args.weights is not None:
            self.saver.restore(sess, args.weights)

        self.lr = self.learning_rate

        for epoch in range(args.iters):
            # ---------------- training ----------------
            sum_iter = 0.0
            average_deviation, sum_deviation = 0.0, 0.0
            # d_loss and huber_t_loss are only logged, never updated here.
            d_loss, gen_t_cost, huber_t_loss = 0.0, 0.0, 0.0

            i_e, t_e, i_t, t_t = read_data.data_iterator(
                self.train_data, self.num_steps, self.length)

            sample_t = read_data.generate_sample_t(self.batch_size, i_t, t_t)
            batch_num = len(
                list(
                    read_data.generate_batch(self.batch_size, i_e, t_e, i_t,
                                             t_t)))
            logging.info('Training batch num {}'.format(batch_num))
            # Log ~10 times per pass; max() avoids a ZeroDivisionError when
            # there are fewer than 10 batches.
            log_every = max(1, batch_num // 10)
            i = 0

            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, i_e, t_e, i_t, t_t):
                # Times are log-transformed and clamped to be non-negative.
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.target_t: t_y,
                    self.targets_e: e_y,
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }

                if i > 0 and i % log_every == 0:
                    self.lr = self.lr * 2. / 3

                # Generator step on every batch. It used to be nested under
                # the lr-decay condition above, so only ~10 batches per epoch
                # were ever trained on.
                _, deviation, gen_t_cost = sess.run(
                    [self.g_train_op, self.deviation, self.gen_t_cost],
                    feed_dict=feed_dict)

                sum_iter = sum_iter + 1
                sum_deviation = sum_deviation + deviation
                average_deviation = sum_deviation / sum_iter

                if i % log_every == 0:
                    logging.info('[epoch: {}, {}] deviation: {}'.format(
                        epoch, int(i // log_every), average_deviation))
                    logging.info(
                        'd_loss: {}, gen_t_loss: {}, hunber_t_loss: {}'.format(
                            d_loss, gen_t_cost, huber_t_loss))
                i += 1

            # ---------------- evaluation ----------------
            i_e, t_e, i_t, t_t = read_data.data_iterator(
                self.valid_data, self.num_steps, self.length)

            sample_t = read_data.generate_sample_t(self.batch_size, i_t, t_t)

            batch_num = len(
                list(
                    read_data.generate_batch(self.batch_size, i_e, t_e, i_t,
                                             t_t)))
            logging.info('Evaluation Batch Num {}'.format(batch_num))
            log_every = max(1, batch_num // 10)

            sum_iter = 0.0
            sum_deviation = 0.0
            i = 0

            self.lr = self.learning_rate

            for e_x, e_y, t_x, t_y in read_data.generate_batch(
                    self.batch_size, i_e, t_e, i_t, t_t):
                feed_dict = {
                    self.input_e: e_x,
                    self.inputs_t: np.maximum(np.log(t_x), 0),
                    self.target_t: t_y,
                    self.targets_e: e_y,
                    self.sample_t: np.maximum(np.log(sample_t), 0)
                }

                # No training happens in this loop. The decay only keeps the
                # value self.lr is left with identical to before.
                if i > 0 and i % log_every == 0:
                    self.lr = self.lr * 2. / 3

                deviation, gen_t_cost = sess.run(
                    [self.deviation, self.gen_t_cost], feed_dict=feed_dict)

                sum_iter = sum_iter + 1
                sum_deviation = sum_deviation + deviation

                if i % log_every == 0:
                    logging.info('{} deviation: {},  g_loss: {}'.format(
                        int(i // log_every), sum_deviation / sum_iter,
                        gen_t_cost))
                i += 1

        self.save_model(sess, self.logdir, args.iters)