Example #1
    def _jrdd(self):
        if self._jrdd_val:
            return self._jrdd_val
        func = self.func
        if not self._bypass_serializer and self.ctx.batchSize != 1:
            oldfunc = self.func
            batchSize = self.ctx.batchSize

            def batched_func(split, iterator):
                return batched(oldfunc(split, iterator), batchSize)

            func = batched_func
        cmds = [func, self._bypass_serializer]
        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        class_manifest = self._prev_jrdd.classManifest()
        env = MapConverter().convert(self.ctx.environment,
                                     self.ctx._gateway._gateway_client)
        python_rdd = self.ctx._jvm.PythonRDD(
            self._prev_jrdd.rdd(), pipe_command, env,
            self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars,
            self.ctx._javaAccumulator, class_manifest)
        self._jrdd_val = python_rdd.asJavaRDD()
        return self._jrdd_val
Example #2
    def agg(self, *exprs):
        """ Compute aggregates by specifying a map from column name
        to aggregate methods.

        The available aggregate methods are `avg`, `max`, `min`,
        `sum`, `count`.

        :param exprs: list of aggregate columns or a map from column
                      name to aggregate methods.

        >>> gdf = df.groupBy(df.name)
        >>> gdf.agg({"*": "count"}).collect()
        [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> gdf.agg(F.min(df.age)).collect()
        [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jmap = MapConverter().convert(
                exprs[0], self.sql_ctx._sc._gateway._gateway_client)
            jdf = self._jdf.agg(jmap)
        else:
            # Columns
            assert all(isinstance(c, Column)
                       for c in exprs), "all exprs should be Column"
            jcols = ListConverter().convert(
                [c._jc for c in exprs[1:]],
                self.sql_ctx._sc._gateway._gateway_client)
            jdf = self._jdf.agg(exprs[0]._jc,
                                self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
        return DataFrame(jdf, self.sql_ctx)
Example #3
    def save(self, path=None, source=None, mode="append", **options):
        """Saves the contents of the :class:`DataFrame` to a data source.

        The data source is specified by the `source` and a set of `options`.
        If `source` is not specified, the default data source configured by
        `spark.sql.sources.default` will be used.

        Additionally, mode is used to specify the behavior of the save operation when
        data already exists in the data source. There are four modes:

        * append: Contents of this :class:`DataFrame` are expected to be appended to existing data.
        * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
        * error: An exception is expected to be thrown.
        * ignore: The save operation is expected to not save the contents of \
            the :class:`DataFrame` and to not change the existing data.
        """
        if path is not None:
            options["path"] = path
        if source is None:
            source = self.sql_ctx.getConf("spark.sql.sources.default",
                                          "org.apache.spark.sql.parquet")
        jmode = self._java_save_mode(mode)
        joptions = MapConverter().convert(options,
                                          self._sc._gateway._gateway_client)
        self._jdf.save(source, jmode, joptions)
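A minimal usage sketch for the save API above, assuming a DataFrame named df from that Spark version; the output path is a placeholder and extra keyword arguments become data-source options:

# `df` and the output path are placeholders.
df.save(path="/tmp/people.parquet",
        source="org.apache.spark.sql.parquet",
        mode="overwrite")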
Example #4
 def submitWorkflowFromFile(self,
                            workflow_xml_file_path,
                            workflow_variables={}):
     workflow_variables_java_map = MapConverter().convert(
         workflow_variables, self.runtime_gateway._gateway_client)
     return self.proactive_scheduler_client.submit(
         self.runtime_gateway.jvm.java.io.File(workflow_xml_file_path),
         workflow_variables_java_map).longValue()
Example #5
    def fillna(self, value, subset=None):
        """Replace null values, alias for ``na.fill()``.

        :param value: int, long, float, string, or dict.
            Value to replace null values with.
            If the value is a dict, then `subset` is ignored and `value` must be a mapping
            from column name (string) to replacement value. The replacement value must be
            an int, long, float, or string.
        :param subset: optional list of column names to consider.
            Columns specified in subset that do not have matching data type are ignored.
            For example, if `value` is a string, and subset contains a non-string column,
            then the non-string column is simply ignored.

        >>> df4.fillna(50).show()
        age height name
        10  80     Alice
        5   50     Bob
        50  50     Tom
        50  50     null

        >>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
        age height name
        10  80     Alice
        5   null   Bob
        50  null   Tom
        50  null   unknown

        >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
        age height name
        10  80     Alice
        5   null   Bob
        50  null   Tom
        50  null   unknown
        """
        if not isinstance(value, (float, int, long, basestring, dict)):
            raise ValueError(
                "value should be a float, int, long, string, or dict")

        if isinstance(value, (int, long)):
            value = float(value)

        if isinstance(value, dict):
            value = MapConverter().convert(
                value, self.sql_ctx._sc._gateway._gateway_client)
            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
        elif subset is None:
            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
        else:
            if isinstance(subset, basestring):
                subset = [subset]
            elif not isinstance(subset, (list, tuple)):
                raise ValueError(
                    "subset should be a list or tuple of column names")

            cols = ListConverter().convert(
                subset, self.sql_ctx._sc._gateway._gateway_client)
            cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
            return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
Example #6
 def submitWorkflowFromCatalog(self,
                               bucket_name,
                               workflow_name,
                               workflow_variables={}):
     workflow_variables_java_map = MapConverter().convert(
         workflow_variables, self.runtime_gateway._gateway_client)
     return self.proactive_scheduler_client.submitFromCatalog(
         self.base_url + "/catalog", bucket_name, workflow_name,
         workflow_variables_java_map).longValue()
Example #7
    def setParams(self, params):
        """
        Update instance-specific MetadataHandle parameters.

        :param params: a map of parameters to be set
        """
        gateway = self.spark.sparkContext._gateway
        jmap = MapConverter().convert(params, gateway._gateway_client)
        self.xskipper.setParams(jmap)
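A hedged usage sketch for setParams; `handle` stands in for whatever object exposes the method above, and the parameter key is illustrative rather than a documented xskipper option:

# `handle` is an existing wrapper exposing setParams();
# the key below is a placeholder, not a real xskipper parameter name.
handle.setParams({"io.xskipper.example.param": "value"})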
Example #8
def saveDFToCsv(df, path, hasheader=True, isOverwrite=False, option={}):
    from pyspark.sql import DataFrame
    from py4j.java_collections import MapConverter
    if isinstance(df, DataFrame):
        intp.saveDFToCsv(
            df._jdf, path, hasheader, isOverwrite,
            MapConverter().convert(option, gateway._gateway_client))
    else:
        print(str(df))
Example #9
    def getSupportedFiletypes(self):
        types = self.get_supported_filetypes_grouped()
        converted = {}
        for k, v in types.items():
            converted[k] = ListConverter().convert(
                v, Gateway.gateway._gateway_client)

        return MapConverter().convert(converted,
                                      Gateway.gateway._gateway_client)
Example #10
def plot_histograms(df, props):
    """
    Plots histograms given in the DataFrame `df` (one from each row), and sets
    the visual properties from `props` onto the plot.

    The DataFrame must have these columns (in any order):
    - `key` (string or `None`): An internal (unique) identifier for each histogram.
    - `label` (string): The name of the histogram; this will appear on the legend.
    - `binedges` (NumPy array): The edges of the histogram's bins.
        Must contain one more element than the number of bins.
    - `binvalues` (NumPy array): The height of each bin in the histogram.
    - `underflows` (float): The (weighted) sum of the underflowed samples (pseudo-bin before the first real one).
    - `overflows` (float): The (weighted) sum of the overflowed samples (pseudo-bin after the last one).
    - `min` (float): The minimum of the collected samples; this is the left edge of the underflow "bin".
    - `max` (float): The maximum of the collected samples; this is the right edge of the overflow "bin".

    The properties in `props` will be made specific to the items added by this call,
    by appending `/<key>` row-by-row to each property key.

    This can only be called if the chart is of type `HISTOGRAM`.
    """
    assert_is_native_chart()

    if sorted(list(df.columns)) != sorted([
            "key", "label", "binedges", "binvalues", "underflows", "overflows",
            "min", "max"
    ]):
        raise RuntimeError("Invalid DataFrame format in plot_histogram")

    # the key of each row is appended to the property names in Java
    Gateway.chart_plotter.plotHistograms(
        pl.dumps([
            {
                "key": row.key,
                "title": str(row.label),

                # this could be computed in Java as well, but just to make things simpler, we do it here
                "sumweights": float(np.sum(row.binvalues) + row.underflows + row.overflows),
                "edges": _list_to_bytes(row.binedges),
                "values": _list_to_bytes(row.binvalues),
                "underflows": float(row.underflows),
                "overflows": float(row.overflows),
                "min": float(row.min),
                "max": float(row.max),
            } for row in df.itertuples(index=False)
        ]),
        MapConverter().convert(props, Gateway.gateway._gateway_client))
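To make the required DataFrame layout concrete, here is a minimal sketch that builds a one-row histogram frame with NumPy/pandas and hands it to plot_histograms; the sample data and the property key are made up:

import numpy as np
import pandas as pd

samples = np.random.normal(size=1000)                 # made-up data
binvalues, binedges = np.histogram(samples, bins=20)  # len(binedges) == len(binvalues) + 1

df = pd.DataFrame([{
    "key": "h0",
    "label": "example histogram",
    "binedges": binedges,
    "binvalues": binvalues.astype(float),
    "underflows": 0.0,
    "overflows": 0.0,
    "min": float(samples.min()),
    "max": float(samples.max()),
}])

plot_histograms(df, {"legend.show": "true"})          # property key is illustrative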
Example #11
    def setConf(sparkSession, params):
        """
        Updates JVM-wide xskipper parameters (only the given parameters will be updated).

        :param sparkSession: SparkSession object
        :param params: a map of parameters to be set
        """
        gateway = sparkSession.sparkContext._gateway
        jmap = MapConverter().convert(params, gateway._gateway_client)
        sparkSession._jvm.io.xskipper.Xskipper.setConf(jmap)
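A minimal sketch of calling the static helper above, assuming an active SparkSession and that the enclosing class is importable as Xskipper; the configuration key is a placeholder:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Placeholder key, not a documented xskipper setting.
Xskipper.setConf(spark, {"io.xskipper.example.setting": "value"})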
Example #12
 def _train(data, type, numClasses, categoricalFeaturesInfo,
            impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
            minInfoGain=0.0):
     first = data.first()
     assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
     sc = data.context
     jrdd = data._to_java_object_rdd()
     cfiMap = MapConverter().convert(categoricalFeaturesInfo,
                                     sc._gateway._gateway_client)
     model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
         jrdd, type, numClasses, cfiMap,
         impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
     return DecisionTreeModel(sc, model)
Example #13
    def registerFunction(self, name, f, returnType="string"):
        def func(split, iterator):
            return imap(f, iterator)

        command = (func, self._sc.serializer, self._sc.serializer)
        env = MapConverter().convert(self._sc.environment,
                                     self._sc._gateway._gateway_client)
        includes = ListConverter().convert(self._sc._python_includes,
                                           self._sc._gateway._gateway_client)
        self._ssql_ctx.registerPython(
            name, bytearray(CloudPickleSerializer().dumps(command)), env,
            includes, self._sc.pythonExec, self._sc._javaAccumulator,
            returnType)
Example #14
    def resolve(self, url, userdata):
        if userdata:
            os.environ['USERDATA'] = userdata
        if 'plugin://' in url or url.startswith('['):
            r = self.resolver.resolve(url)
        else:
            r = self.scan(url)
#		# FIXME: popping USERDATA here causes Py4JException
#		if userdata and userdata == os.getEnv('USERDATA'):
#			os.environ.pop('USERDATA')
        if not r:
            pms.log('unresolved: %s' % (url))
        return MapConverter().convert(r, pms.gateway_client) if r else None
Example #15
    def submitWorkflowFromURL(self, workflow_url_spec, workflow_variables={}):
        """
        Submit a job from a URL to the scheduler

        :param workflow_url_spec: The workflow URL
        :param workflow_variables: The workflow input variables
        :return: The submitted job id
        """
        workflow_variables_java_map = MapConverter().convert(
            workflow_variables, self.runtime_gateway._gateway_client)
        self.logger.debug('Submitting from URL the job \'' +
                          workflow_url_spec + '\'')
        return self.proactive_scheduler_client.submit(
            self.runtime_gateway.jvm.java.net.URL(workflow_url_spec),
            workflow_variables_java_map).longValue()
Example #16
    def submitWorkflowFromFile(self,
                               workflow_xml_file_path,
                               workflow_variables={}):
        """
        Submit a job from an XML file to the scheduler

        :param workflow_xml_file_path: The workflow XML file path
        :param workflow_variables: The workflow input variables
        :return: The submitted job id
        """
        workflow_variables_java_map = MapConverter().convert(
            workflow_variables, self.runtime_gateway._gateway_client)
        self.logger.debug('Submitting from file the job \'' +
                          workflow_xml_file_path + '\'')
        return self.proactive_scheduler_client.submit(
            self.runtime_gateway.jvm.java.io.File(workflow_xml_file_path),
            workflow_variables_java_map).longValue()
Example #17
 def fromFile(cls,
              sbml_file,
              margins={},
              default_margin=10.,
              input_val=1.,
              current_assignment='universal',
              disable_block_elision=False,
              wiring_switch=False):
     if not wiring_switch:
         return cls(
             gateway.jvm.com.cytocomp.mtt.Wiring2FromFile.apply(
                 sbml_file,
                 MapConverter().convert(margins, gateway._gateway_client),
                 default_margin, input_val, current_assignment,
                 disable_block_elision))
     else:
         return cls(
             gateway.jvm.com.cytocomp.mtt.WiringFromFile.apply(sbml_file))
Example #18
def _py2java(sc, obj):
    """ Convert Python object into Java """
    if isinstance(obj, RDD):
        obj = _to_java_object_rdd(obj)
    elif isinstance(obj, SparkContext):
        obj = obj._jsc
    elif isinstance(obj, dict):
        obj = MapConverter().convert(obj, sc._gateway._gateway_client)
    elif isinstance(obj, (list, tuple)):
        obj = ListConverter().convert(obj, sc._gateway._gateway_client)
    elif isinstance(obj, JavaObject):
        pass
    elif isinstance(obj, (int, long, float, bool, basestring)):
        pass
    else:
        bytes = bytearray(PickleSerializer().dumps(obj))
        obj = sc._jvm.SerDe.loads(bytes)
    return obj
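The dict and list branches above use the standard py4j converters; the same pattern works standalone, assuming a Java-side GatewayServer is already running:

from py4j.java_gateway import JavaGateway
from py4j.java_collections import MapConverter, ListConverter

gateway = JavaGateway()   # connects to a running GatewayServer on the default port
jmap = MapConverter().convert({"a": 1, "b": 2}, gateway._gateway_client)
jlist = ListConverter().convert([1, 2, 3], gateway._gateway_client)
# jmap and jlist are proxies for java.util.HashMap / java.util.ArrayList
# and can be passed to Java methods that expect Map or List arguments.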
Example #19
    def _to_java_params(self, field_query, begin_time, end_time):

        gc = self._sc._gateway._gateway_client

        def to_java_list(x):
            if isinstance(x, basestring):
                return ListConverter().convert([x], gc)
            return ListConverter().convert(x, gc)

        java_list_field_query = {
            k: to_java_list(v)
            for k, v in field_query.items()
        }
        java_field_query = MapConverter().convert(java_list_field_query, gc)
        java_begin_time = self._to_java_timestamp(begin_time)
        java_end_time = self._to_java_timestamp(end_time)

        return java_field_query, java_begin_time, java_end_time
Example #20
    def submitWorkflowFromCatalog(self,
                                  bucket_name,
                                  workflow_name,
                                  workflow_variables={}):
        """
        Submit a job from the catalog to the scheduler

        :param bucket_name: The bucket in which the workflow is saved
        :param workflow_name: The workflow name
        :param workflow_variables: The workflow input variables
        :return: The submitted job id
        """
        workflow_variables_java_map = MapConverter().convert(
            workflow_variables, self.runtime_gateway._gateway_client)
        self.logger.debug('Submitting from catalog the job \'' + bucket_name +
                          '/' + workflow_name + '\'')
        return self.proactive_scheduler_client.submitFromCatalog(
            self.base_url + "/catalog", bucket_name, workflow_name,
            workflow_variables_java_map).longValue()
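Taken together, the three submit helpers (from URL, from XML file, from catalog) are called the same way; a combined sketch, assuming an already-connected client object exposing these methods, with placeholder URLs, paths, bucket and workflow names:

variables = {"nb_nodes": "2"}   # placeholder workflow variables

job_id_url = client.submitWorkflowFromURL(
    "http://example.com/workflows/demo.xml", variables)
job_id_file = client.submitWorkflowFromFile(
    "/tmp/demo_workflow.xml", variables)
job_id_catalog = client.submitWorkflowFromCatalog(
    "demo-bucket", "demo_workflow", variables)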
Example #21
def search_substructure(pattern, molecules):
    if host_os == 'Linux' or host_os == 'Darwin':
        g = JavaGateway.launch_gateway(classpath="{}:{}:{}/".format(py4j_jar_path,
                                                                    os.path.join(cdk_jar_path, 'cdk-2.2.jar'),
                                                                    cdk_jar_path), java_path=java_path)
    elif host_os == 'Windows':
        g = JavaGateway.launch_gateway(classpath="{};{};{}\\".format(py4j_jar_path,
                                                                     os.path.join(cdk_jar_path, 'cdk-2.2.jar'),
                                                                     cdk_jar_path), java_path='java')

    # search_handler = g.jvm.SearchHandler(MapConverter().convert(molecules, g._gateway_client))
    search_handler = g.jvm.SearchHandler()

    matches = search_handler.searchPattern(pattern, MapConverter().convert(molecules, g._gateway_client))

    results = copy.deepcopy([{'id': copy.deepcopy(str(compound_id)), 'match_count': copy.deepcopy(int(match_count)),
                              'svg': copy.deepcopy(str(svg))}
                             for compound_id, match_count, svg in matches])
    g.shutdown()
    return results
Example #22
    def addCustomIndex(self, indexClass, cols, params, keyMetadata=None):
        """
        Adds a custom index on the given columns

        :param cols: a sequence of cols
        :param params: a map of index specific parameters
        :param keyMetadata: optional key metadata
        """
        gateway = self.spark.sparkContext._gateway
        jmap = MapConverter().convert(params, gateway._gateway_client)
        objCls = gateway.jvm.String
        colsArr = gateway.new_array(objCls, len(cols))
        for i in range(len(cols)):
            colsArr[i] = cols[i]
        if keyMetadata:
            self._jindexBuilder.addCustomIndex(indexClass, colsArr, jmap,
                                               keyMetadata)
        else:
            self._jindexBuilder.addCustomIndex(indexClass, colsArr, jmap)
        return self
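A hedged sketch of calling addCustomIndex on an existing index builder; the index class name and parameter key are placeholders, not a documented xskipper index:

# `builder` is an existing wrapper exposing addCustomIndex().
builder.addCustomIndex(
    "com.example.indexes.CustomIndex",   # hypothetical fully qualified index class
    ["city", "country"],                 # columns to index
    {"param": "value"})                  # index-specific parameters (placeholder)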
Example #23
    def saveAsTable(self, tableName, source=None, mode="error", **options):
        """Saves the contents of this :class:`DataFrame` to a data source as a table.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Additionally, mode is used to specify the behavior of the saveAsTable operation when
        table already exists in the data source. There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.
        """
        if source is None:
            source = self.sql_ctx.getConf("spark.sql.sources.default",
                                          "org.apache.spark.sql.parquet")
        jmode = self._java_save_mode(mode)
        joptions = MapConverter().convert(
            options, self.sql_ctx._sc._gateway._gateway_client)
        self._jdf.saveAsTable(tableName, source, jmode, joptions)
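A minimal usage sketch for saveAsTable, mirroring the save example earlier; `df` and the table name are placeholders:

# Hypothetical DataFrame `df` from the same Spark version.
df.saveAsTable("people",
               source="org.apache.spark.sql.parquet",
               mode="ignore")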
Example #24
    def trainClassifier(data,
                        numClasses,
                        categoricalFeaturesInfo,
                        impurity="gini",
                        maxDepth=5,
                        maxBins=32,
                        minInstancesPerNode=1,
                        minInfoGain=0.0):
        """
        Train a DecisionTreeModel for classification.

        :param data: Training data: RDD of LabeledPoint.
                     Labels are integers {0,1,...,numClasses-1}.
        :param numClasses: Number of classes for classification.
        :param categoricalFeaturesInfo: Map from categorical feature index
                                        to number of categories.
                                        Any feature not in this map
                                        is treated as continuous.
        :param impurity: Supported values: "entropy" or "gini"
        :param maxDepth: Max depth of tree.
                         E.g., depth 0 means 1 leaf node.
                         Depth 1 means 1 internal node + 2 leaf nodes.
        :param maxBins: Number of bins used for finding splits at each node.
        :param minInstancesPerNode: Min number of instances required at child nodes to create
                                    the parent split
        :param minInfoGain: Min info gain required to create a split
        :return: DecisionTreeModel
        """
        sc = data.context
        dataBytes = _get_unmangled_labeled_point_rdd(data)
        categoricalFeaturesInfoJMap = \
            MapConverter().convert(categoricalFeaturesInfo,
                                   sc._gateway._gateway_client)
        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
            dataBytes._jrdd, "classification", numClasses,
            categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins,
            minInstancesPerNode, minInfoGain)
        dataBytes.unpersist()
        return DecisionTreeModel(sc, model)
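A small end-to-end sketch of calling trainClassifier on a toy RDD of LabeledPoint, with feature 0 declared categorical (two categories); the data is made up and `sc` is assumed to be an existing SparkContext:

from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

data = sc.parallelize([
    LabeledPoint(0.0, [0.0, 1.2]),
    LabeledPoint(1.0, [1.0, 3.4]),
    LabeledPoint(1.0, [1.0, 2.9]),
    LabeledPoint(0.0, [0.0, 0.7]),
])
model = DecisionTree.trainClassifier(
    data, numClasses=2,
    categoricalFeaturesInfo={0: 2},   # feature 0 takes values {0, 1}
    impurity="gini", maxDepth=3)
print(model.predict([1.0, 3.0]))      # predicts a label for one feature vector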
Example #25
    def test_kafka_stream(self):
        """Test the Python Kafka stream API."""
        topic = "topic1"
        sendData = {"a": 3, "b": 5, "c": 10}
        jSendData = MapConverter().convert(
            sendData, self.ssc.sparkContext._gateway._gateway_client)

        self._kafkaTestUtils.createTopic(topic)
        self._kafkaTestUtils.sendMessages(topic, jSendData)

        stream = KafkaUtils.createStream(self.ssc,
                                         self._kafkaTestUtils.zkAddress(),
                                         "test-streaming-consumer", {topic: 1},
                                         {"auto.offset.reset": "smallest"})

        result = {}
        for i in chain.from_iterable(
                self._collect(stream.map(lambda x: x[1]),
                              sum(sendData.values()))):
            result[i] = result.get(i, 0) + 1

        self.assertEqual(sendData, result)
Example #26
def plot_lines(df, props):
    """
    Plots lines given in the DataFrame `df` (one line from each row), and sets
    the visual properties from `props` onto the plot.

    The DataFrame must have these columns (in any order):
    - `key` (string or `None`): An internal (unique) identifier for each line.
    - `label` (string): The name of the line, this will appear on the legend, string.
    - `xs` (array): Stores the X coordinates of the points used to draw the line.
    - `ys` (array): Stores the Y coordinates of the points used to draw the line.
        Must have the same number of elements as `xs`

    The properties in `props` will be made specific to the items added by this call,
    by appending `/<key>` row-by-row to each property key.

    This can only be called if the chart is of type `LINE`.
    """
    assert_is_native_chart()
    if sorted(list(df.columns)) != sorted(["key", "label", "xs", "ys"]):
        raise RuntimeError("Invalid DataFrame format in plot_lines")

    # only used on posix, to unlink them later
    shm_objs = list()
    # only used on windows, to prevent gc
    mmap_objs = list()

    # the key of each row is appended to the property names in Java
    Gateway.chart_plotter.plotVectors(
        pl.dumps([{
            "key": row.key,
            "title": str(row.label),
            "xs": _put_array_in_shm(row.xs, shm_objs, mmap_objs),
            "ys": _put_array_in_shm(row.ys, shm_objs, mmap_objs)
        } for row in df.itertuples(index=False)]),
        MapConverter().convert(props, Gateway.gateway._gateway_client))

    # this is a no-op on Windows
    for o in shm_objs:
        o.unlink()
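As with the histogram case, a short sketch that builds the expected key/label/xs/ys frame and plots two lines; the data and the property key are made up:

import numpy as np
import pandas as pd

xs = np.linspace(0, 10, 100)
df = pd.DataFrame([
    {"key": "sin", "label": "sin(x)", "xs": xs, "ys": np.sin(xs)},
    {"key": "cos", "label": "cos(x)", "xs": xs, "ys": np.cos(xs)},
])
plot_lines(df, {"legend.show": "true"})   # property key is illustrative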
Example #27
    def getJavaVertexRDD(self, rdd, rdd_deserializer):
        if self.bypass_serializer:
            self.jvertex_rdd_deserializer = NoOpSerializer()
            rdd_deserializer = NoOpSerializer()
        # enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
        # profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
        def f(index, iterator):
            return iterator

        command = (f, rdd_deserializer, rdd_deserializer)
        # the serialized command will be compressed by broadcast
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # 1M
            self.broadcast = self.ctx.broadcast(pickled_command)
            pickled_command = ser.dumps(self.broadcast)

        # the serialized command will be compressed by broadcast
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        env = MapConverter().convert(self.ctx.environment,
                                     self.ctx._gateway._gateway_client)
        includes = ListConverter().convert(self.ctx._python_includes,
                                           self.ctx._gateway._gateway_client)
        target_storage_level = StorageLevel.MEMORY_ONLY
        java_storage_level = self.ctx._getJavaStorageLevel(
            target_storage_level)
        prdd = self.ctx._jvm.PythonVertexRDD(
            rdd._jrdd, bytearray(pickled_command), env, includes,
            self.preserve_partitioning, self.ctx.pythonExec, broadcast_vars,
            self.ctx._javaAccumulator, java_storage_level)
        self.jvertex_rdd = prdd.asJavaVertexRDD()
        # if enable_profile:
        #     self.id = self.jvertex_rdd.id()
        #     self.ctx._add_profile(self.id, profileStats)
        return self.jvertex_rdd
Example #28
def plot_lines(df, props = dict()):  # key, label, xs, ys
    _assert_is_native_chart()
    if sorted(list(df.columns)) != sorted(["key", "label", "xs", "ys"]):
        raise RuntimeError("Invalid DataFrame format in plot_lines")

    # only used on posix, to unlink them later
    shm_objs = list()
    # only used on windows, to prevent gc
    mmap_objs = list()

    Gateway.chart_plotter.plotVectors(pl.dumps([
        {
            "key": row.key,
            "title": row.label,
            "xs": _put_array_in_shm(row.xs, shm_objs, mmap_objs),
            "ys": _put_array_in_shm(row.ys, shm_objs, mmap_objs)
        }
        for row in df.itertuples(index=False)
    ]), MapConverter().convert(props, Gateway.gateway._gateway_client))

    # this is a no-op on Windows
    for o in shm_objs:
        o.unlink()
Example #29
    def agg(self, *exprs):
        """Compute aggregates and returns the result as a :class:`DataFrame`.

        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        >>> gdf = df.groupBy(df.name)
        >>> gdf.agg({"*": "count"}).collect()
        [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> gdf.agg(F.min(df.age)).collect()
        [Row(MIN(age)=5), Row(MIN(age)=2)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jmap = MapConverter().convert(
                exprs[0], self.sql_ctx._sc._gateway._gateway_client)
            jdf = self._jdf.agg(jmap)
        else:
            # Columns
            assert all(isinstance(c, Column)
                       for c in exprs), "all exprs should be Column"
            jcols = ListConverter().convert(
                [c._jc for c in exprs[1:]],
                self.sql_ctx._sc._gateway._gateway_client)
            jdf = self._jdf.agg(exprs[0]._jc,
                                self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
        return DataFrame(jdf, self.sql_ctx)
Example #30
 def _jrdd(self):
     if self._jrdd_val:
         return self._jrdd_val
     if self._bypass_serializer:
         serializer = NoOpSerializer()
     else:
         serializer = self.ctx.serializer
     command = (self.func, self._prev_jrdd_deserializer, serializer)
     pickled_command = CloudPickleSerializer().dumps(command)
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
         self.ctx._gateway._gateway_client)
     self.ctx._pickled_broadcast_vars.clear()
     class_tag = self._prev_jrdd.classTag()
     env = MapConverter().convert(self.ctx.environment,
                                  self.ctx._gateway._gateway_client)
     includes = ListConverter().convert(self.ctx._python_includes,
                                  self.ctx._gateway._gateway_client)
     python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
         bytearray(pickled_command), env, includes, self.preservesPartitioning,
         self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
         class_tag)
     self._jrdd_val = python_rdd.asJavaRDD()
     return self._jrdd_val