def _deserialize_dict(self, dict_values):
    deserialized_dict = dict()
    for key, val in dict_values.items():
        if val is None:
            deserialized_dict[key] = None
        elif key == EstimatorParams.model.name:
            # The model was saved with the framework-specific serializer, so it
            # needs the matching deserializer rather than the base64 codec.
            deserialize = deserialize_fn()
            deserialized_dict[key] = deserialize(val)
        else:
            deserialized_dict[key] = codec.loads_base64(val)
    return deserialized_dict
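

# Illustrative counterpart of _deserialize_dict (a minimal sketch, not part of
# the original module): it assumes the params were produced by a mirror-image
# serializer along these lines. The name `_serialize_dict_sketch`, and
# `codec.dumps_base64` as the inverse of `codec.loads_base64`, are assumptions
# for illustration.
def _serialize_dict_sketch(param_values):
    serialized = {}
    for key, val in param_values.items():
        if val is None:
            serialized[key] = None
        elif key == EstimatorParams.model.name:
            # The model goes through the framework-specific serializer.
            serialize = serialize_fn()
            serialized[key] = serialize(val)
        else:
            serialized[key] = codec.dumps_base64(val)
    return serialized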


def _transform(self, df):
    import copy
    from pyspark.sql.types import StructField, StructType
    from pyspark.ml.linalg import VectorUDT

    model_pre_predict = self.getModel()
    deserialize = deserialize_fn()
    serialize = serialize_fn()
    serialized_model = serialize(model_pre_predict)

    input_shapes = self.getInputShapes()
    label_cols = self.getLabelColumns()
    output_cols = self.getOutputCols()
    feature_cols = self.getFeatureColumns()
    metadata = self._get_metadata()

    final_output_cols = util.get_output_cols(df.schema, output_cols)

    def predict(rows):
        from pyspark import Row
        from pyspark.ml.linalg import DenseVector, SparseVector

        model = deserialize(serialized_model)

        # Perform predictions.
        for row in rows:
            fields = row.asDict().copy()

            # Note: if the col is SparseVector, torch.tensor(col) correctly
            # converts it to a dense torch tensor.
            data = [
                torch.tensor([row[col]]).reshape(shape)
                for col, shape in zip(feature_cols, input_shapes)
            ]

            with torch.no_grad():
                preds = model(*data)

            if not isinstance(preds, (list, tuple)):
                preds = [preds]

            for label_col, output_col, pred in zip(label_cols, output_cols, preds):
                meta = metadata[label_col]
                col_type = meta['spark_data_type']
                # dtype for dense and spark tensor is always np.float64.
                if col_type == DenseVector:
                    shape = np.prod(pred.shape)
                    flattened_pred = pred.reshape(shape)
                    field = DenseVector(flattened_pred)
                elif col_type == SparseVector:
                    shape = meta['shape']
                    flattened_pred = pred.reshape(shape)
                    nonzero_indices = flattened_pred.nonzero()[0]
                    field = SparseVector(shape, nonzero_indices,
                                         flattened_pred[nonzero_indices])
                elif pred.shape.numel() == 1:
                    # The column is a scalar type (int, float, etc.).
                    value = pred.item()
                    python_type = util.spark_scalar_to_python_type(col_type)
                    if issubclass(python_type, numbers.Integral):
                        value = round(value)
                    field = python_type(value)
                else:
                    field = DenseVector(pred.reshape(-1))

                fields[output_col] = field

            values = [fields[col] for col in final_output_cols]

            yield Row(*values)

    spark0 = SparkSession._instantiatedSession

    final_output_fields = []

    # Copy the input schema.
    for field in df.schema.fields:
        final_output_fields.append(copy.deepcopy(field))

    # Append the output schema.
    override_fields = df.limit(1).rdd.mapPartitions(
        predict).toDF().schema.fields[-len(output_cols):]
    for name, override, label in zip(output_cols, override_fields, label_cols):
        # Default to the label's data type.
        data_type = metadata[label]['spark_data_type']()
        if type(override.dataType) == VectorUDT:
            # Override the output to vector. This is mainly for torch's
            # classification loss, where the label is a scalar but the model
            # output is a vector.
            data_type = VectorUDT()
        final_output_fields.append(
            StructField(name=name, dataType=data_type, nullable=True))

    final_output_schema = StructType(final_output_fields)

    pred_rdd = df.rdd.mapPartitions(predict)
    # Use the schema constructed above to build the final DataFrame with predictions.
    return spark0.createDataFrame(pred_rdd, schema=final_output_schema)
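

# A minimal usage sketch for the transformer above (assumptions for
# illustration: `torch_model` is a fitted model whose _transform is shown
# above, and `df` is a Spark DataFrame whose feature columns match the input
# shapes the model was trained with):
#
#     pred_df = torch_model.setOutputCols(['label_pred']).transform(df)
#     pred_df.select('label_pred').show(5)
#
# transform() is the public pyspark.ml entry point that dispatches to
# _transform; each partition deserializes the model once and streams its rows
# through `predict`.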


def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() if estimator.getValBatchSize() \
        else batch_size
    epochs = estimator.getEpochs()
    random_seed = estimator.getRandomSeed()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    terminate_on_nan = estimator.getTerminateOnNan()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()
    callbacks = estimator.getCallbacks() or []
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    num_gpus = estimator.getNumGPUs()
    data_module = estimator.getDataModule() if estimator.getDataModule() \
        else PetastormDataModule
    loader_num_epochs = estimator.getLoaderNumEpochs()
    verbose = (estimator.getVerbose() > 0)
    trainer_args = estimator.getTrainerArgs()
    debug_data_loader = estimator.getDebugDataLoader()
    train_async_data_loader_queue_size = estimator.getTrainAsyncDataLoaderQueueSize()
    val_async_data_loader_queue_size = estimator.getValAsyncDataLoaderQueueSize()

    # Get the logger.
    logger = estimator.getLogger()
    log_every_n_steps = estimator.getLogEveryNSteps()
    print(f"logger is configured: {logger}")

    # The Comet logger's experiment key is not serialized correctly, so remember
    # the key here and resume the logger experiment on the GPU instance.
    if isinstance(logger, CometLogger):
        logger_experiment_key = logger._experiment_key
        print(f"logger vars: {vars(logger)}")
    else:
        logger_experiment_key = None

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)
    storage_options = store.storage_options

    profiler = estimator.getProfiler()

    def train(serialized_model):
        import horovod.torch as hvd

        if random_seed is not None:
            pl.utilities.seed.seed_everything(seed=random_seed)

        # Horovod: initialize library.
        hvd.init()

        if verbose:
            import horovod as _horovod
            print(f"Shared lib path is pointing to: "
                  f"{_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

        _checkpoint_callback = None
        require_checkpoint = False

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
            os.makedirs(logs_path, exist_ok=True)
            print(f"Made directory {logs_path} for horovod rank {hvd.rank()}")
            ckpt_dir = run_output_dir
            ckpt_filename = remote_store.checkpoint_filename

            if logger is None:
                # Use the default logger if no logger is supplied.
                train_logger = TensorBoardLogger(logs_path)
                print(f"Setup logger: Using TensorBoardLogger: {train_logger}")
            elif isinstance(logger, CometLogger):
                if logger._experiment_key:
                    # Use the logger passed in.
                    train_logger = logger
                    train_logger._save_dir = logs_path
                    print(f"Setup logger: change save_dir of the logger to {logs_path}")
                elif logger_experiment_key:
                    # Resume the logger experiment with the new log path if the
                    # key was passed over correctly from the CPU instance.
                    train_logger = CometLogger(
                        save_dir=logs_path,
                        api_key=logger.api_key,
                        experiment_key=logger_experiment_key,
                    )
                    print(f"Setup logger: Resume comet logger: {vars(train_logger)}")
                else:
                    print(f"Failed to set up or resume comet logger. origin logger: {vars(logger)}")
            else:
                # Use the logger passed in.
                train_logger = logger
                train_logger.save_dir = logs_path
                print(f"Setup logger: Using logger passed from estimator: {train_logger}")

            # Lightning requires checkpoint callbacks to be added on all ranks;
            # otherwise training hangs.
            for cb in callbacks:
                if isinstance(cb, ModelCheckpoint):
                    cb.dirpath = ckpt_dir
                    cb.filename = ckpt_filename
                    _checkpoint_callback = cb
                    require_checkpoint = True
                    break
            if not _checkpoint_callback:
                # By default 'monitor' is None, which saves a checkpoint only
                # for the last epoch.
                _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                                       filename=ckpt_filename,
                                                       verbose=True)
                callbacks.append(_checkpoint_callback)

            if remote_store.saving_runs and hvd.rank() == 0:
                # Horovod: sync checkpoint and logging files only on rank 0 to
                # prevent other ranks from corrupting them.
                class _SyncCallback(Callback):
                    def on_epoch_end(self, trainer: "pl.Trainer",
                                     pl_module: "pl.LightningModule") -> None:
                        remote_store.sync(run_output_dir)

                callbacks.append(_SyncCallback())

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else \
                int(math.floor(float(train_rows) / batch_size / hvd.size()))

            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
                int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

            shuffle_size = calculate_shuffle_buffer_size()
            if verbose:
                print(f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}, "
                      f"Shuffle_size: {shuffle_size}, Random seed: {random_seed}\n"
                      f"Train rows: {train_rows}, Train batch size: {batch_size}, "
                      f"Train_steps_per_epoch: {_train_steps_per_epoch}\n"
                      f"Val rows: {val_rows}, Val batch size: {val_batch_size}, "
                      f"Val_steps_per_epoch: {_val_steps_per_epoch}\n"
                      f"Checkpoint file: {remote_store.checkpoint_path}, "
                      f"Logs dir: {remote_store.logs_path}\n")

            cuda_available = torch.cuda.is_available()
            # Check that all ranks have the same device type for training:
            # Horovod doesn't support heterogeneous allreduce for gradients.
            cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
            if cuda_avail_list.count(cuda_available) != hvd.size():
                raise RuntimeError("All ranks don't have the same device type!")

            if cuda_available:
                # Horovod: pin GPU to local rank or the GPU assigned by Spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            # Refresh the progress bar once per epoch: detailed loss and metrics
            # are available in the logger, so there is no need to print them to
            # the screen here.
            # Users can still override this via trainer_args.
            progress_bar_refresh_rate = _train_steps_per_epoch

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'logger': train_logger,
                'log_every_n_steps': log_every_n_steps,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False,
                'progress_bar_refresh_rate': progress_bar_refresh_rate,
                'terminate_on_nan': terminate_on_nan,
                'profiler': profiler,
            }
            if trainer_args:
                kwargs.update(trainer_args)

            if verbose and hvd.rank() == 0:
                print("Creating trainer with: \n ", kwargs)

            trainer = Trainer(**kwargs)

            if profiler != 'simple' and trainer.profiler:
                print(f"Set profiler's logs_path for {hvd.rank()} to {logs_path}")
                trainer.profiler.dirpath = logs_path
                # Filename where the profiler results will be saved instead of
                # printing to stdout. The .txt extension will be used automatically.
                trainer.profiler.filename = "profile"

            if verbose and hvd.rank() == 0:
                print(f"pytorch_lightning version={pl.__version__}")

            data_module_kwargs = {
                'train_dir': remote_store.train_data_path,
                'val_dir': remote_store.val_data_path,
                'num_train_epochs': epochs,
                'has_val': should_validate is not None,
                'train_batch_size': batch_size,
                'val_batch_size': val_batch_size,
                'shuffle_size': shuffle_size,
                'num_reader_epochs': loader_num_epochs,
                'reader_pool_type': reader_pool_type,
                'reader_worker_count': train_reader_worker_count,
                'transform_spec': transformation,
                'inmemory_cache_all': inmemory_cache_all,
                'cur_shard': hvd.rank(),
                'shard_count': hvd.size(),
                'schema_fields': schema_fields,
                'storage_options': storage_options,
                'steps_per_epoch_train': _train_steps_per_epoch,
                'steps_per_epoch_val': _val_steps_per_epoch,
                'verbose': verbose,
                'debug_data_loader': debug_data_loader,
                'train_async_data_loader_queue_size': train_async_data_loader_queue_size,
                'val_async_data_loader_queue_size': val_async_data_loader_queue_size,
            }
            if debug_data_loader and hvd.rank() == 0:
                print(f"Creating data module with args:\n {data_module_kwargs}")

            dataset = data_module(**data_module_kwargs)

            trainer.fit(model, dataset)

            if hvd.rank() == 0:
                if remote_store.saving_runs and trainer.profiler:
                    # One more file sync to push the profiler result.
                    remote_store.sync(logs_path)

                # Rank 0 overwrites the model with the best checkpoint and returns it.
                if require_checkpoint:
                    if verbose:
                        print("load from checkpoint best model path:",
                              _checkpoint_callback.best_model_path)
                    best_model = model.load_from_checkpoint(
                        _checkpoint_callback.best_model_path)
                else:
                    best_model = model

                serialized_checkpoint = io.BytesIO()
                module = best_model if not is_legacy else best_model._model

                output = {'model': module.state_dict(),
                          'logged_metrics': trainer.logged_metrics}

                torch.save(output, serialized_checkpoint)
                # Rewind so the caller reads the checkpoint from the beginning.
                serialized_checkpoint.seek(0)
                return serialized_checkpoint

    return train
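

# A minimal driver-side invocation sketch for the RemoteTrainer above
# (assumptions for illustration: `backend` is a horovod.spark backend whose
# run() executes a function on every training process, and serialize_fn() is
# the counterpart of the deserialize_fn() used inside train()):
#
#     train_fn = RemoteTrainer(estimator, metadata, ckpt_bytes, run_id,
#                              dataset_idx, train_rows, val_rows,
#                              avg_row_size, is_legacy=False)
#     serialized_model = serialize_fn()(estimator.getModel())
#     results = backend.run(train_fn, args=(serialized_model,))
#     # Only rank 0 returns the BytesIO checkpoint; other ranks return None.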


def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() if estimator.getValBatchSize() \
        else batch_size
    epochs = estimator.getEpochs()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()
    callbacks = estimator.getCallbacks()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    num_gpus = estimator.getNumGPUs()
    logger = estimator.getLogger()
    log_every_n_steps = estimator.getLogEveryNSteps()
    data_loader_cls = estimator.getDataLoaderClass()
    loader_num_epochs = estimator.getLoaderNumEpochs()
    verbose = (estimator.getVerbose() > 0)

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    data_loader_cls = _create_dataloader(feature_columns, input_shapes, metadata,
                                         inmemory_cache_all, data_loader_cls)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)

    set_data_loader = _set_data_loader_fn(transformation, schema_fields, batch_size,
                                          data_loader_cls, loader_num_epochs, store,
                                          epochs, inmemory_cache_all, verbose)

    def train(serialized_model):
        import horovod.torch as hvd

        # Horovod: initialize library.
        hvd.init()

        with tempfile.TemporaryDirectory() as last_ckpt_dir, \
                remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            # TODO: Pass the logger from the estimator constructor.
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)

            # Use the default logger if no logger is supplied.
            train_logger = logger
            if train_logger is None:
                train_logger = TensorBoardLogger(logs_path)

            # TODO: find a way to use a ckpt_path created from the remote store
            # while ingesting all other parameters from the estimator config:
            # ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
            # os.makedirs(ckpt_path, exist_ok=True)
            # model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
            # callbacks.append(model_checkpoint_callback)

            is_model_checkpoint_callback_exist = False
            if callbacks is not None:
                for cb in callbacks:
                    if isinstance(cb, ModelCheckpoint):
                        is_model_checkpoint_callback_exist = True
                        break

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else \
                int(math.floor(float(train_rows) / batch_size / hvd.size()))

            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
                int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

            print(f"Training data of rank[{hvd.local_rank()}]: train_rows:{train_rows}, "
                  f"batch_size:{batch_size}, _train_steps_per_epoch:{_train_steps_per_epoch}.")
            print(f"Validation data of rank[{hvd.local_rank()}]: val_rows:{val_rows}, "
                  f"val_batch_size:{val_batch_size}, _val_steps_per_epoch:{_val_steps_per_epoch}, "
                  f"should_validate:{should_validate}")

            cuda_available = torch.cuda.is_available()
            # Check that all ranks have the same device type for training:
            # Horovod doesn't support heterogeneous allreduce for gradients.
            cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
            if cuda_avail_list.count(cuda_available) != hvd.size():
                raise RuntimeError("All ranks don't have the same device type!")

            if cuda_available:
                # Horovod: pin GPU to local rank or the GPU assigned by Spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'logger': train_logger,
                'log_every_n_steps': log_every_n_steps,
                'resume_from_checkpoint': (last_ckpt_file if ckpt_bytes else None),
                'checkpoint_callback': is_model_checkpoint_callback_exist,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False,
                'progress_bar_refresh_rate': _train_steps_per_epoch // 10,
            }
            print("Creating trainer with: \n ", kwargs)

            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            # Print row groups for debugging:
            # pq_file = pq.ParquetFile(remote_store.train_data_path)
            # for rowgroup in range(pq_file.metadata.num_row_groups):
            #     row_group = pq_file.metadata.row_group(rowgroup)
            #     print(row_group)

            with set_data_loader(model, remote_store.train_data_path, 'train_dataloader',
                                 train_reader_worker_count, reader_pool_type,
                                 calculate_shuffle_buffer_size(),
                                 name="train_dataloader",
                                 limit_step_per_epoch=_train_steps_per_epoch), \
                    set_data_loader(model, remote_store.val_data_path, 'val_dataloader',
                                    val_reader_worker_count, reader_pool_type, 0,
                                    should_validate, name="val_dataloader",
                                    limit_step_per_epoch=_val_steps_per_epoch):
                trainer.fit(model)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model

            # TODO: find a way to pass trainer.logged_metrics out.
            output = {'model': module.state_dict()}

            torch.save(output, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint

    return train
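

# Sketch of the driver-side unpacking implied by the BytesIO returned from
# train() above (a minimal sketch; `fresh_model`, a newly constructed module
# with the same architecture, is an assumption for illustration):
def _load_checkpoint_sketch(serialized_checkpoint, fresh_model):
    # train() wrote {'model': state_dict} with torch.save and rewound the
    # buffer with seek(0), so torch.load can read it directly.
    checkpoint = torch.load(serialized_checkpoint)
    fresh_model.load_state_dict(checkpoint['model'])
    return fresh_model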


def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() if estimator.getValBatchSize() \
        else batch_size
    epochs = estimator.getEpochs()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    dataloader_cls = _create_dataloader(feature_columns, input_shapes, metadata)
    make_petastorm_reader = _make_petastorm_reader_fn(
        transformation, schema_fields, batch_size,
        calculate_shuffle_buffer_size, dataloader_cls)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)
    is_dbfs = isinstance(store, DBFSLocalStore)

    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    train_percent = train_rows / train_steps_per_epoch if train_steps_per_epoch else 1.0

    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    val_percent = val_rows / val_steps_per_epoch if val_steps_per_epoch else 1.0

    # Disable callbacks for now, because petastorm cannot reset its index
    # during training.
    callbacks = None  # _make_callbacks()

    def train(serialized_model):
        with tempfile.TemporaryDirectory() as last_ckpt_dir, \
                remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
            logger = TensorBoardLogger(logs_path)

            ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
            os.makedirs(ckpt_path, exist_ok=True)

            # Disable the checkpoint callback for now, pending the fix for
            # https://github.com/PyTorchLightning/pytorch-lightning/issues/6343
            checkpoint_callback = None  # ModelCheckpoint(dirpath=ckpt_path)

            model = deserialize(serialized_model)

            kwargs = {
                'accelerator': 'horovod',
                'gpus': (1 if torch.cuda.is_available() else 0),
                'callbacks': callbacks,
                'max_epochs': epochs,
                'limit_train_batches': train_percent,
                'limit_val_batches': val_percent,
                'logger': logger,
                'checkpoint_callback': checkpoint_callback,
                'resume_from_checkpoint': (last_ckpt_file if ckpt_bytes else None),
                'num_sanity_val_steps': 0,
            }
            print("Creating trainer with: \n ", kwargs)

            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            # Print row groups for debugging:
            # pq_file = pq.ParquetFile(remote_store.train_data_path)
            # for rowgroup in range(pq_file.metadata.num_row_groups):
            #     row_group = pq_file.metadata.row_group(rowgroup)
            #     print(row_group)

            with make_petastorm_reader(model, remote_store.train_data_path,
                                       'train_dataloader', train_reader_worker_count), \
                    make_petastorm_reader(model, remote_store.val_data_path,
                                          'val_dataloader', val_reader_worker_count,
                                          should_validate):
                trainer.fit(model)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model
            torch.save({'model': module.state_dict()}, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint

    return train
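

# Sketch of the contract make_petastorm_reader is used under in train() above
# (an illustration under assumptions, not the original helper): a context
# manager that opens a petastorm reader over `data_path` and attaches a
# dataloader method (e.g. 'train_dataloader') to the LightningModule for the
# duration of trainer.fit(). `_dataloader_cls` and `_batch_size` stand in for
# the closures that _make_petastorm_reader_fn captures.
from contextlib import contextmanager

@contextmanager
def _petastorm_reader_sketch(model, data_path, dataloader_attr,
                             reader_worker_count, should_read=True,
                             _dataloader_cls=None, _batch_size=None):
    from petastorm import make_batch_reader

    if not should_read or not data_path:
        yield
        return

    with make_batch_reader(data_path, workers_count=reader_worker_count) as reader:
        # Wrap the reader into an iterable of batches and expose it under the
        # attribute name Lightning looks up (e.g. model.train_dataloader).
        loader = _dataloader_cls(reader, _batch_size)
        setattr(model, dataloader_attr, lambda: loader)
        try:
            yield
        finally:
            delattr(model, dataloader_attr)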


def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() if estimator.getValBatchSize() \
        else batch_size
    epochs = estimator.getEpochs()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()
    callbacks = estimator.getCallbacks()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    num_gpus = estimator.getNumGPUs()

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    dataloader_cls = _create_dataloader(feature_columns, input_shapes, metadata)
    make_petastorm_reader = _make_petastorm_reader_fn(
        transformation, schema_fields, batch_size,
        calculate_shuffle_buffer_size, dataloader_cls)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)

    def train(serialized_model):
        import horovod.torch as hvd

        # Horovod: initialize library.
        hvd.init()

        with tempfile.TemporaryDirectory() as last_ckpt_dir, \
                remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            # TODO: Pass the logger from the estimator constructor.
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
            logger = TensorBoardLogger(logs_path)

            # TODO: find a way to use a ckpt_path created from the remote store
            # while ingesting all other parameters from the estimator config:
            # ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
            # os.makedirs(ckpt_path, exist_ok=True)
            # model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
            # callbacks.append(model_checkpoint_callback)

            is_model_checkpoint_callback_exist = False
            if callbacks is not None:
                for cb in callbacks:
                    if isinstance(cb, ModelCheckpoint):
                        is_model_checkpoint_callback_exist = True
                        break

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else 1.0
            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else 1.0

            cuda_available = torch.cuda.is_available()
            if cuda_available:
                # Horovod: pin GPU to local rank or the GPU assigned by Spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'limit_train_batches': _train_steps_per_epoch,
                'limit_val_batches': _val_steps_per_epoch,
                'logger': logger,
                'resume_from_checkpoint': (last_ckpt_file if ckpt_bytes else None),
                'checkpoint_callback': is_model_checkpoint_callback_exist,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False,
            }
            print("Creating trainer with: \n ", kwargs)

            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            # Print row groups for debugging:
            # pq_file = pq.ParquetFile(remote_store.train_data_path)
            # for rowgroup in range(pq_file.metadata.num_row_groups):
            #     row_group = pq_file.metadata.row_group(rowgroup)
            #     print(row_group)

            with make_petastorm_reader(model, remote_store.train_data_path,
                                       'train_dataloader', train_reader_worker_count,
                                       reader_pool_type), \
                    make_petastorm_reader(model, remote_store.val_data_path,
                                          'val_dataloader', val_reader_worker_count,
                                          reader_pool_type, should_validate):
                trainer.fit(model)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model

            # TODO: find a way to pass trainer.logged_metrics out.
            output = {'model': module.state_dict()}

            torch.save(output, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint

    return train
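

# Note on `is_legacy`, used by each train() above when shipping the state dict
# back: in legacy mode the estimator wraps a plain torch.nn.Module in a
# LightningModule and keeps the original module on `._model`, so only the
# inner module's weights are saved. A minimal sketch of such a wrapper
# (hypothetical name, for illustration only):
class _LegacyWrapperSketch(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self._model = model  # the user's plain nn.Module

    def forward(self, *args, **kwargs):
        return self._model(*args, **kwargs)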