Пример #1
0
 def _get_ppidataset(self):
     """Load the bundled PPI sample data and the scorer that goes with it.

     The sample file ``sample_train.json`` is resolved relative to this
     module, in the sibling ``data`` directory.

     Returns:
         A ``(dataset, scorer)`` tuple, both produced by the factory
         registered under the name ``"PpiDatasetFactory"``.
     """
     # Arrange
     module_dir = os.path.dirname(__file__)
     sample_path = os.path.join(module_dir, "..", "data", "sample_train.json")
     ppi_factory = DatasetFactory().get_datasetfactory("PpiDatasetFactory")
     return (
         ppi_factory.get_dataset(sample_path),
         ppi_factory.get_metric_factory().get(),
     )
Пример #2
0
 def _get_aimeddataset(self):
     """Load the bundled AIMed sample data and the scorer that goes with it.

     The sample file ``Aimedsample.json`` is resolved relative to this
     module, in the sibling ``data`` directory.

     Returns:
         A ``(dataset, scorer)`` tuple, both produced by the factory
         registered under the name ``"PpiAimedDatasetFactory"``.
     """
     here = os.path.dirname(__file__)
     sample_path = os.path.join(here, "..", "data", "Aimedsample.json")

     aimed_factory = DatasetFactory().get_datasetfactory(
         "PpiAimedDatasetFactory")
     aimed_dataset = aimed_factory.get_dataset(sample_path)
     metric = aimed_factory.get_metric_factory().get()
     return aimed_dataset, metric
Пример #3
0
    def _get_interactiondataset(self):
        """Load the bundled classification sample and its matching scorer.

        The sample file ``sample_classification.json`` is resolved relative
        to this module, in the sibling ``data`` directory.

        Returns:
            A ``(dataset, scorer)`` tuple, both produced by the factory
            registered under the name ``"InteractionDatasetFactory"``.
        """
        path_parts = (os.path.dirname(__file__), "..", "data",
                      "sample_classification.json")
        sample_path = os.path.join(*path_parts)

        interaction_factory = DatasetFactory().get_datasetfactory(
            "InteractionDatasetFactory")
        loaded = interaction_factory.get_dataset(sample_path)
        metric = interaction_factory.get_metric_factory().get()

        return loaded, metric
Пример #4
0
def run(dataset, datajson, artefactsbase_dir, outdir,
        positives_filter_threshold):
    """Load ``datajson`` with the named dataset factory and run inference.

    Args:
        dataset: Name of the dataset factory to look up via
            ``DatasetFactory.get_datasetfactory``.
        datajson: Input passed to the factory's ``get_dataset`` and forwarded
            to the inference pipeline.
        artefactsbase_dir: Forwarded unchanged to ``InferencePipeline.run``.
        outdir: Forwarded unchanged to ``InferencePipeline.run``.
        positives_filter_threshold: Forwarded unchanged to
            ``InferencePipeline.run``.
    """
    factory = DatasetFactory().get_datasetfactory(dataset)
    loaded_dataset = factory.get_dataset(datajson)

    pipeline = InferencePipeline()
    pipeline.run(loaded_dataset, datajson, artefactsbase_dir, outdir,
                 positives_filter_threshold)
Пример #5
0
    def test_get_datasetfactory(self):
        """The first registered factory name resolves to a dataset factory.

        Looks up the first entry of ``dataset_factory_names`` and checks that
        the object returned is a ``CustomDatasetFactoryBase`` instance.
        """
        # Arrange
        registry = DatasetFactory()
        first_factory_name = registry.dataset_factory_names[0]

        # Act
        created = registry.get_datasetfactory(first_factory_name)

        # Assert
        self.assertIsInstance(created, CustomDatasetFactoryBase)
Пример #6
0
    def test_dataset_factory_names(self):
        """``dataset_factory_names`` lists the expected number of factories.

        The expected count is hard-coded, so this test must be updated
        whenever a dataset factory class is added or removed.
        """
        # Arrange
        sut = DatasetFactory()
        expected_num_factories = 8

        # Act
        class_names = sut.dataset_factory_names

        # Assert
        # Report the base class itself: type(CustomDatasetFactoryBase) would
        # yield its metaclass, which says nothing about where to look.
        self.assertEqual(
            len(class_names), expected_num_factories,
            "The number of expected dataset factory classes doesn't match. "
            "Check the number of classes that inherit from {} and update the "
            "test if required!".format(CustomDatasetFactoryBase.__name__))
Пример #7
0
def run_file(dataset_name, datajson, artefactsbase_dir, outdir,
             positives_filter_threshold):
    """Run the inference pipeline over a single data file.

    Args:
        dataset_name: Name of the dataset factory to look up via
            ``DatasetFactory.get_datasetfactory``.
        datajson: Input passed to the factory's ``get_dataset`` and forwarded
            to the inference pipeline.
        artefactsbase_dir: Forwarded unchanged to ``InferencePipeline.run``.
        outdir: Forwarded unchanged to ``InferencePipeline.run``.
        positives_filter_threshold: Forwarded unchanged to
            ``InferencePipeline.run``.
    """
    factory = DatasetFactory().get_datasetfactory(dataset_name)
    loaded = factory.get_dataset(datajson)

    inference = InferencePipeline()
    inference.run(loaded, datajson, artefactsbase_dir, outdir,
                  positives_filter_threshold)
Пример #8
0

def run_file(dataset_name, datajson, artefactsbase_dir, outdir,
             positives_filter_threshold):
    """Build the dataset for ``datajson`` and hand it to the inference pipeline.

    Args:
        dataset_name: Name of the dataset factory to look up via
            ``DatasetFactory.get_datasetfactory``.
        datajson: Input passed to the factory's ``get_dataset`` and forwarded
            to the inference pipeline.
        artefactsbase_dir: Forwarded unchanged to ``InferencePipeline.run``.
        outdir: Forwarded unchanged to ``InferencePipeline.run``.
        positives_filter_threshold: Forwarded unchanged to
            ``InferencePipeline.run``.
    """
    records = DatasetFactory().get_datasetfactory(dataset_name).get_dataset(
        datajson)
    InferencePipeline().run(
        records,
        datajson,
        artefactsbase_dir,
        outdir,
        positives_filter_threshold,
    )


if "__main__" == __name__:
    parser = argparse.ArgumentParser()

    parser.add_argument("dataset",
                        help="The dataset type",
                        choices=DatasetFactory().dataset_factory_names)

    parser.add_argument("datajson", help="The json data to predict")

    parser.add_argument(
        "artefactsdir",
        help=
        "The base of artefacts dir that contains directories of model, vocab etc"
    )
    parser.add_argument("outdir", help="The output dir")

    parser.add_argument("--log-level",
                        help="Log level",
                        default="INFO",
                        choices={"INFO", "WARN", "DEBUG", "ERROR"})
    parser.add_argument("--positives-filter-threshold",
Пример #9
0
    def __call__(self, train_file, val_file, test_file=None):
        """Train on ``train_file``, validate on ``val_file``, report scores.

        The dataset factory named by ``self.dataset_factory_name`` builds the
        train and validation datasets and supplies the objective scorer. A
        train pipeline is built from ``TrainInferenceBuilder`` and run once.
        Binary, micro and macro average scores and the confusion matrix are
        then reported for the validation predictions.

        Args:
            train_file: Input passed to the factory's ``get_dataset`` to build
                the training dataset.
            val_file: Input passed to the factory's ``get_dataset`` to build
                the validation dataset.
            test_file: Optional; when not ``None`` it is passed to
                ``self.predict_test_set`` after training.

        Returns:
            The ``(val_results, val_actuals, val_predicted)`` tuple returned
            by the train pipeline.
        """
        dataset_factory = DatasetFactory().get_datasetfactory(
            self.dataset_factory_name)
        scorer = dataset_factory.get_metric_factory().get()

        self.logger.info("Using objective metric {}".format(type(scorer)))

        train = dataset_factory.get_dataset(train_file)
        val = dataset_factory.get_dataset(val_file)

        with open(self.embedding_file, "r") as embedding_handle:
            builder = TrainInferenceBuilder(
                dataset=train,
                embedding_dim=self.embed_dim,
                embedding_handle=embedding_handle,
                model_dir=self.model_dir,
                output_dir=self.out_dir,
                epochs=self.epochs,
                patience_epochs=self.earlystoppingpatience,
                extra_args=self.additionalargs,
                network_factory_name=self.network_factory_name,
                results_scorer=scorer)
            train_pipeline = builder.get_trainpipeline()

            # Run the pipeline exactly once. A previous extra call here
            # trained the model a second time and threw away the first result.
            val_results, val_actuals, val_predicted = train_pipeline(
                train, val)

        # Binary averaging is best effort: a ValueError is logged, not raised.
        try:
            self.print_average_scores(val_actuals,
                                      val_predicted,
                                      pos_label=train.positive_label,
                                      average='binary')
        except ValueError as ev:
            self.logger.info(
                "Could not compute binary average scores, failed with error {}"
                .format(ev))

        for average in ('micro', 'macro'):
            self.print_average_scores(val_actuals,
                                      val_predicted,
                                      pos_label=train.positive_label,
                                      average=average)

        # NOTE(review): the "tn, fp, fn, tp" labels in the log line presumably
        # only match the flattened matrix for a two-class problem - confirm.
        confusion_matrix_results = confusion_matrix(val_actuals,
                                                    val_predicted).ravel()

        self.logger.info("Confusion matrix: tn, fp, fn, tp  is {}".format(
            confusion_matrix_results))

        if test_file is not None:
            train_pipeline = builder.get_trainpipeline()

            self.predict_test_set(dataset_factory, self.model_dir,
                                  self.out_dir, test_file, train_pipeline)

        return val_results, val_actuals, val_predicted
Пример #10
0
 def _get_dataset(self):
     """Return the dataset factory registered as ``"PpiDatasetFactory"``.

     Despite the method name, the returned object is the factory itself, not
     a dataset built from it.
     """
     return DatasetFactory().get_datasetfactory("PpiDatasetFactory")
Пример #11
0
def run(dataset_factory_name,
        network_factory_name,
        train_file,
        val_file,
        model_dir,
        out_dir,
        epochs,
        earlystoppingpatience,
        additionalargs,
        test_file=None):
    """Train with ``BertTrainInferenceBuilder``, validate and report scores.

    Args:
        dataset_factory_name: Name of the dataset factory to look up via
            ``DatasetFactory.get_datasetfactory``.
        network_factory_name: Forwarded to the builder as
            ``network_factory_name``.
        train_file: Input passed to the factory's ``get_dataset`` to build the
            training dataset.
        val_file: Input passed to the factory's ``get_dataset`` to build the
            validation dataset.
        model_dir: Existing directory, forwarded to the builder as
            ``model_dir``.
        out_dir: Existing directory, forwarded to the builder as
            ``output_dir``.
        epochs: Forwarded to the builder as ``epochs``.
        earlystoppingpatience: Forwarded to the builder as
            ``patience_epochs``.
        additionalargs: Forwarded to the builder as ``extra_args``.
        test_file: Optional; when not ``None`` it is passed to
            ``predict_test_set`` after training.

    Returns:
        The ``(val_results, val_actuals, val_predicted)`` tuple returned by
        the train pipeline.

    Raises:
        FileNotFoundError: If ``out_dir`` or ``model_dir`` is not an existing
            directory.
    """
    logger = logging.getLogger(__name__)

    dataset_factory = DatasetFactory().get_datasetfactory(dataset_factory_name)

    train = dataset_factory.get_dataset(train_file)
    val = dataset_factory.get_dataset(val_file)

    scorer = dataset_factory.get_metric_factory().get()

    logger.info("Using objective metric {}".format(type(scorer)))

    # os.path.isdir is already False for a missing path, so one check covers
    # both "does not exist" and "is not a directory".
    for required_dir in (out_dir, model_dir):
        if not os.path.isdir(required_dir):
            raise FileNotFoundError(
                "The path {} should exist and must be a directory".format(
                    required_dir))

    builder = BertTrainInferenceBuilder(
        dataset=train,
        model_dir=model_dir,
        output_dir=out_dir,
        epochs=epochs,
        patience_epochs=earlystoppingpatience,
        extra_args=additionalargs,
        network_factory_name=network_factory_name,
        results_scorer=scorer)
    train_pipeline = builder.get_trainpipeline()
    val_results, val_actuals, val_predicted = train_pipeline(train, val)

    # Binary averaging is best effort: a ValueError is logged, not raised.
    try:
        print_average_scores(val_actuals,
                             val_predicted,
                             pos_label=train.positive_label,
                             average='binary')
    except ValueError as ev:
        logger.info(
            "Could not compute binary average scores, failed with error {}".
            format(ev))

    for average in ('micro', 'macro'):
        print_average_scores(val_actuals,
                             val_predicted,
                             pos_label=train.positive_label,
                             average=average)

    # NOTE(review): the "tn, fp, fn, tp" labels in the log line presumably
    # only match the flattened matrix for a two-class problem - confirm.
    confusion_matrix_results = confusion_matrix(val_actuals,
                                                val_predicted).ravel()

    logger.info("Confusion matrix: tn, fp, fn, tp  is {}".format(
        confusion_matrix_results))

    if test_file is not None:
        predict_test_set(dataset_factory, model_dir, out_dir, test_file,
                         train_pipeline)

    return val_results, val_actuals, val_predicted