Example No. 1
 def optimize(self):
     """
     Run the optimization and return the trained model as a Layer.
     """
     # Delegate to the JVM-side Optimizer.optimize through the Py4J gateway.
     jmodel = callJavaFunc(get_spark_context(), self.value.optimize)
     from bigdl.nn.layer import Layer
     return Layer.of(jmodel)
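A minimal, hypothetical usage sketch for this wrapper, assuming BigDL 0.x with a local Spark context; the toy model and random data below are invented for illustration:

import numpy as np
from bigdl.util.common import init_engine, create_spark_conf, get_spark_context, Sample
from bigdl.nn.layer import Sequential, Linear
from bigdl.nn.criterion import MSECriterion
from bigdl.optim.optimizer import Optimizer, SGD, MaxEpoch

# Assumed setup: a local SparkContext plus the BigDL engine.
sc = get_spark_context(create_spark_conf().setMaster("local[2]"))
init_engine()

model = Sequential().add(Linear(2, 1))
train_rdd = sc.parallelize(
    [Sample.from_ndarray(np.random.rand(2), np.random.rand(1)) for _ in range(64)])

opt = Optimizer(model=model, training_rdd=train_rdd,
                criterion=MSECriterion(),
                optim_method=SGD(learningrate=0.01),
                end_trigger=MaxEpoch(2), batch_size=32)
trained = opt.optimize()  # calls into the JVM as shown above; returns a Layer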
Example No. 2
 def predict(self, x, batch_per_thread=-1, distributed=True):
     """
     Use the model to do prediction.

     x may be an ImageSet, a numpy ndarray, or an RDD[Sample];
     a list of ndarrays is also accepted in local mode.
     """
     if isinstance(x, ImageSet):
         results = callBigDlFunc(self.bigdl_type, "zooPredict", self.value,
                                 x, batch_per_thread)
         return ImageSet(results)
     if distributed:
         # Wrap a local ndarray into an RDD[Sample] with dummy labels.
         if isinstance(x, np.ndarray):
             data_rdd = to_sample_rdd(x, np.zeros([x.shape[0]]),
                                      get_spark_context())
         elif isinstance(x, RDD):
             data_rdd = x
         else:
             raise TypeError("Unsupported prediction data type: %s" %
                             type(x))
         results = callBigDlFunc(self.bigdl_type, "zooPredict", self.value,
                                 data_rdd, batch_per_thread)
         return results.map(lambda result: Layer.convert_output(result))
     else:
         if isinstance(x, (np.ndarray, list)):
             results = callBigDlFunc(self.bigdl_type,
                                     "zooPredict", self.value,
                                     self._to_jtensors(x), batch_per_thread)
             return [Layer.convert_output(result) for result in results]
         else:
             raise TypeError("Unsupported prediction data type: %s" %
                             type(x))
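Hedged call patterns for this predict method, assuming `model` is an already-loaded Analytics Zoo model and a SparkContext is active; the input shape is made up:

import numpy as np

x = np.random.rand(8, 10).astype("float32")  # illustrative shape

# Distributed path: the ndarray is wrapped into an RDD[Sample] with dummy labels.
pred_rdd = model.predict(x, distributed=True)
print(pred_rdd.take(2))

# Local path: tensors go straight to the driver-side model; a list comes back.
pred_local = model.predict(x, distributed=False)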
Example No. 3
 def evaluate(self, x, y, batch_size=32, sample_weight=None, is_distributed=False):
     """
     Evaluate the model against the given metrics.

     :param x: ndarray or list of ndarray for local mode;
               RDD[Sample] for distributed mode.
     :param y: ndarray or list of ndarray for local mode; None for distributed mode.
     :param batch_size: number of records per mini-batch.
     :param is_distributed: whether to run evaluation in distributed mode.
            NB: if is_distributed=True, x should be an RDD[Sample] and y should be None.
     :return: a list of evaluation results, one per metric.
     """
     if sample_weight:
         unsupport_exp("sample_weight")
     if is_distributed:
         if isinstance(x, np.ndarray):
             data_rdd = to_sample_rdd(x, y)
         elif isinstance(x, RDD):
             data_rdd = x
         else:
             raise TypeError("Unsupported evaluation data type: %s" % type(x))
         if self.metrics:
             return [r.result for r in
                     self.bmodel.evaluate(data_rdd, batch_size, self.metrics)]
         else:
             raise Exception("No Metrics found.")
     else:
         raise Exception("We only support evaluation in distributed mode")
Example No. 4
def _java2py(gateway, r, encoding="bytes"):
    from py4j.protocol import Py4JJavaError
    from py4j.java_gateway import JavaObject
    from py4j.java_collections import JavaArray, JavaList, JavaMap
    from pyspark import RDD
    from pyspark.serializers import PickleSerializer
    from pyspark.sql import DataFrame
    from bigdl.util.common import get_spark_context, _picklable_classes, get_spark_sql_context

    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = 'JavaRDD'

        if clsName == 'JavaRDD':
            jrdd = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.javaToPython(
                r)
            return RDD(jrdd, get_spark_context())

        if clsName == 'DataFrame':
            return DataFrame(r, get_spark_sql_context(get_spark_context()))

        if clsName == 'Dataset':
            return DataFrame(r, get_spark_sql_context(get_spark_context()))

        if clsName == "ImageFrame[]":
            return r

        if clsName in _picklable_classes:
            r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(
                r)
        elif isinstance(r, (JavaArray, JavaList)) and len(r) != 0 \
                and isinstance(r[0], JavaObject) \
                and r[0].getClass().getSimpleName() in ['DataFrame', 'Dataset']:
            # A Java array/list of DataFrames: wrap each element individually.
            spark = get_spark_sql_context(get_spark_context())
            r = list(map(lambda x: DataFrame(x, spark), r))
        elif isinstance(r, (JavaArray, JavaList, JavaMap)):
            try:
                r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(
                    r)
            except Py4JJavaError:
                pass  # not picklable
        if isinstance(r, (bytearray, bytes)):
            r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r
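For context, a simplified picture (not the exact library code) of where `_java2py` sits in the call chain: a `callJavaFunc`-style wrapper invokes a JVM method through Py4J and maps the returned JavaObject back to a Python value:

def callJavaFunc(func, *args):
    # Simplified sketch; the real wrapper also converts the Python arguments
    # to their Java counterparts before the call.
    from pyspark import SparkContext
    gateway = SparkContext._gateway      # the active Py4J gateway
    result = func(*args)                 # run the JVM-side method
    # JavaRDD -> pyspark RDD, DataFrame/Dataset -> pyspark DataFrame,
    # picklable Java objects -> Python objects, anything else passes through.
    return _java2py(gateway, result)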
Example No. 5
def _java2py(gateway, r, encoding="bytes"):
    from py4j.protocol import Py4JJavaError
    from py4j.java_gateway import JavaObject
    from py4j.java_collections import JavaArray, JavaList, JavaMap
    from pyspark import RDD
    from pyspark.serializers import PickleSerializer
    from pyspark.sql import DataFrame
    from bigdl.util.common import get_spark_context, _picklable_classes, get_spark_sql_context
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = 'JavaRDD'

        if clsName == 'JavaRDD':
            jrdd = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.javaToPython(
                r)
            return RDD(jrdd, get_spark_context())

        if clsName == 'DataFrame':
            return DataFrame(r, get_spark_sql_context(get_spark_context()))

        if clsName == 'Dataset':
            return DataFrame(r, get_spark_sql_context(get_spark_context()))

        if clsName == "ImageFrame[]":
            return r

        if clsName in _picklable_classes:
            r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(
                r)
        elif isinstance(r, (JavaArray, JavaList, JavaMap)):
            try:
                r = gateway.jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(
                    r)
            except Py4JJavaError:
                pass  # not picklable
        if isinstance(r, (bytearray, bytes)):
            r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r
Example No. 6
 def __create_distributed_optimizer(self, training_rdd,
                                    batch_size=32,
                                    nb_epoch=10,
                                    validation_data=None):
     bopt = boptimizer.Optimizer(
         model=self.bmodel,
         training_rdd=training_rdd,
         criterion=self.criterion,
         end_trigger=boptimizer.MaxEpoch(nb_epoch),
         batch_size=batch_size,
         optim_method=self.optim_method
     )
     if validation_data:
         bopt.set_validation(batch_size,
                             val_rdd=validation_data,
                             # TODO: check whether Keras uses the same strategy
                             trigger=boptimizer.EveryEpoch(),
                             val_method=self.metrics)
     return bopt
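This private helper is normally reached through the public Keras-style training path; a hedged sketch, where `train_rdd` and `val_rdd` are assumed to be RDD[Sample] and the hyperparameters are illustrative:

# fit() builds the distributed Optimizer shown above and runs it.
model.compile(optimizer='sgd', loss='mse')  # sets optim_method and criterion
model.fit(train_rdd, batch_size=64, nb_epoch=5,
          validation_data=val_rdd, distributed=True)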