Exemplo n.º 1
0
    def train(self, epoch, draw_graph=False):
        """Run k-fold cross-validation, training ``self.model`` on every fold.

        For each (train, test) split yielded by
        ``self.train_collection.generator()`` the model is restored to the
        weights it had when this method was called, a
        ``WandBValidationLogger`` is created on the held-out split, and the
        model is trained with ``PairwiseTraining`` for ``epoch`` epochs.
        After all folds, the mean of the per-fold ``current_best`` scores is
        printed and stored in the W&B run summary under
        ``"best_xval_<comparison_metric>"``.

        Args:
            epoch: number of epochs forwarded to ``PairwiseTraining.train``
                for each fold.
            draw_graph: forwarded unchanged to ``PairwiseTraining.train``.

        Raises:
            KeyError: if ``self.wandb_config`` is None (only the W&B
                validation logger is supported).
            ValueError: if the collection generator yields no folds.
        """
        # Fail fast. Before, this check ran only after the config had
        # already been subscripted, so a None config surfaced as a TypeError
        # instead of this error.
        if self.wandb_config is None:
            raise KeyError("Please use wandb for now!!!")

        print("Start the Cross validation for",
              self.train_collection.get_n_folds(), "folds")

        # Capture the run name once. Each fold is then named
        # "Fold_XX_<base>" instead of accumulating all previous prefixes.
        base_run_name = self.wandb_config["name"]

        temp_dir = tempfile.mkdtemp()
        init_weights_path = os.path.join(temp_dir, "temp_weights.h5")

        try:
            # Snapshot the initial weights so every fold starts from them.
            save_model_weights(init_weights_path, self.model)

            best_test_scores = []
            for i, (train_collection, test_collection) in enumerate(
                    self.train_collection.generator()):
                print("Prepare FOLD", i)

                # Show baseline metrics over the previous ranking order.
                pre_metrics = test_collection.evaluate_pre_rerank()
                print("Evaluation of the original ranking order")
                for n, m in pre_metrics:
                    print(n, m)

                # Reset all the states.
                # NOTE(review): K.clear_session() resets Keras global state
                # while self.model is reused afterwards. Confirm the model
                # stays usable with the backend in use.
                set_random_seed()
                K.clear_session()

                # Restore the initial weights.
                load_model_weights(init_weights_path, self.model)

                # Per-fold copy, so self.wandb_config is not mutated across
                # folds. The format is "Fold_0<i>" for i < 10, as before.
                fold_config = dict(self.wandb_config)
                fold_config["name"] = "Fold_{:02d}_{}".format(i,
                                                              base_run_name)

                wandb_val_logger = WandBValidationLogger(
                    wandb_args=fold_config,
                    steps_per_epoch=train_collection.get_steps(),
                    validation_collection=test_collection)

                callbacks = [wandb_val_logger] + self.callbacks

                print("Train and test FOLD", i)

                pairwise_training = PairwiseTraining(
                    model=self.model,
                    train_collection=train_collection,
                    loss=self.loss,
                    optimizer=self.optimizer,
                    callbacks=callbacks)

                pairwise_training.train(epoch, draw_graph=draw_graph)

                # Read the fold's best score only AFTER training. Reading it
                # before training recorded the logger's initial value.
                best_test_scores.append(wandb_val_logger.current_best)

            if not best_test_scores:
                raise ValueError(
                    "The train collection generator produced no folds")

            x_score = sum(best_test_scores) / len(best_test_scores)
            print("X validation best score:", x_score)
            # The summary is written via the last fold's logger/run.
            wandb_val_logger.wandb.run.summary[
                "best_xval_" + wandb_val_logger.comparison_metric] = x_score
        finally:
            # Always remove the temp directory, even when training fails.
            print("Remove {}".format(temp_dir))
            shutil.rmtree(temp_dir)
Exemplo n.º 2
0
from mmnrm.utils import set_random_seed
from nir.embeddings import FastText, Word2Vec

# Seed the random number generators at import time, before the remaining
# imports (notably tensorflow) are executed. This is presumably done so that
# weight initialisation and data shuffling are reproducible.
# NOTE(review): confirm the ordering requirement against
# mmnrm.utils.set_random_seed before moving this call.
set_random_seed()

import io
from nir.tokenizers import Regex, BioCleanTokenizer, BioCleanTokenizer2
import numpy as np
import math
import os
import json

import tensorflow as tf
from tensorflow.keras import backend as K

from mmnrm.dataset import TrainCollectionV2, TestCollectionV2, sentence_splitter_builderV2, TrainPairwiseCollection
from mmnrm.modelsv2 import deep_rank
from mmnrm.callbacks import TriangularLR, WandBValidationLogger, LearningRateScheduler
from mmnrm.training import PairwiseTraining, pairwise_cross_entropy
from mmnrm.utils import merge_dicts, load_model


def main():

    min_freq = 5
    mun_itter = 15
    emb_size = 200

    use_triangularLR = False
    use_step_decay = False