def train(self, x_train, y_train, x_val, y_val):
    """Fine-tune the model, checkpointing the best validation accuracy.

    Runs up to ``config.epochs_num`` epochs of mini-batch training. Every
    ``config.print_per_batch`` steps it evaluates on the validation set,
    saves the model when validation accuracy improves, and early-stops
    after repeated stretches of ``config.improvement_step`` steps without
    improvement.

    Args:
        x_train: raw training texts; converted to BERT inputs
            (ids / masks / segment ids) by ``get_bert_param_lists``.
        y_train: training labels (converted to ``np.ndarray``).
        x_val: raw validation texts.
        y_val: validation labels.
    """
    x_train = get_bert_param_lists(x_train)
    x_val = get_bert_param_lists(x_val)
    y_train = np.asarray(y_train)
    y_val = np.asarray(y_val)
    data_len = len(y_train)
    # Total optimizer steps: ceil(data_len / batch_size) batches per epoch.
    # Floor division is exact for any integer sizes (no float rounding).
    step_sum = ((data_len - 1) // config.batch_size + 1) * config.epochs_num
    best_acc_val = 0
    cur_step = 0
    last_improved_step = 0
    adjust_num = 0  # how many times we hit a no-improvement stretch
    flag = True  # set False to stop training early
    saver = tf.train.Saver(max_to_keep=1)  # keep only the best checkpoint
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(config.epochs_num):
            for batch_x_train, batch_y_train in bert_bacth_iter(x_train, y_train, config.batch_size):
                input_ids = batch_x_train[0]
                input_masks = batch_x_train[1]
                segment_ids = batch_x_train[2]
                feed_dict = {
                    self.input_ids: input_ids,
                    self.input_masks: input_masks,
                    self.segment_ids: segment_ids,
                    self.labels: batch_y_train,
                    self.is_training: True,
                }
                cur_step += 1
                fetches = [self.train_op, self.global_step]
                sess.run(fetches, feed_dict=feed_dict)
                if cur_step % config.print_per_batch == 0:
                    # Train metrics reuse the current batch's feed_dict
                    # (is_training stays True for these fetches).
                    fetches = [self.loss, self.accuracy]
                    loss_train, acc_train = sess.run(fetches, feed_dict=feed_dict)
                    loss_val, acc_val = self.evaluate(sess, x_val, y_val)
                    if acc_val >= best_acc_val:
                        # New best: checkpoint and mark the log line with '*'.
                        best_acc_val = acc_val
                        last_improved_step = cur_step
                        saver.save(sess, model_config.model_save_path)
                        improved_str = '*'
                    else:
                        improved_str = ''
                    cur_step_str = str(cur_step) + "/" + str(step_sum)
                    msg = 'the Current step: {0}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                        + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, {5}'
                    print(msg.format(cur_step_str, loss_train, acc_train, loss_val, acc_val, improved_str))
                    if cur_step - last_improved_step >= config.improvement_step:
                        # NOTE(review): the message claims the learning rate is
                        # adjusted, but the decay call is commented out, so this
                        # branch only resets the counter — confirm intent.
                        last_improved_step = cur_step
                        print("No optimization for a long time, auto adjust learning_rate...")
                        # learning_rate = learning_rate_decay(learning_rate)
                        adjust_num += 1
                        if adjust_num > 3:
                            # Give up after four stalled stretches.
                            print("No optimization for a long time, auto-stopping...")
                            flag = False
                if not flag:
                    break
            if not flag:
                break
def evaluate(self, sess, x_val, y_val):
    """Return (mean loss, mean accuracy) over the whole evaluation set.

    Per-batch metrics are weighted by batch size so a smaller final
    batch does not skew the averages.

    Args:
        sess: live ``tf.Session`` with initialized variables.
        x_val: BERT input triple (ids, masks, segment ids).
        y_val: labels for ``x_val``.
    """
    sample_count = len(y_val)
    loss_sum, acc_sum = 0.0, 0.0
    for batch_x, batch_y in bert_bacth_iter(x_val, y_val, config.batch_size):
        feed = {
            self.input_ids: batch_x[0],
            self.input_masks: batch_x[1],
            self.segment_ids: batch_x[2],
            self.labels: batch_y,
            self.is_training: False,  # disable dropout etc. during eval
        }
        batch_loss, batch_acc = sess.run([self.loss, self.accuracy], feed_dict=feed)
        weight = len(batch_y)
        loss_sum += batch_loss * weight
        acc_sum += batch_acc * weight
    return loss_sum / sample_count, acc_sum / sample_count