def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row_size, dataset_idx=None):
    self._check_params(metadata)

    run_id = self.getRunId()
    if run_id is None:
        run_id = 'pytorch_' + str(int(time.time()))

    model = self.getModel()
    is_legacy = not isinstance(model, LightningModule)
    if is_legacy:
        # Legacy: convert params to LightningModule
        model = to_lightning_module(model=self.getModel(),
                                    optimizer=self._get_optimizer(),
                                    loss_fns=self.getLoss(),
                                    loss_weights=self.getLossWeights(),
                                    feature_cols=self.getFeatureCols(),
                                    label_cols=self.getLabelCols(),
                                    sample_weights_col=self.getSampleWeightCol(),
                                    validation=self.getValidation())

    serialized_model = serialize_fn()(model)
    # FIXME: checkpoint bytes should be loaded into serialized_model, same as Keras Estimator.
    ckpt_bytes = self._read_checkpoint(run_id) if self._has_checkpoint(run_id) else None
    trainer = remote.RemoteTrainer(self,
                                   metadata=metadata,
                                   ckpt_bytes=ckpt_bytes,
                                   run_id=run_id,
                                   dataset_idx=dataset_idx,
                                   train_rows=train_rows,
                                   val_rows=val_rows,
                                   avg_row_size=avg_row_size,
                                   is_legacy=is_legacy)
    handle = backend.run(trainer, args=(serialized_model,), env={})
    return self._create_model(handle, run_id, metadata)
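# Illustrative sketch only (not part of this module): a minimal end-to-end flow that
# reaches _fit_on_prepared_data() via Estimator.fit(). The TorchEstimator parameter
# names below follow the public API as assumed here; the model, columns, and store
# path are hypothetical. Never called from this module.
def _example_fit_usage(train_df):
    import torch.nn as nn
    import torch.optim as optim
    from horovod.spark.torch import TorchEstimator
    from horovod.spark.common.store import LocalStore

    model = nn.Sequential(nn.Linear(2, 1))                # toy model over 2 input features
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    estimator = TorchEstimator(
        model=model,
        optimizer=optimizer,
        loss=lambda output, label: nn.functional.mse_loss(output, label),
        feature_cols=['features'],
        label_cols=['y'],
        input_shapes=[[-1, 2]],
        batch_size=16,
        epochs=1,
        store=LocalStore('/tmp/horovod_store'))           # hypothetical local path

    # fit() prepares the data, then dispatches to _fit_on_prepared_data() above,
    # which serializes the model and launches RemoteTrainer on the backend.
    torch_model = estimator.fit(train_df)
    return torch_model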
def _torch_param_serialize(param_name, param_val):
    if param_val is None:
        return None

    if param_name in [EstimatorParams.backend.name, EstimatorParams.store.name]:
        # We do not serialize backend and store. These params have to be regenerated for each
        # run of the pipeline.
        return None
    elif param_name == EstimatorParams.model.name:
        serialize = serialize_fn()
        return serialize(param_val)

    return codec.dumps_base64(param_val)
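# Illustrative sketch only: expected behavior of the per-parameter serializer above.
# `EstimatorParams`, `serialize_fn`, and `codec` come from the surrounding module;
# the concrete values here are hypothetical. Never called from this module.
def _example_param_serialize():
    import torch.nn as nn

    # Ordinary params (e.g. batch_size) go through codec.dumps_base64.
    encoded_batch_size = _torch_param_serialize('batch_size', 32)

    # backend and store are intentionally dropped; they are rebuilt for every run.
    dropped = _torch_param_serialize(EstimatorParams.store.name, object())
    assert dropped is None

    # The model param takes the serialize_fn() path, which handles torch modules.
    serialized_model = _torch_param_serialize(EstimatorParams.model.name, nn.Linear(2, 1))

    return encoded_batch_size, serialized_model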
def _transform(self, df):
    import copy
    from pyspark.sql.types import StructField, StructType
    from pyspark.ml.linalg import VectorUDT

    model_pre_predict = self.getModel()

    deserialize = deserialize_fn()
    serialize = serialize_fn()
    serialized_model = serialize(model_pre_predict)

    input_shapes = self.getInputShapes()
    label_cols = self.getLabelColumns()
    output_cols = self.getOutputCols()
    feature_cols = self.getFeatureColumns()
    metadata = self._get_metadata()

    final_output_cols = util.get_output_cols(df.schema, output_cols)

    def predict(rows):
        from pyspark import Row
        from pyspark.ml.linalg import DenseVector, SparseVector

        model = deserialize(serialized_model)
        # Perform predictions.
        for row in rows:
            fields = row.asDict().copy()

            # Note: if the col is SparseVector, torch.tensor(col) correctly converts it to a
            # dense torch tensor.
            data = [torch.tensor([row[col]]).reshape(shape)
                    for col, shape in zip(feature_cols, input_shapes)]

            with torch.no_grad():
                preds = model(*data)

            if not isinstance(preds, (list, tuple)):
                preds = [preds]

            for label_col, output_col, pred in zip(label_cols, output_cols, preds):
                meta = metadata[label_col]
                col_type = meta['spark_data_type']
                # dtype for dense and spark tensor is always np.float64
                if col_type == DenseVector:
                    shape = np.prod(pred.shape)
                    flattened_pred = pred.reshape(shape,)
                    field = DenseVector(flattened_pred)
                elif col_type == SparseVector:
                    shape = meta['shape']
                    flattened_pred = pred.reshape(shape,)
                    nonzero_indices = flattened_pred.nonzero()[0]
                    field = SparseVector(shape, nonzero_indices,
                                         flattened_pred[nonzero_indices])
                elif pred.shape.numel() == 1:
                    # If the column is scalar type, int, float, etc.
                    value = pred.item()
                    python_type = util.spark_scalar_to_python_type(col_type)
                    if issubclass(python_type, numbers.Integral):
                        value = round(value)
                    field = python_type(value)
                else:
                    field = DenseVector(pred.reshape(-1))

                fields[output_col] = field

            values = [fields[col] for col in final_output_cols]

            yield Row(*values)

    spark0 = SparkSession._instantiatedSession

    final_output_fields = []

    # copy input schema
    for field in df.schema.fields:
        final_output_fields.append(copy.deepcopy(field))

    # append output schema
    override_fields = df.limit(1).rdd.mapPartitions(predict).toDF().schema.fields[-len(output_cols):]
    for name, override, label in zip(output_cols, override_fields, label_cols):
        # default data type as label type
        data_type = metadata[label]['spark_data_type']()
        if type(override.dataType) == VectorUDT:
            # Override output to vector. This is mainly for torch's classification loss
            # where label is a scalar but model output is a vector.
            data_type = VectorUDT()
        final_output_fields.append(StructField(name=name, dataType=data_type, nullable=True))

    final_output_schema = StructType(final_output_fields)

    pred_rdd = df.rdd.mapPartitions(predict)

    # Use the schema from previous section to construct the final DF with prediction
    return spark0.createDataFrame(pred_rdd, schema=final_output_schema)
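# Illustrative sketch only (not part of this module): applying a fitted model.
# Model.transform() dispatches to _transform() above. The column names are
# hypothetical, and the availability of setOutputCols() on the model is assumed
# here rather than confirmed by this module. Never called from this module.
def _example_transform_usage(torch_model, test_df):
    # One output column per label column; name it explicitly for clarity.
    torch_model.setOutputCols(['y_pred'])

    # The returned DataFrame keeps every input column and appends the prediction
    # column(s), typed according to the label metadata (scalar or vector).
    pred_df = torch_model.transform(test_df)
    pred_df.select('y', 'y_pred').show(5)
    return pred_df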