Example #1
    def train(self, x_train, y_train, x_val, y_val):
        x_train = get_bert_param_lists(x_train)
        x_val = get_bert_param_lists(x_val)
        y_train = np.asarray(y_train)
        y_val = np.asarray(y_val)

        data_len = len(y_train)
        # Total number of optimizer steps across all epochs, used for progress display.
        step_sum = (int((data_len - 1) / config.batch_size) + 1) * config.epochs_num
        best_acc_val = 0        # best validation accuracy seen so far
        cur_step = 0            # global step counter across all epochs
        last_improved_step = 0  # step at which best_acc_val last improved
        adjust_num = 0          # number of learning-rate adjustments performed
        flag = True             # set to False to stop training early
        saver = tf.train.Saver(max_to_keep=1)  # keep only the best checkpoint
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(config.epochs_num):
                for batch_x_train, batch_y_train in bert_bacth_iter(x_train, y_train, config.batch_size):
                    input_ids = batch_x_train[0]
                    input_masks = batch_x_train[1]
                    segment_ids = batch_x_train[2]
                    feed_dict = {
                        self.input_ids: input_ids,
                        self.input_masks: input_masks,
                        self.segment_ids: segment_ids,
                        self.labels: batch_y_train,
                        self.is_training: True,
                    }
                    cur_step += 1
                    fetches = [self.train_op, self.global_step]
                    sess.run(fetches, feed_dict=feed_dict)
                    if cur_step % config.print_per_batch == 0:
                        fetches = [self.loss, self.accuracy]
                        loss_train, acc_train = sess.run(fetches, feed_dict=feed_dict)
                        loss_val, acc_val = self.evaluate(sess, x_val, y_val)
                        if acc_val >= best_acc_val:
                            best_acc_val = acc_val
                            last_improved_step = cur_step
                            saver.save(sess, model_config.model_save_path)
                            improved_str = '*'
                        else:
                            improved_str = ''
                        cur_step_str = str(cur_step) + "/" + str(step_sum)
                        msg = ('Current step: {0}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%}, '
                               'Val Loss: {3:>6.2}, Val Acc: {4:>7.2%} {5}')
                        print(msg.format(cur_step_str, loss_train, acc_train, loss_val, acc_val, improved_str))
                    if cur_step - last_improved_step >= config.improvement_step:
                        # No improvement for improvement_step steps: reset the counter and
                        # adjust the learning rate (the decay call is left commented out here).
                        last_improved_step = cur_step
                        print("No optimization for a long time, auto adjusting learning_rate...")
                        # learning_rate = learning_rate_decay(learning_rate)
                        adjust_num += 1
                        if adjust_num > 3:
                            # Stop after three fruitless learning-rate adjustments.
                            print("No optimization for a long time, auto-stopping...")
                            flag = False
                    if not flag:
                        break
                if not flag:
                    break
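
Both examples lean on module-level imports (numpy as np, TensorFlow 1.x as tf, the project's config and model_config) and on two project helpers, get_bert_param_lists and bert_bacth_iter, that are not shown. Below is a minimal sketch of what bert_bacth_iter plausibly does, assuming get_bert_param_lists returns three parallel NumPy arrays (input_ids, input_masks, segment_ids); it is an illustration, not the project's actual implementation.

import numpy as np

def bert_bacth_iter(x, y, batch_size, shuffle=True):
    """Yield ([input_ids, input_masks, segment_ids], labels) mini-batches.

    Assumes x is a list/tuple of three parallel NumPy arrays of shape
    (num_examples, seq_len) and y is a 1-D label array of matching length.
    """
    input_ids, input_masks, segment_ids = x
    data_len = len(y)
    order = np.random.permutation(data_len) if shuffle else np.arange(data_len)
    for start in range(0, data_len, batch_size):
        idx = order[start:start + batch_size]
        yield [input_ids[idx], input_masks[idx], segment_ids[idx]], y[idx]
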
Example #2
    def evaluate(self, sess, x_val, y_val):
        """Evaluate loss and accuracy on the given dataset."""
        data_len = len(y_val)
        total_loss = 0.0
        total_acc = 0.0
        for batch_x_val, batch_y_val in bert_bacth_iter(x_val, y_val, config.batch_size):
            input_ids = batch_x_val[0]
            input_masks = batch_x_val[1]
            segment_ids = batch_x_val[2]
            feed_dict = {
                self.input_ids: input_ids,
                self.input_masks: input_masks,
                self.segment_ids: segment_ids,
                self.labels: batch_y_val,
                self.is_training: False,
            }
            batch_len = len(batch_y_val)
            loss, acc = sess.run([self.loss, self.accuracy], feed_dict=feed_dict)
            # Weight batch metrics by batch size so the averages stay exact
            # even when the last batch is smaller.
            total_loss += loss * batch_len
            total_acc += acc * batch_len
        return total_loss / data_len, total_acc / data_len
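
The feed_dict keys used by both methods (self.input_ids, self.input_masks, self.segment_ids, self.labels, self.is_training) imply TF 1.x placeholders along the lines of the sketch below. The shapes, dtypes, and MAX_SEQ_LEN constant are assumptions for illustration; the class's real definitions presumably come from config and the BERT graph builder.

import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API

tf.disable_eager_execution()

MAX_SEQ_LEN = 128  # illustrative stand-in for the value taken from config

# Placeholders matching the feed_dict keys used by train() and evaluate()
input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LEN], name="input_ids")
input_masks = tf.placeholder(tf.int32, [None, MAX_SEQ_LEN], name="input_masks")
segment_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LEN], name="segment_ids")
labels = tf.placeholder(tf.int64, [None], name="labels")
is_training = tf.placeholder(tf.bool, shape=(), name="is_training")
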