Example no. 1
    def _choose_strategy(self) -> Union[Strategy, str]:
        if self._accelerator_flag == "ipu":
            return IPUStrategy.strategy_name
        if self._accelerator_flag == "hpu":
            if self._parallel_devices and len(self._parallel_devices) > 1:
                return HPUParallelStrategy.strategy_name
            else:
                return SingleHPUStrategy(device=torch.device("hpu"))
        if self._accelerator_flag == "tpu":
            if self._parallel_devices and len(self._parallel_devices) > 1:
                return TPUSpawnStrategy.strategy_name
            else:
                # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
                return SingleTPUStrategy(
                    device=self._parallel_devices[0])  # type: ignore
        if _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os.environ
                                   or "HOROVOD_RANK" in os.environ):
            return HorovodStrategy.strategy_name
        if self._num_nodes_flag > 1:
            return DDPStrategy.strategy_name
        if len(self._parallel_devices) <= 1:
            device = (
                device_parser.determine_root_gpu_device(
                    self._parallel_devices)  # type: ignore
                if self._accelerator_flag == "gpu" else "cpu")
            # TODO: lazy initialized device, then here could be self._strategy_flag = "single_device"
            return SingleDeviceStrategy(device=device)  # type: ignore
        if len(self._parallel_devices) > 1:
            return DDPSpawnStrategy.strategy_name

        return DDPStrategy.strategy_name
Example no. 2
    def select_strategy(self) -> Strategy:
        if isinstance(self.distributed_backend, Accelerator) and self.distributed_backend.strategy is not None:
            plugin = self.distributed_backend.strategy
        elif self.use_ddp2:
            plugin = DDP2Strategy(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
        elif self.use_ddp and self.use_deepspeed:
            plugin = DeepSpeedStrategy(
                cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
            )
        elif self.use_ddp:
            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks()
            use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.detect()
            use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.detect()
            use_ddp_spawn = self._strategy_type == _StrategyType.DDP_SPAWN
            use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
            use_tpu_spawn = self.use_tpu and self._strategy_type == _StrategyType.TPU_SPAWN
            use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.detect()
            use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.detect()
            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks()
            use_ddp_sharded = self._strategy_type == _StrategyType.DDP_SHARDED
            use_ddp_sharded_spawn = self._strategy_type == _StrategyType.DDP_SHARDED_SPAWN
            use_ddp_fully_sharded = self._strategy_type == _StrategyType.DDP_FULLY_SHARDED

            if use_tpu_spawn:
                ddp_strategy_cls = TPUSpawnStrategy
            elif use_ddp_sharded:
                ddp_strategy_cls = DDPShardedStrategy
            elif use_ddp_sharded_spawn:
                ddp_strategy_cls = DDPSpawnShardedStrategy
            elif (
                use_ddp_cpu_slurm
                or use_slurm_ddp
                or use_ddp_cpu_torch_elastic
                or use_torchelastic_ddp
                or use_kubeflow_ddp
                or use_ddp_cpu_kubeflow
            ):
                ddp_strategy_cls = DDPStrategy
            elif use_ddp_spawn or use_ddp_cpu_spawn:
                ddp_strategy_cls = DDPSpawnStrategy
            elif use_ddp_fully_sharded:
                ddp_strategy_cls = DDPFullyShardedStrategy
            else:
                ddp_strategy_cls = DDPStrategy

            plugin = ddp_strategy_cls(
                parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment
            )
        elif self.use_dp:
            plugin = DataParallelStrategy(parallel_devices=self.parallel_devices)
        elif self.use_horovod:
            plugin = HorovodStrategy(parallel_devices=self.parallel_devices)
        elif self.use_tpu and isinstance(self.tpu_cores, list):
            plugin = SingleTPUStrategy(self.tpu_id)
        elif self.use_ipu:
            plugin = IPUStrategy(parallel_devices=self.parallel_devices)
        else:
            single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
            plugin = SingleDeviceStrategy(device=single_gpu_ordinal if self.use_gpu else "cpu")
        return plugin
Example no. 3
    def select_training_type_plugin(self) -> TrainingTypePlugin:
        if self.use_ddp2:
            plugin = DDP2Plugin(parallel_devices=self.parallel_devices,
                                cluster_environment=self.cluster_environment)
        elif self.use_ddp and self.use_deepspeed:
            plugin = DeepSpeedPlugin(
                num_nodes=self.num_nodes,
                cluster_environment=self.select_cluster_environment(),
                parallel_devices=self.parallel_devices)
        elif self.use_ddp:
            use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
            use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
            use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
            use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
            use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
            use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
            use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN

            # TODO: decouple from TE
            # ddp script mode uses the same flags as TE
            if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
                use_torchelastic_ddp = False

            if self.on_tpu:
                ddp_plugin_cls = TPUSpawnPlugin
            elif use_ddp_sharded:
                ddp_plugin_cls = DDPShardedPlugin
            elif use_ddp_sharded_spawn:
                ddp_plugin_cls = DDPSpawnShardedPlugin
            elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
                ddp_plugin_cls = DDPPlugin
            elif use_ddp_spawn or use_ddp_cpu_spawn:
                ddp_plugin_cls = DDPSpawnPlugin
            else:
                ddp_plugin_cls = DDPPlugin

            plugin = ddp_plugin_cls(
                parallel_devices=self.parallel_devices,
                num_nodes=self.num_nodes,
                cluster_environment=self.cluster_environment,
                sync_batchnorm=self.sync_batchnorm,
            )
        elif self.use_dp:
            plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
        elif self.use_horovod:
            plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
        elif self.on_tpu:
            if isinstance(self.tpu_cores, list):
                plugin = SingleTPUPlugin(self.tpu_id)
            else:
                plugin = TPUSpawnPlugin(
                    parallel_devices=list(range(self.tpu_cores)))
        else:
            single_gpu_ordinal = device_parser.determine_root_gpu_device(
                self.parallel_device_ids)
            plugin = SingleDevicePlugin(device=torch.device(
                f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
        return plugin
Example no. 4
    def _choose_strategy(self) -> Union[Strategy, str]:
        if self._accelerator_flag == "ipu":
            return IPUStrategy.strategy_name
        if self._accelerator_flag == "hpu":
            if self._parallel_devices and len(self._parallel_devices) > 1:
                return HPUParallelStrategy.strategy_name
            else:
                return SingleHPUStrategy(device=torch.device("hpu"))
        if self._accelerator_flag == "tpu":
            if self._parallel_devices and len(self._parallel_devices) > 1:
                return TPUSpawnStrategy.strategy_name
            else:
                # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
                return SingleTPUStrategy(
                    device=self._parallel_devices[0])  # type: ignore
        if _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os.environ
                                   or "HOROVOD_RANK" in os.environ):
            return HorovodStrategy.strategy_name
        if self._num_nodes_flag > 1:
            return DDPStrategy.strategy_name
        if len(self._parallel_devices) <= 1:
            # TODO: Change this once gpu accelerator was renamed to cuda accelerator
            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
            ):
                device = device_parser.determine_root_gpu_device(
                    self._parallel_devices)
            else:
                device = "cpu"
            # TODO: lazy initialized device, then here could be self._strategy_flag = "single_device"
            return SingleDeviceStrategy(device=device)  # type: ignore
        if len(self._parallel_devices) > 1:
            if _IS_INTERACTIVE:
                return "ddp_fork"
            return "ddp_spawn"

        return DDPStrategy.strategy_name
Example no. 5
    def select_training_type_plugin(self) -> TrainingTypePlugin:
        if isinstance(
            self.distributed_backend, Accelerator
        ) and self.distributed_backend.training_type_plugin is not None:
            plugin = self.distributed_backend.training_type_plugin
        elif self.use_ddp2:
            plugin = DDP2Plugin(
                parallel_devices=self.parallel_devices,
                cluster_environment=self.cluster_environment,
            )
        elif self.use_ddp and self.use_deepspeed:
            plugin = DeepSpeedPlugin(
                cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
            )
        elif self.use_ddp:
            use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
            use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
            use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
            use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
            use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
            use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
            use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
            use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
            use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
            use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
            use_ddp_fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED

            # TODO: decouple from TE
            # ddp script mode uses the same flags as TE
            if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
                use_torchelastic_ddp = False

            if use_tpu_spawn:
                ddp_plugin_cls = TPUSpawnPlugin
            elif use_ddp_sharded:
                ddp_plugin_cls = DDPShardedPlugin
            elif use_ddp_sharded_spawn:
                ddp_plugin_cls = DDPSpawnShardedPlugin
            elif (
                use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp
                or use_kubeflow_ddp or use_ddp_cpu_kubeflow
            ):
                ddp_plugin_cls = DDPPlugin
            elif use_ddp_spawn or use_ddp_cpu_spawn:
                ddp_plugin_cls = DDPSpawnPlugin
            elif use_ddp_fully_sharded:
                ddp_plugin_cls = DDPFullyShardedPlugin
            else:
                ddp_plugin_cls = DDPPlugin

            plugin = ddp_plugin_cls(
                parallel_devices=self.parallel_devices,
                cluster_environment=self.cluster_environment,
            )
        elif self.use_dp:
            plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
        elif self.use_horovod:
            plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
        elif self.on_tpu and isinstance(self.tpu_cores, list):
            plugin = SingleTPUPlugin(self.tpu_id)
        elif self.on_ipu:
            plugin = IPUPlugin(parallel_devices=self.parallel_devices)
        else:
            single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
        return plugin
Example no. 6
def test_determine_root_gpu_device(gpus, expected_root_gpu):
    assert device_parser.determine_root_gpu_device(gpus) == expected_root_gpu
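The test above expects `gpus` and `expected_root_gpu` to be injected by a pytest parametrization. Below is a minimal, self-contained sketch of how such a parametrization could look; the concrete cases and the `device_parser` import path are illustrative assumptions, not taken from the original test suite.

import pytest

from pytorch_lightning.utilities import device_parser  # assumed import path


# Hypothetical cases, assuming determine_root_gpu_device returns the first id
# of a non-empty GPU-id list and None when no GPUs are requested.
@pytest.mark.parametrize(
    ["gpus", "expected_root_gpu"],
    [
        (None, None),   # no GPUs requested -> no root device
        ([0], 0),       # single GPU -> that id is the root
        ([1, 2], 1),    # first id in the list becomes the root device
    ],
)
def test_determine_root_gpu_device(gpus, expected_root_gpu):
    assert device_parser.determine_root_gpu_device(gpus) == expected_root_gpu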
Example no. 7
    def on_trainer_init(
        self,
        num_processes,
        tpu_cores,
        accelerator,
        distributed_backend,
        auto_select_gpus,
        gpus,
        num_nodes,
        log_gpu_memory,
        sync_batchnorm,
        benchmark,
        replace_sampler_ddp,
        deterministic,
    ):
        # temp until we remove all dist backend references
        distributed_backend = self._map_deprecated_dist_backend(
            accelerator, distributed_backend)

        self.trainer.deterministic = deterministic

        torch.backends.cudnn.deterministic = self.trainer.deterministic
        if self.trainer.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # distributed backend choice
        self.trainer.distributed_backend = distributed_backend.lower() if distributed_backend else None

        # init the default rank if exists
        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
        # this way we only show it on rank 0
        if 'LOCAL_RANK' in os.environ:
            rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

        # benchmarking
        self.trainer.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.trainer.benchmark

        # Transfer params
        self.trainer.num_nodes = num_nodes
        self.trainer.log_gpu_memory = log_gpu_memory

        # sync-bn backend
        self.trainer.sync_batchnorm = sync_batchnorm

        self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        self.trainer.on_tpu = self.trainer.tpu_cores is not None

        self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(
            self.trainer.tpu_cores, list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn(
                "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
            )
        self.trainer.num_processes = num_processes

        # override with environment flag
        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)
        self.trainer.gpus = gpus

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.trainer.gpus = self.trainer.tuner.pick_multiple_gpus(gpus)

        self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(
            self.trainer.gpus)
        self.trainer.root_gpu = device_parser.determine_root_gpu_device(
            self.trainer.data_parallel_device_ids)
        self.trainer.root_device = torch.device("cpu")

        self.trainer.on_gpu = bool(self.trainer.data_parallel_device_ids and torch.cuda.is_available())

        # tpu state flags
        self.trainer.use_tpu = False
        self.trainer.tpu_local_core_rank = None
        self.trainer.tpu_global_core_rank = None

        # distributed backend choice
        self.set_distributed_mode()

        # override dist backend when using tpus
        if self.trainer.on_tpu:
            self.trainer.distributed_backend = "tpu"
            self.trainer.use_tpu = True

        # init flags for SLURM+DDP to work
        self.trainer.world_size = 1
        self.trainer.interactive_ddp_procs = []

        # link up SLURM
        # TODO: this should be taken out of here... but depends too much on DDP
        self.trainer.slurm_connector.on_trainer_init(self.trainer.num_nodes)
        self.trainer.node_rank = self.determine_ddp_node_rank()
        self.trainer.local_rank = self.determine_local_rank()
        self.trainer.global_rank = 0

        # NVIDIA setup
        self.set_nvidia_flags(self.trainer.is_slurm_managing_tasks,
                              self.trainer.data_parallel_device_ids)

        self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv(
            'KAGGLE_URL_BASE')

        self.trainer.replace_sampler_ddp = replace_sampler_ddp
Example no. 8
    def __init__(
        self,
        num_processes,
        tpu_cores,
        distributed_backend,
        auto_select_gpus,
        gpus,
        num_nodes,
        sync_batchnorm,
        benchmark,
        replace_sampler_ddp,
        deterministic,
        precision,
        amp_type,
        amp_level,
        cluster_environment,
    ):
        # initialization
        self._device_type = DeviceType.CPU
        self._distrib_type = None

        self.num_processes = num_processes
        self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        self.distributed_backend = distributed_backend
        self.auto_select_gpus = auto_select_gpus
        self.gpus = gpus
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.benchmark = benchmark
        self.replace_sampler_ddp = replace_sampler_ddp
        self.deterministic = deterministic
        self.precision = precision
        self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
        self.amp_level = amp_level
        self.cluster_environment = cluster_environment
        self.is_slurm_managing_tasks = False

        # init the default rank if exists
        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
        # this way we only show it on rank 0
        if "LOCAL_RANK" in os.environ:
            rank_zero_only.rank = int(os.environ["LOCAL_RANK"])

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = pick_multiple_gpus(gpus)

        self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
        self.root_gpu = device_parser.determine_root_gpu_device(
            self.parallel_device_ids)

        self.set_distributed_mode()
        self.configure_slurm_ddp()

        self.accelerator = self.select_accelerator()

        # override dist backend when using tpus
        if self.on_tpu:
            self.distributed_backend = "tpu"
            self.use_tpu = True

        # init flags for SLURM+DDP to work
        self.world_size = 1
        self.interactive_ddp_procs = []
        self.global_rank = 0

        # NVIDIA setup
        # self.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)

        # benchmarking
        # TODO: should this be moved to GPU accelerator?
        torch.backends.cudnn.benchmark = self.benchmark

        # determinism for cudnn
        # TODO: should this be moved to GPU accelerator?
        torch.backends.cudnn.deterministic = deterministic
        if deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # TODO: move this to TPU accelerator/plugin
        self.on_colab_kaggle = os.getenv("COLAB_GPU") or os.getenv(
            "KAGGLE_URL_BASE")

        self.replace_sampler_ddp = replace_sampler_ddp
Example no. 9
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: Optional[List[Callback]] = None,
            default_root_dir: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            num_processes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            auto_select_gpus: bool = False,
            tpu_cores: Optional[Union[List[int], str, int]] = None,
            log_gpu_memory: Optional[str] = None,
            progress_bar_refresh_rate: int = 1,
            overfit_batches: Union[int, float] = 0.0,
            track_grad_norm: Union[int, float, str] = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            limit_train_batches: Union[int, float] = 1.0,
            limit_val_batches: Union[int, float] = 1.0,
            limit_test_batches: Union[int, float] = 1.0,
            val_check_interval: Union[int, float] = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 50,
            distributed_backend: Optional[str] = None,
            sync_batchnorm: bool = False,
            precision: int = 32,
            weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
            weights_save_path: Optional[str] = None,
            num_sanity_val_steps: int = 2,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[Union[BaseProfiler, bool]] = None,
            benchmark: bool = False,
            deterministic: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            auto_lr_find: Union[bool, str] = False,
            replace_sampler_ddp: bool = True,
            terminate_on_nan: bool = False,
            auto_scale_batch_size: Union[str, bool] = False,
            prepare_data_per_node: bool = True,
            amp_backend: str = 'native',
            amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
            val_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
            test_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
            train_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
            overfit_pct: float = None,  # backward compatible, todo: remove in v1.0.0
    ):
        super().__init__()

        self.deterministic = deterministic
        torch.backends.cudnn.deterministic = self.deterministic
        if self.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # init the default rank if exists
        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
        # this way we only show it on rank 0
        if 'LOCAL_RANK' in os.environ:
            rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

        # tracks internal state for debugging
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.lr_scheduler_connector = LRSchedulerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.initializer = Initializer(self)
        self.tuner = Tuner(self)
        self.accelerator_backend = None

        # loops
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.num_training_batches = 0
        self.num_val_batches = []
        self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # when true, prints test results
        self.verbose_test = True

        # when .test() is called, it sets this
        self.tested_ckpt_path = None

        # training state
        self.model = None
        self.datamodule = None
        self.testing = False
        self.prepare_data_per_node = prepare_data_per_node
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False
        self.should_stop = False
        self.running_sanity_check = False
        self._state = TrainerState.INITIALIZING

        self._default_root_dir = default_root_dir or os.getcwd()
        self._weights_save_path = weights_save_path or self._default_root_dir

        # init callbacks
        self.callbacks = callbacks or []

        # configure early stop callback
        # creates a default one if none passed in
        early_stop_callback = self.configure_early_stopping(
            early_stop_callback)
        if early_stop_callback:
            self.callbacks.append(early_stop_callback)

        # configure checkpoint callback
        # it is important that this is the last callback to run
        # pass through the required args to figure out defaults
        checkpoint_callback = self.configure_checkpoint_callback(
            checkpoint_callback)
        if checkpoint_callback:
            self.callbacks.append(checkpoint_callback)

        # TODO refactor codebase (tests) to not directly reach into these callbacks
        self.checkpoint_callback = checkpoint_callback
        self.early_stop_callback = early_stop_callback

        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.benchmark

        # Transfer params
        self.num_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory

        # sync-bn backend
        self.sync_batchnorm = sync_batchnorm

        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch

        if not isinstance(track_grad_norm,
                          (int, float)) and track_grad_norm != 'inf':
            raise MisconfigurationException(
                "track_grad_norm can be an int, a float or 'inf' (infinity norm)."
            )
        self.track_grad_norm = float(track_grad_norm)

        self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        self.on_tpu = self.tpu_cores is not None

        self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores,
                                                      list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn(
                "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
            )
        self.num_processes = num_processes

        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        if num_sanity_val_steps == -1:
            self.num_sanity_val_steps = float('inf')
        else:
            self.num_sanity_val_steps = num_sanity_val_steps

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,'
                ' val and test loop using a single batch')

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # override with environment flag
        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = self.tuner.pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
        self.root_gpu = device_parser.determine_root_gpu_device(
            self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")

        self.on_gpu = bool(self.data_parallel_device_ids and torch.cuda.is_available())

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.on_tpu:
            self.distributed_backend = 'tpu'
            self.init_tpu()

        # init flags for SLURM+DDP to work
        self.world_size = 1
        self.interactive_ddp_procs = []
        self.configure_slurm_ddp(self.num_nodes)
        self.node_rank = self.determine_ddp_node_rank()
        self.local_rank = self.determine_local_rank()
        self.global_rank = 0

        # NVIDIA setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks,
                              self.data_parallel_device_ids)

        self._progress_bar_callback = self.configure_progress_bar(
            progress_bar_refresh_rate, process_position)

        # logging
        self.configure_logger(logger)
        self.log_save_interval = log_save_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        # TODO: remove in 0.10.0
        if overfit_pct is not None:
            rank_zero_warn(
                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            overfit_batches = overfit_pct

        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_val_batches = val_percent_check

        # TODO: remove in 0.10.0
        if test_percent_check is not None:
            rank_zero_warn(
                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_test_batches = test_percent_check

        # TODO: remove in 0.10.0
        if train_percent_check is not None:
            rank_zero_warn(
                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_train_batches = train_percent_check

        self.limit_train_batches = _determine_batch_limits(
            limit_train_batches, 'limit_train_batches')
        self.limit_val_batches = _determine_batch_limits(
            limit_val_batches, 'limit_val_batches')
        self.limit_test_batches = _determine_batch_limits(
            limit_test_batches, 'limit_test_batches')
        self.val_check_interval = _determine_batch_limits(
            val_check_interval, 'val_check_interval')
        self.overfit_batches = _determine_batch_limits(overfit_batches,
                                                       'overfit_batches')
        self.determine_data_use_amount(self.overfit_batches)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.precision = precision
        self.scaler = None

        self.amp_level = amp_level
        self.initializer.init_amp(amp_backend)

        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv(
            'KAGGLE_URL_BASE')

        # Callback system
        self.on_init_end()
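For orientation, here is a minimal, hypothetical usage sketch of the constructor shown in Example no. 9; the argument values are purely illustrative and nothing below is prescribed by the original snippet.

from pytorch_lightning import Trainer  # assumed import path for the class above

# Hypothetical construction touching a few of the code paths shown above:
# `gpus` is parsed via device_parser.parse_gpu_ids, `distributed_backend` is
# passed to set_distributed_mode, and `deterministic` toggles
# torch.backends.cudnn.deterministic.
trainer = Trainer(
    gpus=2,
    distributed_backend="ddp",
    max_epochs=10,
    deterministic=True,
)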