Code example #1
File: remote.py  Project: lakersdf/horovod
    def train(serialized_model):
        import horovod.torch as hvd

        if random_seed is not None:
            pl.utilities.seed.seed_everything(seed=random_seed)

        # Horovod: initialize library.
        hvd.init()

        if verbose:
            import horovod as _horovod
            print(
                f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}"
            )

        _checkpoint_callback = None
        require_checkpoint = False

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
            os.makedirs(logs_path, exist_ok=True)
            print(f"Made directory {logs_path} for horovod rank {hvd.rank()}")
            ckpt_dir = run_output_dir
            ckpt_filename = remote_store.checkpoint_filename

            if logger is None:
                # Use default logger if no logger is supplied
                train_logger = TensorBoardLogger(logs_path)
                print(f"Setup logger: Using TensorBoardLogger: {train_logger}")

            elif isinstance(logger, CometLogger):
                if logger._experiment_key:
                    # use logger passed in.
                    train_logger = logger
                    train_logger._save_dir = logs_path
                    print(
                        f"Setup logger: change save_dir of the logger to {logs_path}"
                    )

                elif logger_experiment_key:
                    # Resume the logger experiment with the new log path if the key was passed correctly from the CPU.
                    train_logger = CometLogger(
                        save_dir=logs_path,
                        api_key=logger.api_key,
                        experiment_key=logger_experiment_key,
                    )

                    print(
                        f"Setup logger: Resume comet logger: {vars(train_logger)}"
                    )

                else:
                    print(
                        f"Failed to setup or resume comet logger. origin logger: {vars(logger)}"
                    )

            else:
                # use logger passed in.
                train_logger = logger
                train_logger.save_dir = logs_path
                print(
                    f"Setup logger: Using logger passed from estimator: {train_logger}"
                )

            # Lightning requires checkpoint callbacks to be added on all ranks;
            # otherwise training can hang.
            for cb in callbacks:
                if isinstance(cb, ModelCheckpoint):
                    cb.dirpath = ckpt_dir
                    cb.filename = ckpt_filename
                    _checkpoint_callback = cb
                    require_checkpoint = True
                    break
            if not _checkpoint_callback:
                # By default 'monitor'=None, which saves a checkpoint only for the last epoch.
                _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                                       filename=ckpt_filename,
                                                       verbose=True)
                callbacks.append(_checkpoint_callback)

            if remote_store.saving_runs and hvd.rank() == 0:
                # Horovod: sync checkpoint and logging files only on rank 0 to
                # prevent other ranks from corrupting them.
                class _SyncCallback(Callback):
                    def on_epoch_end(self, trainer: "pl.Trainer",
                                     pl_module: "pl.LightningModule") -> None:
                        remote_store.sync(run_output_dir)

                callbacks.append(_SyncCallback())

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else \
                int(math.floor(float(train_rows) / batch_size / hvd.size()))

            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
                int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

            shuffle_size = calculate_shuffle_buffer_size()
            if verbose:
                print(
                    f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}, "
                    f"Shuffle_size: {shuffle_size}, Random seed: {random_seed}\n"
                    f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {_train_steps_per_epoch}\n"
                    f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n"
                    f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n"
                )

            cuda_available = torch.cuda.is_available()
            # We need to check that all ranks have the same device type for training.
            # Horovod doesn't support heterogeneous allreduce for gradients.
            cuda_avail_list = hvd.allgather_object(cuda_available,
                                                   name='device type')
            if cuda_avail_list.count(cuda_available) != hvd.size():
                raise RuntimeError("All ranks don't have same device type!")

            if cuda_available:
                # Horovod: pin GPU to local rank or the assigned GPU from spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            # Set the progress bar refresh rate to once per epoch; detailed loss and metrics
            # are available in the logger, so there is no need to print them here.
            # Users can still override this in trainer_args.
            progress_bar_refresh_rate = _train_steps_per_epoch

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'logger': train_logger,
                'log_every_n_steps': log_every_n_steps,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False,
                'progress_bar_refresh_rate': progress_bar_refresh_rate,
                'terminate_on_nan': terminate_on_nan,
                'profiler': profiler
            }
            if trainer_args:
                kwargs.update(trainer_args)

            if verbose and hvd.rank() == 0:
                print("Creating trainer with: \n ", kwargs)

            trainer = Trainer(**kwargs)

            if profiler != 'simple' and trainer.profiler:
                print(
                    f"Set profiler's logs_path for {hvd.rank()} to {logs_path}"
                )
                trainer.profiler.dirpath = logs_path
                # Filename where the profiler results will be saved instead of being
                # printed to stdout. The .txt extension will be used automatically.
                trainer.profiler.filename = "profile"

            if verbose and hvd.rank() == 0:
                print(f"pytorch_lightning version={pl.__version__}")

            data_module_kwargs = {
                'train_dir': remote_store.train_data_path,
                'val_dir': remote_store.val_data_path,
                'num_train_epochs': epochs,
                'has_val': should_validate is not None,
                'train_batch_size': batch_size,
                'val_batch_size': val_batch_size,
                'shuffle_size': shuffle_size,
                'num_reader_epochs': loader_num_epochs,
                'reader_pool_type': reader_pool_type,
                'reader_worker_count': train_reader_worker_count,
                'transform_spec': transformation,
                'inmemory_cache_all': inmemory_cache_all,
                'cur_shard': hvd.rank(),
                'shard_count': hvd.size(),
                'schema_fields': schema_fields,
                'storage_options': storage_options,
                'steps_per_epoch_train': _train_steps_per_epoch,
                'steps_per_epoch_val': _val_steps_per_epoch,
                'verbose': verbose,
                'debug_data_loader': debug_data_loader,
                'train_async_data_loader_queue_size': train_async_data_loader_queue_size,
                'val_async_data_loader_queue_size': val_async_data_loader_queue_size,
            }
            if debug_data_loader and hvd.rank() == 0:
                print(
                    f"Creating data module with args:\n {data_module_kwargs}")

            dataset = data_module(**data_module_kwargs)

            trainer.fit(model, dataset)

            if hvd.rank() == 0:
                if remote_store.saving_runs and trainer.profiler:
                    # One more file sync to push profiler result.
                    remote_store.sync(logs_path)

                # rank 0 overwrites model with best checkpoint and returns.
                if require_checkpoint:
                    if verbose:
                        print("load from checkpoint best model path:",
                              _checkpoint_callback.best_model_path)
                    best_model = model.load_from_checkpoint(
                        _checkpoint_callback.best_model_path)
                else:
                    best_model = model
                serialized_checkpoint = io.BytesIO()
                module = best_model if not is_legacy else best_model._model

                output = {
                    'model': module.state_dict(),
                    'logged_metrics': trainer.logged_metrics
                }

                torch.save(output, serialized_checkpoint)
                # Rewind the buffer so the caller can read the checkpoint from the start.
                serialized_checkpoint.seek(0)

                return serialized_checkpoint
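
For context, here is a minimal driver-side sketch of how the io.BytesIO buffer returned by this train function could be consumed. The helper name load_trained_state is hypothetical and not part of the Horovod API; it only mirrors the 'model' / 'logged_metrics' keys written above.

import io
import torch

def load_trained_state(serialized_checkpoint: io.BytesIO):
    # Hypothetical helper: rewind in case the worker did not seek back to the start.
    serialized_checkpoint.seek(0)
    checkpoint = torch.load(serialized_checkpoint, map_location='cpu')
    state_dict = checkpoint['model']                       # best (or last) model weights
    logged_metrics = checkpoint.get('logged_metrics', {})  # metrics collected by the Trainer, if present
    return state_dict, logged_metrics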
Code example #2
File: remote.py  Project: raajay/horovod
    def train(serialized_model):
        import horovod.torch as hvd
        # Horovod: initialize library.
        hvd.init()

        with tempfile.TemporaryDirectory() as last_ckpt_dir, \
                remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            # TODO: Pass the logger from estimator constructor
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)

            # Use default logger if no logger is supplied
            train_logger = logger

            if train_logger is None:
                train_logger = TensorBoardLogger(logs_path)
            elif isinstance(train_logger,
                            CometLogger) and train_logger._save_dir is None:
                # Setting the CometLogger's save_dir allows us to sync checkpoints and profiler output
                train_logger._save_dir = logs_path

            # TODO: find a way to use a ckpt_path created from the remote store while ingesting all other parameters from the estimator config.
            # ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
            # os.makedirs(ckpt_path, exist_ok=True)
            # model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
            # callbacks.append(model_checkpoint_callback)

            is_model_checkpoint_callback_exist = False
            for cb in callbacks:
                if isinstance(cb, ModelCheckpoint):
                    is_model_checkpoint_callback_exist = True
                    break

            if remote_store.saving_runs and hvd.rank() == 0:

                class _SyncCallback(Callback):
                    def on_epoch_end(self, trainer: "pl.Trainer",
                                     pl_module: "pl.LightningModule") -> None:
                        print("Syncing to remote_store.")
                        remote_store.sync(logs_path)

                callbacks.append(_SyncCallback())

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else \
                int(math.floor(float(train_rows) / batch_size / hvd.size()))

            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
                int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

            print(
                f"Training data of rank[{hvd.local_rank()}]: train_rows:{train_rows}, batch_size:{batch_size}, _train_steps_per_epoch:{_train_steps_per_epoch}."
            )

            cuda_available = torch.cuda.is_available()
            # We need to check that all ranks have the same device type for training.
            # Horovod doesn't support heterogeneous allreduce for gradients.
            cuda_avail_list = hvd.allgather_object(cuda_available,
                                                   name='device type')
            if cuda_avail_list.count(cuda_available) != hvd.size():
                raise RuntimeError("All ranks don't have same device type!")

            if cuda_available:
                # Horovod: pin GPU to local rank or the assigned GPU from spark.
                torch.cuda.set_device(
                    _get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            kwargs = {
                'accelerator': 'horovod',
                'gpus': _num_gpus,
                'callbacks': callbacks,
                'max_epochs': epochs,
                'logger': train_logger,
                'log_every_n_steps': log_every_n_steps,
                'resume_from_checkpoint': last_ckpt_file if ckpt_bytes else None,
                'checkpoint_callback': is_model_checkpoint_callback_exist,
                'num_sanity_val_steps': 0,
                'reload_dataloaders_every_epoch': False,
                'progress_bar_refresh_rate': _train_steps_per_epoch // 10,
                'terminate_on_nan': terminate_on_nan,
                'profiler': estimator.getProfiler()
            }
            print("Creating trainer with: \n ", kwargs)
            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            dataset = data_module(
                train_dir=remote_store.train_data_path,
                val_dir=remote_store.val_data_path,
                num_train_epochs=epochs,
                has_val=should_validate is not None,
                train_batch_size=batch_size,
                val_batch_size=val_batch_size,
                shuffle_size=calculate_shuffle_buffer_size(),
                num_reader_epochs=loader_num_epochs,
                reader_pool_type=reader_pool_type,
                reader_worker_count=train_reader_worker_count,
                transform_spec=transformation,
                inmemory_cache_all=inmemory_cache_all,
                cur_shard=hvd.rank(),
                shard_count=hvd.size(),
                schema_fields=schema_fields,
                storage_options=storage_options,
                steps_per_epoch_train=_train_steps_per_epoch,
                steps_per_epoch_val=_val_steps_per_epoch,
                verbose=verbose)
            trainer.fit(model, dataset)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model

            # TODO: find a way to pass trainer.logged_metrics out.
            output = {'model': module.state_dict()}

            torch.save(output, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint
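
As a complement, here is a minimal sketch of how the ckpt_bytes consumed above might be produced on the driver. read_checkpoint_bytes is a hypothetical helper, not part of the Horovod API; train() writes these bytes to last.ckpt on each worker and hands that path to the Trainer via 'resume_from_checkpoint'.

import os

def read_checkpoint_bytes(ckpt_path):
    # Hypothetical driver-side helper: load a previously saved Lightning
    # checkpoint into memory so it can be shipped to each worker.
    if not os.path.exists(ckpt_path):
        return None  # no prior checkpoint: train() starts from scratch
    with open(ckpt_path, 'rb') as f:
        return f.read()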