def _valid_step(model: tf.keras.Model, dataset: tf.data.Dataset, progress_bar: ProgressBar,
                loss_metric: tf.keras.metrics.Mean, max_train_steps: Any = -1) -> Dict:
    """Run one validation pass over a sentence-pair dataset and report loss and ROC-AUC.

    For every batch ``(first_queries, second_queries, labels)`` the model is called on
    ``[first_queries, second_queries]``. The sparse categorical cross-entropy between
    ``labels`` and the outputs (``from_logits`` left at its default, False) is fed into
    ``loss_metric``. Column 1 of the outputs is kept as the positive-class score for ROC-AUC.

    :param model: model to validate
    :param dataset: validation dataset yielding ``(first_queries, second_queries, labels)``
    :param progress_bar: progress reporter; its total is expected to be set by the caller
    :param loss_metric: running-mean loss metric; it is reset at the start of this pass
    :param max_train_steps: number of validation batches to take; -1 takes the whole dataset
    :return: ``{"valid_loss": mean loss, "valid_auc": ROC-AUC score}``
    """
    print("验证轮次")
    start_time = time.time()
    loss_metric.reset_states()
    # The Keras loss object carries no per-batch state, so build it once rather than per batch.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

    # Gather per-batch tensors and concatenate once after the loop. Concatenating inside
    # the loop re-copies everything collected so far on every batch.
    # The empty seed tensors keep the float32 / int32 dtypes of the accumulated results.
    score_chunks = [tf.convert_to_tensor([], dtype=tf.float32)]
    label_chunks = [tf.convert_to_tensor([], dtype=tf.int32)]
    for batch, (first_queries, second_queries, labels) in enumerate(dataset.take(max_train_steps)):
        outputs = model(inputs=[first_queries, second_queries])
        loss_metric(loss_fn(labels, outputs))
        score_chunks.append(outputs[:, 1])
        label_chunks.append(labels)
        progress_bar(current=batch + 1, metrics=get_dict_string(data={"valid_loss": loss_metric.result()}))

    result = tf.concat(score_chunks, axis=0)
    targets = tf.concat(label_chunks, axis=0)
    auc_score = roc_auc_score(y_true=targets, y_score=result)

    valid_loss = loss_metric.result()
    progress_bar(current=progress_bar.total, metrics=get_dict_string(
        data={"valid_loss": valid_loss, "valid_auc": auc_score}
    ))
    progress_bar.done(step_time=time.time() - start_time)
    return {"valid_loss": valid_loss, "valid_auc": auc_score}
def _valid_step(model: tf.keras.Model, dataset: tf.data.Dataset, progress_bar: ProgressBar, batch_size: Any,
                loss_metric: tf.keras.metrics.Mean, max_train_steps: Any = -1) -> Dict:
    """Run one validation pass for a sequence-output model and report loss and ROC-AUC.

    For every batch ``(queries, _, true_outputs, labels)`` the model is called on ``queries``.
    The per-position cross-entropy against ``true_outputs`` is computed as follows:

    - positions whose target id is 0 are masked out;
    - the remaining terms are summed and divided by ``batch_size``;
    - the result is fed into ``loss_metric``.

    The ROC-AUC score is the second entry of a softmax over last-axis columns 5:7, taken at
    index 0 of axis 1 of the model output.

    :param model: model to validate
    :param dataset: validation dataset yielding ``(queries, _, true_outputs, labels)``
    :param progress_bar: progress reporter; its total is expected to be set by the caller
    :param batch_size: divisor used to normalise each batch's summed, masked loss
    :param loss_metric: running-mean loss metric; it is reset at the start of this pass
    :param max_train_steps: number of validation batches to take; -1 takes the whole dataset
    :return: ``{"valid_loss": mean loss, "valid_auc": ROC-AUC score}``
    """
    print("验证轮次")
    start_time = time.time()
    loss_metric.reset_states()
    # Build the (stateless) per-element loss once instead of once per batch.
    # NOTE(review): from_logits defaults to False here, while the score path below applies
    # softmax to the raw outputs. Confirm whether the model emits probabilities or logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

    # Gather per-batch tensors and concatenate once after the loop. Concatenating inside
    # the loop re-copies everything collected so far on every batch.
    # The empty seed tensors keep the float32 / int32 dtypes of the accumulated results.
    score_chunks = [tf.convert_to_tensor([], dtype=tf.float32)]
    label_chunks = [tf.convert_to_tensor([], dtype=tf.int32)]
    for batch, (queries, _, true_outputs, labels) in enumerate(dataset.take(max_train_steps)):
        outputs = model(inputs=queries)
        loss = loss_fn(true_outputs, outputs)
        # Target id 0 marks positions excluded from the loss.
        mask = tf.cast(x=tf.math.not_equal(true_outputs, 0), dtype=tf.float32)
        # NOTE(review): dividing by the configured batch_size (not the actual batch length)
        # under-weights a smaller final batch. Confirm this matches the training loss.
        batch_loss = tf.reduce_sum(mask * loss) / batch_size
        loss_metric(batch_loss)

        score_chunks.append(tf.nn.softmax(logits=outputs[:, 0, 5:7], axis=-1)[:, 1])
        label_chunks.append(labels)
        progress_bar(current=batch + 1, metrics=get_dict_string(data={"valid_loss": loss_metric.result()}))

    result = tf.concat(score_chunks, axis=0)
    targets = tf.concat(label_chunks, axis=0)
    auc_score = roc_auc_score(y_true=targets, y_score=result)

    valid_loss = loss_metric.result()
    progress_bar(current=progress_bar.total, metrics=get_dict_string(data={
        "valid_loss": valid_loss,
        "valid_auc": auc_score
    }))
    progress_bar.done(step_time=time.time() - start_time)
    return {"valid_loss": valid_loss, "valid_auc": auc_score}
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any, buffer_size: Any,
          epochs: Any, embedding_dim: Any, train_data_path: AnyStr, valid_data_path: AnyStr,
          max_sentence_length: Any, max_train_steps: Any = -1, checkpoint_save_freq: Any = 2,
          *args, **kwargs) -> Dict:
    """Train a sentence-pair model, saving checkpoints and validating periodically.

    :param model: model to train
    :param checkpoint: checkpoint manager; ``save()`` is called every ``checkpoint_save_freq`` epochs
    :param batch_size: batch size
    :param buffer_size: buffer size passed to ``load_pair_dataset``
    :param epochs: number of training epochs
    :param embedding_dim: word-embedding size (only used by the disabled learning-rate schedule)
    :param train_data_path: path of the training data file
    :param valid_data_path: path of the validation data file
    :param max_sentence_length: maximum sentence length passed to ``load_pair_dataset``. This
        was previously written ``max_sentence_length=Any``, which made ``typing.Any`` the
        default *value* and forwarded it to the loader; it is now a plain annotation, as in
        the other trainers.
    :param max_train_steps: maximum number of training batches per epoch; -1 means all data
    :param checkpoint_save_freq: save a checkpoint, then validate, every this many epochs
    :param kwargs: forwarded to ``_valid_step``
    :return: an empty dict
    """
    print("训练开始,正在准备数据中")
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.98, name="optimizer")

    train_dataset = load_pair_dataset(data_path=train_data_path, max_sentence_length=max_sentence_length,
                                      batch_size=batch_size, buffer_size=buffer_size, data_type="train")
    valid_dataset = load_pair_dataset(data_path=valid_data_path, max_sentence_length=max_sentence_length,
                                      batch_size=batch_size, buffer_size=buffer_size, data_type="valid")

    # 206588 and 10000 look like the train/valid example counts (TODO confirm). They only
    # drive the progress-bar totals and the "skip validation" check.
    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (206588 // batch_size)
    valid_steps_per_epoch = 10000 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for batch, (first_queries, second_queries, labels) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, _ = _train_step(
                model=model, optimizer=optimizer, loss_metric=loss_metric,
                first_queries=first_queries, second_queries=second_queries, labels=labels
            )
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            # Validation runs on checkpoint epochs only. It reuses (and resets) loss_metric,
            # which is safe because the training metric was already reported above.
            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("验证数据量过小,小于batch_size,已跳过验证轮次")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                _valid_step(model=model, dataset=valid_dataset, progress_bar=progress_bar,
                            loss_metric=loss_metric, **kwargs)

    print("训练结束")
    return {}
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any, buffer_size: Any,
          epochs: Any, train_data_path: AnyStr, test_data_path: AnyStr, dict_path: AnyStr,
          max_sentence_length: Any, max_train_steps: Any = -1, checkpoint_save_freq: Any = 2,
          *args, **kwargs) -> Dict:
    """Train a model on the raw dataset, saving checkpoints and validating periodically.

    :param model: model to train
    :param checkpoint: checkpoint manager; ``save()`` is called every ``checkpoint_save_freq`` epochs
    :param batch_size: batch size
    :param buffer_size: buffer size passed to ``load_raw_dataset``
    :param epochs: number of training epochs
    :param train_data_path: path of the training data
    :param test_data_path: path of the test data, passed to ``load_raw_dataset``
    :param dict_path: vocabulary file path
    :param max_sentence_length: maximum sentence-pair length
    :param max_train_steps: maximum number of training batches per epoch; -1 means all data
    :param checkpoint_save_freq: save a checkpoint, then validate, every this many epochs
    :param kwargs: forwarded to ``_valid_step``
    :return: an empty dict
    """
    print("训练开始,正在准备数据中")
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=1e-5)

    # The progress-bar total must agree with train_dataset.take(max_train_steps) below.
    # Previously it ignored max_train_steps, so the bar was wrong whenever that was set.
    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (125000 // batch_size)
    valid_steps_per_epoch = 10000 // batch_size

    train_dataset, valid_dataset = load_raw_dataset(
        train_data_path=train_data_path, max_sentence_length=max_sentence_length, batch_size=batch_size,
        buffer_size=buffer_size, dict_path=dict_path, test_data_path=test_data_path)

    print("训练开始")
    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for batch, (queries, segments, labels) in enumerate(train_dataset.take(max_train_steps)):
            train_metric = _train_step(
                model=model, optimizer=optimizer, loss_metric=loss_metric,
                queries=queries, segments=segments, targets=labels
            )
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            # Validation runs on checkpoint epochs only. It reuses (and resets) loss_metric,
            # which is safe because the training metric was already reported above.
            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("验证数据量过小,小于batch_size,已跳过验证轮次")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                _valid_step(model=model, dataset=valid_dataset, progress_bar=progress_bar,
                            loss_metric=loss_metric, **kwargs)

    print("训练结束")
    return {}
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any, buffer_size: Any,
          epochs: Any, embedding_dim: Any, train_data_path: AnyStr, valid_data_path: AnyStr,
          max_sentence_length: Any, max_train_steps: Any = -1, checkpoint_save_freq: Any = 2,
          *args, **kwargs) -> Dict:
    """Train a sequence-output model, tracking train ROC-AUC and validating periodically.

    :param model: model to train
    :param checkpoint: checkpoint manager; ``save()`` is called every ``checkpoint_save_freq`` epochs
    :param batch_size: batch size (also forwarded to ``_train_step`` and ``_valid_step``)
    :param buffer_size: buffer size passed to ``load_raw_dataset``
    :param epochs: number of training epochs
    :param embedding_dim: word-embedding size, used as ``d_model`` of the learning-rate schedule
    :param train_data_path: path of the training data
    :param valid_data_path: path of the validation data
    :param max_sentence_length: maximum sentence-pair length
    :param max_train_steps: maximum number of training batches per epoch; -1 means all data
    :param checkpoint_save_freq: save a checkpoint, then validate, every this many epochs
    :param kwargs: forwarded to ``_valid_step``
    :return: an empty dict
    """
    print("训练开始,正在准备数据中")
    learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.98, name="optimizer")

    train_dataset = load_raw_dataset(data_path=train_data_path, max_sentence_length=max_sentence_length,
                                     batch_size=batch_size, buffer_size=buffer_size, data_type="train")
    valid_dataset = load_raw_dataset(data_path=valid_data_path, max_sentence_length=max_sentence_length,
                                     batch_size=batch_size, buffer_size=buffer_size, data_type="valid")

    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (90000 // batch_size)
    valid_steps_per_epoch = 10000 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        # Gather per-batch tensors and concatenate once after the loop. Concatenating inside
        # the loop re-copies everything collected so far on every batch.
        # The empty seed tensors keep the float32 / int32 dtypes of the accumulated results.
        score_chunks = [tf.convert_to_tensor([], dtype=tf.float32)]
        label_chunks = [tf.convert_to_tensor([], dtype=tf.int32)]
        for batch, (queries, _, outputs, labels) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(model=model, optimizer=optimizer, batch_size=batch_size,
                                                   loss_metric=loss_metric, queries=queries, targets=outputs)
            score_chunks.append(prediction[:, 1])
            label_chunks.append(labels)
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        result = tf.concat(score_chunks, axis=0)
        targets = tf.concat(label_chunks, axis=0)
        auc_score = roc_auc_score(y_true=targets, y_score=result)
        train_metric["train_auc"] = auc_score
        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            # Validation runs on checkpoint epochs only. It reuses (and resets) loss_metric,
            # which is safe because the training metric was already reported above.
            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("验证数据量过小,小于batch_size,已跳过验证轮次")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                _valid_step(model=model, dataset=valid_dataset, batch_size=batch_size,
                            progress_bar=progress_bar, loss_metric=loss_metric, **kwargs)

    print("训练结束")
    return {}