Example 1
 def _deserialize_dict(self, dict_values):
     deserialized_dict = dict()
     for key, val in dict_values.items():
         if val is None:
             deserialized_dict[key] = None
         elif key == EstimatorParams.model.name:
             deserialize = deserialize_fn()
             deserialized_dict[key] = deserialize(val)
         else:
             deserialized_dict[key] = codec.loads_base64(val)
     return deserialized_dict
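For context, `_deserialize_dict` undoes the base64 pickling applied when estimator parameters are shipped to workers. A minimal sketch of that round-trip, assuming `codec` above refers to Horovod's `horovod.runner.common.util.codec` helpers:

from horovod.runner.common.util import codec  # assumed source of `codec` above

# A value encoded the way the estimator ships it, plus an unset parameter.
encoded = {'batch_size': codec.dumps_base64(32), 'some_param': None}
# _deserialize_dict would map this back to {'batch_size': 32, 'some_param': None}.
assert codec.loads_base64(encoded['batch_size']) == 32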
Example 2
    def _transform(self, df):
        import copy
        import numbers

        import numpy as np
        import torch
        from pyspark.sql import SparkSession
        from pyspark.sql.types import StructField, StructType
        from pyspark.ml.linalg import VectorUDT
        # Helpers such as util, serialize_fn and deserialize_fn come from the
        # enclosing module.

        model_pre_predict = self.getModel()
        deserialize = deserialize_fn()
        serialize = serialize_fn()
        serialized_model = serialize(model_pre_predict)

        input_shapes = self.getInputShapes()
        label_cols = self.getLabelColumns()
        output_cols = self.getOutputCols()
        feature_cols = self.getFeatureColumns()
        metadata = self._get_metadata()

        final_output_cols = util.get_output_cols(df.schema, output_cols)

        def predict(rows):
            from pyspark.sql import Row
            from pyspark.ml.linalg import DenseVector, SparseVector

            model = deserialize(serialized_model)
            # Perform predictions.
            for row in rows:
                fields = row.asDict().copy()

                # Note: if the col is SparseVector, torch.tensor(col) correctly converts it to a
                # dense torch tensor.
                data = [
                    torch.tensor([row[col]]).reshape(shape)
                    for col, shape in zip(feature_cols, input_shapes)
                ]

                with torch.no_grad():
                    preds = model(*data)

                if not isinstance(preds, (list, tuple)):
                    preds = [preds]

                for label_col, output_col, pred in zip(label_cols, output_cols,
                                                       preds):
                    meta = metadata[label_col]
                    col_type = meta['spark_data_type']
                    # The dtype for DenseVector and Spark tensors is always np.float64.
                    if col_type == DenseVector:
                        shape = np.prod(pred.shape)
                        flattened_pred = pred.reshape(shape)
                        field = DenseVector(flattened_pred)
                    elif col_type == SparseVector:
                        shape = meta['shape']
                        flattened_pred = pred.reshape(shape)
                        nonzero_indices = flattened_pred.nonzero()[0]
                        field = SparseVector(shape, nonzero_indices,
                                             flattened_pred[nonzero_indices])
                    elif pred.shape.numel() == 1:
                        # If the column is scalar type, int, float, etc.
                        value = pred.item()
                        python_type = util.spark_scalar_to_python_type(
                            col_type)
                        if issubclass(python_type, numbers.Integral):
                            value = round(value)
                        field = python_type(value)
                    else:
                        field = DenseVector(pred.reshape(-1))

                    fields[output_col] = field

                values = [fields[col] for col in final_output_cols]

                yield Row(*values)

        # Reuse the already-instantiated SparkSession on the driver.
        spark0 = SparkSession._instantiatedSession

        final_output_fields = []

        # Copy the input schema.
        for field in df.schema.fields:
            final_output_fields.append(copy.deepcopy(field))

        # Append the output schema, inferred from a single-row prediction.
        override_fields = df.limit(1).rdd.mapPartitions(
            predict).toDF().schema.fields[-len(output_cols):]
        for name, override, label in zip(output_cols, override_fields,
                                         label_cols):
            # Default the output data type to the label's type.
            data_type = metadata[label]['spark_data_type']()

            if isinstance(override.dataType, VectorUDT):
                # Override the output type to vector. This is mainly for torch's
                # classification losses, where the label is a scalar but the model
                # output is a vector.
                data_type = VectorUDT()
            final_output_fields.append(
                StructField(name=name, dataType=data_type, nullable=True))

        final_output_schema = StructType(final_output_fields)

        pred_rdd = df.rdd.mapPartitions(predict)

        # Use the schema built above to construct the final DataFrame with predictions.
        return spark0.createDataFrame(pred_rdd, schema=final_output_schema)
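A hedged usage sketch of where `_transform` fits: it backs the standard Spark ML `Model.transform` call on a fitted Horovod Spark Torch model. The estimator, DataFrames, and column name below are placeholders, not the library's exact API surface:

# Illustrative wiring only; exact constructor arguments vary by Horovod version.
torch_model = torch_estimator.fit(train_df)   # fitted model whose transform() runs _transform
pred_df = torch_model.transform(test_df)      # executes predict() over each partition
pred_df.select('label__output').show(5)       # 'label__output' is a placeholder output column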
Example 3
def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() or batch_size
    epochs = estimator.getEpochs()
    random_seed = estimator.getRandomSeed()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    terminate_on_nan = estimator.getTerminateOnNan()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()
    callbacks = estimator.getCallbacks() or []
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    num_gpus = estimator.getNumGPUs()
    data_module = estimator.getDataModule() or PetastormDataModule
    loader_num_epochs = estimator.getLoaderNumEpochs()
    verbose = (estimator.getVerbose() > 0)
    trainer_args = estimator.getTrainerArgs()
    debug_data_loader = estimator.getDebugDataLoader()
    train_async_data_loader_queue_size = estimator.getTrainAsyncDataLoaderQueueSize()
    val_async_data_loader_queue_size = estimator.getValAsyncDataLoaderQueueSize()

    # Logger parameters
    logger = estimator.getLogger()
    log_every_n_steps = estimator.getLogEveryNSteps()
    print(f"Logger is configured: {logger}")

    # The Comet logger's experiment key is not serialized correctly, so remember
    # the key and resume the logger experiment from the GPU instance.
    if isinstance(logger, CometLogger):
        logger_experiment_key = logger._experiment_key
        print(f"logger vars: {vars(logger)}")
    else:
        logger_experiment_key = None

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)
    storage_options = store.storage_options

    profiler = estimator.getProfiler()

    def train(serialized_model):
        import horovod.torch as hvd

        if random_seed is not None:
            pl.utilities.seed.seed_everything(seed=random_seed)

        # Horovod: initialize library.
        hvd.init()

        if verbose:
            import horovod as _horovod
            print(
                f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}"
            )

        _checkpoint_callback = None
        require_checkpoint = False

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
            os.makedirs(logs_path, exist_ok=True)
            print(f"Made directory {logs_path} for horovod rank {hvd.rank()}")
            ckpt_dir = run_output_dir
            ckpt_filename = remote_store.checkpoint_filename

            if logger is None:
                # Use default logger if no logger is supplied
                train_logger = TensorBoardLogger(logs_path)
                print(f"Setup logger: Using TensorBoardLogger: {train_logger}")

            elif isinstance(logger, CometLogger):
                if logger._experiment_key:
                    # Use the logger passed in.
                    train_logger = logger
                    train_logger._save_dir = logs_path
                    print(
                        f"Setup logger: change save_dir of the logger to {logs_path}"
                    )

                elif logger_experiment_key:
                    # Resume the logger experiment with the new log path if the
                    # key was passed correctly from the CPU.
                    train_logger = CometLogger(
                        save_dir=logs_path,
                        api_key=logger.api_key,
                        experiment_key=logger_experiment_key,
                    )

                    print(
                        f"Setup logger: Resume comet logger: {vars(train_logger)}"
                    )

                else:
                    print(
                        f"Failed to set up or resume the Comet logger. Original logger: {vars(logger)}"
                    )

            else:
                # Use the logger passed in.
                train_logger = logger
                train_logger.save_dir = logs_path
                print(
                    f"Setup logger: Using logger passed from estimator: {train_logger}"
                )

            # Lightning requires checkpoint callbacks to be added on all ranks;
            # otherwise training hangs.
            for cb in callbacks:
                if isinstance(cb, ModelCheckpoint):
                    cb.dirpath = ckpt_dir
                    cb.filename = ckpt_filename
                    _checkpoint_callback = cb
                    require_checkpoint = True
                    break
            if not _checkpoint_callback:
                # By default 'monitor' is None, which saves a checkpoint only for the last epoch.
                _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                                       filename=ckpt_filename,
                                                       verbose=True)
                callbacks.append(_checkpoint_callback)

            if remote_store.saving_runs and hvd.rank() == 0:
                # Horovod: sync checkpoint and logging files only on rank 0 to
                # prevent other ranks from corrupting them.
                class _SyncCallback(Callback):
                    def on_epoch_end(self, trainer: "pl.Trainer",
                                     pl_module: "pl.LightningModule") -> None:
                        remote_store.sync(run_output_dir)

                callbacks.append(_SyncCallback())

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else \
                int(math.floor(float(train_rows) / batch_size / hvd.size()))

            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
                int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

            shuffle_size = calculate_shuffle_buffer_size()
            if verbose:
                print(
                    f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}, "
                    f"Shuffle_size: {shuffle_size}, Random seed: {random_seed}\n"
                    f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {_train_steps_per_epoch}\n"
                    f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n"
                    f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n"
                )

            cuda_available = torch.cuda.is_available()
            # All ranks must have the same device type for training, since
            # Horovod doesn't support heterogeneous allreduce for gradients.
            cuda_avail_list = hvd.allgather_object(cuda_available,
                                                   name='device type')
            if cuda_avail_list.count(cuda_available) != hvd.size():
                raise RuntimeError("All ranks don't have same device type!")

            if cuda_available:
                # Horovod: pin the GPU to the local rank or the GPU assigned by Spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            # Set the progress bar to refresh once per epoch; detailed loss and
            # metrics are available in the logger, so there is no need to print
            # them on screen here. Users can still override this in trainer_args.
            progress_bar_refresh_rate = _train_steps_per_epoch

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'logger': train_logger,
                'log_every_n_steps': log_every_n_steps,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False,
                'progress_bar_refresh_rate': progress_bar_refresh_rate,
                'terminate_on_nan': terminate_on_nan,
                'profiler': profiler
            }
            if trainer_args:
                kwargs.update(trainer_args)

            if verbose and hvd.rank() == 0:
                print("Creating trainer with: \n ", kwargs)

            trainer = Trainer(**kwargs)

            if profiler != 'simple' and trainer.profiler:
                print(
                    f"Set profiler's logs_path for {hvd.rank()} to {logs_path}"
                )
                trainer.profiler.dirpath = logs_path
                # filename where the profiler results will be saved instead of
                # printing to stdout. The .txt extension will be used automatically.
                trainer.profiler.filename = "profile"

            if verbose and hvd.rank() == 0:
                print(f"pytorch_lightning version={pl.__version__}")

            data_module_kwargs = {
                'train_dir': remote_store.train_data_path,
                'val_dir': remote_store.val_data_path,
                'num_train_epochs': epochs,
                'has_val': should_validate is not None,
                'train_batch_size': batch_size,
                'val_batch_size': val_batch_size,
                'shuffle_size': shuffle_size,
                'num_reader_epochs': loader_num_epochs,
                'reader_pool_type': reader_pool_type,
                'reader_worker_count': train_reader_worker_count,
                'transform_spec': transformation,
                'inmemory_cache_all': inmemory_cache_all,
                'cur_shard': hvd.rank(),
                'shard_count': hvd.size(),
                'schema_fields': schema_fields,
                'storage_options': storage_options,
                'steps_per_epoch_train': _train_steps_per_epoch,
                'steps_per_epoch_val': _val_steps_per_epoch,
                'verbose': verbose,
                'debug_data_loader': debug_data_loader,
                'train_async_data_loader_queue_size': train_async_data_loader_queue_size,
                'val_async_data_loader_queue_size': val_async_data_loader_queue_size,
            }
            if debug_data_loader and hvd.rank() == 0:
                print(
                    f"Creating data module with args:\n {data_module_kwargs}")

            dataset = data_module(**data_module_kwargs)

            trainer.fit(model, dataset)

            if hvd.rank() == 0:
                if remote_store.saving_runs and trainer.profiler:
                    # One more file sync to push the profiler results.
                    remote_store.sync(logs_path)

                # Rank 0 overwrites the model with the best checkpoint and returns it.
                if require_checkpoint:
                    if verbose:
                        print("load from checkpoint best model path:",
                              _checkpoint_callback.best_model_path)
                    best_model = model.load_from_checkpoint(
                        _checkpoint_callback.best_model_path)
                else:
                    best_model = model
                serialized_checkpoint = io.BytesIO()
                module = best_model if not is_legacy else best_model._model

                output = {
                    'model': module.state_dict(),
                    'logged_metrics': trainer.logged_metrics
                }

                torch.save(output, serialized_checkpoint)
                # Rewind so callers can read the checkpoint bytes from the start.
                serialized_checkpoint.seek(0)

                return serialized_checkpoint

    return train
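A hedged sketch of how the returned `train` closure is typically driven: `RemoteTrainer` is constructed on the Spark driver, the closure is shipped to Horovod workers, and rank 0 returns the serialized checkpoint. The wiring below is an assumption for illustration, not the exact launch code:

# Driver side (illustrative values).
train_fn = RemoteTrainer(estimator, metadata, ckpt_bytes=None, run_id='run-0',
                         dataset_idx=0, train_rows=10_000, val_rows=1_000,
                         avg_row_size=512, is_legacy=False)
serialized_model = serialize_fn()(model)  # counterpart of the deserialize_fn() above

# Each Horovod worker then executes train_fn(serialized_model); rank 0 returns an
# io.BytesIO holding {'model': state_dict, 'logged_metrics': ...} saved via torch.save.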
Example 4
def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() or batch_size
    epochs = estimator.getEpochs()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()
    callbacks = estimator.getCallbacks()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    num_gpus = estimator.getNumGPUs()
    logger = estimator.getLogger()
    log_every_n_steps = estimator.getLogEveryNSteps()
    data_loader_cls = estimator.getDataLoaderClass()
    loader_num_epochs = estimator.getLoaderNumEpochs()
    verbose = (estimator.getVerbose() > 0)

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    data_loader_cls = _create_dataloader(feature_columns, input_shapes,
                                         metadata, inmemory_cache_all,
                                         data_loader_cls)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)

    set_data_loader = _set_data_loader_fn(transformation, schema_fields,
                                          batch_size, data_loader_cls,
                                          loader_num_epochs, store, epochs,
                                          inmemory_cache_all, verbose)

    def train(serialized_model):
        import horovod.torch as hvd
        # Horovod: initialize library.
        hvd.init()

        with tempfile.TemporaryDirectory() as last_ckpt_dir, \
                remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            # TODO: Pass the logger from estimator constructor
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)

            # Use default logger if no logger is supplied
            train_logger = logger
            if train_logger is None:
                train_logger = TensorBoardLogger(logs_path)

            # TODO: find a way to use the ckpt_path created from the remote store,
            # while ingesting all other parameters from the estimator config.
            # ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
            # os.makedirs(ckpt_path, exist_ok=True)
            # model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
            # callbacks.append(model_checkpoint_callback)

            is_model_checkpoint_callback_exist = False
            if callbacks is not None:
                for cb in callbacks:
                    if isinstance(cb, ModelCheckpoint):
                        is_model_checkpoint_callback_exist = True
                        break

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else \
                int(math.floor(float(train_rows) / batch_size / hvd.size()))

            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
                int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

            print(
                f"Training data of rank[{hvd.local_rank()}]: train_rows:{train_rows}, batch_size:{batch_size}, _train_steps_per_epoch:{_train_steps_per_epoch}."
            )
            print(
                f"Validation data of rank[{hvd.local_rank()}]: val_rows:{val_rows}, val_batch_size:{val_batch_size}, _val_steps_per_epoch:{_val_steps_per_epoch}, should_validate:{should_validate}"
            )

            cuda_available = torch.cuda.is_available()
            # All ranks must have the same device type for training, since
            # Horovod doesn't support heterogeneous allreduce for gradients.
            cuda_avail_list = hvd.allgather_object(cuda_available,
                                                   name='device type')
            if cuda_avail_list.count(cuda_available) != hvd.size():
                raise RuntimeError("All ranks don't have same device type!")

            if cuda_available:
                # Horovod: pin the GPU to the local rank or the GPU assigned by Spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'logger': train_logger,
                'log_every_n_steps': log_every_n_steps,
                'resume_from_checkpoint': last_ckpt_file if ckpt_bytes else None,
                'checkpoint_callback': is_model_checkpoint_callback_exist,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False,
                'progress_bar_refresh_rate': _train_steps_per_epoch // 10
            }
            print("Creating trainer with: \n ", kwargs)
            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            # Print row groups for debugging:
            # pq_file = pq.ParquetFile(remote_store.train_data_path)
            # for rowgroup in range(pq_file.metadata.num_row_groups):
            #     row_group = pq_file.metadata.row_group(rowgroup)
            #     print(row_group)

            with set_data_loader(model, remote_store.train_data_path, 'train_dataloader',
                                 train_reader_worker_count, reader_pool_type, calculate_shuffle_buffer_size(),
                                 name="train_dataloader",
                                 limit_step_per_epoch=_train_steps_per_epoch), \
                    set_data_loader(model, remote_store.val_data_path, 'val_dataloader',
                                    val_reader_worker_count, reader_pool_type, 0,
                                    should_validate, name="val_dataloader",
                                    limit_step_per_epoch=_val_steps_per_epoch):

                trainer.fit(model)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model

            # TODO: find a way to pass trainer.logged_metrics out.
            output = {'model': module.state_dict()}

            torch.save(output, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint

    return train
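This variant also supports resuming: checkpoint bytes returned from an earlier run can be fed back as `ckpt_bytes`, which `train` writes to `last.ckpt` and hands to the Trainer through `resume_from_checkpoint`. A minimal sketch of that round-trip, with placeholder variables:

# Illustrative resume round-trip.
ckpt_buf = train_fn(serialized_model)          # io.BytesIO from a previous run
resumed_train_fn = RemoteTrainer(estimator, metadata, ckpt_buf.getvalue(),
                                 run_id, dataset_idx, train_rows, val_rows,
                                 avg_row_size, is_legacy=False)
# Inside train(), the bytes land in <tmpdir>/last.ckpt and Lightning resumes
# from them via Trainer(resume_from_checkpoint=last_ckpt_file).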
Example 5
def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() or batch_size
    epochs = estimator.getEpochs()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    dataloader_cls = _create_dataloader(feature_columns, input_shapes,
                                        metadata)
    make_petastorm_reader = _make_petastorm_reader_fn(
        transformation, schema_fields, batch_size,
        calculate_shuffle_buffer_size, dataloader_cls)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)
    is_dbfs = isinstance(store, DBFSLocalStore)

    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    train_percent = train_rows / train_steps_per_epoch if train_steps_per_epoch else 1.0

    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    val_percent = val_rows / val_steps_per_epoch if val_steps_per_epoch else 1.0

    # Disable callbacks for now, because Petastorm cannot reset the reader index
    # during training.
    callbacks = None  # _make_callbacks()

    def train(serialized_model):
        with tempfile.TemporaryDirectory() as last_ckpt_dir, \
                remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
            logger = TensorBoardLogger(logs_path)

            ckpt_path = os.path.join(run_output_dir,
                                     remote_store.checkpoint_filename)
            os.makedirs(ckpt_path, exist_ok=True)

            # Disable the checkpoint callback for now, pending the fix for
            # https://github.com/PyTorchLightning/pytorch-lightning/issues/6343
            checkpoint_callback = None  # ModelCheckpoint(dirpath=ckpt_path)

            model = deserialize(serialized_model)
            kwargs = {
                'accelerator': 'horovod',
                'gpus': (1 if torch.cuda.is_available() else 0),
                'callbacks': callbacks,
                'max_epochs': epochs,
                'limit_train_batches': train_percent,
                'limit_val_batches': val_percent,
                'logger': logger,
                'checkpoint_callback': checkpoint_callback,
                'resume_from_checkpoint': last_ckpt_file if ckpt_bytes else None,
                'num_sanity_val_steps': 0
            }
            print("Creating trainer with: \n ", kwargs)
            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            # Print row groups for debugging:
            # pq_file = pq.ParquetFile(remote_store.train_data_path)
            # for rowgroup in range(pq_file.metadata.num_row_groups):
            #     row_group = pq_file.metadata.row_group(rowgroup)
            #     print(row_group)

            with make_petastorm_reader(model, remote_store.train_data_path, 'train_dataloader',
                                       train_reader_worker_count), \
                    make_petastorm_reader(model, remote_store.val_data_path, 'val_dataloader',
                                          val_reader_worker_count, should_validate):

                trainer.fit(model)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model
            torch.save({'model': module.state_dict()}, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint

    return train
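For context on the `limit_train_batches` / `limit_val_batches` arguments used above: PyTorch Lightning interprets a float as a fraction of the available batches and an int as an absolute batch count. A self-contained sketch of the semantics (`MyLightningModule` is a placeholder):

from pytorch_lightning import Trainer

# Run 25% of the training batches and at most 10 validation batches per epoch.
trainer = Trainer(max_epochs=1, limit_train_batches=0.25, limit_val_batches=10)
# trainer.fit(MyLightningModule())  # placeholder LightningModule goes here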
Example 6
def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
                  train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() or batch_size
    epochs = estimator.getEpochs()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()
    callbacks = estimator.getCallbacks()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    num_gpus = estimator.getNumGPUs()

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    dataloader_cls = _create_dataloader(feature_columns, input_shapes,
                                        metadata)
    make_petastorm_reader = _make_petastorm_reader_fn(
        transformation, schema_fields, batch_size,
        calculate_shuffle_buffer_size, dataloader_cls)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)

    def train(serialized_model):
        import horovod.torch as hvd
        # Horovod: initialize library.
        hvd.init()

        with tempfile.TemporaryDirectory() as last_ckpt_dir, \
                remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            # TODO: Pass the logger from estimator constructor
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
            logger = TensorBoardLogger(logs_path)

            # TODO: find a way to use the ckpt_path created from the remote store,
            # while ingesting all other parameters from the estimator config.
            # ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
            # os.makedirs(ckpt_path, exist_ok=True)
            # model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
            # callbacks.append(model_checkpoint_callback)

            is_model_checkpoint_callback_exist = False
            if callbacks is not None:
                for cb in callbacks:
                    if isinstance(cb, ModelCheckpoint):
                        is_model_checkpoint_callback_exist = True
                        break

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else 1.0
            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else 1.0

            cuda_available = torch.cuda.is_available()
            if cuda_available:
                # Horovod: pin the GPU to the local rank or the GPU assigned by Spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'limit_train_batches': _train_steps_per_epoch,
                'limit_val_batches': _val_steps_per_epoch,
                'logger': logger,
                'resume_from_checkpoint': last_ckpt_file if ckpt_bytes else None,
                'checkpoint_callback': is_model_checkpoint_callback_exist,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False
            }
            print("Creating trainer with: \n ", kwargs)
            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            # Print row groups for debugging:
            # pq_file = pq.ParquetFile(remote_store.train_data_path)
            # for rowgroup in range(pq_file.metadata.num_row_groups):
            #     row_group = pq_file.metadata.row_group(rowgroup)
            #     print(row_group)

            with make_petastorm_reader(model, remote_store.train_data_path, 'train_dataloader',
                                       train_reader_worker_count, reader_pool_type), \
                    make_petastorm_reader(model, remote_store.val_data_path, 'val_dataloader',
                                          val_reader_worker_count, reader_pool_type, should_validate):

                trainer.fit(model)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model

            # TODO: find a way to pass trainer.logged_metrics out.
            output = {'model': module.state_dict()}

            torch.save(output, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint

    return train