Example #1
def avg_checkpints(hparams: Hparams):
    """Average the given checkpoints into a single model, evaluate it, and save the result."""
    logger = logging.getLogger(__name__)
    (model,) = build_model(hparams, return_losses=False, return_metrics=False, return_optimizer=False)
    logger.info(f"Average checkpoints from {hparams.prefix_or_checkpints}")
    average_checkpoints(model, hparams.prefix_or_checkpints, hparams.num_last_checkpoints, hparams.ckpt_weights)
    evaluation(hparams, model=model)
    logger.info(f"Save model in {hparams.get_model_filename()}")
    model.save_weights(hparams.get_model_filename(), save_format="tf")
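Example #1 depends on the project's `average_checkpoints` helper, which is not shown in the snippet. As a rough illustration of what averaging weights across several Keras checkpoints can look like, the sketch below loads each checkpoint into the model, scales its weights, and sums them; the name `average_keras_checkpoints` and its signature are assumptions made here for illustration, not the project's actual implementation.

import numpy as np


def average_keras_checkpoints(model, checkpoint_paths, ckpt_weights=None):
    """Average the weights stored in several Keras checkpoints into `model`.

    `checkpoint_paths` are prefixes previously written with model.save_weights();
    `ckpt_weights` are optional per-checkpoint mixing coefficients (uniform if None).
    """
    if ckpt_weights is None:
        ckpt_weights = [1.0 / len(checkpoint_paths)] * len(checkpoint_paths)
    averaged = None
    for path, w in zip(checkpoint_paths, ckpt_weights):
        # load one checkpoint, scale its weights, and accumulate
        model.load_weights(path)
        current = [np.asarray(v) * w for v in model.get_weights()]
        averaged = current if averaged is None else [
            a + c for a, c in zip(averaged, current)
        ]
    model.set_weights(averaged)
    return model

Uniform weights reproduce the common "average the last N checkpoints" behaviour; the checkpoints are assumed to come from the same architecture.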
Example #2
def evaluation(hparams: Hparams,
               checkpoints=None,
               model=None,
               test_dataset=None):
    """Evaluate the model and build report according to different task.

    :param model:
    :param test_dataset:
    :param hparams:
    :return:
    """
    logger.info("Start Evaluate.")
    output_hparams = deepcopy(hparams.dataset.outputs)

    if test_dataset is None:
        test_dataset = next(
            load_dataset(hparams,
                         ret_train=False,
                         ret_dev=False,
                         ret_info=False))[0]

    if model is None:
        # build model
        (model, ) = build_model(hparams,
                                return_losses=False,
                                return_metrics=False,
                                return_optimizer=False)

    # predict using default model saved
    if checkpoints is None:
        # load weights
        if not os.path.exists(hparams.get_model_filename() + ".index"):
            logger.warning(
                f"Model from {hparams.get_model_filename()} is not exists, load nothing!"
            )
        else:
            logger.info(
                f"Load model weights from {hparams.get_model_filename()}")
            model.load_weights(hparams.get_model_filename())

        # prediction
        # print(model.evaluate(test_dataset))
        for inputs, outputs in tqdm(test_dataset):
            model_outputs = model.predict(inputs)
            if not isinstance(model_outputs, (tuple, list)):
                model_outputs = (model_outputs, )
            for idx, one_output_hparam in enumerate(output_hparams):
                if "ground_truth" not in one_output_hparam:
                    one_output_hparam["ground_truth"] = []
                if "predictions" not in one_output_hparam:
                    one_output_hparam['predictions'] = []
                prediction_output = tf.nn.softmax(model_outputs[idx], -1)
                tmp_name = one_output_hparam.name
                tmp_type = one_output_hparam.type
                tmp_ground_truth = outputs[tmp_name]
                if tmp_type in [CLASSLABEL, LIST_OF_CLASSLABEL, LIST_OF_INT]:
                    if tmp_type in [LIST_OF_INT]:
                        tmp_tg = tf.argmax(tmp_ground_truth, -1)
                    else:
                        tmp_tg = tmp_ground_truth
                    if one_output_hparam.task == NER:  # [[sent1], [sent2]]
                        one_output_hparam.ground_truth.extend(
                            tmp_tg.numpy().tolist())
                        tmp_predictions = tf.argmax(prediction_output,
                                                    -1).numpy().tolist()
                        one_output_hparam.predictions.extend(tmp_predictions)
                    else:  # [1, 0, 1, ...]
                        one_output_hparam.ground_truth.extend(
                            tmp_tg.numpy().reshape(-1).tolist())
                        tmp_predictions = tf.argmax(
                            prediction_output,
                            -1).numpy().reshape(-1).tolist()
                        one_output_hparam.predictions.extend(tmp_predictions)
    elif isinstance(checkpoints, (tuple, list)):
        # predict using multiple checkpoints from k-fold cross validation.
        for i, ckpt in enumerate(checkpoints):
            if not os.path.exists(ckpt + ".index"):
                logger.warning(
                    f"Model from {ckpt} is not exists, load nothing!")
                continue
            else:
                logger.info(f"Load model weights from {ckpt}")
                model.load_weights(ckpt)

            for j, (inputs, outputs) in tqdm(enumerate(test_dataset)):
                model_outputs = model.predict(inputs)
                if not isinstance(model_outputs, (tuple, list)):
                    model_outputs = (model_outputs, )
                for idx, one_output_hparam in enumerate(output_hparams):
                    prediction_output = tf.nn.softmax(model_outputs[idx], -1)
                    if i == 0:
                        if "ground_truth" not in one_output_hparam:
                            one_output_hparam["ground_truth"] = []
                        if "predictions" not in one_output_hparam:
                            one_output_hparam['predictions'] = []
                            one_output_hparam['tmp_preds'] = []
                        one_output_hparam['tmp_preds'].append(
                            prediction_output)
                        tmp_name = one_output_hparam.name
                        tmp_type = one_output_hparam.type
                        tmp_ground_truth = outputs[tmp_name]
                        if tmp_type in [
                                CLASSLABEL, LIST_OF_CLASSLABEL, LIST_OF_INT
                        ]:
                            if tmp_type in [LIST_OF_INT]:
                                tmp_tg = tf.argmax(tmp_ground_truth, -1)
                            else:
                                tmp_tg = tmp_ground_truth
                            if one_output_hparam.task == NER:  # [[sent1], [sent2]]
                                one_output_hparam.ground_truth.extend(
                                    tmp_tg.numpy().tolist())
                            else:  # [1, 0, 1, ...]
                                one_output_hparam.ground_truth.extend(
                                    tmp_tg.numpy().reshape(-1).tolist())
                    else:
                        one_output_hparam['tmp_preds'][j] += prediction_output

        # aggregate the ensembled (summed softmax) predictions of every batch,
        # in the same order as the ground truth collected above
        for one_output_hparam in output_hparams:
            tmp_type = one_output_hparam.type
            if tmp_type in [CLASSLABEL, LIST_OF_CLASSLABEL, LIST_OF_INT]:
                for prediction_output in one_output_hparam['tmp_preds']:
                    if one_output_hparam.task == NER:  # [[sent1], [sent2]]
                        tmp_predictions = tf.argmax(prediction_output,
                                                    -1).numpy().tolist()
                    else:  # [1, 0, 1, ...]
                        tmp_predictions = tf.argmax(
                            prediction_output,
                            -1).numpy().reshape(-1).tolist()
                    one_output_hparam.predictions.extend(tmp_predictions)

    # save reports
    report_folder = hparams.get_report_dir()
    # evaluation, TODO more reports
    for one_output_hparam in output_hparams:
        ground_truth = one_output_hparam.ground_truth
        predictions = one_output_hparam.predictions
        if one_output_hparam.type in [
                CLASSLABEL, LIST_OF_CLASSLABEL, LIST_OF_INT
        ]:
            # some filename
            cur_report_folder = os.path.join(
                report_folder,
                f'{one_output_hparam.name}_{one_output_hparam.type.lower()}')
            if not os.path.exists(cur_report_folder):
                os.makedirs(cur_report_folder)

            if one_output_hparam.task == NER:
                labels = one_output_hparam.labels
                # confusion matrix
                cm = ConfusionMatrix(_2d_to_1d_list(ground_truth),
                                     _2d_to_1d_list(predictions), labels)
                # ner evaluation
                labels = list(
                    set([
                        itm[2:] for itm in labels
                        if itm.startswith("B-") or itm.startswith("I-")
                    ]))
                ner_eval = NEREvaluator(
                    _id_to_label(ground_truth, one_output_hparam.labels),
                    _id_to_label(predictions, one_output_hparam.labels),
                    labels)
                ner_results, ner_results_agg = ner_eval.evaluate()
                save_json(os.path.join(cur_report_folder, "ner_results.json"),
                          ner_results)
                save_json(
                    os.path.join(cur_report_folder, "ner_results_agg.json"),
                    ner_results_agg)
            else:
                cm = ConfusionMatrix(ground_truth, predictions,
                                     one_output_hparam.labels)

            # print some reports
            print_boxed(f"{one_output_hparam.name} Evaluation")

            cms = cm.confusion_matrix_visual()
            if len(cm.label2idx) < 10:
                print(cms)
                # save reports to files
                with open(
                        os.path.join(cur_report_folder,
                                     "confusion_matrix.txt"), 'w') as f:
                    f.write(cms)
            print()
            print(json.dumps(cm.stats(), indent=4))
            save_json(os.path.join(cur_report_folder, "stats.json"),
                      cm.stats())
            save_json(os.path.join(cur_report_folder, 'per_class_stats.json'),
                      cm.per_class_stats())
            # save reports to hparams
            hparams['performance'] = Hparams()
            hparams.performance["stats"] = cm.stats()
            hparams.performance["per_class_stats"] = cm.per_class_stats()
            logger.info(
                f"Save {one_output_hparam.name} reports in {cur_report_folder}"
            )
        else:
            logger.warning(
                f"{one_output_hparam.name}'s evaluation has not be implemented."
            )
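
The NER branch of `evaluation` above uses two small helpers, `_2d_to_1d_list` and `_id_to_label`, that do not appear in the snippet. Minimal versions consistent with how they are called (flattening nested label-id lists for the confusion matrix, and mapping ids back to label strings for the NER evaluator) might look like the following; these are plausible sketches, not the project's actual definitions.

def _2d_to_1d_list(nested):
    """Flatten a list of per-sentence label ids, e.g. [[1, 2], [3]] -> [1, 2, 3]."""
    return [item for row in nested for item in row]


def _id_to_label(nested_ids, labels):
    """Map nested label ids back to label strings, e.g. [[0, 1]] -> [["O", "B-PER"]]."""
    return [[labels[i] for i in row] for row in nested_ids]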
Example #3
def k_fold_experiment(hparams: Hparams):
    """
    k_fold training
    :param hparams:
    :return:
    """
    logger = logging.getLogger(__name__)
    if hparams.use_mixed_float16:
        logger.info("Use auto mixed policy")
        # tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

    strategy = tf.distribute.MirroredStrategy(
        devices=[f"/gpu:{id}" for id in hparams.gpus])
    # build dataset

    model_saved_dirs = []

    for idx, (train_dataset, dev_dataset,
              dataset_info) in enumerate(load_dataset(hparams,
                                                      ret_test=False)):
        logger.info(f"Start {idx}th-fold training")
        with strategy.scope():
            # build model
            model, (losses,
                    loss_weights), metrics, optimizer = build_model(hparams)
            # build callbacks
            callbacks = build_callbacks(hparams.training.callbacks)
            # compile
            model.compile(optimizer=optimizer,
                          loss=losses,
                          metrics=metrics,
                          loss_weights=loss_weights)
            # fit
            if hparams.training.do_eval:
                validation_data = dev_dataset
                validation_steps = hparams.training.validation_steps
            else:
                logger.info("Do not evaluate.")
                validation_data = None
                validation_steps = None

            model.fit(
                train_dataset,
                validation_data=validation_data,
                epochs=hparams.training.max_epochs,
                callbacks=callbacks,
                steps_per_epoch=hparams.training.steps_per_epoch,
                validation_steps=validation_steps,
            )

            # build archive dir
            k_fold_dir = os.path.join(hparams.get_workspace_dir(), "k_fold",
                                      str(idx))
            if not os.path.exists(k_fold_dir):
                os.makedirs(k_fold_dir)

            # load best model
            checkpoint_dir = os.path.join(hparams.get_workspace_dir(),
                                          "checkpoint")
            if hparams.eval_use_best and os.path.exists(checkpoint_dir):
                logger.info(f"Load best model from {checkpoint_dir}")
                average_checkpoints(model, checkpoint_dir)
                logger.info(f"Move {checkpoint_dir, k_fold_dir}")
                shutil.move(checkpoint_dir, k_fold_dir)

            # save best model
            logger.info(
                f'Save fold {idx} model in {hparams.get_model_filename()}')
            model.save_weights(hparams.get_model_filename(), save_format="tf")

        # eval on test dataset and make reports
        evaluation(hparams)
        logger.info(f"Move {hparams.get_report_dir()} to {k_fold_dir}")
        shutil.move(hparams.get_report_dir(), k_fold_dir)
        logger.info(f"Move {hparams.get_saved_model_dir()} to {k_fold_dir}")
        cur_model_saved_dir = shutil.move(hparams.get_saved_model_dir(),
                                          k_fold_dir)
        logger.info(
            f"New model saved path for {idx}th fold: {cur_model_saved_dir}")
        model_saved_dirs.append(cur_model_saved_dir)

        logger.info(f'Fold {idx} experiment finished!')

    # eval on test dataset after average_checkpoints
    # logger.info("Average models of all fold models.")
    checkpoints = [f'{itm}/model' for itm in model_saved_dirs]
    # average_checkpoints(model, checkpoints)

    # logger.info(f"Save averaged model in {hparams.get_model_filename()}")
    # model.save_weights(hparams.get_model_filename(), save_format="tf")
    if hparams.training.do_eval:
        evaluation(hparams, checkpoints=checkpoints)

    logger.info('Experiment finished!')
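
Example #3 consumes `(train_dataset, dev_dataset, dataset_info)` tuples yielded per fold by `load_dataset(hparams, ret_test=False)`. A minimal stand-in for that generator, built from scikit-learn's `KFold` and `tf.data` under the assumption of in-memory features and labels, is sketched below; the name `k_fold_datasets` and its arguments are illustrative, since the real loader derives its folds from `hparams`.

import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold


def k_fold_datasets(features, labels, n_splits=5, batch_size=32):
    """Yield one (train_dataset, dev_dataset) pair of tf.data.Datasets per fold."""
    features = np.asarray(features)
    labels = np.asarray(labels)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)
    for train_idx, dev_idx in kf.split(features):
        # build the training pipeline for this fold
        train_ds = (tf.data.Dataset
                    .from_tensor_slices((features[train_idx], labels[train_idx]))
                    .shuffle(1024)
                    .batch(batch_size))
        # held-out fold used for validation
        dev_ds = (tf.data.Dataset
                  .from_tensor_slices((features[dev_idx], labels[dev_idx]))
                  .batch(batch_size))
        yield train_ds, dev_ds

Each yielded pair plays the role of `train_dataset` and `dev_dataset` in the fold loop above; `dataset_info` is omitted in this sketch.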
Example #4
def experiment(hparams: Hparams):
    logger = logging.getLogger(__name__)
    if hparams.use_mixed_float16:
        logger.info("Use auto mixed policy")
        # tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

    strategy = tf.distribute.MirroredStrategy(
        devices=[f"/gpu:{id}" for id in hparams.gpus])
    # build dataset
    train_dataset, dev_dataset, dataset_info = next(
        load_dataset(hparams, ret_test=False))

    with strategy.scope():
        # build model
        model, (losses,
                loss_weights), metrics, optimizer = build_model(hparams)
        # build callbacks
        callbacks = build_callbacks(hparams)
        # compile
        model.compile(optimizer=optimizer,
                      loss=losses,
                      metrics=metrics,
                      loss_weights=loss_weights)
        # fit
        if hparams.training.do_eval:
            validation_data = dev_dataset
            validation_steps = hparams.training.validation_steps
        else:
            logger.info("Do not evaluate.")
            validation_data = None
            validation_steps = None

        model.fit(
            train_dataset,
            validation_data=validation_data,
            epochs=hparams.training.max_epochs,
            callbacks=callbacks,
            steps_per_epoch=hparams.training.steps_per_epoch,
            validation_steps=validation_steps,
        )

    # run the learning-rate finder if its callback is present
    lr_finder_call_back = [
        cb for cb in callbacks if hasattr(cb, "lr_finder_plot")
    ]
    if len(lr_finder_call_back) != 0:
        logger.info(
            f"Do lr finder, and save result in {hparams.get_lr_finder_jpg_file()}"
        )
        lr_finder_call_back[0].lr_finder_plot(hparams.get_lr_finder_jpg_file())
    else:
        # load best model
        checkpoint_dir = os.path.join(hparams.get_workspace_dir(),
                                      "checkpoint")
        if hparams.eval_use_best and os.path.exists(checkpoint_dir):
            logger.info(f"Load best model from {checkpoint_dir}")
            average_checkpoints(model, checkpoint_dir)
        # save best model
        logger.info(f'Save model in {hparams.get_model_filename()}')
        model.save_weights(hparams.get_model_filename(), save_format="tf")

        # eval on test dataset and make reports
        if hparams.training.do_eval:
            evaluation(hparams)

    logger.info('Experiment finished!')
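
Example #4 looks for a callback exposing an `lr_finder_plot` method but does not show one. A rough, self-contained sketch of such a callback (exponentially increasing the learning rate each batch, recording the loss, and plotting loss versus learning rate) is given below; the class name `SimpleLRFinder` and its parameters are assumptions, and the project's own callback may behave differently.

import matplotlib.pyplot as plt
import tensorflow as tf


class SimpleLRFinder(tf.keras.callbacks.Callback):
    """Sweep the learning rate upward each batch and record the training loss."""

    def __init__(self, start_lr=1e-6, factor=1.1):
        super().__init__()
        self.start_lr = start_lr
        self.factor = factor
        self.lrs, self.losses = [], []

    def on_train_begin(self, logs=None):
        # start the sweep from a very small learning rate
        tf.keras.backend.set_value(self.model.optimizer.learning_rate,
                                   self.start_lr)

    def on_train_batch_end(self, batch, logs=None):
        # record the current (lr, loss) pair, then increase the learning rate
        lr = float(tf.keras.backend.get_value(
            self.model.optimizer.learning_rate))
        self.lrs.append(lr)
        self.losses.append(logs["loss"])
        tf.keras.backend.set_value(self.model.optimizer.learning_rate,
                                   lr * self.factor)

    def lr_finder_plot(self, filename):
        """Save a loss vs. learning-rate plot, mirroring the call in experiment()."""
        plt.figure()
        plt.plot(self.lrs, self.losses)
        plt.xscale("log")
        plt.xlabel("learning rate")
        plt.ylabel("loss")
        plt.savefig(filename)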