def getStateFromFrame(self):
    my = self.cc.getMyCharacter()
    myAction = get_field(my, "action").__str__()
    if ("RECOV" in myAction or "THROW" in myAction):
        # run away
        print("running away.")
        raise RunAsFastAsYouCan

    opp = self.cc.getEnemyCharacter()
    state = [[
        get_field(my, "x"), get_field(my, "y"),
        get_field(my, "energy"), get_field(my, "hp"),
        get_field(opp, "x"), get_field(opp, "y"),
        get_field(opp, "energy"), get_field(opp, "hp")
    ]]
    oppAction = get_field(opp, "action").__str__()
    onehotted = self.onehot([oppAction])
    state = numpy.concatenate((state, onehotted), axis=1)
    return state
def from_java(cls, java_map_info: JavaObject) -> "MapInfo":
    if java_map_info is None:
        return None
    else:
        key_type = java_gateway.get_field(java_map_info, "keyType").toString()
        value_type = java_gateway.get_field(java_map_info, "valueType").toString()
        size = java_gateway.get_field(java_map_info, "size")
        return cls(key_type, value_type, size)
def testSetField(self):
    self.gateway = JavaGateway(
        gateway_parameters=GatewayParameters(auto_field=False))
    ex = self.gateway.getNewExample()

    set_field(ex, "field10", 2334)
    self.assertEquals(get_field(ex, "field10"), 2334)

    sb = self.gateway.jvm.java.lang.StringBuffer("Hello World!")
    set_field(ex, "field21", sb)
    self.assertEquals(get_field(ex, "field21").toString(), "Hello World!")

    self.assertRaises(Exception, set_field, ex, "field1", 123)
def testSetField(self):
    self.gateway = JavaGateway(auto_field=False)
    ex = self.gateway.getNewExample()

    set_field(ex, 'field10', 2334)
    self.assertEquals(get_field(ex, 'field10'), 2334)

    sb = self.gateway.jvm.java.lang.StringBuffer('Hello World!')
    set_field(ex, 'field21', sb)
    self.assertEquals(get_field(ex, 'field21').toString(), 'Hello World!')

    self.assertRaises(Exception, set_field, ex, 'field1', 123)
def __convert_fitness_j_to_p(self, f):
    return Fitness(
        value=get_field(f, "value"),
        u_sell=get_field(f, "uSell"),
        u_buy=get_field(f, "uBuy"),
        noop=get_field(f, "noop"),
        realised_profit=get_field(f, "realisedProfit"),
        mdd=get_field(f, "MDD"),
        ret=get_field(f, "Return"),
        wealth=get_field(f, "wealth"),
        no_of_transactions=get_field(f, "noOfTransactions"),
        no_of_short_selling_transactions=get_field(
            f, "noOfShortSellingTransactions"))
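# Hedged sketch, not part of the original code: the same Java-to-Python conversion can be
# driven by a name table and a dict comprehension. The Java field names are taken from the
# snippet above; the table, helper name, and the assumption that `Fitness` accepts these
# keyword arguments are illustrative only.
_FITNESS_FIELD_MAP = {
    "value": "value",
    "u_sell": "uSell",
    "u_buy": "uBuy",
    "noop": "noop",
    "realised_profit": "realisedProfit",
    "mdd": "MDD",
    "ret": "Return",
    "wealth": "wealth",
    "no_of_transactions": "noOfTransactions",
    "no_of_short_selling_transactions": "noOfShortSellingTransactions",
}

def convert_fitness_j_to_p(f):
    # Read every mapped Java field once and pass the results as keyword arguments.
    return Fitness(**{py_name: get_field(f, java_name)
                      for py_name, java_name in _FITNESS_FIELD_MAP.items()})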
def build_univariate_detector(job_conf, jvm):
    univariate_detector_type = job_conf['univariate_detector_type']
    if univariate_detector_type == "CUSUM":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.cusum()
    elif univariate_detector_type == "ADWIN":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.adwin()
    elif univariate_detector_type == "GEOMETRIC_MOVING_AVERAGE":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.geometricMovingAverage()
    elif univariate_detector_type == "DDM":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.ddm()
    elif univariate_detector_type == "EDDM":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.eddm()
    elif univariate_detector_type == "EWMA":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.ewma()
    elif univariate_detector_type == "PAGE_HINKLEY":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.pageHinkley()
    elif univariate_detector_type == "HDDM_A":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.hddmA()
    elif univariate_detector_type == "HDDM_W":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.hddmW()
    elif univariate_detector_type == "SEED":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.seed()
    elif univariate_detector_type == "SEQ1":
        univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.seq1()
    # elif univariate_detector_type == "SEQ2":
    #     univariate_detector = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate.seq2()
    else:
        raise Exception("Univariate drift detector %s not implemented"
                        % univariate_detector_type)

    options = job_conf['univariate_detector_options']
    for k, v in options.items():
        moaDetector = get_field(univariate_detector, 'moaDetector')
        if v['type'] == 'MultiChoiceOption':
            get_field(moaDetector, k).setChosenIndex(v['value'])
        else:
            get_field(moaDetector, k).setValue(v['value'])
    return univariate_detector
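# Hedged sketch, not part of the original job code: the if/elif dispatch above can be written
# as a name lookup followed by a getattr call on the same Univariate factory class. The helper
# name and lookup table below are assumptions; the detector names, JVM path, and option
# handling are copied from the snippet above.
def build_univariate_detector_via_lookup(job_conf, jvm):
    factory_names = {
        "CUSUM": "cusum",
        "ADWIN": "adwin",
        "GEOMETRIC_MOVING_AVERAGE": "geometricMovingAverage",
        "DDM": "ddm",
        "EDDM": "eddm",
        "EWMA": "ewma",
        "PAGE_HINKLEY": "pageHinkley",
        "HDDM_A": "hddmA",
        "HDDM_W": "hddmW",
        "SEED": "seed",
        "SEQ1": "seq1",
    }
    detector_type = job_conf['univariate_detector_type']
    if detector_type not in factory_names:
        raise Exception("Univariate drift detector %s not implemented" % detector_type)

    # getattr(...)() is equivalent to the explicit .cusum(), .adwin(), ... calls above.
    univariate = jvm.uk.ac.bangor.meander.detectors.Detectors.Univariate
    detector = getattr(univariate, factory_names[detector_type])()

    # Same option handling as in the snippet: each option names a field of the wrapped
    # MOA detector and is set either by index or by value.
    for k, v in job_conf['univariate_detector_options'].items():
        moa_detector = get_field(detector, 'moaDetector')
        if v['type'] == 'MultiChoiceOption':
            get_field(moa_detector, k).setChosenIndex(v['value'])
        else:
            get_field(moa_detector, k).setValue(v['value'])
    return detector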
def predict(self, query_string, body_data):
    if self._verbose:
        self._logger.debug(
            self._prefix_msg + "predict, query_string: {}, body_data: {}"
            .format(query_string, body_data))

    if self._model_loaded:
        result = self._component_via_py4j.predict(query_string, body_data)
        returned_code = get_field(result, "returned_code")
        json = get_field(result, "json")
        if self._verbose:
            self._logger.debug(
                self._prefix_msg + "got response ... code: {}, json: {}"
                .format(returned_code, str(json)))
        return (returned_code, str(json))
    else:
        return 404, '{"error": "H2O model was not loaded yet!"}'
def testSetField(self):
    self.gateway = JavaGateway(auto_field=False)
    ex = self.gateway.getNewExample()

    set_field(ex, "field10", 2334)
    self.assertEquals(get_field(ex, "field10"), 2334)

    sb = self.gateway.jvm.java.lang.StringBuffer("Hello World!")
    set_field(ex, "field21", sb)
    self.assertEquals(get_field(ex, "field21").toString(), "Hello World!")

    try:
        set_field(ex, "field1", 123)
        self.fail()
    except Exception:
        self.assertTrue(True)
def deserializeNetwork(self, network):
    neurons = []
    for i, neuron in enumerate(network):
        neurons.append(Neuron(neuron, i))
        if get_field(neuron, 'index') == self.outputNeuron:
            break
    return neurons
def view_data(key):
    data_info = get_data_info(key)
    if not data_info or 'data' not in data_info:
        raise Exception('no data for ' + key)
    data = data_info['data']
    type_name = status_gateway.get_field(data_info['status'], 'typeName')
    if isinstance(data, pd.DataFrame):
        return data_json.to_json({
            'type': 'table',
            'data': {
                'count': data.shape[0],
                'bytes': -1,
                'schema': [{'column-name': c, 'column-type': data.dtypes[c].name}
                           for c in data.columns.to_series()],
                'data': [[d[col_name] for col_name in data.columns]
                         for index, d in data.iterrows()]
            }
        })
    elif psdf and isinstance(data, psdf.DataFrame):
        return data_json.to_json({
            'type': 'table',
            'data': {
                'count': data.count(),
                'bytes': -1,
                'schema': [{'column-name': c[0], 'column-type': c[1]} for c in data.dtypes],
                'data': [[col for col in row] for row in data.collect()]
            }
        })
    else:
        return data_json.to_json({
            'type': type_name,
            'data': data
        })
def partitionTree(self) -> JvmPartitioner:
    """

    :return:
    """
    return JvmPartitioner(get_field(self._srdd, "partitionTree"))
def get_jvm_reporter(self, jvm):
    jvm_zeppelin_context = self.zeppelin_context.z
    jvm_intp_context = jvm_zeppelin_context.getInterpreterContext()
    jvm_output_stream = get_field(jvm_intp_context, "out")
    jvm_print_stream = jvm.java.io.PrintStream(jvm_output_stream)
    return jvm.de.frosner.ddq.reporters.ZeppelinReporter(
        jvm_print_stream
    )
def indexedRawRDD(self):
    jrdd = get_field(self._srdd, "indexedRawRDD")
    if not hasattr(self, "_indexed_raw_rdd"):
        RDD.saveAsObjectFile = lambda x, path: x._jrdd.saveAsObjectFile(path)
        RDD.count = lambda x: x._jrdd.count()
        setattr(self, "_indexed_raw_rdd", RDD(jrdd, self._sc))
    else:
        self._indexed_raw_rdd._jrdd = jrdd
    return getattr(self, "_indexed_raw_rdd")
def boundaryEnvelope(self) -> Envelope:
    """

    :return:
    """
    if not self._is_analyzed:
        raise TypeError("Please use analyze before")
    java_boundary_envelope = get_field(self._srdd, "boundaryEnvelope")
    return Envelope.from_jvm_instance(java_boundary_envelope)
def __getattr__(self, item):
    if "java_fields" in self.__dict__ and item in self.__dict__["java_fields"]:
        a = java_gateway.get_field(self.java_obj, item)
        if gateway.jvm.pyboof.PyBoofEntryPoint.isConfigClass(a):
            return JavaConfig(a)
        else:
            return a
    else:
        return object.__getattribute__(self, item)
def fieldNames(self) -> List[str]:
    """

    :return:
    """
    try:
        field_names = list(get_field(self._srdd, "fieldNames"))
    except TypeError:
        field_names = []
    return field_names
def kmeans(data, centers=2, runs=5, max_iters=10):
    """
    Train Kmeans on a given DDF

    :param data: DDF
    :param centers: number of clusters
    :param runs: number of runs
    :param max_iters: number of iterations
    :return: an object of type KMeansModel
    """
    ml_obj = java_gateway.get_field(data._jddf, 'ML')
    return KMeansModel(ml_obj.KMeans(centers, runs, max_iters),
                       data._gateway_client, data.colnames)
def deserialize_weights(self, neuralModel):
    weights = [0] * len(neuralModel)
    for weight in neuralModel:
        w = Weight(weight)
        w.postprocess()
        dim = tuple(get_field(weight, 'dimensions'))
        if not dim:
            dim = (1,)
        w.expr = self.dynet_model.add_parameters(dim)
        weights[w.index] = w.expr
    return weights
def is_regular_service(name):
    serv_model = oc.service_model(name)
    params = serv_model.getModelParamList()
    is_regular = True
    for param in params:
        internal = get_field(param, "internal")
        p_name = get_field(param, "name")
        # print(p_name, internal)
        if internal:
            pass
        elif get_field(param, 'mode') != 'OUT':
            param_type = get_field(param, 'type')
            if param_type == 'java.nio.ByteBuffer':
                print("service {} parameter {}'s type is ByteBuffer"
                      .format(name, p_name))
                is_regular = False
                break
            elif param_type not in known_in_types:
                is_regular = False
                break
    return is_regular
def ground_NNs(args):
    sources = main.getSources(prepare_args(args), settings)
    nnbuilder = neuralogic.pipelines.building.End2endNNBuilder(settings, sources)
    pipeline = nnbuilder.buildPipeline()
    result = pipeline.execute(sources)

    samples = get_field(get_field(result, 's'), 's')
    samples = list(samples)
    logic = get_field(get_field(get_field(result, 's'), 'r'), 'r')
    neuralParams = get_field(get_field(get_field(result, 's'), 'r'), 's')
    return samples, list(neuralParams), logic
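# Hedged sketch, not part of the original pipeline code: the nested get_field calls above are
# repeated field lookups on the Java result object, which a small helper can make easier to
# read. The helper name is an assumption; only get_field from py4j is used.
def get_nested_field(java_obj, *field_names):
    """Follow a chain of Java fields, e.g. get_nested_field(result, 's', 'r', 'r')."""
    for name in field_names:
        java_obj = get_field(java_obj, name)
    return java_obj

# Hypothetical usage against the `result` object returned by pipeline.execute above:
# samples = list(get_nested_field(result, 's', 's'))
# logic = get_nested_field(result, 's', 'r', 'r')
# neuralParams = get_nested_field(result, 's', 'r', 's')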
def classify(text, RNN):
    # Make a gateway, send the text and get back a Treepack object
    connection = JavaGateway()
    treepack = connection.entry_point.parse(text)

    # Time to classify!
    treecount = gateway.get_field(treepack, 'count')
    if treecount == 1:
        # Make a tree and classify
        print('Making tree...')
        tree = Tree(_id=-1)
        tree.root = Node()
        sentence = gateway.get_field(treepack, 'first')
        tree.root.read(sentence, 0, True)
        # Chaipi as usual
        tree.root = tree.root.children[0]
        prediction_vector = RNN.predict_tree(tree)
        print(prediction_vector)
    elif treecount == 2:
        # Make a request for classification. Note: request = 2 trees
        print('Making request...')
        request = Request()
        first_sentence = gateway.get_field(treepack, 'first')
        second_sentence = gateway.get_field(treepack, 'second')
        request.add_tree(first_sentence, -1)
        request.add_tree(second_sentence, -1)
        prediction_vector = RNN.predict_request(request)
        print(prediction_vector)
    else:
        # Probably return some kind of error to the client
        prediction_vector = None
    return prediction_vector
def spatialPartitionedRDD(self):
    """

    :return:
    """
    serialized_spatial_rdd = GeoSparkPythonConverter(self._jvm).translate_spatial_rdd_to_python(
        get_field(self._srdd, "spatialPartitionedRDD"))
    if not hasattr(self, "_spatial_partitioned_rdd"):
        setattr(self, "_spatial_partitioned_rdd",
                RDD(serialized_spatial_rdd, self._sc, GeoSparkPickler()))
    else:
        self._spatial_partitioned_rdd._jrdd = serialized_spatial_rdd
    return getattr(self, "_spatial_partitioned_rdd")
def testNoAutoField(self):
    self.gateway = JavaGateway(auto_field=False)
    ex = self.gateway.getNewExample()
    self.assertTrue(isinstance(ex.field10, JavaMember))
    self.assertTrue(isinstance(ex.field50, JavaMember))
    self.assertEqual(10, get_field(ex, 'field10'))

    try:
        get_field(ex, 'field50')
        self.fail()
    except Exception:
        self.assertTrue(True)

    ex._auto_field = True
    sb = ex.field20
    sb.append('Hello')
    self.assertEqual('Hello', sb.toString())

    try:
        get_field(ex, 'field20')
        self.fail()
    except Exception:
        self.assertTrue(True)
def __init__(self, neuron, index):
    self.index = index
    self.name = get_field(neuron, 'name')
    self.weighted = get_field(neuron, 'weighted')
    self.activation = get_field(neuron, 'activation')
    self.inputs = get_field(neuron, 'inputs')
    self.weights = get_field(neuron, 'weights')
    self.offset = get_field(neuron, 'offset')
    self.value = get_field(neuron, 'value')
    self.pooling = get_field(neuron, 'pooling')
    self.expr = None
def spatialPartitionedRDD(self):
    """

    :return:
    """
    serialized_spatial_rdd = self._jvm.GeoSerializerData.serializeToPython(
        get_field(self._srdd, "spatialPartitionedRDD"))
    if not hasattr(self, "_spatial_partitioned_rdd"):
        setattr(self, "_spatial_partitioned_rdd",
                RDD(serialized_spatial_rdd, self._sc, GeoSparkPickler()))
    else:
        self._spatial_partitioned_rdd._jrdd = serialized_spatial_rdd
    return getattr(self, "_spatial_partitioned_rdd")
def testNoAutoField(self):
    self.gateway = JavaGateway(auto_field=False)
    ex = self.gateway.getNewExample()
    self.assertTrue(isinstance(ex.field10, JavaMember))
    self.assertTrue(isinstance(ex.field50, JavaMember))
    self.assertEqual(10, get_field(ex, 'field10'))

    # This field does not exist
    self.assertRaises(Exception, get_field, ex, 'field50')

    # With auto field = True
    ex._auto_field = True
    sb = ex.field20
    sb.append('Hello')
    self.assertEqual('Hello', sb.toString())
def extractDATA(pathDss, TituloDSS, model):
    # getting data
    record = theFile.get(pathDss, True)
    values = [p for p in get_field(record, "values")]
    times = pd.DatetimeIndex(
        data=[np.datetime64(int(t) * 60 - 2208988800, 's')
              for t in get_field(record, "times")],
        name="Time")

    # printing the most important data to calibrate a model
    print(model)
    print(TituloDSS)
    print('Peak flow : ', max(values), ' m3/s')
    print('Time of peak flow : ', times[values.index(max(values))])

    # Integral approximation with trapezoidal method
    volume = np.trapz(values) * 60 / 10**6
    print("Volume :", volume, ' Hm3')
    print('')

    # Plotting graphs
    f, ax = plt.subplots()
    ts = pd.Series(values, index=times)
    ts.plot()
    plt.ylabel('Flow m3/s')
    plt.gca().set_ylim([0, 1500])
    ax.set_title(TituloDSS)
def testNoAutoField(self):
    self.gateway = JavaGateway(
        gateway_parameters=GatewayParameters(auto_field=False))
    ex = self.gateway.getNewExample()
    self.assertTrue(isinstance(ex.field10, JavaMember))
    self.assertTrue(isinstance(ex.field50, JavaMember))
    self.assertEqual(10, get_field(ex, "field10"))

    # This field does not exist
    self.assertRaises(Exception, get_field, ex, "field50")

    # With auto field = True
    ex._auto_field = True
    sb = ex.field20
    sb.append("Hello")
    self.assertEqual("Hello", sb.toString())
def logistic_regression_gd(data, step_size=1.0, max_iters=10):
    """

    :param data:
    :param step_size:
    :param max_iters:
    :return:
    """
    ml_obj = java_gateway.get_field(data._jddf, 'ML')
    gateway = data._gateway_client
    model = ml_obj.train('logisticRegressionWithSGD',
                         util.to_java_array([max_iters, step_size],
                                            gateway.jvm.Object, gateway))
    weights = [float(model.getRawModel().intercept())] + \
        list(model.getRawModel().weights().toArray())
    weights = pd.DataFrame(data=[weights],
                           columns=['Intercept'] + data.colnames[:-1])
    return LogisticRegressionModel(model, gateway, weights)
def create_service_data_frame(name, show_internal=False):
    import pandas as pd

    model = MetaService(name).model
    params = model.getModelParamList()
    model_desc = {
        'name': [str(get_field(param, "name")) for param in params],
        'type': [str(get_field(param, "type")) for param in params],
        'required': ['*' if not get_field(param, "optional") else ' ' for param in params],
        'override optional': ['*' if not get_field(param, "overrideOptional") else ' '
                              for param in params],
        'entity name': [str(get_field(param, "entityName")) for param in params],
        'mode': [str(get_field(param, "mode")) for param in params],
        'internal': ['*' if get_field(param, "internal") else ' ' for param in params],
        'description': [str(get_field(param, "description")) for param in params]
    }
    df = pd.DataFrame(model_desc)
    # df['parameter mode']=df['mode'].astype('category')
    if not show_internal:
        df = df[df['internal'] != '*']
    # return df.sort_values(by='parameter mode')
    return df
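# Hedged sketch, not part of the original code: the column-wise list comprehensions above can
# also be built row by row, iterating the Java param objects once. Column names mirror the
# snippet above; the helper name is an assumption.
def service_params_as_rows(params):
    rows = []
    for param in params:
        rows.append({
            'name': str(get_field(param, "name")),
            'type': str(get_field(param, "type")),
            'required': '*' if not get_field(param, "optional") else ' ',
            'override optional': '*' if not get_field(param, "overrideOptional") else ' ',
            'entity name': str(get_field(param, "entityName")),
            'mode': str(get_field(param, "mode")),
            'internal': '*' if get_field(param, "internal") else ' ',
            'description': str(get_field(param, "description")),
        })
    return rows

# pd.DataFrame(service_params_as_rows(params)) would then give the same frame as above.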
def createLog(self, cc):
    my = self.cc.getMyCharacter()
    energy = my.getEnergy()
    my_x = get_field(my, 'x')
    my_y = get_field(my, 'y')
    my_state = get_field(my, "state")

    opp = self.cc.getEnemyCharacter()
    opp_x = get_field(opp, 'x')
    opp_y = get_field(opp, 'y')
    opp_state = get_field(opp, "state")
def print_table(obj, *args):
    """
    Usage:

        from sagas.ofbiz.util import print_table
        model=s('model').createProductReview
        print_table(model, 'location')

    :param obj:
    :param args:
    :return:
    """
    from py4j.java_gateway import get_field
    from tabulate import tabulate

    table_header = ['name', 'value']
    table_data = []
    for arg in args:
        table_data.append((arg, get_field(obj, arg)))
    print(tabulate(table_data, headers=table_header, tablefmt='psql'))
def collect_component_data_files(self):
    oc.import_package('org.apache.ofbiz.base.component.ComponentConfig')
    allComponents = oc.j.ComponentConfig.getAllComponents()
    index = 1
    self.comp_files = {}
    for c in allComponents:
        self.comp_files[c.getGlobalName()] = []
    for c in allComponents:
        # print(index, c.getComponentName(), c.getRootLocation())
        # print(index, c.getGlobalName(), c.getRootLocation())
        index = index + 1
        ts = c.getTestSuiteInfos()
        # print('\t', 'contains test suites', len(ts))
        ent_res = c.getEntityResourceInfos()
        for es in ent_res:
            data_type = get_field(es, 'type')
            if data_type == 'data':
                # print('\t', data_type, get_field(es, 'readerName'),
                #       es.getLocation())
                self.comp_files[c.getGlobalName()].append(
                    c.getRootLocation() + es.getLocation())
    return self.comp_files
def get_property(self, name):
    return java_gateway.get_field(self.java_obj, name)
def gapply(grouped_data, func, schema, *cols):
    """Applies the function ``func`` to data grouped by key. In particular, given a dataframe
    grouped by some set of key columns key1, key2, ..., keyn, this method groups all the values
    for each row with the same key columns into a single Pandas dataframe and by default invokes
    ``func((key1, key2, ..., keyn), values)`` where the number and order of the key arguments is
    determined by columns on which this instance's parent :class:`DataFrame` was grouped and
    ``values`` is a ``pandas.DataFrame`` of columns selected by ``cols``, in that order.

    If there is only one key then the key tuple is automatically unpacked, with
    ``func(key, values)`` called.

    ``func`` is expected to return a ``pandas.DataFrame`` of the specified schema ``schema``,
    which should be of type :class:`StructType` (output columns are of this name and order).

    If ``spark.conf.get("spark.sql.retainGroupColumns")`` is not ``u'true'``, then ``func`` is
    called with an empty key tuple (note it is set to ``u'true'`` by default).

    If no ``cols`` are specified, then all grouped columns will be offered, in the order of the
    columns in the original dataframe. In either case, the Pandas columns will be named
    according to the DataFrame column names.

    The order of the rows passed in as Pandas rows is not guaranteed to be stable relative to
    the original row order.

    :note: Users must ensure that the grouped values for every group fit entirely in memory.
    :note: This method is only available if Pandas is installed.

    :param grouped_data: data grouped by key
    :param func: a two argument function, which may be either a lambda or named function
    :param schema: the return schema for ``func``, a :class:`StructType`
    :param cols: list of column names (string only)

    :raise ValueError: if ``"*"`` is in ``cols``
    :raise ValueError: if ``cols`` contains duplicates
    :raise ValueError: if ``schema`` is not a :class:`StructType`
    :raise ImportError: if ``pandas`` module is not installed
    :raise ImportError: if ``pandas`` version is too old (less than 0.7.1)

    :return: the new :class:`DataFrame` with the original key columns replicated for each
        returned value in each group's resulting pandas dataframe, the schema being the
        original key schema prepended to ``schema``, where all the resulting groups' rows are
        concatenated. Of course, if retaining group columns is disabled, then the output will
        exactly match ``schema`` since no keys can be prepended.

    >>> import pandas as pd
    >>> from pyspark.sql import SparkSession
    >>> from spark_sklearn.group_apply import gapply
    >>> from spark_sklearn.util import createLocalSparkSession
    >>> spark = createLocalSparkSession()
    >>> df = (spark
    ...     .createDataFrame([Row(course="dotNET", year=2012, earnings=10000),
    ...                       Row(course="Java", year=2012, earnings=20000),
    ...                       Row(course="dotNET", year=2012, earnings=5000),
    ...                       Row(course="dotNET", year=2013, earnings=48000),
    ...                       Row(course="Java", year=2013, earnings=30000)])
    ...     .select("course", "year", "earnings"))
    >>> def yearlyMedian(_, vals):
    ...     all_years = set(vals['year'])
    ...     # Note that interpolation is performed, so we need to cast back to int.
    ...     yearly_median = [(year, int(vals['earnings'][vals['year'] == year].median()))
    ...                      for year in all_years]
    ...     return pd.DataFrame.from_records(yearly_median)
    >>> newSchema = StructType().add("year", LongType()).add("median_earnings", LongType())
    >>> gapply(df.groupBy("course"), yearlyMedian, newSchema).orderBy("median_earnings").show()
    +------+----+---------------+
    |course|year|median_earnings|
    +------+----+---------------+
    |dotNET|2012|           7500|
    |  Java|2012|          20000|
    |  Java|2013|          30000|
    |dotNET|2013|          48000|
    +------+----+---------------+
    <BLANKLINE>
    >>> def twoKeyYearlyMedian(_, vals):
    ...     return pd.DataFrame.from_records([(int(vals["earnings"].median()),)])
    >>> newSchema = StructType([df.schema["earnings"]])
    >>> gapply(df.groupBy("course", "year"), twoKeyYearlyMedian, newSchema, "earnings").orderBy(
    ...     "earnings").show()
    +------+----+--------+
    |course|year|earnings|
    +------+----+--------+
    |dotNET|2012|    7500|
    |  Java|2012|   20000|
    |  Java|2013|   30000|
    |dotNET|2013|   48000|
    +------+----+--------+
    <BLANKLINE>
    >>> spark.stop(); SparkSession._instantiatedContext = None
    """
    import pandas as pd
    minPandasVersion = '0.7.1'
    if LooseVersion(pd.__version__) < LooseVersion(minPandasVersion):
        raise ImportError('Pandas installed but version is {}, {} required'
                          .format(pd.__version__, minPandasVersion))

    # Do a null aggregation to retrieve the keys first (should be no computation)
    # Also consistent with spark.sql.retainGroupColumns
    keySchema = grouped_data.agg({}).schema
    keyCols = grouped_data.agg({}).columns

    if not cols:
        # Extract the full column list with the parent df
        javaDFName = "org$apache$spark$sql$RelationalGroupedDataset$$df"
        parentDF = java_gateway.get_field(grouped_data._jgd, javaDFName)
        allCols = DataFrame(parentDF, None).columns
        keyColsSet = set(keyCols)
        cols = [col for col in allCols if col not in keyColsSet]

    if "*" in cols:
        raise ValueError("cols expected to contain only singular columns")
    if len(set(cols)) < len(cols):
        raise ValueError("cols expected not to contain duplicate columns")
    if not isinstance(schema, StructType):
        raise ValueError("output schema should be a StructType")

    inputAggDF = grouped_data.agg({col: 'collect_list' for col in cols})
    # Recover canonical order (aggregation may change column order)
    canonicalOrder = chain(keyCols, [inputAggDF['collect_list(' + col + ')'] for col in cols])
    inputAggDF = inputAggDF.select(*canonicalOrder)

    # Wraps the user-provided function with another python function, which prepares the
    # input in the form specified by the documentation. Then, once the function completes,
    # this wrapper prepends the keys to the output values and converts from pandas.
    def pandasWrappedFunc(*args):
        nvals = len(cols)
        keys, collectedCols = args[:-nvals], args[-nvals:]
        paramKeys = tuple(keys)
        if len(paramKeys) == 1:
            paramKeys = paramKeys[0]
        valuesDF = pd.DataFrame.from_dict(dict(zip(cols, collectedCols)))
        valuesDF = valuesDF[list(cols)]  # reorder to canonical
        outputDF = func(paramKeys, valuesDF)
        valCols = outputDF.columns.tolist()
        for key, keyName in zip(keys, keyCols):
            outputDF[keyName] = key
        outputDF = outputDF[keyCols + valCols]  # reorder to canonical
        # To recover native python types for serialization, we need
        # to convert the pandas dataframe to a numpy array, then to a
        # native list (can't go straight to native, since pandas will
        # attempt to preserve the numpy type).
        return outputDF.values.tolist()

    keyPrependedSchema = StructType(list(chain(keySchema, schema)))
    outputAggSchema = ArrayType(keyPrependedSchema, containsNull=False)
    pandasUDF = udf(pandasWrappedFunc, outputAggSchema)
    outputAggDF = inputAggDF.select(pandasUDF(*inputAggDF))
    explodedDF = outputAggDF.select(explode(*outputAggDF).alias("gapply"))
    # automatically retrieves nested schema column names
    return explodedDF.select("gapply.*")