    def test_sample_text_elements_with_labels_info(self):
        workspace_id = 'test_sample_text_elements_with_labels_info'
        dataset_name = self.test_sample_text_elements_with_labels_info.__name__ + '_dump'
        sample_all = 10**100  # a huge sample_size to sample all elements
        docs = generate_corpus(dataset_name, 5)
        # add labels info for a single doc
        selected_doc = docs[0]
        texts_and_labels_list = generate_random_texts_and_labels(
            selected_doc, 5, ['Autobots', 'Decepticons'])
        data_access.set_labels(workspace_id, texts_and_labels_list)
        texts_and_labels_dict = dict(texts_and_labels_list)

        sampled_texts_res = data_access.sample_text_elements_with_labels_info(
            workspace_id, dataset_name, sample_all)
        for doc_text in selected_doc.text_elements:
            sampled_text = [
                sampled for sampled in sampled_texts_res['results']
                if sampled.uri == doc_text.uri
            ]
            self.assertEqual(1, len(sampled_text))
            if sampled_text[0].uri in texts_and_labels_dict:
                self.assertDictEqual(
                    sampled_text[0].category_to_label,
                    texts_and_labels_dict[sampled_text[0].uri],
                    f'for text {doc_text}')
            else:
                self.assertDictEqual(sampled_text[0].category_to_label, {},
                                     f'for text {doc_text}')
        ds_loader.clear_all_saved_files(dataset_name)
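Several of these tests rely on a generate_random_texts_and_labels helper that is not included in this collection. A minimal sketch of what it is assumed to do, producing the [(uri, {category: Label})] structure the tests unpack; the element selection and the random positive/negative choice are assumptions, and Label, LABEL_POSITIVE and LABEL_NEGATIVE come from the project under test:

import random

def generate_random_texts_and_labels(doc, num_elements_to_label, categories):
    # Hypothetical helper: take the first few elements of the document and
    # attach a random binary Label for each of the given categories.
    texts_and_labels = []
    for element in doc.text_elements[:num_elements_to_label]:
        category_to_label = {
            cat: Label(labels=random.choice([LABEL_POSITIVE, LABEL_NEGATIVE]),
                       metadata={})
            for cat in categories
        }
        texts_and_labels.append((element.uri, category_to_label))
    return texts_and_labels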
    def test_get_all_document_uris(self):
        dataset_name = self.test_get_all_document_uris.__name__ + '_dump'
        docs = generate_corpus(dataset_name, random.randint(1, 10))
        docs_uris_in_memory = data_access.get_all_document_uris(dataset_name)
        docs_uris_expected = [doc.uri for doc in docs]
        self.assertSetEqual(set(docs_uris_expected), set(docs_uris_in_memory))
        ds_loader.clear_all_saved_files(dataset_name)
    def test_get_label_counts(self):
        workspace_id = 'test_get_label_counts'
        dataset_name = self.test_get_label_counts.__name__ + '_dump'
        category = 'Decepticons'
        docs = generate_corpus(dataset_name, 2)
        # add labels info for a single doc
        selected_doc = docs[0]
        texts_and_labels_list = generate_random_texts_and_labels(
            selected_doc, 5, ['Autobots'])
        # ensure at least one element carries a label for the queried
        # category ('Decepticons'), since the labels generated above only
        # cover 'Autobots'
        if texts_and_labels_list:
            if category in texts_and_labels_list[0][1]:
                texts_and_labels_list[0][1][category].labels = frozenset(
                    LABEL_NEGATIVE)
            else:
                texts_and_labels_list[0][1][category] = Label(
                    labels=LABEL_NEGATIVE, metadata={})
        data_access.set_labels(workspace_id, texts_and_labels_list)

        category_label_counts = data_access.get_label_counts(
            workspace_id, dataset_name, category)
        for label_val, observed_count in category_label_counts.items():
            expected_count = len([
                t for t in texts_and_labels_list
                if category in t[1] and label_val in t[1][category].labels
            ])
            self.assertEqual(expected_count, observed_count,
                             f'count for {label_val} does not match.')
        ds_loader.clear_all_saved_files(dataset_name)
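For reference, the Label objects and label constants used throughout these tests are assumed to look roughly like the sketch below. The exact values are not confirmed by the snippets, and they key the counts inconsistently (labels_count['true'] in one test, labels_count[LABEL_POSITIVE] in another), so treat this as a hypothetical stand-in:

from dataclasses import dataclass, field

# Hypothetical stand-ins for the project's label constants.
LABEL_POSITIVE = frozenset({'true'})
LABEL_NEGATIVE = frozenset({'false'})
BINARY_LABELS = LABEL_POSITIVE | LABEL_NEGATIVE

@dataclass
class Label:
    labels: frozenset                 # label values assigned to the element
    metadata: dict = field(default_factory=dict)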
    def test_copy_workspace(self):
        ws_id = self.test_copy_workspace.__name__ + '_workspace'
        dataset_name = self.test_copy_workspace.__name__ + '_dump'
        cat_name = "cat1"
        cat_desc = "cat_desc"
        model_id = "1"
        orchestrator_api.add_documents(dataset_name,
                                       [generate_simple_doc(dataset_name)])
        orchestrator_state_api.create_workspace(workspace_id=ws_id,
                                                dataset_name=dataset_name)
        orchestrator_state_api.add_category_to_workspace(
            ws_id, cat_name, cat_desc, BINARY_LABELS)
        orchestrator_state_api.add_model(workspace_id=ws_id,
                                         category_name=cat_name,
                                         model_id=model_id,
                                         model_status=ModelStatus.READY,
                                         model_type=ModelTypes.RAND,
                                         model_metadata={})
        new_ws_id = "new_" + ws_id
        orchestrator_state_api.copy_workspace(ws_id, new_ws_id)
        old_ws = orchestrator_state_api.get_workspace(ws_id)
        new_ws = orchestrator_state_api.get_workspace(new_ws_id)
        self.assertEqual(new_ws.workspace_id, new_ws_id)
        self.assertEqual(old_ws.dataset_name, new_ws.dataset_name)
        self.assertEqual(old_ws.category_to_description,
                         new_ws.category_to_description)
        self.assertEqual(old_ws.category_to_model_to_recommendations,
                         new_ws.category_to_model_to_recommendations)
        self.assertEqual(old_ws.category_to_models, new_ws.category_to_models)

        orchestrator_state_api.delete_workspace_state(ws_id)
        orchestrator_state_api.delete_workspace_state(new_ws_id)
        single_dataset_loader.clear_all_saved_files(dataset_name)
    def test_sample_labeled_text_elements(self):
        workspace_id = 'test_sample_labeled_text_elements'
        dataset_name = self.test_sample_labeled_text_elements.__name__ + '_dump'
        category = 'Decepticons'
        sample_all = 10**100  # a huge sample_size to sample all elements
        docs = generate_corpus(dataset_name, 2)
        # add labels info for a single doc
        selected_doc = docs[0]
        texts_and_labels_list = generate_random_texts_and_labels(
            selected_doc, 5, [category])
        data_access.set_labels(workspace_id, texts_and_labels_list)
        texts_and_labels_dict = dict(texts_and_labels_list)

        sampled_texts_res = data_access.sample_labeled_text_elements(
            workspace_id, dataset_name, category, sample_all)
        self.assertEqual(
            len(texts_and_labels_list), len(sampled_texts_res['results']),
            f'all and only the {len(texts_and_labels_list)} labeled elements should have been sampled.'
        )

        for sampled_text in sampled_texts_res['results']:
            self.assertIn(
                sampled_text.uri, texts_and_labels_dict.keys(),
                f'the sampled text uri - {sampled_text.uri} - was not found in the '
                f'texts that were labeled: {texts_and_labels_dict}')
            self.assertDictEqual(sampled_text.category_to_label,
                                 texts_and_labels_dict[sampled_text.uri])
        ds_loader.clear_all_saved_files(dataset_name)
def generate_corpus(dataset_name, num_of_documents=1, add_duplicate=False):
    ds_loader.clear_all_saved_files(dataset_name)
    docs = [
        generate_simple_doc(dataset_name, doc_id, add_duplicate)
        for doc_id in range(0, num_of_documents)
    ]
    data_access.add_documents(dataset_name=dataset_name, documents=docs)
    return docs
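generate_corpus delegates to a generate_simple_doc helper that is not shown here. Judging by the element list quoted in test_sample_by_query_text_elements below, it is assumed to build a fixed four-element document, roughly as follows; the Document/TextElement constructors, the span handling, and the URI scheme are all assumptions:

SENTENCES = ['Document Title is Super Interesting',
             'First sentence is not that attractive.',
             'The second one is a bit better.',
             'Last sentence offers a promising view for the future!']

def generate_simple_doc(dataset_name, doc_id=0, add_duplicate=False):
    # Hypothetical helper: one document with a title and three sentences;
    # optionally append a verbatim duplicate for the duplicates-removal test.
    sentences = SENTENCES + ([SENTENCES[1]] if add_duplicate else [])
    text_elements = [
        TextElement(uri=f'{dataset_name}-{doc_id}-{idx}', text=sent,
                    span=[(0, len(sent))], metadata={}, category_to_label={})
        for idx, sent in enumerate(sentences)
    ]
    return Document(uri=f'{dataset_name}-{doc_id}',
                    text_elements=text_elements, metadata={})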
    def test_add_documents_and_get_documents(self):
        dataset_name = self.test_add_documents_and_get_documents.__name__ + '_dump'
        doc = generate_corpus(dataset_name)[0]
        doc_in_memory = data_access.get_documents(dataset_name, [doc.uri])[0]
        # compare all fields
        diffs = [(field, getattr(doc_in_memory, field), getattr(doc, field))
                 for field in Document.__annotations__
                 if getattr(doc_in_memory, field) != getattr(doc, field)]
        self.assertEqual(0, len(diffs))
        ds_loader.clear_all_saved_files(dataset_name)
    def test_get_all_text_elements(self):
        dataset_name = self.test_get_all_text_elements.__name__ + '_dump'
        docs = generate_corpus(dataset_name, random.randint(1, 10))
        text_elements_found = data_access.get_all_text_elements(dataset_name)
        text_elements_found.sort(key=lambda t: t.uri)
        text_elements_expected = [
            text for doc in docs for text in doc.text_elements
        ]
        text_elements_expected.sort(key=lambda t: t.uri)
        self.assertListEqual(text_elements_expected, text_elements_found)
        ds_loader.clear_all_saved_files(dataset_name)
def load(dataset: str, force_new: bool = False):
    for part in DatasetPart:
        dataset_name = dataset + '_' + part.name.lower()
        # load dataset (generate Documents and TextElements)
        if force_new:
            single_dataset_loader.clear_all_saved_files(dataset_name)
        single_dataset_loader.load_dataset(dataset_name, force_new)
        # load gold labels
        if force_new:
            gold_labels_loader.clear_gold_labels_file(dataset_name)
        gold_labels_loader.load_gold_labels(dataset_name, force_new)
        logging.info('-' * 60)
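load iterates over a DatasetPart enum that is not shown. A plausible sketch, assuming the usual train/dev/test split names:

from enum import Enum

class DatasetPart(Enum):
    # Hypothetical split names; load() lowercases part.name to build the
    # per-part dataset name, e.g. 'polarity' + '_' + 'train'.
    TRAIN = 1
    DEV = 2
    TEST = 3

Under that assumption, load('polarity', force_new=True) would wipe and reload polarity_train, polarity_dev and polarity_test, along with their gold labels.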
    def test_set_train_param(self):
        ws_id = self.test_set_train_param.__name__ + '_workspace'
        dataset_name = "None"
        orchestrator_api.add_documents(dataset_name,
                                       [generate_simple_doc(dataset_name)])
        orchestrator_state_api.create_workspace(workspace_id=ws_id,
                                                dataset_name=dataset_name)
        orchestrator_state_api.add_train_param(ws_id, "key", "value")
        self.assertEqual(
            orchestrator_state_api.get_workspace(ws_id).train_params["key"],
            "value")

        orchestrator_state_api.delete_workspace_state(ws_id)
        single_dataset_loader.clear_all_saved_files(dataset_name)
    def test_unset_labels(self):
        workspace_id = 'test_unset_labels'
        dataset_name = self.test_unset_labels.__name__ + '_dump'
        category = "cat1"
        doc = generate_corpus(dataset_name)[0]
        texts_and_labels_list = add_labels_to_doc(doc, category)
        data_access.set_labels(workspace_id, texts_and_labels_list)

        labels_count = data_access.get_label_counts(workspace_id, dataset_name,
                                                    category)
        self.assertGreater(labels_count['true'], 0)
        data_access.unset_labels(workspace_id, category,
                                 [x[0] for x in texts_and_labels_list])
        labels_count_after_unset = data_access.get_label_counts(
            workspace_id, dataset_name, category)
        self.assertEqual(0, labels_count_after_unset["true"])
        ds_loader.clear_all_saved_files(dataset_name)
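test_unset_labels uses an add_labels_to_doc helper that is also absent from this collection. Given that the test expects a non-zero 'true' count immediately after calling it, a minimal sketch might label every element positively:

def add_labels_to_doc(doc, category):
    # Hypothetical helper: mark every element of the document as a positive
    # example of the category, in the [(uri, {category: Label})] shape
    # expected by data_access.set_labels.
    return [(element.uri, {category: Label(labels=LABEL_POSITIVE, metadata={})})
            for element in doc.text_elements]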
    def test_sample_unlabeled_text_elements(self):
        workspace_id = 'test_sample_unlabeled_text_elements'
        dataset_name = self.test_sample_unlabeled_text_elements.__name__ + '_dump'
        category = 'Autobots'
        sample_all = 10**100  # a huge sample_size to sample all elements
        docs = generate_corpus(dataset_name, 2)
        # add labels info for a single doc
        selected_doc = docs[0]
        texts_and_labels_list = generate_random_texts_and_labels(
            selected_doc, 5, [category])
        data_access.set_labels(workspace_id, texts_and_labels_list)

        sampled_texts_res = data_access.sample_unlabeled_text_elements(
            workspace_id, dataset_name, category, sample_all)
        for sampled_text in sampled_texts_res['results']:
            self.assertDictEqual(sampled_text.category_to_label, {})
        ds_loader.clear_all_saved_files(dataset_name)
    def test_set_labels_and_get_documents_with_labels_info(self):
        workspace_id = 'test_set_labels'
        dataset_name = self.test_set_labels_and_get_documents_with_labels_info.__name__ + '_dump'
        categories = ['cat_' + str(i) for i in range(3)]
        doc = generate_corpus(dataset_name)[0]
        texts_and_labels_list = generate_random_texts_and_labels(
            doc, 5, categories)  # [(uri, {category: Label})]
        data_access.set_labels(workspace_id, texts_and_labels_list)

        doc_with_labels_info = data_access.get_documents_with_labels_info(
            workspace_id, dataset_name, [doc.uri])
        texts_and_labels_dict = dict(texts_and_labels_list)
        for text in doc_with_labels_info[0].text_elements:
            if text.uri in texts_and_labels_dict:
                self.assertDictEqual(text.category_to_label,
                                     texts_and_labels_dict[text.uri])
            else:
                self.assertDictEqual(text.category_to_label, {})
        ds_loader.clear_all_saved_files(dataset_name)
    def test_sample_by_query_text_elements(self):
        workspace_id = 'test_sample_by_query_text_elements'
        dataset_name = self.test_sample_by_query_text_elements.__name__ + '_dump'
        category = 'Autobots'
        query = 'sentence'
        sample_all = 10**100  # a huge sample_size to sample all elements
        doc = generate_corpus(dataset_name, 1)[0]
        # doc's elements = ['Document Title is Super Interesting', 'First sentence is not that attractive.',
        #          'The second one is a bit better.', 'Last sentence offers a promising view for the future!']
        # add labels info for a single doc
        texts_and_labels_list = [
            # element 0 (the document title) does not contain the query term 'sentence'
            (doc.text_elements[0].uri, {
                category: Label(labels=LABEL_POSITIVE, metadata={})
            }),
            # element 1 (the first sentence) contains the query term 'sentence'
            (doc.text_elements[1].uri, {
                category: Label(labels=LABEL_POSITIVE, metadata={})
            })
        ]
        data_access.set_labels(workspace_id, texts_and_labels_list)

        # query + unlabeled elements
        sampled_texts_res = data_access.sample_unlabeled_text_elements(
            workspace_id, dataset_name, category, sample_all, query)
        for sampled_text in sampled_texts_res['results']:
            self.assertDictEqual(sampled_text.category_to_label, {})

        # query + labeled elements
        sampled_texts_res = data_access.sample_labeled_text_elements(
            workspace_id, dataset_name, category, sample_all, query)
        self.assertEqual(
            1, len(sampled_texts_res['results']),
            'only the single labeled element that matches the query should have been sampled.'
        )
        texts_and_labels_dict = dict(texts_and_labels_list)
        for sampled_text in sampled_texts_res['results']:
            self.assertIn(
                sampled_text.uri, texts_and_labels_dict.keys(),
                f'the sampled text uri - {sampled_text.uri} - was not found in the '
                f'texts that were labeled: {texts_and_labels_dict}')
            self.assertIn(query, sampled_text.text)
        ds_loader.clear_all_saved_files(dataset_name)
    def test_get_text_elements_by_id(self):
        workspace_id = "test_get_text_elements_by_id"
        dataset_name = self.test_get_text_elements_by_id.__name__ + '_dump'
        categories = ['cat_' + str(i) for i in range(3)]
        docs = generate_corpus(dataset_name, 2)
        doc = docs[0]
        texts_and_labels_list = generate_random_texts_and_labels(
            doc, 5, categories)  # [(uri, {category: Label})]
        uri_to_labels = dict(texts_and_labels_list)
        data_access.set_labels(workspace_id, texts_and_labels_list)
        uris = [x.uri for doc in docs for x in doc.text_elements][0:2]
        all_elements = data_access.get_text_elements_with_labels_info(
            workspace_id, dataset_name, uris)

        self.assertEqual(len(uris), len(all_elements))
        self.assertEqual(uri_to_labels[all_elements[0].uri],
                         all_elements[0].category_to_label)
        self.assertEqual(uri_to_labels[all_elements[1].uri],
                         all_elements[1].category_to_label)
        ds_loader.clear_all_saved_files(dataset_name)
    def test_copy_existing_workspace_with_labeled_data(self):
        try:
            workspace_id = "wd_id"
            new_workspace_id = "new_" + workspace_id
            orchestrator_api.delete_workspace(workspace_id, ignore_errors=True)
            orchestrator_api.delete_workspace(new_workspace_id, ignore_errors=True)
            dataset_name = "ds_name"
            cat_name = "cat_name"
            cat_desc = "cat_desc"
            document = generate_simple_doc(dataset_name)
            orchestrator_api.add_documents(dataset_name, [document])
            orchestrator_api.create_workspace(workspace_id, dataset_name)
            orchestrator_api.create_new_category(workspace_id, cat_name, cat_desc)
            # List[(str,mapping(str,Label))]
            uri1 = document.text_elements[0].uri
            uri2 = document.text_elements[1].uri
            labels = [(uri1, {cat_name: Label(LABEL_POSITIVE, {})}), (uri2, {cat_name: Label(LABEL_NEGATIVE, {})})]
            orchestrator_api.set_labels(workspace_id, labels)

            orchestrator_api.copy_workspace(workspace_id, new_workspace_id)
            results_original = orchestrator_api.query(workspace_id=workspace_id, dataset_name=dataset_name,
                                                      category_name=cat_name,
                                                      query="with label", unlabeled_only=False, sample_size=10)
            results_new = orchestrator_api.query(workspace_id=new_workspace_id, dataset_name=dataset_name,
                                                 category_name=cat_name,
                                                 query="with label", unlabeled_only=False, sample_size=10)

            self.assertEqual(results_original["results"], results_new["results"])

            labels = [(uri1, {cat_name: Label(LABEL_NEGATIVE, {})}), (uri2, {cat_name: Label(LABEL_POSITIVE, {})})]
            orchestrator_api.set_labels(new_workspace_id, labels)
            results_new = orchestrator_api.query(workspace_id=new_workspace_id, dataset_name=dataset_name,
                                                 category_name=cat_name,
                                                 query="with label", unlabeled_only=False, sample_size=10)
            self.assertNotEqual(results_original["results"], results_new["results"])
        finally:
            orchestrator_api.delete_workspace(workspace_id, ignore_errors=True)
            orchestrator_api.delete_workspace(new_workspace_id, ignore_errors=True)
            single_dataset_loader.clear_all_saved_files(dataset_name)
    def test_duplicates_removal(self):
        workspace_id = 'test_duplicates_removal'
        dataset_name = self.test_duplicates_removal.__name__ + '_dump'
        generate_corpus(dataset_name, 1, add_duplicate=True)
        all_elements = data_access.get_all_text_elements(dataset_name)
        all_elements2 = data_access.sample_text_elements(
            dataset_name, 10**6, remove_duplicates=False)['results']
        self.assertListEqual(all_elements, all_elements2)
        all_without_dups = data_access.sample_text_elements(
            dataset_name, 10**6, remove_duplicates=True)['results']
        self.assertEqual(len(all_elements), len(all_without_dups) + 1)

        category = 'cat1'
        texts_and_labels_list = [(elem.uri, {
            category:
            Label(labels=LABEL_POSITIVE, metadata={})
        }) for elem in all_without_dups]
        # set labels without propagating to duplicates
        data_access.set_labels(workspace_id,
                               texts_and_labels_list,
                               propagate_to_duplicates=False)
        labels_count = data_access.get_label_counts(workspace_id, dataset_name,
                                                    category)
        self.assertEqual(labels_count[LABEL_POSITIVE], len(all_without_dups))
        # unset labels
        data_access.unset_labels(workspace_id, category,
                                 [elem.uri for elem in all_without_dups])
        labels_count = data_access.get_label_counts(workspace_id, dataset_name,
                                                    category)
        self.assertEqual(labels_count[LABEL_POSITIVE], 0)
        # set labels with propagating to duplicates
        data_access.set_labels(workspace_id,
                               texts_and_labels_list,
                               propagate_to_duplicates=True)
        labels_count = data_access.get_label_counts(workspace_id, dataset_name,
                                                    category)
        self.assertEqual(labels_count[LABEL_POSITIVE], len(all_elements))
        ds_loader.clear_all_saved_files(dataset_name)
    def test_sample_text_elements(self):
        dataset_name = self.test_sample_text_elements.__name__ + '_dump'
        sample_size = 5
        generate_corpus(dataset_name, 10)
        sampled_texts_res = data_access.sample_text_elements(
            dataset_name, sample_size)
        self.assertEqual(sample_size, len(sampled_texts_res['results']))

        sample_all = 10**100  # a huge sample_size to sample all elements
        sampled_texts_res = data_access.sample_text_elements(
            dataset_name, sample_all)
        self.assertEqual(
            sampled_texts_res['hit_count'], len(sampled_texts_res['results']),
            'the number of sampled elements does not equal the hit count, '
            'even though we asked to sample all elements.')
        self.assertEqual(
            len(data_access.get_all_text_elements_uris(dataset_name)),
            sampled_texts_res['hit_count'],
            'the hit count does not equal the total number of element uris in the dataset, '
            'even though we asked to sample all elements.')
        # assert no labels were added
        self.assertDictEqual(sampled_texts_res['results'][0].category_to_label,
                             {})
        ds_loader.clear_all_saved_files(dataset_name)