import time
from typing import Any, AnyStr, Dict

import tensorflow as tf
from sklearn.metrics import roc_auc_score

# Project-local helpers (ProgressBar, get_dict_string, CustomSchedule, the
# load_* dataset loaders and the _train_step functions) are assumed to be
# importable from the surrounding package.


def _valid_step(model: tf.keras.Model, dataset: tf.data.Dataset, progress_bar: ProgressBar,
                loss_metric: tf.keras.metrics.Mean, max_train_steps: Any = -1) -> Dict:
    """ Validation step

    :param model: model to validate
    :param dataset: validation dataset
    :param progress_bar: progress manager
    :param loss_metric: loss accumulator
    :param max_train_steps: number of validation steps
    :return: validation metrics
    """
    print("Validation epoch")
    start_time = time.time()
    loss_metric.reset_states()
    result, targets = tf.convert_to_tensor([], dtype=tf.float32), tf.convert_to_tensor([], dtype=tf.int32)

    for (batch, (first_queries, second_queries, labels)) in enumerate(dataset.take(max_train_steps)):
        outputs = model(inputs=[first_queries, second_queries])

        loss = tf.keras.losses.SparseCategoricalCrossentropy()(labels, outputs)
        loss_metric(loss)

        # Accumulate positive-class scores and labels for the epoch-level AUC.
        result = tf.concat([result, outputs[:, 1]], axis=0)
        targets = tf.concat([targets, labels], axis=0)

        progress_bar(current=batch + 1, metrics=get_dict_string(data={"valid_loss": loss_metric.result()}))

    auc_score = roc_auc_score(y_true=targets, y_score=result)
    progress_bar(current=progress_bar.total, metrics=get_dict_string(
        data={"valid_loss": loss_metric.result(), "valid_auc": auc_score}
    ))

    progress_bar.done(step_time=time.time() - start_time)

    return {"valid_loss": loss_metric.result(), "valid_auc": auc_score}
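
# A minimal alternative sketch (an assumption, not part of the source): the
# epoch-level AUC gathered above via growing concat tensors could instead be
# streamed with Keras's built-in approximate AUC metric.
def _valid_auc_streaming_sketch(model: tf.keras.Model, dataset: tf.data.Dataset) -> tf.Tensor:
    auc_metric = tf.keras.metrics.AUC(name="valid_auc")  # bucketed ROC AUC estimate
    for first_queries, second_queries, labels in dataset:
        outputs = model(inputs=[first_queries, second_queries])
        auc_metric.update_state(y_true=labels, y_pred=outputs[:, 1])
    return auc_metric.result()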
Example #2
def evaluate(model: tf.keras.Model, batch_size: Any, buffer_size: Any,
             record_data_path: Any, *args, **kwargs) -> Dict:
    """ Evaluator

    :param model: model to evaluate
    :param batch_size: batch size
    :param buffer_size: shuffle buffer size
    :param record_data_path: path to the TFRecord data file
    :return: evaluation metrics
    """
    progress_bar = ProgressBar()
    loss_metric = tf.keras.metrics.Mean(name="evaluate_loss")

    dataset = load_dataset(record_path=record_data_path,
                           batch_size=batch_size,
                           buffer_size=buffer_size,
                           data_type="valid")
    steps_per_epoch = 10000 // batch_size  # the validation split is assumed to hold 10,000 examples
    progress_bar.reset(total=steps_per_epoch, num=batch_size)

    valid_metrics = _valid_step(model=model,
                                dataset=dataset,
                                progress_bar=progress_bar,
                                loss_metric=loss_metric,
                                **kwargs)

    return valid_metrics
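
# Hypothetical usage of the evaluator above; build_model and the TFRecord path
# are illustrative assumptions, not part of the source:
#
#     model = build_model()  # any tf.keras.Model emitting (batch, 2) class scores
#     metrics = evaluate(model=model, batch_size=32, buffer_size=10000,
#                        record_data_path="data/valid.tfrecord", max_train_steps=-1)
#     print(metrics["valid_auc"])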
Example #3
def _valid_step(model: tf.keras.Model,
                dataset: tf.data.Dataset,
                progress_bar: ProgressBar,
                batch_size: Any,
                loss_metric: tf.keras.metrics.Mean,
                max_train_steps: Any = -1) -> Dict:
    """ Validation step

    :param model: model to validate
    :param dataset: validation dataset
    :param progress_bar: progress manager
    :param batch_size: batch size
    :param loss_metric: loss accumulator
    :param max_train_steps: number of validation steps
    :return: validation metrics
    """
    print("Validation epoch")
    start_time = time.time()
    loss_metric.reset_states()
    result, targets = tf.convert_to_tensor(
        [], dtype=tf.float32), tf.convert_to_tensor([], dtype=tf.int32)

    for (batch, (queries, _, true_outputs,
                 labels)) in enumerate(dataset.take(max_train_steps)):
        outputs = model(inputs=queries)
        # Per-token cross-entropy; padded positions (token id 0) are masked out
        # before averaging over the batch.
        loss = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE)(true_outputs, outputs)
        mask = tf.cast(x=tf.math.not_equal(true_outputs, 0), dtype=tf.float32)
        batch_loss = tf.reduce_sum(mask * loss) / batch_size

        loss_metric(batch_loss)

        # Softmax over the logits of the two class tokens (vocabulary ids 5 and 6
        # are assumed to encode the labels); keep the positive-class probability.
        result = tf.concat(
            [result,
             tf.nn.softmax(logits=outputs[:, 0, 5:7], axis=-1)[:, 1]],
            axis=0)
        targets = tf.concat([targets, labels], axis=0)

        progress_bar(
            current=batch + 1,
            metrics=get_dict_string(data={"valid_loss": loss_metric.result()}))

    auc_score = roc_auc_score(y_true=targets, y_score=result)
    progress_bar(current=progress_bar.total,
                 metrics=get_dict_string(data={
                     "valid_loss": loss_metric.result(),
                     "valid_auc": auc_score
                 }))

    progress_bar.done(step_time=time.time() - start_time)

    return {"valid_loss": loss_metric.result(), "valid_auc": auc_score}
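
# Standalone toy sketch (an assumption, not from the source) of the masked-loss
# pattern in _valid_step above: per-token cross-entropy in which padded
# positions (token id 0) contribute nothing to the batch loss.
def _masked_loss_sketch() -> tf.Tensor:
    labels = tf.constant([[5, 6, 0, 0], [6, 5, 6, 0]])  # 0 marks padding
    logits = tf.random.normal(shape=(2, 4, 8))  # batch of 2, length 4, vocab of 8
    per_token = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
    mask = tf.cast(tf.math.not_equal(labels, 0), dtype=tf.float32)
    return tf.reduce_sum(mask * per_token) / 2  # divide by batch_size, as above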
Example #4
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any, buffer_size: Any,
          epochs: Any, embedding_dim: Any, train_data_path: AnyStr, valid_data_path: AnyStr,
          max_sentence_length: Any, max_train_steps: Any = -1, checkpoint_save_freq: Any = 2, *args, **kwargs) -> Dict:
    """ Trainer

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param buffer_size: shuffle buffer size
    :param epochs: number of training epochs
    :param embedding_dim: embedding size
    :param train_data_path: path to the training data file
    :param valid_data_path: path to the validation data file
    :param max_sentence_length: maximum sentence length
    :param max_train_steps: maximum number of training batches, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :return:
    """
    print("Training started, preparing data")
    # learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.98, name="optimizer")

    train_dataset = load_pair_dataset(data_path=train_data_path, max_sentence_length=max_sentence_length,
                                      batch_size=batch_size, buffer_size=buffer_size, data_type="train")
    valid_dataset = load_pair_dataset(data_path=valid_data_path, max_sentence_length=max_sentence_length,
                                      batch_size=batch_size, buffer_size=buffer_size, data_type="valid")
    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (206588 // batch_size)  # 206,588 training pairs assumed
    valid_steps_per_epoch = 10000 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for (batch, (first_queries, second_queries, labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(
                model=model, optimizer=optimizer, loss_metric=loss_metric,
                first_queries=first_queries, second_queries=second_queries, labels=labels
            )

            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))

        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation set is smaller than batch_size; skipping the validation epoch")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model, dataset=valid_dataset,
                                            progress_bar=progress_bar, loss_metric=loss_metric, **kwargs)
    print("Training finished")
    return {}
Example #5
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any, buffer_size: Any,
          epochs: Any, train_data_path: AnyStr, test_data_path: AnyStr, dict_path: AnyStr, max_sentence_length: Any,
          max_train_steps: Any = -1, checkpoint_save_freq: Any = 2, *args, **kwargs) -> Dict:
    """ Trainer

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param buffer_size: shuffle buffer size
    :param epochs: number of training epochs
    :param train_data_path: path to the training data
    :param test_data_path: path to the test data
    :param dict_path: vocabulary file
    :param max_sentence_length: maximum sentence-pair length
    :param max_train_steps: maximum number of training batches, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :return:
    """
    print("Training started, preparing data")

    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")

    optimizer = tf.optimizers.Adam(learning_rate=1e-5)

    train_steps_per_epoch = 125000 // batch_size
    valid_steps_per_epoch = 10000 // batch_size
    # warmup_steps and total_steps only feed the alternative (commented-out) optimizers below.
    warmup_steps = train_steps_per_epoch // 3
    total_steps = train_steps_per_epoch * epochs - warmup_steps

    # learning_rate = CustomSchedule(d_model=768, warmup_steps=warmup_steps)
    # optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, name="optimizer")

    # optimizer, _ = create_optimizer(init_lr=2e-5, num_train_steps=total_steps, num_warmup_steps=warmup_steps)

    train_dataset, valid_dataset = load_raw_dataset(
        train_data_path=train_data_path, max_sentence_length=max_sentence_length, batch_size=batch_size,
        buffer_size=buffer_size, dict_path=dict_path, test_data_path=test_data_path)

    print("Training started")
    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for (batch, (queries, segments, labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric = _train_step(
                model=model, optimizer=optimizer, loss_metric=loss_metric,
                queries=queries, segments=segments, targets=labels
            )
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))

        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation set is smaller than batch_size; skipping the validation epoch")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model, dataset=valid_dataset,
                                            progress_bar=progress_bar, loss_metric=loss_metric, **kwargs)
    print("Training finished")
    return {}
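
# The commented-out create_optimizer call above matches the signature of the
# Hugging Face transformers helper of the same name (AdamW with linear warmup
# and decay); a sketch assuming that is the intended import:
#
#     from transformers import create_optimizer
#     optimizer, lr_schedule = create_optimizer(
#         init_lr=2e-5, num_train_steps=total_steps, num_warmup_steps=warmup_steps)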
Example #6
def train(model: tf.keras.Model,
          checkpoint: tf.train.CheckpointManager,
          batch_size: Any,
          buffer_size: Any,
          epochs: Any,
          embedding_dim: Any,
          train_data_path: AnyStr,
          valid_data_path: AnyStr,
          max_sentence_length: Any,
          max_train_steps: Any = -1,
          checkpoint_save_freq: Any = 2,
          *args,
          **kwargs) -> Dict:
    """ Trainer

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param buffer_size: shuffle buffer size
    :param epochs: number of training epochs
    :param embedding_dim: embedding size
    :param train_data_path: path to the training data
    :param valid_data_path: path to the validation data
    :param max_sentence_length: maximum sentence-pair length
    :param max_train_steps: maximum number of training batches, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :return:
    """
    print("Training started, preparing data")
    learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=learning_rate,
                                   beta_1=0.9,
                                   beta_2=0.98,
                                   name="optimizer")

    train_dataset = load_raw_dataset(data_path=train_data_path,
                                     max_sentence_length=max_sentence_length,
                                     batch_size=batch_size,
                                     buffer_size=buffer_size,
                                     data_type="train")
    valid_dataset = load_raw_dataset(data_path=valid_data_path,
                                     max_sentence_length=max_sentence_length,
                                     batch_size=batch_size,
                                     buffer_size=buffer_size,
                                     data_type="valid")
    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (
        90000 // batch_size)
    valid_steps_per_epoch = 10000 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        result, targets = tf.convert_to_tensor(
            [], dtype=tf.float32), tf.convert_to_tensor([], dtype=tf.int32)
        for (batch,
             (queries, _, outputs,
              labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(model=model,
                                                   optimizer=optimizer,
                                                   batch_size=batch_size,
                                                   loss_metric=loss_metric,
                                                   queries=queries,
                                                   targets=outputs)
            result = tf.concat([result, prediction[:, 1]], axis=0)
            targets = tf.concat([targets, labels], axis=0)
            progress_bar(current=batch + 1,
                         metrics=get_dict_string(data=train_metric))

        auc_score = roc_auc_score(y_true=targets, y_score=result)
        train_metric["train_auc"] = auc_score
        progress_bar(current=progress_bar.total,
                     metrics=get_dict_string(data=train_metric))

        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation set is smaller than batch_size; skipping the validation epoch")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model,
                                            dataset=valid_dataset,
                                            batch_size=batch_size,
                                            progress_bar=progress_bar,
                                            loss_metric=loss_metric,
                                            **kwargs)
    print("Training finished")
    return {}
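
# CustomSchedule above is project-local. A sketch under the assumption that it
# implements the Transformer schedule from "Attention Is All You Need"
# (linear warmup, then decay proportional to 1/sqrt(step)):
class _CustomScheduleSketch(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model: int, warmup_steps: int = 4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup = step * (self.warmup_steps ** -1.5)  # linear ramp during warmup
        decay = tf.math.rsqrt(step)  # then 1/sqrt(step) decay
        return tf.math.rsqrt(self.d_model) * tf.minimum(decay, warmup)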