def datediff(end, start):
    """
    Returns the number of days from `start` to `end`.

    >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
    >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()
    [Row(diff=32)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start)))
def levenshtein(left, right):
    """Computes the Levenshtein distance of the two given strings.

    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
    [Row(d=3)]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
    return Column(jc)
def months_between(date1, date2):
    """
    Returns the number of months between date1 and date2.

    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
    >>> df.select(months_between(df.t, df.d).alias('months')).collect()
    [Row(months=3.9495967...)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))
def approxCountDistinct(col, rsd=None):
    """Returns a new :class:`Column` for approximate distinct count of ``col``.

    >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
    [Row(c=2)]
    """
    sc = SparkContext._active_spark_context
    if rsd is None:
        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
    else:
        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
    return Column(jc)
def to_avro(data):
    """
    Converts a column into binary of avro format.

    Note: Avro is built-in but external data source module since Spark 2.4. Please deploy the
    application as per the deployment section of "Apache Avro Data Source Guide".

    :param data: the data column.

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.avro.functions import to_avro
    >>> data = [(1, Row(name='Alice', age=2))]
    >>> df = spark.createDataFrame(data, ("key", "value"))
    >>> df.select(to_avro(df.value).alias("avro")).collect()
    [Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
    """
    sc = SparkContext._active_spark_context
    try:
        jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
    except TypeError as e:
        if str(e) == "'JavaPackage' object is not callable":
            _print_missing_jar("Avro", "avro", "avro", sc.version)
        raise
    return Column(jc)
def _(*cols):
    jcontainer = self.get_java_container(package_name=package_name,
                                         object_name=object_name,
                                         java_class_instance=java_class_instance)
    # Ensure that your argument is a column
    function = getattr(jcontainer, name)
    judf = function()
    jc = judf.apply(self.to_scala_seq([_to_java_column(c) for c in cols]))
    return Column(jc)
def decode(col, charset):
    """
    Decodes the first argument from binary into a string using the provided character set
    (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.decode(_to_java_column(col), charset))
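# Illustrative usage sketch for decode(): assumes an active `sqlContext` like the one in
# the doctests above; the expected row is an assumption, not a verified doctest.
def _decode_example(sqlContext):
    df = sqlContext.createDataFrame([(bytearray(b'ABC'),)], ['a'])
    # Interpret the binary column as UTF-8 text.
    return df.select(decode('a', 'UTF-8').alias('s')).collect()  # e.g. [Row(s=u'ABC')]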
def shiftLeft(col, numBits):
    """Shift the given value numBits left.

    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
    [Row(r=42)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits))
def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"):
    """
    Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
    representing the timestamp of that moment in the current system time zone in the given
    format.
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format))
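# Illustrative usage sketch for from_unixtime(): assumes an active `sqlContext`; the
# rendered string depends on the session time zone, so no exact output is shown.
def _from_unixtime_example(sqlContext):
    df = sqlContext.createDataFrame([(1428476400,)], ['unix_time'])
    # Render the epoch seconds as a timestamp string in the default format.
    return df.select(from_unixtime('unix_time').alias('ts')).collect()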
def length(col):
    """Calculates the length of a string or binary expression.

    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
    [Row(length=3)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.length(_to_java_column(col)))
def log(arg1, arg2=None):
    """Returns the first argument-based logarithm of the second argument.

    If there is only one argument, then this takes the natural logarithm of the argument.

    >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect()
    ['0.30102', '0.69897']

    >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect()
    ['0.69314', '1.60943']
    """
    sc = SparkContext._active_spark_context
    if arg2 is None:
        jc = sc._jvm.functions.log(_to_java_column(arg1))
    else:
        jc = sc._jvm.functions.log(arg1, _to_java_column(arg2))
    return Column(jc)
def initcap(col):
    """Translate the first letter of each word to upper case in the sentence.

    >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
    [Row(v=u'Ab Cd')]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.initcap(_to_java_column(col)))
def log2(col):
    """Returns the base-2 logarithm of the argument.

    >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect()
    [Row(log2=2.0)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.log2(_to_java_column(col)))
def _convertDF(df, sp_key=None, metadata=None):
    ctx = SparkContext._active_spark_context._rf_context
    if sp_key is None:
        return RasterFrame(ctx._jrfctx.asRF(df._jdf), ctx._spark_session)
    else:
        import json
        return RasterFrame(ctx._jrfctx.asRF(
            df._jdf, _to_java_column(sp_key), json.dumps(metadata)), ctx._spark_session)
def bin(col):
    """Returns the string representation of the binary value of the given column.

    >>> df.select(bin(df.age).alias('c')).collect()
    [Row(c=u'10'), Row(c=u'101')]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.bin(_to_java_column(col))
    return Column(jc)
def sha1(col):
    """Returns the hex string result of SHA-1.

    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
    [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.sha1(_to_java_column(col))
    return Column(jc)
def unhex(col):
    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
    and converts it to the byte representation of the number.

    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
    [Row(unhex(a)=bytearray(b'ABC'))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.unhex(_to_java_column(col)))
def shiftRight(col, numBits):
    """Shift the given value numBits right.

    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
    [Row(r=21)]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
    return Column(jc)
def md5(col):
    """Calculates the MD5 digest and returns the value as a 32 character hex string.

    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
    [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.md5(_to_java_column(col))
    return Column(jc)
def shiftRightUnsigned(col, numBits):
    """Unsigned shift the given value numBits right.

    >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
        .collect()
    [Row(r=9223372036854775787)]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
    return Column(jc)
def format_number(col, d):
    """Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
    and returns the result as a string.

    :param col: the column name of the numeric value to be formatted
    :param d: the number of decimal places

    >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
    [Row(v=u'5.0000')]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
def add_months(start, months):
    """
    Returns the date that is `months` months after `start`.

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(add_months(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 5, 8))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
def date_sub(start, days):
    """
    Returns the date that is `days` days before `start`.

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(date_sub(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 4, 7))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
def last_day(date):
    """
    Returns the last day of the month which the given date belongs to.

    >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d'])
    >>> df.select(last_day(df.d).alias('date')).collect()
    [Row(date=datetime.date(1997, 2, 28))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.last_day(_to_java_column(date)))
def crc32(col):
    """
    Calculates the cyclic redundancy check value (CRC32) of a binary column and
    returns the value as a bigint.

    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
    [Row(crc32=2743272264)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.crc32(_to_java_column(col)))
def to_utc_timestamp(timestamp, tz):
    """
    Assumes given timestamp is in given timezone and converts to UTC.

    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
    >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect()
    [Row(t=datetime.datetime(1997, 2, 28, 18, 30))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
def regexp_extract(str, pattern, idx):
    """Extract a specific (idx) group identified by a Java regex, from the specified string column.

    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
    >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
    [Row(d=u'100')]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
    return Column(jc)
def to_date(col):
    """
    Converts the column of StringType or TimestampType into DateType.

    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
    >>> df.select(to_date(df.t).alias('date')).collect()
    [Row(date=datetime.date(1997, 2, 28))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.to_date(_to_java_column(col)))
def regexp_replace(str, pattern, replacement):
    """Replace all substrings of the specified string value that match regexp with rep.

    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
    >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect()
    [Row(d=u'##-##')]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
    return Column(jc)
def hex(col):
    """Computes hex value of the given column, which could be StringType, BinaryType,
    IntegerType or LongType.

    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
    [Row(hex(a)=u'414243', hex(b)=u'3')]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.hex(_to_java_column(col))
    return Column(jc)
def subset_struct(struct: Union[Column, str], *fields: str) -> Column:
    """
    Selects fields from a struct.

    Added in version 0.3.0.

    Examples:
        >>> df = spark.createDataFrame([Row(struct=Row(a=1, b=2, c=3))])
        >>> df.select(glow.subset_struct('struct', 'a', 'c').alias('struct')).collect()
        [Row(struct=Row(a=1, c=3))]

    Args:
        struct : Struct from which to select fields
        fields : Fields to select

    Returns:
        A struct containing only the indicated fields
    """
    assert check_argument_types()
    output = Column(sc()._jvm.io.projectglow.functions.subset_struct(
        _to_java_column(struct), _to_seq(sc(), fields)))
    assert check_return_type(output)
    return output
def vector_to_array(vector: Union[Column, str]) -> Column:
    """
    Converts a ``spark.ml`` ``Vector`` (sparse or dense) to an array of doubles.

    Added in version 0.3.0.

    Examples:
        >>> from pyspark.ml.linalg import DenseVector, SparseVector
        >>> df = spark.createDataFrame([Row(v=SparseVector(3, {0: 1.0, 2: 2.0})), Row(v=DenseVector([3.0, 4.0]))])
        >>> df.select(glow.vector_to_array('v').alias('arr')).collect()
        [Row(arr=[1.0, 0.0, 2.0]), Row(arr=[3.0, 4.0])]

    Args:
        vector : Vector to convert

    Returns:
        An array of doubles
    """
    assert check_argument_types()
    output = Column(sc()._jvm.io.projectglow.functions.vector_to_array(
        _to_java_column(vector)))
    assert check_return_type(output)
    return output
def from_avro(col):
    # Decodes a Confluent-Avro-encoded binary column via the ABRiS library. `col` is the
    # name of the column to decode; `sc` (a SparkContext) and `spark` (a SparkSession with
    # `schemaRegistryUrl` and `readTopic` in its conf) are assumed to be in scope.
    jvm_gateway = sc._active_spark_context._gateway.jvm
    abris_avro = jvm_gateway.za.co.absa.abris.avro
    naming_strategy = getattr(
        getattr(abris_avro.read.confluent.SchemaManager,
                "SchemaStorageNamingStrategies$"), "MODULE$").TOPIC_NAME()

    # Schema Registry options expected by ABRiS, keyed per decoded column.
    schema_registry_config_dict = {
        "schema.registry.url": spark.conf.get("schemaRegistryUrl"),
        "schema.registry.topic": spark.conf.get("readTopic"),
        "{col}.schema.id".format(col=col): "latest",
        "{col}.schema.naming.strategy".format(col=col): naming_strategy
    }

    # Convert the Python dict into an immutable Scala Map for the JVM call.
    conf_map = getattr(
        getattr(jvm_gateway.scala.collection.immutable.Map, "EmptyMap$"), "MODULE$")
    for k, v in schema_registry_config_dict.items():
        conf_map = getattr(conf_map, "$plus")(jvm_gateway.scala.Tuple2(k, v))

    return Column(
        abris_avro.functions.from_confluent_avro(_to_java_column(col), conf_map))
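# Illustrative usage sketch for the ABRiS-based from_avro() above: assumes a SparkSession
# `spark` with `schemaRegistryUrl` and `readTopic` set in its conf, and a Kafka source
# whose `value` column holds Confluent-Avro-encoded records. The bootstrap servers are a
# placeholder parameter, not something defined elsewhere in this module.
def _from_avro_kafka_example(spark, bootstrap_servers):
    stream = (spark.readStream
              .format("kafka")
              .option("kafka.bootstrap.servers", bootstrap_servers)
              .option("subscribe", spark.conf.get("readTopic"))
              .load())
    # Decode the Avro payload using the schema registered for the topic.
    return stream.select(from_avro("value").alias("decoded"))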
def call_summary_stats(genotypes: Union[Column, str]) -> Column:
    """
    Computes call summary statistics for an array of genotype structs. See :ref:`variant-qc` for more details.

    Added in version 0.3.0.

    Examples:
        >>> schema = 'genotypes: array<struct<calls: array<int>>>'
        >>> df = spark.createDataFrame([Row(genotypes=[Row(calls=[0, 0]), Row(calls=[1, 0]), Row(calls=[1, 1])])], schema)
        >>> df.select(glow.expand_struct(glow.call_summary_stats('genotypes'))).collect()
        [Row(callRate=1.0, nCalled=3, nUncalled=0, nHet=1, nHomozygous=[1, 1], nNonRef=2, nAllelesCalled=6, alleleCounts=[3, 3], alleleFrequencies=[0.5, 0.5])]

    Args:
        genotypes : The array of genotype structs with ``calls`` field

    Returns:
        A struct containing ``callRate``, ``nCalled``, ``nUncalled``, ``nHet``, ``nHomozygous``, ``nNonRef``,
        ``nAllelesCalled``, ``alleleCounts``, ``alleleFrequencies`` fields. See :ref:`variant-qc`.
    """
    assert check_argument_types()
    output = Column(sc()._jvm.io.projectglow.functions.call_summary_stats(
        _to_java_column(genotypes)))
    assert check_return_type(output)
    return output
def array_to_dense_vector(arr: Union[Column, str]) -> Column:
    """
    Converts an array of numerics into a ``spark.ml`` ``DenseVector``.

    Added in version 0.3.0.

    Examples:
        >>> from pyspark.ml.linalg import DenseVector
        >>> df = spark.createDataFrame([Row(arr=[1, 2, 3])])
        >>> df.select(glow.array_to_dense_vector('arr').alias('v')).collect()
        [Row(v=DenseVector([1.0, 2.0, 3.0]))]

    Args:
        arr : The array of numerics

    Returns:
        A ``spark.ml`` ``DenseVector``
    """
    assert check_argument_types()
    output = Column(sc()._jvm.io.projectglow.functions.array_to_dense_vector(
        _to_java_column(arr)))
    assert check_return_type(output)
    return output
def from_avro(data, jsonFormatSchema, options={}):
    """
    Converts a binary column of avro format into its corresponding catalyst value. The specified
    schema must match the read data, otherwise the behavior is undefined: it may fail or return
    arbitrary result.

    Note: Avro is built-in but external data source module since Spark 2.4. Please deploy the
    application as per the deployment section of "Apache Avro Data Source Guide".

    :param data: the binary column.
    :param jsonFormatSchema: the avro schema in JSON string format.
    :param options: options to control how the Avro record is parsed.

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.avro.functions import from_avro, to_avro
    >>> data = [(1, Row(name='Alice', age=2))]
    >>> df = spark.createDataFrame(data, ("key", "value"))
    >>> avroDf = df.select(to_avro(df.value).alias("avro"))
    >>> avroDf.collect()
    [Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
    >>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":
    ...     [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord",
    ...     "fields":[{"name":"age","type":["long","null"]},
    ...     {"name":"name","type":["string","null"]}]},"null"]}]}'''
    >>> avroDf.select(from_avro(avroDf.avro, jsonFormatSchema).alias("value")).collect()
    [Row(value=Row(avro=Row(age=2, name=u'Alice')))]
    """
    sc = SparkContext._active_spark_context
    try:
        jc = sc._jvm.org.apache.spark.sql.avro.functions.from_avro(
            _to_java_column(data), jsonFormatSchema, options)
    except TypeError as e:
        if str(e) == "'JavaPackage' object is not callable":
            _print_missing_jar("Avro", "avro", "avro", sc.version)
        raise
    return Column(jc)
def from_avro(dfcol: Column, jsonformatschema: str) -> Column:
    """ Decode the Avro data contained in a DataFrame column into a struct.

    Note:
        Pyspark does not have all features contained in Spark core (Scala), hence
        we provide here a wrapper around the Scala function `from_avro`.
        You need to have the package org.apache.spark:spark-avro_2.11:2.x.y in the
        classpath to have access to it from the JVM.

    Parameters
    ----------
    dfcol: Column
        Streaming DataFrame Column with encoded Avro data (binary).
        Typically this is what comes from reading stream from Kafka.
    jsonformatschema: str
        Avro schema in JSON string format.

    Returns
    ----------
    out: Column
        DataFrame Column with decoded Avro data.

    Examples
    ----------
    >>> _, _, alert_schema_json = get_schemas_from_avro(ztf_alert_sample)
    >>> df_decoded = dfstream.select(
    ...     from_avro(dfstream["value"], alert_schema_json).alias("decoded"))
    >>> query = df_decoded.writeStream.queryName("qraw").format("memory")
    >>> t = query.outputMode("update").start()
    >>> t.stop()
    """
    sc = SparkContext._active_spark_context
    avro = sc._jvm.org.apache.spark.sql.avro
    f = getattr(getattr(avro, "package$"), "MODULE$").from_avro
    return Column(f(_to_java_column(dfcol), jsonformatschema))
def add_struct_fields(struct: Union[Column, str], *fields: Union[Column, str]) -> Column:
    """
    Adds fields to a struct.

    Added in version 0.3.0.

    Examples:
        >>> df = spark.createDataFrame([Row(struct=Row(a=1))])
        >>> df.select(glow.add_struct_fields('struct', lit('b'), lit(2)).alias('struct')).collect()
        [Row(struct=Row(a=1, b=2))]

    Args:
        struct : The struct to which fields will be added
        fields : The new fields to add. The arguments must alternate between string-typed literal field names
            and field values.

    Returns:
        A struct consisting of the input struct and the added fields
    """
    assert check_argument_types()
    output = Column(sc()._jvm.io.projectglow.functions.add_struct_fields(
        _to_java_column(struct), _to_seq(sc(), fields, _to_java_column)))
    assert check_return_type(output)
    return output
def explode_matrix(matrix: Union[Column, str]) -> Column:
    """
    Explodes a ``spark.ml`` ``Matrix`` (sparse or dense) into multiple arrays, one per row of the matrix.

    Added in version 0.3.0.

    Examples:
        >>> from pyspark.ml.linalg import DenseMatrix
        >>> m = DenseMatrix(numRows=3, numCols=2, values=[1, 2, 3, 4, 5, 6])
        >>> df = spark.createDataFrame([Row(matrix=m)])
        >>> df.select(glow.explode_matrix('matrix').alias('row')).collect()
        [Row(row=[1.0, 4.0]), Row(row=[2.0, 5.0]), Row(row=[3.0, 6.0])]

    Args:
        matrix : The ``spark.ml`` ``Matrix`` to explode

    Returns:
        An array column in which each row is a row of the input matrix
    """
    assert check_argument_types()
    output = Column(sc()._jvm.io.projectglow.functions.explode_matrix(
        _to_java_column(matrix)))
    assert check_return_type(output)
    return output
def lift_over_coordinates(contigName: Union[Column, str],
                          start: Union[Column, str],
                          end: Union[Column, str],
                          chainFile: str,
                          minMatchRatio: float = None) -> Column:
    """
    Performs liftover for the coordinates of a variant. To perform liftover of alleles and add additional metadata,
    see :ref:`liftover`.

    Added in version 0.3.0.

    Examples:
        >>> df = spark.read.format('vcf').load('test-data/liftover/unlifted.test.vcf').where('start = 18210071')
        >>> chain_file = 'test-data/liftover/hg38ToHg19.over.chain.gz'
        >>> reference_file = 'test-data/liftover/hg19.chr20.fa.gz'
        >>> df.select('contigName', 'start', 'end').head()
        Row(contigName='chr20', start=18210071, end=18210072)
        >>> lifted_df = df.select(glow.expand_struct(glow.lift_over_coordinates('contigName', 'start', 'end', chain_file)))
        >>> lifted_df.head()
        Row(contigName='chr20', start=18190715, end=18190716)

    Args:
        contigName : The current contig name
        start : The current start
        end : The current end
        chainFile : Location of the chain file on each node in the cluster
        minMatchRatio : Minimum fraction of bases that must remap to do liftover successfully. If not provided,
            defaults to ``0.95``.

    Returns:
        A struct containing ``contigName``, ``start``, and ``end`` fields after liftover
    """
    assert check_argument_types()
    if minMatchRatio is None:
        output = Column(
            sc()._jvm.io.projectglow.functions.lift_over_coordinates(
                _to_java_column(contigName), _to_java_column(start),
                _to_java_column(end), chainFile))
    else:
        output = Column(
            sc()._jvm.io.projectglow.functions.lift_over_coordinates(
                _to_java_column(contigName), _to_java_column(start),
                _to_java_column(end), chainFile, minMatchRatio))
    assert check_return_type(output)
    return output
def hard_calls(probabilities: Union[Column, str],
               numAlts: Union[Column, str],
               phased: Union[Column, str],
               threshold: float = None) -> Column:
    """
    Converts an array of probabilities to hard calls. The probabilities are assumed to be diploid. See
    :ref:`variant-data-transformations` for more details.

    Added in version 0.3.0.

    Examples:
        >>> df = spark.createDataFrame([Row(probs=[0.95, 0.05, 0.0])])
        >>> df.select(glow.hard_calls('probs', numAlts=lit(1), phased=lit(False)).alias('calls')).collect()
        [Row(calls=[0, 0])]
        >>> df = spark.createDataFrame([Row(probs=[0.05, 0.95, 0.0])])
        >>> df.select(glow.hard_calls('probs', numAlts=lit(1), phased=lit(False)).alias('calls')).collect()
        [Row(calls=[0, 1])]
        >>> # Use the threshold parameter to change the minimum probability required for a call
        >>> df = spark.createDataFrame([Row(probs=[0.05, 0.95, 0.0])])
        >>> df.select(glow.hard_calls('probs', numAlts=lit(1), phased=lit(False), threshold=0.99).alias('calls')).collect()
        [Row(calls=[-1, -1])]

    Args:
        probabilities : The array of probabilities to convert
        numAlts : The number of alternate alleles
        phased : Whether the probabilities are phased. If phased, we expect ``2 * numAlts`` values in the
            probabilities array. If unphased, we expect one probability per possible genotype.
        threshold : The minimum probability to make a call. If no probability falls into the range of
            ``[0, 1 - threshold]`` or ``[threshold, 1]``, a no-call (represented by ``-1`` s) will be emitted.
            If not provided, this parameter defaults to ``0.9``.

    Returns:
        An array of hard calls
    """
    assert check_argument_types()
    if threshold is None:
        output = Column(sc()._jvm.io.projectglow.functions.hard_calls(
            _to_java_column(probabilities), _to_java_column(numAlts),
            _to_java_column(phased)))
    else:
        output = Column(sc()._jvm.io.projectglow.functions.hard_calls(
            _to_java_column(probabilities), _to_java_column(numAlts),
            _to_java_column(phased), threshold))
    assert check_return_type(output)
    return output
def test_validate_column_types(self):
    from pyspark.sql.functions import udf, to_json
    from pyspark.sql.column import _to_java_column

    # Both column names and Column objects should resolve to a JVM Column.
    self.assertTrue("Column" in _to_java_column("a").getClass().toString())
    self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())

    self.assertRaisesRegex(
        TypeError,
        "Invalid argument, not a string or column",
        lambda: _to_java_column(1))

    class A:
        pass

    self.assertRaises(TypeError, lambda: _to_java_column(A()))
    self.assertRaises(TypeError, lambda: _to_java_column([]))

    self.assertRaisesRegex(
        TypeError,
        "Invalid argument, not a string or column",
        lambda: udf(lambda x: x)(None))
    self.assertRaises(TypeError, lambda: to_json(1))
def add_suffix(suffix, df, common_cols):
    # Append `suffix` to every column of `df` that appears in `common_cols`,
    # leaving all other columns untouched.
    cols = [
        Column(_to_java_column(c)).alias(c + suffix) if c in common_cols else c
        for c in df.columns
    ]
    return df.select(*cols)
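# Illustrative usage sketch for add_suffix(): assumes two DataFrames that share column
# names, where suffixing the shared columns of one frame avoids ambiguity after a join.
# The frame names and the '_r' suffix are made up for the example.
def _add_suffix_example(left_df, right_df):
    common = set(left_df.columns) & set(right_df.columns)
    # Rename the shared columns of the right-hand frame, e.g. 'id' -> 'id_r'.
    return add_suffix('_r', right_df, common)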
def _(geometryCol, srcCRSName, dstCRSName):
    jfcn = RFContext.active().lookup('reprojectGeometry')
    return Column(jfcn(_to_java_column(geometryCol), srcCRSName, dstCRSName))
def _(geometryCol, boundsCol, valueCol, numCols, numRows):
    jfcn = RFContext.active().lookup('rasterize')
    return Column(jfcn(_to_java_column(geometryCol), _to_java_column(boundsCol),
                       _to_java_column(valueCol), numCols, numRows))
def _(col, scalar):
    jfcn = getattr(_checked_context(), name)
    return Column(jfcn(_to_java_column(col), scalar))
def _(*args):
    jfcn = RFContext.active().lookup('explodeTiles')
    jcols = [_to_java_column(arg) for arg in args]
    return Column(jfcn(RFContext.active().list_to_seq(jcols)))
def _(geometryCol, srcCRSName, dstCRSName):
    jfcn = getattr(_checked_context(), 'reprojectGeometry')
    return Column(jfcn(_to_java_column(geometryCol), srcCRSName, dstCRSName))
def explode_outer(col):
    sc = SparkContext._active_spark_context
    _explode_outer = sc._jvm.org.apache.spark.sql.functions.explode_outer
    return Column(_explode_outer(_to_java_column(col)))
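# Illustrative usage sketch for the explode_outer() wrapper above: assumes a SparkSession
# `spark`. Unlike explode(), rows whose array is null are kept and yield a null value.
def _explode_outer_example(spark):
    df = spark.createDataFrame([(1, ['a', 'b']), (2, None)], ['id', 'xs'])
    # id=2 survives with x = null instead of being dropped.
    return df.select('id', explode_outer('xs').alias('x')).collect()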
def _(*args):
    jfcn = RFContext.active().lookup(name)
    jcols = [_to_java_column(arg) for arg in args]
    return Column(jfcn(*jcols))
def _(*args):
    jfcn = getattr(_checked_context(), name)
    jcols = [_to_java_column(arg) for arg in args]
    return Column(jfcn(*jcols))
def _(col, scalar):
    jfcn = RFContext.active().lookup(name)
    return Column(jfcn(_to_java_column(col), scalar))
def _(tileCol, cellType):
    jfcn = getattr(_checked_context(), 'convertCellType')
    return Column(jfcn(_to_java_column(tileCol), _celltype(cellType)))
def _(tileCol, cellType):
    jfcn = RFContext.active().lookup('convertCellType')
    return Column(jfcn(_to_java_column(tileCol), _celltype(cellType)))
def explode_outer_(col):
    _explode_outer = spark.sparkContext._jvm.org.apache.spark.sql.functions.explode_outer
    return Column(_explode_outer(_to_java_column(col)))
def replaceArrayElement(srcCol, replaceCol, idx):
    sc = SparkContext._active_spark_context
    jsrcCol, jreplaceCol = _to_java_column(srcCol), _to_java_column(replaceCol)
    return Column(sc._jvm.gluefunctions.replaceArrayElement(jsrcCol, jreplaceCol, idx))
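# Illustrative usage sketch for replaceArrayElement(): the semantics live in the JVM-side
# `gluefunctions` object, so this only shows how the wrapper is called. The assumed
# behavior is that the element at position `idx` of the array column is replaced by the
# corresponding value from the second column; the column names below are hypothetical.
def _replace_array_element_example(df):
    # df is assumed to have an array column 'arr' and a replacement column 'repl'.
    return df.select(replaceArrayElement('arr', 'repl', 1).alias('patched'))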
def _(arrayCol, numCols, numRows):
    jfcn = getattr(_checked_context(), 'arrayToTile')
    return Column(jfcn(_to_java_column(arrayCol), numCols, numRows))
def _(colIndex, rowIndex, cellData, numCols, numRows, cellType):
    jfcn = RFContext.active().lookup('assembleTile')
    return Column(jfcn(_to_java_column(colIndex), _to_java_column(rowIndex),
                       _to_java_column(cellData), numCols, numRows, _celltype(cellType)))
def _(arrayCol, numCols, numRows):
    jfcn = RFContext.active().lookup('arrayToTile')
    return Column(jfcn(_to_java_column(arrayCol), numCols, numRows))
def explodeWithIndex(col):
    sc = SparkContext._active_spark_context
    jc = sc._jvm.gluefunctions.explodeWithIndex(_to_java_column(col))
    return Column(jc).alias('index', 'val')
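# Illustrative usage sketch for explodeWithIndex(): the JVM-side `gluefunctions`
# implementation is not shown here, so the assumed behavior is that each array element is
# emitted together with its position, surfaced through the 'index' and 'val' aliases above.
def _explode_with_index_example(df):
    # df is assumed to have an array column named 'xs' (a hypothetical name).
    return df.select(explodeWithIndex('xs'))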