Example #1
    def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row_size, dataset_idx=None):
        self._check_params(metadata)

        run_id = self.getRunId()
        if run_id is None:
            # Fall back to a timestamp-based run id when the user has not set one.
            run_id = 'pytorch_' + str(int(time.time()))

        model = self.getModel()
        is_legacy = not isinstance(model, LightningModule)
        if is_legacy:
            # Legacy: convert params to LightningModule
            model = to_lightning_module(model=self.getModel(),
                                        optimizer=self._get_optimizer(),
                                        loss_fns=self.getLoss(),
                                        loss_weights=self.getLossWeights(),
                                        feature_cols=self.getFeatureCols(),
                                        label_cols=self.getLabelCols(),
                                        sample_weights_col=self.getSampleWeightCol(),
                                        validation=self.getValidation())

        # Serialize the (possibly wrapped) model so it can be shipped to the workers.
        serialized_model = serialize_fn()(model)
        # FIXME: checkpoint bytes should be loaded into serialized_model, same as Keras Estimator.
        ckpt_bytes = self._read_checkpoint(run_id) if self._has_checkpoint(run_id) else None
        trainer = remote.RemoteTrainer(self,
                                       metadata=metadata,
                                       ckpt_bytes=ckpt_bytes,
                                       run_id=run_id,
                                       dataset_idx=dataset_idx,
                                       train_rows=train_rows,
                                       val_rows=val_rows,
                                       avg_row_size=avg_row_size,
                                       is_legacy=is_legacy)
        handle = backend.run(trainer, args=(serialized_model,), env={})
        return self._create_model(handle, run_id, metadata)
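
The flow above serializes the model once on the driver, then hands the bytes to a backend that runs the trainer on the cluster. The sketch below reproduces that shape with purely local, hypothetical helpers (LocalBackend and make_trainer are not Horovod's RemoteTrainer API) and uses cloudpickle as a stand-in for serialize_fn(); it assumes torch and cloudpickle are installed.

import time

import cloudpickle
import torch


class LocalBackend:
    # Hypothetical stand-in for the Spark backend used above: it runs the
    # trainer callable in-process instead of dispatching it to executors.
    def run(self, trainer, args=(), env=None):
        return trainer(*args)


def make_trainer(run_id, epochs=1):
    def train(serialized_model):
        # Deserialize on the "worker" side, run a trivial training loop, and
        # return the result, mimicking what the remote trainer handle does.
        model = cloudpickle.loads(serialized_model)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        features, labels = torch.randn(8, 4), torch.randn(8, 2)
        for _ in range(epochs):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(features), labels)
            loss.backward()
            optimizer.step()
        return run_id, model.state_dict()
    return train


if __name__ == '__main__':
    run_id = 'pytorch_' + str(int(time.time()))
    model = torch.nn.Linear(4, 2)
    handle = LocalBackend().run(make_trainer(run_id),
                                args=(cloudpickle.dumps(model),), env={})
    print(handle[0], list(handle[1].keys()))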
Example #2
def _torch_param_serialize(param_name, param_val):
    if param_val is None:
        return None

    if param_name in [EstimatorParams.backend.name, EstimatorParams.store.name]:
        # We do not serialize the backend and store params; they have to be
        # regenerated for each run of the pipeline.
        return None
    elif param_name == EstimatorParams.model.name:
        serialize = serialize_fn()
        return serialize(param_val)

    return codec.dumps_base64(param_val)
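
The same skip / special-case / encode dispatch can be shown without Horovod's EstimatorParams and codec modules. The helpers below (NON_SERIALIZABLE_PARAMS, dumps_base64, loads_base64, param_serialize) are hypothetical stand-ins that illustrate the pattern with cloudpickle and base64; they are not the library's actual implementation.

import base64

import cloudpickle

# Hypothetical set standing in for EstimatorParams.backend/store above.
NON_SERIALIZABLE_PARAMS = {'backend', 'store'}


def dumps_base64(obj):
    # Stand-in for codec.dumps_base64: pickle, then base64-encode so the value
    # can be stored as text alongside the other estimator params.
    return base64.b64encode(cloudpickle.dumps(obj)).decode('ascii')


def loads_base64(payload):
    return cloudpickle.loads(base64.b64decode(payload))


def param_serialize(param_name, param_val):
    if param_val is None or param_name in NON_SERIALIZABLE_PARAMS:
        return None
    # The model could be routed to a framework-specific serializer here, as in
    # the snippet above; this sketch treats every value the same way.
    return dumps_base64(param_val)


if __name__ == '__main__':
    assert param_serialize('backend', object()) is None
    print(loads_base64(param_serialize('batch_size', 32)))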
Example #3
    def _transform(self, df):
        import copy
        from pyspark.sql.types import StructField, StructType
        from pyspark.ml.linalg import VectorUDT

        model_pre_predict = self.getModel()
        deserialize = deserialize_fn()
        serialize = serialize_fn()
        serialized_model = serialize(model_pre_predict)

        input_shapes = self.getInputShapes()
        label_cols = self.getLabelColumns()
        output_cols = self.getOutputCols()
        feature_cols = self.getFeatureColumns()
        metadata = self._get_metadata()

        final_output_cols = util.get_output_cols(df.schema, output_cols)

        def predict(rows):
            from pyspark import Row
            from pyspark.ml.linalg import DenseVector, SparseVector

            model = deserialize(serialized_model)
            # Perform predictions.
            for row in rows:
                fields = row.asDict().copy()

                # Note: if the col is SparseVector, torch.tensor(col) correctly converts it to a
                # dense torch tensor.
                data = [
                    torch.tensor([row[col]]).reshape(shape)
                    for col, shape in zip(feature_cols, input_shapes)
                ]

                with torch.no_grad():
                    preds = model(*data)

                if not isinstance(preds, (list, tuple)):
                    preds = [preds]

                for label_col, output_col, pred in zip(label_cols, output_cols,
                                                       preds):
                    meta = metadata[label_col]
                    col_type = meta['spark_data_type']
                    # dtype for dense and sparse vectors is always np.float64
                    if col_type == DenseVector:
                        shape = np.prod(pred.shape)
                        flattened_pred = pred.reshape(shape, )
                        field = DenseVector(flattened_pred)
                    elif col_type == SparseVector:
                        shape = meta['shape']
                        flattened_pred = pred.reshape(shape, )
                        nonzero_indices = flattened_pred.nonzero()[0]
                        field = SparseVector(shape, nonzero_indices,
                                             flattened_pred[nonzero_indices])
                    elif pred.shape.numel() == 1:
                        # If the column is scalar type, int, float, etc.
                        value = pred.item()
                        python_type = util.spark_scalar_to_python_type(
                            col_type)
                        if issubclass(python_type, numbers.Integral):
                            value = round(value)
                        field = python_type(value)
                    else:
                        field = DenseVector(pred.reshape(-1))

                    fields[output_col] = field

                values = [fields[col] for col in final_output_cols]

                yield Row(*values)

        spark0 = SparkSession._instantiatedSession

        final_output_fields = []

        # copy input schema
        for field in df.schema.fields:
            final_output_fields.append(copy.deepcopy(field))

        # append output schema
        override_fields = df.limit(1).rdd.mapPartitions(
            predict).toDF().schema.fields[-len(output_cols):]
        for name, override, label in zip(output_cols, override_fields,
                                         label_cols):
            # Default the output data type to the label column's data type.
            data_type = metadata[label]['spark_data_type']()

            if type(override.dataType) == VectorUDT:
                # Override output to vector. This is mainly for torch's classification loss
                # where label is a scalar but model output is a vector.
                data_type = VectorUDT()
            final_output_fields.append(
                StructField(name=name, dataType=data_type, nullable=True))

        final_output_schema = StructType(final_output_fields)

        pred_rdd = df.rdd.mapPartitions(predict)

        # Use the schema constructed above to build the final DataFrame with predictions.
        return spark0.createDataFrame(pred_rdd, schema=final_output_schema)
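
A stripped-down sketch of that mapPartitions prediction pattern, with hypothetical column names ('features', 'prediction') and a toy torch.nn.Linear model: it deserializes the model once per partition and appends a single prediction column, but omits the DenseVector/SparseVector/scalar handling and the schema override done above. It assumes a local Spark session and that torch and cloudpickle are available on the executors.

import cloudpickle
import torch
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master('local[2]').getOrCreate()

# Serialize the model on the driver; the bytes are captured by the closure and
# shipped to the executors along with the predict function.
model = torch.nn.Linear(2, 1)
serialized_model = cloudpickle.dumps(model)

df = spark.createDataFrame([Row(features=[1.0, 2.0]),
                            Row(features=[3.0, 4.0])])


def predict(rows):
    # Deserialize once per partition so the cost is not paid for every row.
    local_model = cloudpickle.loads(serialized_model)
    with torch.no_grad():
        for row in rows:
            pred = local_model(torch.tensor(row.features))
            yield Row(features=row.features, prediction=float(pred.item()))


pred_df = spark.createDataFrame(df.rdd.mapPartitions(predict))
pred_df.show()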