Example #1
File: service.py  Project: recognai/rubrix
    def read_dataset(
        self,
        dataset: str,
        owner: Optional[str],
        query: Optional[TokenClassificationQuery] = None,
    ) -> Iterable[TokenClassificationRecord]:
        """
        Scan the records of a dataset

        Parameters
        ----------
        dataset:
            The dataset name
        owner:
            The dataset owner. Optional
        query:
            If provided, the scan retrieves only records matching
            the given query filters. Optional

        """
        dataset = self.__datasets__.find_by_name(dataset, owner=owner)
        for db_record in self.__dao__.scan_dataset(
            dataset, search=RecordSearch(query=as_elasticsearch(query))
        ):
            yield TokenClassificationRecord.parse_obj(db_record)
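A minimal usage sketch for the generator above. How a service instance is obtained is an assumption here (in the project, services are normally wired up through dependency injection):

# Hedged sketch: stream every record of a dataset through the service.
# `service` is assumed to be an already-constructed token classification
# service instance; read_dataset yields records lazily, so this works
# for datasets that do not fit in memory.
for record in service.read_dataset("my-dataset", owner=None):
    print(record.text)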
Example #2
def test_too_long_metadata():
    text = "On one ones o no"
    record = TokenClassificationRecord.parse_obj({
        "text": text,
        "tokens": text.split(),
        "metadata": {
            "too_long": "a" * 1000
        },
    })

    # String values in metadata are truncated to MAX_KEYWORD_LENGTH on parse.
    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
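Read together with Example #7 below, where an entity label fails validation at 128 characters, this suggests that MAX_KEYWORD_LENGTH is 128 and that the same keyword limit backs both fields; that value is an inference from the tests rather than something shown here. Note the asymmetry in behavior: over-long metadata values are silently truncated, while over-long labels raise a ValidationError.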
Example #3
def test_fix_substrings():
    text = "On one ones o no"
    # The span [3, 6) exactly covers the token "one"; constructing the
    # record should validate the span against the text without raising.
    TokenClassificationRecord(
        text=text,
        tokens=text.split(),
        prediction=TokenClassificationAnnotation(
            agent="test",
            entities=[
                EntitySpan(start=3, end=6, label="test"),
            ],
        ),
    )
Example #4
def test_entities_with_spaces():

    text = "This is  a  great  space"
    # start=9 is the "a" after the double space; end=len(text)=24, so the
    # span covers "a  great  space", doubled spaces included.
    TokenClassificationRecord(
        text=text,
        tokens=["This", "is", " ", "a", " ", "great", " ", "space"],
        prediction=TokenClassificationAnnotation(
            agent="test",
            entities=[
                EntitySpan(start=9, end=len(text), label="test"),
            ],
        ),
    )
Example #5
File: service.py  Project: drahnreb/rubrix
    def search(
        self,
        dataset: str,
        owner: Optional[str],
        search: TokenClassificationQuery,
        record_from: int = 0,
        size: int = 100,
    ) -> TokenClassificationSearchResults:
        """
        Run a search in a dataset

        Parameters
        ----------
        dataset:
            The dataset name
        owner:
            The dataset owner
        search:
            The search parameters
        record_from:
            The offset of the first record to return (used for pagination)
        size:
            The max number of records to return

        Returns
        -------
            The matched records, along with aggregation info for the task

        """
        dataset = self.__datasets__.find_by_name(dataset, owner=owner)

        results = self.__dao__.search_records(
            dataset,
            search=RecordSearch(query=as_elasticsearch(search)),
            size=size,
            record_from=record_from,
        )
        return TokenClassificationSearchResults(
            total=results.total,
            records=[
                TokenClassificationRecord.parse_obj(r) for r in results.records
            ],
            aggregations=TokenClassificationAggregations(
                **results.aggregations,
                words=results.words,
                metadata=results.metadata or {},
            ) if results.aggregations else None,
        )
Example #6
def test_char_position():

    with pytest.raises(
            ValidationError,
            match="End character cannot be placed before the starting character,"
            " it must be at least one character after.",
    ):
        EntitySpan(start=1, end=1, label="label")

    text = "I am Maxi"
    TokenClassificationRecord(
        text=text,
        tokens=text.split(),
        prediction=TokenClassificationAnnotation(
            agent="test", entities=[EntitySpan(start=0, end=1, label="test")]),
    )
Example #7
def test_entity_label_too_long():
    text = "On one ones o no"
    with pytest.raises(ValidationError,
                       match="ensure this value has at most 128 character"):
        # The 1000-character label exceeds the 128-character keyword limit,
        # so parsing the prediction fails.
        TokenClassificationRecord(
            text=text,
            tokens=text.split(),
            prediction=TokenClassificationAnnotation(
                agent="test",
                entities=[
                    EntitySpan(
                        start=9,
                        end=len(text),
                        label="a" * 1000,
                    ),
                ],
            ),
        )
Example #8
File: service.py  Project: recognai/rubrix
    def add_records(
        self,
        dataset: str,
        owner: Optional[str],
        records: List[CreationTokenClassificationRecord],
    ):
        dataset = self.__datasets__.find_by_name(dataset, owner=owner)

        db_records = []
        now = datetime.datetime.now()
        for record in records:
            db_record = TokenClassificationRecord.parse_obj(record)
            db_record.last_updated = now
            db_records.append(db_record.dict(exclude_none=True))

        failed = self.__dao__.add_records(
            dataset=dataset,
            records=db_records,
        )
        return BulkResponse(dataset=dataset.name, processed=len(records), failed=failed)
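A hedged sketch of a bulk upload through the method above. The service instance is an assumption, and the payload accepted by CreationTokenClassificationRecord is inferred from the text/tokens fields used by the tests on this page:

# Hypothetical usage sketch: index two records and inspect the bulk summary.
records = [
    CreationTokenClassificationRecord.parse_obj({"text": t, "tokens": t.split()})
    for t in ("I am Maxi", "On one ones o no")
]
response = service.add_records("my-dataset", owner=None, records=records)
print(response.processed, response.failed)  # BulkResponse fields from the code above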
Example #9
File: service.py  Project: recognai/rubrix
    def search(
        self,
        dataset: str,
        owner: Optional[str],
        search: TokenClassificationQuery,
        record_from: int = 0,
        size: int = 100,
    ) -> TokenClassificationSearchResults:
        """
        Run a search in a dataset

        Parameters
        ----------
        dataset:
            The dataset name
        owner:
            The dataset owner
        search:
            The search parameters
        record_from:
            The offset of the first record to return (used for pagination)
        size:
            The max number of records to return

        Returns
        -------
            The matched records, along with aggregation info for the task

        """
        dataset = self.__datasets__.find_by_name(dataset, owner=owner)

        results = self.__dao__.search_records(
            dataset,
            search=RecordSearch(
                query=as_elasticsearch(search),
                aggregations={
                    **aggregations.nested_aggregation(
                        name=PREDICTED_MENTIONS_ES_FIELD_NAME,
                        nested_path=PREDICTED_MENTIONS_ES_FIELD_NAME,
                        inner_aggregation=aggregations.bidimentional_terms_aggregations(
                            name=PREDICTED_MENTIONS_ES_FIELD_NAME,
                            field_name_x=PREDICTED_MENTIONS_ES_FIELD_NAME + ".entity",
                            field_name_y=PREDICTED_MENTIONS_ES_FIELD_NAME + ".mention",
                        ),
                    ),
                    **aggregations.nested_aggregation(
                        name=MENTIONS_ES_FIELD_NAME,
                        nested_path=MENTIONS_ES_FIELD_NAME,
                        inner_aggregation=aggregations.bidimentional_terms_aggregations(
                            name=MENTIONS_ES_FIELD_NAME,
                            field_name_x=MENTIONS_ES_FIELD_NAME + ".entity",
                            field_name_y=MENTIONS_ES_FIELD_NAME + ".mention",
                        ),
                    ),
                },
            ),
            size=size,
            record_from=record_from,
        )
        return TokenClassificationSearchResults(
            total=results.total,
            records=[TokenClassificationRecord.parse_obj(r) for r in results.records],
            aggregations=TokenClassificationAggregations(
                **results.aggregations,
                words=results.words,
                metadata=results.metadata or {},
            )
            if results.aggregations
            else None,
        )
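For orientation, here is a hedged sketch of the kind of Elasticsearch aggregation body the two helper calls above plausibly assemble. The exact keys emitted by aggregations.nested_aggregation and aggregations.bidimentional_terms_aggregations are assumptions; only the nested-plus-terms structure below is standard Elasticsearch:

# Assumed shape only: a nested aggregation over the mentions field, with a
# terms aggregation on .entity and a sub-terms aggregation on .mention,
# yielding (entity, mention) co-occurrence counts across the matched records.
mentions_aggregation = {
    "mentions": {
        "nested": {"path": "mentions"},
        "aggs": {
            "mentions": {
                "terms": {"field": "mentions.entity"},
                "aggs": {
                    "mention": {"terms": {"field": "mentions.mention"}},
                },
            }
        },
    }
}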