예제 #1
0
def map_annotations_strict(f):
    """Creates a Spark UDF to map over an Annotator's results, for which the
    return type is explicitly defined as an `Annotation.dataType()`.

    Each element of the incoming annotation column is converted to an
    :class:`Annotation` with ``Annotation.fromRow`` before ``f`` is called,
    and every object ``f`` yields is converted back with ``Annotation.toRow``.
    A null (``None``) column value is returned as null instead of raising.

    Parameters
    ----------
    f : function
        The function to be applied over the results. It receives a list of
        :class:`Annotation` objects and must return an iterable of objects
        accepted by ``Annotation.toRow``.

    Returns
    -------
    :func:`pyspark.sql.functions.udf`
        Spark UserDefinedFunction (udf)

    Examples
    --------
    >>> from sparknlp.pretrained import PretrainedPipeline
    >>> explain_document_pipeline = PretrainedPipeline("explain_document_dl")
    >>> data = spark.createDataFrame([["U.N. official Ekeus heads for Baghdad."]]).toDF("text")
    >>> result = explain_document_pipeline.transform(data)
    >>> def nnp_tokens(annotations):
    ...     return list(
    ...         filter(lambda annotation: annotation.result == 'NNP', annotations)
    ...     )
    >>> result.select(
    ...     map_annotations_strict(nnp_tokens)('pos').alias("nnp")
    ... ).selectExpr("explode(nnp) as nnp").show(truncate=False)
    +-----------------------------------------+
    |nnp                                      |
    +-----------------------------------------+
    |[pos, 0, 2, NNP, [word -> U.N], []]      |
    |[pos, 14, 18, NNP, [word -> Ekeus], []]  |
    |[pos, 30, 36, NNP, [word -> Baghdad], []]|
    +-----------------------------------------+
    """
    return udf(
        # Spark hands a null column value to a Python UDF as None; propagate
        # it as null rather than failing on `for r in None`. A lambda is kept
        # (instead of a named def) so the UDF's function name is unchanged.
        lambda content: None if content is None else [
            Annotation.toRow(a)
            for a in f([Annotation.fromRow(r) for r in content])
        ],
        ArrayType(Annotation.dataType()),
    )
예제 #2
0
def map_annotations_strict(f):
    """Creates a Spark UDF that applies ``f`` to an annotation column, with the
    return type declared as an array of ``Annotation.dataType()``.

    The column value of each row is passed to ``f`` as-is and ``f``'s result is
    returned unchanged; no conversion to or from :class:`Annotation` objects is
    done here. Presumably ``f`` must therefore return values compatible with
    ``Annotation.dataType()`` (confirm against callers).

    Side effect: every call rebinds ``sys.modules['sparknlp.annotation']`` to
    the module-level ``sparknlp`` object.

    Parameters
    ----------
    f : function
        Function applied to each row's raw annotation column value.

    Returns
    -------
    :func:`pyspark.sql.functions.udf`
        Spark UserDefinedFunction (udf)
    """
    # NOTE(review): on the first call this loads the real sparknlp.annotation
    # module; after the rebinding below, later calls resolve this import path
    # through the `sparknlp` object instead (presumably it also exposes
    # Annotation — confirm). Keep this import before the rebinding.
    from sparknlp.annotation import Annotation
    # Relies on `sys` and `sparknlp` being bound at module level (not visible
    # here). NOTE(review): mutates interpreter-global state on every call.
    sys.modules[
        'sparknlp.annotation'] = sparknlp  # Makes Annotation() pickle serializable in top-level
    # f is called with the raw column value; its result is returned as-is.
    return udf(lambda content: f(content), ArrayType(Annotation.dataType()))
예제 #3
0
def map_annotations_strict(f):
    """Creates a Spark UDF to map over an Annotator's results, for which the
    return type is explicitly defined as an `Annotation.dataType()`.

    Each element of the incoming annotation column is converted to an
    :class:`Annotation` with ``Annotation.fromRow`` before ``f`` is called,
    and every object ``f`` yields is converted back with ``Annotation.toRow``.
    A null (``None``) column value is returned as null instead of raising.

    Parameters
    ----------
    f : function
        The function to be applied over the results. It receives a list of
        :class:`Annotation` objects and must return an iterable of objects
        accepted by ``Annotation.toRow``.

    Returns
    -------
    :func:`pyspark.sql.functions.udf`
        Spark UserDefinedFunction (udf)
    """
    return udf(
        # Spark hands a null column value to a Python UDF as None; propagate
        # it as null rather than failing on `for r in None`. A lambda is kept
        # (instead of a named def) so the UDF's function name is unchanged.
        lambda content: None if content is None else [
            Annotation.toRow(a)
            for a in f([Annotation.fromRow(r) for r in content])
        ],
        ArrayType(Annotation.dataType()),
    )
예제 #4
0
def map_annotations_strict(f):
    """Wrap ``f`` as a Spark UDF whose declared return type is an array of
    ``Annotation.dataType()``.

    The column value of each row is handed to ``f`` untouched, and whatever
    ``f`` returns becomes the UDF's result; this wrapper performs no
    conversion to or from :class:`Annotation` objects.

    Parameters
    ----------
    f : function
        Callable applied to each row's raw annotation column value.

    Returns
    -------
    :func:`pyspark.sql.functions.udf`
        Spark UserDefinedFunction (udf)
    """
    element_type = Annotation.dataType()
    result_type = ArrayType(element_type)
    return udf(lambda rows: f(rows), result_type)