def test_add_and_get_gold_labels_for_category_no_dump(self):
        num_elements_to_labels = 10
        dataset_name = self.test_add_and_get_gold_labels_for_category_no_dump.__name__

        target_category = 'Autobots'
        uris_to_gold_labels_expected = generate_random_uris_and_labels(
            dataset_name, num_elements_to_labels, [target_category])
        non_target_category = 'Decepticons'
        uris_to_gold_labels = [
            (uri, dict(labels, **{non_target_category: Label(labels=LABEL_POSITIVE, metadata={})}))
            for uri, labels in uris_to_gold_labels_expected
        ]

        oracle.add_gold_labels(dataset_name, uris_to_gold_labels)

        uris_to_retrieve = [
            uri for uri, labels_dict in uris_to_gold_labels_expected
        ]
        uri_to_gold_labels_found = oracle.get_gold_labels(
            dataset_name, uris_to_retrieve, target_category)

        self.assert_uri_labels_equal(uris_to_gold_labels_expected,
                                     uri_to_gold_labels_found)
        loader.clear_gold_labels_file(dataset_name)
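
    # Note on the dict(labels, **extra) idiom used above (plain-Python behavior,
    # not a framework API): it builds a new dict with all of `labels` plus the
    # extra category, leaving the original dict untouched. A quick sketch:
    #   base = {'Autobots': label_a}
    #   merged = dict(base, **{'Decepticons': label_b})
    #   assert 'Decepticons' in merged and 'Decepticons' not in base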

    def test_add_and_get_gold_labels_from_dump(self):
        num_elements_to_labels = 10

        # generate gold labels for the first dataset
        dataset_name_first = self.test_add_and_get_gold_labels_from_dump.__name__ + '_a'
        uris_to_gold_labels_expected_first = generate_random_uris_and_labels(
            dataset_name_first, num_elements_to_labels,
            ['Autobots', 'Decepticons'])
        oracle.add_gold_labels(dataset_name_first,
                               uris_to_gold_labels_expected_first)

        # generate gold labels for the second dataset (the first dataset is kept in a dump file)
        dataset_name_second = self.test_add_and_get_gold_labels_from_dump.__name__ + '_b'
        uris_to_gold_labels_expected_second = generate_random_uris_and_labels(
            dataset_name_second, num_elements_to_labels,
            ['Autobots', 'Decepticons'])
        oracle.add_gold_labels(dataset_name_second,
                               uris_to_gold_labels_expected_second)

        # switch back to the first dataset and retrieve data from dump file
        uris_to_retrieve = [
            uri for uri, labels_dict in uris_to_gold_labels_expected_first
        ]
        uri_to_gold_labels_found = oracle.get_gold_labels(
            dataset_name_first, uris_to_retrieve)
        self.assert_uri_labels_equal(uris_to_gold_labels_expected_first,
                                     uri_to_gold_labels_found)

        # clean both dataset dumps
        loader.clear_gold_labels_file(dataset_name_first)
        loader.clear_gold_labels_file(dataset_name_second)

    def test_get_gold_labels_no_dump_and_no_in_memory(self):
        dataset_name = self.test_get_gold_labels_no_dump_and_no_in_memory.__name__

        uris_to_retrieve = ['uri1', 'uri2']
        uris_to_gold_labels_expected = []
        uri_to_gold_labels_found = oracle.get_gold_labels(
            dataset_name, uris_to_retrieve)

        self.assert_uri_labels_equal(uris_to_gold_labels_expected,
                                     uri_to_gold_labels_found)
        loader.clear_gold_labels_file(dataset_name)

    def train_first_model(self, config: ExperimentParams):
        if orchestrator_api.workspace_exists(config.workspace_id):
            orchestrator_api.delete_workspace(config.workspace_id)

        orchestrator_api.create_workspace(
            config.workspace_id,
            config.train_dataset_name,
            dev_dataset_name=config.dev_dataset_name)
        orchestrator_api.create_new_category(config.workspace_id,
                                             config.category_name,
                                             "No description for you")

        dev_text_elements_uris = orchestrator_api.get_all_text_elements_uris(
            config.dev_dataset_name)
        dev_text_elements_and_labels = oracle_data_access_api.get_gold_labels(
            config.dev_dataset_name, dev_text_elements_uris)
        if dev_text_elements_and_labels is not None:
            orchestrator_api.set_labels(config.workspace_id,
                                        dev_text_elements_and_labels)

        random_seed = sum([ord(c) for c in config.workspace_id])
        logging.info(str(config))
        logging.info(f'random seed: {random_seed}')

        self.set_first_model_positives(config, random_seed)
        self.set_first_model_negatives(config, random_seed)

        # train first model
        logging.info(
            f'Starting first model training (model: {config.model.name})\tworkspace: {config.workspace_id}'
        )
        new_model_id = orchestrator_api.train(config.workspace_id,
                                              config.category_name,
                                              config.model,
                                              train_params=config.train_params)
        if new_model_id is None:
            raise Exception(
                f'A new model was not trained\tworkspace: {config.workspace_id}'
            )

        eval_dataset = config.test_dataset_name
        res_dict = self.evaluate(config,
                                 al=self.NO_AL,
                                 iteration=0,
                                 eval_dataset=eval_dataset)
        # ensure AL-related keys are present in the results dictionary
        res_dict.update(self.generate_al_batch_dict(config))

        logging.info(
            f'Evaluation on dataset: {eval_dataset}, iteration: 0, first model (id: {new_model_id}) '
            f'repeat: {config.repeat_id}, is: {res_dict}\t'
            f'workspace: {config.workspace_id}')

        return res_dict
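
    # The per-workspace random seed above is just the sum of the character codes
    # of the workspace id, so repeated runs with the same workspace_id sample the
    # same first-model examples. Worked example (hypothetical id):
    #   sum(ord(c) for c in 'ws_1') == 119 + 115 + 95 + 49 == 378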

    def set_first_model_positives(self, config, random_seed) -> List[str]:
        """
        Choose instances by queries, regardless of their gold label.

        :param config: experiment config for this run
        :param random_seed: a seed for the Random being used for sampling
        :return: a list of the URIs of the sampled TextElements
        """
        general_dataset_name = config.train_dataset_name.split('_train')[0]
        queries = self.queries_per_dataset[general_dataset_name][
            config.category_name]
        sampled_unlabeled_text_elements = []
        for query in queries:
            sampled_unlabeled_text_elements.extend(
                self.data_access.sample_unlabeled_text_elements(
                    workspace_id=config.workspace_id,
                    dataset_name=config.train_dataset_name,
                    category_name=config.category_name,
                    sample_size=self.first_model_positives_num,
                    query=query,
                    remove_duplicates=True)['results'])
            logging.info(
                f"Positive sampling, after query {query} size is {len(sampled_unlabeled_text_elements)}")

        if len(sampled_unlabeled_text_elements) > self.first_model_positives_num:
            random.seed(random_seed)
            sampled_unlabeled_text_elements = random.sample(
                sampled_unlabeled_text_elements,
                self.first_model_positives_num)

        sampled_uris = [t.uri for t in sampled_unlabeled_text_elements]
        sampled_uris_and_gold_labels = dict(
            oracle_data_access_api.get_gold_labels(config.train_dataset_name,
                                                   sampled_uris))
        sampled_uris_and_label = \
            [(x.uri, {config.category_name: sampled_uris_and_gold_labels[x.uri][config.category_name]})
             for x in sampled_unlabeled_text_elements]
        orchestrator_api.set_labels(config.workspace_id,
                                    sampled_uris_and_label)

        logging.info(
            f'Set the label of {len(sampled_uris_and_label)} instances sampled by queries {queries} '
            f'using the oracle for category {config.category_name}')
        logging.info(
            f"Positive sampling, returned {len(sampled_uris)} elements")

        return sampled_uris
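
    # A plausible shape for self.queries_per_dataset, mirroring the
    # datasets_categories_queries mapping in the hit-count script at the end of
    # this section (values are illustrative, not part of the framework API):
    #   queries_per_dataset = {
    #       'trec': {'LOC': ['Where|countr.*|cit.*']},  # dataset -> category -> regex queries
    #   }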
Example #6

def generate_performance_metrics_dict(config, evaluation_dataset):
    all_text_elements = orchestrator_api.get_all_text_elements(
        evaluation_dataset)
    all_text_elements_uris = [elem.uri for elem in all_text_elements]

    predicted_labels = orchestrator_api.infer(config.workspace_id,
                                              config.category_name,
                                              all_text_elements)
    uris_and_gold_labels = oracle_data_access_api.get_gold_labels(
        evaluation_dataset, all_text_elements_uris)
    gold_of_category = [
        x[1][config.category_name].labels for x in uris_and_gold_labels
    ]
    performance_metrics = evaluate_predictions(gold_of_category,
                                               predicted_labels)
    return performance_metrics
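
A minimal usage sketch, assuming `config` exposes workspace_id and category_name as used above; the dataset name 'trec_test' is illustrative:

metrics = generate_performance_metrics_dict(config, evaluation_dataset='trec_test')
logging.info(f'performance metrics for {config.category_name}: {metrics}')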

    def test_add_and_get_gold_labels_no_dump(self):
        num_elements_to_labels = 10
        dataset_name = self.test_add_and_get_gold_labels_no_dump.__name__

        uris_to_gold_labels_expected = generate_random_uris_and_labels(
            dataset_name, num_elements_to_labels, ['Autobots', 'Decepticons'])
        oracle.add_gold_labels(dataset_name, uris_to_gold_labels_expected)

        uris_to_retrieve = [
            uri for uri, labels_dict in uris_to_gold_labels_expected
        ]
        uri_to_gold_labels_found = oracle.get_gold_labels(
            dataset_name, uris_to_retrieve)

        self.assert_uri_labels_equal(uris_to_gold_labels_expected,
                                     uri_to_gold_labels_found)
        loader.clear_gold_labels_file(dataset_name)

    def get_suggested_elements_and_gold_labels(self, config, al):
        start = time.time()
        suggested_text_elements_for_labeling = \
            orchestrator_api.get_elements_to_label(config.workspace_id, config.category_name,
                                                   self.active_learning_suggestions_num)
        end = time.time()
        logging.info(
            f'{len(suggested_text_elements_for_labeling)} instances '
            f'suggested by active learning strategy: {al.name} '
            f'for dataset: {config.train_dataset_name} and category: {config.category_name}.\t'
            f'runtime: {end - start}\tworkspace: {config.workspace_id}')
        uris_for_labeling = [
            elem.uri for elem in suggested_text_elements_for_labeling
        ]
        uris_and_gold_labels = oracle_data_access_api.get_gold_labels(
            config.train_dataset_name, uris_for_labeling, config.category_name)
        return suggested_text_elements_for_labeling, uris_and_gold_labels

    def set_first_model_positives(self, config, random_seed) -> List[str]:
        """
        Randomly choose instances, regardless of their gold label.

        :param config: experiment config for this run
        :param random_seed: a seed for the Random being used for sampling
        :return: a list of the URIs of the sampled TextElements
        """
        sample_size = self.first_model_positives_num
        sampled_elements = self.data_access.sample_text_elements(config.train_dataset_name, sample_size,
                                                                 remove_duplicates=True)['results']
        sampled_uris = [element.uri for element in sampled_elements]
        sampled_uris_with_labels = \
            oracle_data_access_api.get_gold_labels(config.train_dataset_name, sampled_uris, config.category_name)
        orchestrator_api.set_labels(config.workspace_id, sampled_uris_with_labels)

        logging.info(f'Set the label of {len(sampled_uris_with_labels)} random instances with their gold label '
                     f'(can be positive or negative) for category {config.category_name}')
        return sampled_uris

    def test_add_and_get_some_gold_labels_no_dump(self):
        num_elements_to_labels = 10
        sample_ratio = 0.4
        dataset_name = self.test_add_and_get_some_gold_labels_no_dump.__name__

        uris_to_gold_labels_expected = generate_random_uris_and_labels(
            dataset_name, num_elements_to_labels, ['Autobots', 'Decepticons'])
        oracle.add_gold_labels(dataset_name, uris_to_gold_labels_expected)
        uris_to_retrieve = [
            uri for uri, labels_dict in uris_to_gold_labels_expected
        ]
        indices = random.sample(range(len(uris_to_retrieve)),
                                floor(num_elements_to_labels * sample_ratio))
        uris_to_retrieve = [uris_to_retrieve[i] for i in sorted(indices)]
        uris_to_gold_labels_expected = [
            uri_label for uri_label in uris_to_gold_labels_expected
            if uri_label[0] in uris_to_retrieve
        ]
        uri_to_gold_labels_found = oracle.get_gold_labels(
            dataset_name, uris_to_retrieve)

        self.assert_uri_labels_equal(uris_to_gold_labels_expected,
                                     uri_to_gold_labels_found)
        loader.clear_gold_labels_file(dataset_name)
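
The index-based subsampling above generalizes to any list; a standalone sketch with illustrative values:

import random
from math import floor

items = ['uri0', 'uri1', 'uri2', 'uri3', 'uri4']
indices = random.sample(range(len(items)), floor(len(items) * 0.4))  # 2 distinct indices
subset = [items[i] for i in sorted(indices)]  # subset preserves the original order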
Example #11

POSITIVE_SAMPLE_SIZE = 10 ** 6
MIN_TERMS_NUM = 3


if __name__ == '__main__':
    datasets_categories_queries = {'trec': {'LOC': ['Where|countr.*|cit.*']}}
    data_access = data_access_factory.get_data_access()
    dataset_part = DatasetPart.DEV.name.lower()

    results_df = pd.DataFrame(columns=['dataset', 'category', 'query', 'hit_count'])
    for dataset in datasets_categories_queries:
        dataset_for_query = dataset + '_' + dataset_part
        uris = data_access.get_all_text_elements_uris(dataset_for_query)
        gold_labels = oracle_data_access_api.get_gold_labels(dataset_name=dataset_for_query, text_element_uris=uris)
        gold_labels = dict(gold_labels)
        for category in datasets_categories_queries[dataset]:
            for query in datasets_categories_queries[dataset][category]:
                search_res = data_access.sample_text_elements(dataset_name=dataset_for_query,
                                                              sample_size=POSITIVE_SAMPLE_SIZE, query=query)
                query_results = search_res['results']
                hit_count = search_res['hit_count']

                # # filter too short texts
                # query_results = [e for e in query_results if len(e.text.split()) >= MIN_TERMS_NUM]

                logging.info(f'dataset: {dataset_for_query}\tcategory: {category}\t'
                             f'query: {query}\thit_count: {hit_count}')
                logging.debug('\n\t' + '\n\t'.join([t.text for t in query_results]))
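                # NOTE (assumption): results_df is never populated in the snippet as shown;
                # presumably each query's hit count was recorded along these lines:
                results_df.loc[len(results_df)] = [dataset_for_query, category, query, hit_count]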