Example 1
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
            callbacks: List[Callback] = [],
            default_save_path: Optional[str] = None,
            gradient_clip_val: float = 0,
            gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
            process_position: int = 0,
            nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
            num_nodes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            num_tpu_cores: Optional[int] = None,
            log_gpu_memory: Optional[str] = None,
            show_progress_bar: bool = True,
            progress_bar_refresh_rate: int = 50,
            overfit_pct: float = 0.0,
            track_grad_norm: int = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
            max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            train_percent_check: float = 1.0,
            val_percent_check: float = 1.0,
            test_percent_check: float = 1.0,
            val_check_interval: Union[int, float] = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 10,
            add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
            distributed_backend: Optional[str] = None,
            use_amp=False,  # backward compatible, todo: remove in v0.8.0
            precision: int = 32,
            print_nan_grads: bool = False,
            weights_summary: str = 'full',
            weights_save_path: Optional[str] = None,
            amp_level: str = 'O1',
            nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
            num_sanity_val_steps: int = 5,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[BaseProfiler] = None,
            benchmark: bool = False,
            reload_dataloaders_every_epoch: bool = False,
    ):
        r"""

        Customize every aspect of training via flags

        Args:
            logger: Logger (or iterable collection of loggers) for experiment tracking.
                Example::

                    from pytorch_lightning.loggers import TensorBoardLogger

                    # default logger used by trainer
                    logger = TensorBoardLogger(
                        save_dir=os.getcwd(),
                        version=self.slurm_job_id,
                        name='lightning_logs'
                    )

                    Trainer(logger=logger)

            checkpoint_callback: Callback for checkpointing.
                Example::

                    from pytorch_lightning.callbacks import ModelCheckpoint

                    # default used by the Trainer
                    checkpoint_callback = ModelCheckpoint(
                        filepath=os.getcwd(),
                        save_best_only=True,
                        verbose=True,
                        monitor='val_loss',
                        mode='min',
                        prefix=''
                    )

                    trainer = Trainer(checkpoint_callback=checkpoint_callback)

            early_stop_callback: Callback for early stopping. If
                set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
                Will raise an error if ``'val_loss'`` is not found.
                If set to ``False``, then early stopping will be disabled.
                If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
                If ``'val_loss'`` is not found, it behaves as if early stopping is disabled.
                Default: ``None``.
                Example::

                    from pytorch_lightning.callbacks import EarlyStopping

                    # default used by the Trainer
                    early_stop_callback = EarlyStopping(
                        monitor='val_loss',
                        patience=3,
                        strict=False,
                        verbose=False,
                        mode='min'
                    )

                    trainer = Trainer(early_stop_callback=early_stop_callback)

            callbacks: Add a list of callbacks.
                Example::

                    from pytorch_lightning.callbacks import Callback

                    class PrintCallback(Callback):
                        def on_train_start(self):
                            print("Training is started!")

                        def on_train_end(self):
                            print(f"Training is done. The logs are: {self.trainer.logs}")

                    # a list of callbacks
                    callbacks = [PrintCallback()]
                    trainer = Trainer(callbacks=callbacks)

            default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
                Example::

                    # default used by the Trainer
                    trainer = Trainer(default_save_path=os.getcwd())

            gradient_clip_val: 0 means don't clip.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(gradient_clip_val=0.0)

            gradient_clip:
                .. warning:: .. deprecated:: 0.5.0
                    Use `gradient_clip_val` instead. Will remove 0.8.0.

            process_position: orders the tqdm bar when running multiple models on same machine.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(process_position=0)

            num_nodes: number of GPU nodes for distributed training.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(num_nodes=1)

                    # to train on 8 nodes
                    trainer = Trainer(num_nodes=8)

            nb_gpu_nodes:
                .. warning:: .. deprecated:: 0.5.0
                    Use `num_nodes` instead. Will remove 0.8.0.

            gpus: Which GPUs to train on.
                Example::

                    # default used by the Trainer (ie: train on CPU)
                    trainer = Trainer(gpus=None)

                    # int: train on 2 gpus
                    trainer = Trainer(gpus=2)

                    # list: train on GPUs 1, 4 (by bus ordering)
                    trainer = Trainer(gpus=[1, 4])
                    trainer = Trainer(gpus='1, 4') # equivalent

                    # -1: train on all gpus
                    trainer = Trainer(gpus=-1)
                    trainer = Trainer(gpus='-1') # equivalent

                    # combine with num_nodes to train on multiple GPUs across nodes
                    trainer = Trainer(gpus=2, num_nodes=4) # uses 8 gpus in total

            num_tpu_cores: How many TPU cores to train on (1 or 8).
                A single TPU v2 or v3 has 8 cores. A TPU pod has
                up to 2048 cores. A slice of a POD means you get as many cores
                as you request.

                You MUST use DistributedSampler with your dataloader for this
                to work. Your effective batch size is batch_size * total tpu cores.

                This parameter can be either 1 or 8.

                Example::

                    # your_trainer_file.py

                    # default used by the Trainer (ie: train on CPU)
                    trainer = Trainer(num_tpu_cores=None)

                    # int: train on a single core
                    trainer = Trainer(num_tpu_cores=1)

                    # int: train on all 8 cores
                    trainer = Trainer(num_tpu_cores=8)

                    # for 8+ cores must submit via xla script with
                    # a max of 8 cores specified. The XLA script
                    # will duplicate script onto each TPU in the POD
                    trainer = Trainer(num_tpu_cores=8)

                    # -1: train on all available TPUs
                    trainer = Trainer(num_tpu_cores=-1)

                To train on more than 8 cores (ie: a POD),
                submit this script using the xla_dist script.

                Example::

                    $ python -m torch_xla.distributed.xla_dist
                    --tpu=$TPU_POD_NAME
                    --conda-env=torch-xla-nightly
                    --env=XLA_USE_BF16=1
                    -- python your_trainer_file.py

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance
                because it uses the output of nvidia-smi.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(log_gpu_memory=None)

                    # log all the GPUs (on master node only)
                    trainer = Trainer(log_gpu_memory='all')

                    # log only the min and max memory on the master node
                    trainer = Trainer(log_gpu_memory='min_max')

            show_progress_bar: If true shows tqdm progress bar
                Example::

                    # default used by the Trainer
                    trainer = Trainer(show_progress_bar=True)

            progress_bar_refresh_rate: How often to refresh progress bar (in steps)

            overfit_pct: uses only this fraction of the train, val and test datasets.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(overfit_pct=0.0)

                    # use only 1% of the train, test, val datasets
                    trainer = Trainer(overfit_pct=0.01)

            track_grad_norm: -1 no tracking. Otherwise tracks that norm
                Example::

                    # default used by the Trainer
                    trainer = Trainer(track_grad_norm=-1)

                    # track the 2-norm
                    trainer = Trainer(track_grad_norm=2)

            check_val_every_n_epoch: Check val every n train epochs.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(check_val_every_n_epoch=1)

                    # run val loop every 10 training epochs
                    trainer = Trainer(check_val_every_n_epoch=10)

            fast_dev_run: runs 1 batch of train, val and test to find any bugs (ie: a sort of unit test).
                Example::

                    # default used by the Trainer
                    trainer = Trainer(fast_dev_run=False)

                    # runs 1 train, val, test batch and program ends
                    trainer = Trainer(fast_dev_run=True)

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
                Example::

                    # default used by the Trainer (no accumulation)
                    trainer = Trainer(accumulate_grad_batches=1)

                    # accumulate every 4 batches (effective batch size is batch*4)
                    trainer = Trainer(accumulate_grad_batches=4)

                    # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
                    trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

            max_epochs: Stop training once this number of epochs is reached.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(max_epochs=1000)

            max_nb_epochs:
                .. warning:: .. deprecated:: 0.5.0
                    Use `max_epochs` instead. Will remove 0.8.0.

            min_epochs: Force training for at least this many epochs.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(min_epochs=1)

            min_nb_epochs:
                .. warning:: .. deprecated:: 0.5.0
                    Use `min_epochs` instead. Will remove 0.8.0.

            max_steps: Stop training after this number of steps. Disabled by default (None).
                Training will stop when either max_steps or max_epochs is reached (whichever comes first).
                Example::

                    # Stop after 100 steps
                    trainer = Trainer(max_steps=100)

            min_steps: Force training for at least this number of steps. Disabled by default (None).
                Trainer will train the model for at least min_steps or min_epochs (whichever comes last).
                Example::

                    # Run at least for 100 steps (disable min_epochs)
                    trainer = Trainer(min_steps=100, min_epochs=0)

            train_percent_check: How much of training dataset to check.
                Useful when debugging or testing something that happens at the end of an epoch.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(train_percent_check=1.0)

                    # run through only 25% of the training set each epoch
                    trainer = Trainer(train_percent_check=0.25)

            val_percent_check: How much of validation dataset to check.
                Useful when debugging or testing something that happens at the end of an epoch.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(val_percent_check=1.0)

                    # run through only 25% of the validation set each epoch
                    trainer = Trainer(val_percent_check=0.25)

            test_percent_check: How much of test dataset to check.
                Useful when debugging or testing something that happens at the end of an epoch.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(test_percent_check=1.0)

                    # run through only 25% of the test set each epoch
                    trainer = Trainer(test_percent_check=0.25)

            val_check_interval: How often within one training epoch to check the validation set
                If float, fraction of the training epoch. If int, check every n batches.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(val_check_interval=1.0)

                    # check validation set 4 times during a training epoch
                    trainer = Trainer(val_check_interval=0.25)

                    # check validation set every 1000 training batches
                    # use this when using iterableDataset and your dataset has no length
                    # (ie: production cases with streaming data)
                    trainer = Trainer(val_check_interval=1000)

            log_save_interval: Writes logs to disk this often
                Example::

                    # default used by the Trainer
                    trainer = Trainer(log_save_interval=100)

            row_log_interval: How often to add logging rows (does not write to disk)
                Example::

                    # default used by the Trainer
                    trainer = Trainer(row_log_interval=10)

            add_row_log_interval:
                .. warning:: .. deprecated:: 0.5.0
                    Use `row_log_interval` instead. Will remove 0.8.0.

            distributed_backend: The distributed backend to use.
                Options: 'dp', 'ddp', 'ddp2'.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(distributed_backend=None)

                    # dp = DataParallel (split a batch onto k gpus on same machine).
                    trainer = Trainer(gpus=2, distributed_backend='dp')

                    # ddp = DistributedDataParallel
                    # Each gpu trains by itself on a subset of the data.
                    # Gradients sync across all gpus and all machines.
                    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')

                    # ddp2 = DistributedDataParallel + dp
                    # behaves like dp on every node
                    # syncs gradients across nodes like ddp
                    # useful for things like increasing the number of negative samples
                    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

            use_amp:
                .. warning:: .. deprecated:: 0.6.1
                    Use `precision` instead. Will remove 0.8.0.

            precision: Full precision (32), half precision (16).
                Can be used on CPU, GPU or TPUs.

                If used on TPU will use torch.bfloat16 but tensor printing
                will still show torch.float32.

                Example::

                    # default used by the Trainer
                    trainer = Trainer(precision=32)

                    # 16-bit precision
                    trainer = Trainer(precision=16)

                    # one day
                    trainer = Trainer(precision=8|4|2)

            print_nan_grads: Prints gradients with nan values
                Example::

                    # default used by the Trainer
                    trainer = Trainer(print_nan_grads=False)

            weights_summary: Prints a summary of the weights when training begins.
                Options: 'full', 'top', None.
                Example::

                    # default used by the Trainer (ie: print all weights)
                    trainer = Trainer(weights_summary='full')

                    # print only the top level modules
                    trainer = Trainer(weights_summary='top')

                    # don't print a summary
                    trainer = Trainer(weights_summary=None)

            weights_save_path: Where to save weights if specified.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(weights_save_path=os.getcwd())

                    # save to your custom path
                    trainer = Trainer(weights_save_path='my/path')

                    # if checkpoint callback used, then overrides the weights path
                    # **NOTE: this saves weights to some/path NOT my/path
                    checkpoint_callback = ModelCheckpoint(filepath='some/path')
                    trainer = Trainer(
                        checkpoint_callback=checkpoint_callback,
                        weights_save_path='my/path'
                    )

            amp_level: The optimization level to use (O1, O2, etc...).
                Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
                Example::

                    # default used by the Trainer
                    trainer = Trainer(amp_level='O1')

            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
                This catches any bugs in your validation without having to wait for the first validation check.
                The Trainer uses 5 steps by default. Turn it off or modify it here.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(num_sanity_val_steps=5)

                    # turn it off
                    trainer = Trainer(num_sanity_val_steps=0)

            nb_sanity_val_steps:
                .. warning:: .. deprecated:: 0.5.0
                    Use `num_sanity_val_steps` instead. Will remove 0.8.0.

            truncated_bptt_steps: Truncated backprop through time splits a much longer sequence
                into chunks and performs backprop every k steps. If this is enabled, your batches will
                automatically get truncated and the trainer will apply Truncated Backprop to them.
                Make sure your batches have a sequence
                dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
                recurrent network trajectories."
                <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
                Example::

                    # default used by the Trainer (ie: disabled)
                    trainer = Trainer(truncated_bptt_steps=None)

                    # backprop every 5 steps in a batch
                    trainer = Trainer(truncated_bptt_steps=5)


                Lightning takes care to split your batch along the time-dimension.

                .. note:: If you need to modify how the batch is split,
                    override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

                .. note:: Using this feature requires updating your LightningModule's
                    :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
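
                A sketch of such a ``training_step`` (``self.lstm`` and ``self.loss`` are
                placeholder attributes, not part of the Lightning API)::

                    def training_step(self, batch, batch_idx, hiddens):
                        # hiddens carries the hidden state from the previous truncated chunk
                        out, hiddens = self.lstm(batch, hiddens)
                        loss = self.loss(out)
                        # returning 'hiddens' hands the state to the next chunk
                        return {'loss': loss, 'hiddens': hiddens}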

            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(resume_from_checkpoint=None)

                    # resume from a specific checkpoint
                    trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

            profiler: To profile individual steps during training and assist in
                identifying bottlenecks.
                Example::

                    from pytorch_lightning.profiler import Profiler, AdvancedProfiler

                    # default used by the Trainer
                    trainer = Trainer(profiler=None)

                    # to profile standard training events
                    trainer = Trainer(profiler=True)

                    # equivalent to profiler=True
                    profiler = Profiler()
                    trainer = Trainer(profiler=profiler)

                    # advanced profiler for function-level stats
                    profiler = AdvancedProfiler()
                    trainer = Trainer(profiler=profiler)

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

            benchmark (bool): If true enables cudnn.benchmark.
                This flag is likely to increase the speed of your system if your
                input sizes don't change. However, if they do, it will likely
                make your system slower.

                The speedup comes from allowing the cudnn auto-tuner to find the best
                algorithm for the hardware `[see discussion here]
                <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`_.

        .. warning:: The following arguments are deprecated and will be removed in v0.8.0:

            - `gradient_clip`
            - `nb_gpu_nodes`
            - `max_nb_epochs`
            - `min_nb_epochs`
            - `add_row_log_interval`
            - `use_amp`
            - `nb_sanity_val_steps`

        """

        # Init callbacks
        self.callbacks = callbacks
        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        if benchmark:
            torch.backends.cudnn.benchmark = True

        # Transfer params
        # Backward compatibility
        if nb_gpu_nodes is not None:
            warnings.warn("`nb_gpu_nodes` has been renamed to `num_nodes` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            if not num_nodes:  # in case you did not set the proper value
                num_nodes = nb_gpu_nodes
        self.num_gpu_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory

        # Backward compatibility
        if gradient_clip is not None:
            warnings.warn("`gradient_clip` has been renamed to `gradient_clip_val` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            if not gradient_clip_val:  # in case you did not set the proper value
                gradient_clip_val = gradient_clip
        self.gradient_clip_val = gradient_clip_val

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False

        # tpu config
        self.on_tpu = num_tpu_cores is not None
        self.num_tpu_cores = num_tpu_cores
        assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8'

        self.process_position = process_position
        self.weights_summary = weights_summary

        # Backward compatibility
        if max_nb_epochs is not None:
            warnings.warn("`max_nb_epochs` has been renamed to `max_epochs` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            if not max_epochs:  # in case you did not set the proper value
                max_epochs = max_nb_epochs
        self.max_epochs = max_epochs

        # Backward compatibility
        if min_nb_epochs is not None:
            warnings.warn("`min_nb_epochs` has been renamed to `min_epochs` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            if not min_epochs:  # in case you did not set the proper value
                min_epochs = min_nb_epochs
        self.min_epochs = min_epochs

        self.max_steps = max_steps
        self.min_steps = min_steps

        # Backward compatibility
        if nb_sanity_val_steps is not None:
            warnings.warn("`nb_sanity_val_steps` has been renamed to `num_sanity_val_steps` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            if not num_sanity_val_steps:  # in case you did not set the proper value
                num_sanity_val_steps = nb_sanity_val_steps

        self.num_sanity_val_steps = num_sanity_val_steps
        self.print_nan_grads = print_nan_grads
        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.num_sanity_val_steps = 1
            self.max_epochs = 1
            m = '''
            Running in fast_dev_run mode: will run a full train,
            val loop using a single batch
            '''
            log.info(m)

        # set default save path if user didn't provide one
        self.default_save_path = default_save_path
        if self.default_save_path is None:
            self.default_save_path = os.getcwd()

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_idx = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure logger
        self.configure_logger(logger)

        # configure profiler
        if profiler is True:
            profiler = Profiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
        # creates a default one if none passed in
        self.configure_early_stopping(early_stop_callback)

        self.reduce_lr_on_plateau_scheduler = None

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.data_parallel_device_ids = parse_gpu_ids(gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend, num_nodes)

        # override dist backend when using tpus
        if self.on_tpu:
            self.init_tpu()
            self.current_tpu_idx = None

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        self.configure_slurm_ddp(num_nodes)

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval

        # backward compatibility
        if add_row_log_interval is not None:
            warnings.warn("`add_row_log_interval` has been renamed to `row_log_interval` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            if not row_log_interval:  # in case you did not set the proper value
                row_log_interval = add_row_log_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # 16 bit mixed precision training using apex
        self.amp_level = amp_level
        self.precision = precision
        if self.precision == 16:
            use_amp = True
        self.init_amp(use_amp)

        # Callback system
        self.on_init_end()
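
A minimal usage sketch showing how the flags documented above are typically combined. This is not part of the original source: MyModel and its import path are placeholders for any LightningModule, and the flag values are illustrative.

import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from my_project.model import MyModel  # hypothetical LightningModule

model = MyModel()

trainer = Trainer(
    logger=TensorBoardLogger(save_dir=os.getcwd(), name='lightning_logs'),
    checkpoint_callback=ModelCheckpoint(filepath=os.getcwd(), monitor='val_loss', mode='min'),
    early_stop_callback=EarlyStopping(monitor='val_loss', patience=3, mode='min'),
    gpus=2,                     # train on 2 GPUs
    precision=16,               # 16-bit training via apex (amp_level='O1' by default)
    accumulate_grad_batches=4,  # effective batch size = 4 * batch_size
    max_epochs=100,
    val_check_interval=0.25,    # run validation 4 times per training epoch
)
trainer.fit(model)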
Example 2
    def __init__(
        self,
        logger=True,
        checkpoint_callback=True,
        early_stop_callback=True,
        default_save_path=None,
        gradient_clip_val=0,
        gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
        process_position=0,
        nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
        num_nodes=1,
        gpus=None,
        log_gpu_memory=None,
        show_progress_bar=True,
        overfit_pct=0.0,
        track_grad_norm=-1,
        check_val_every_n_epoch=1,
        fast_dev_run=False,
        accumulate_grad_batches=1,
        max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        max_num_epochs=1000,
        min_num_epochs=1,
        train_percent_check=1.0,
        val_percent_check=1.0,
        test_percent_check=1.0,
        val_check_interval=1.0,
        log_save_interval=100,
        row_log_interval=10,
        add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
        distributed_backend=None,
        use_amp=False,
        print_nan_grads=False,
        weights_summary='full',
        weights_save_path=None,
        amp_level='O1',
        nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
        num_sanity_val_steps=5,
        truncated_bptt_steps=None,
        resume_from_checkpoint=None,
    ):
        """

        :param logger: Logger for experiment tracking
        :param checkpoint_callback: Callback for checkpointing
        :param early_stop_callback: Callback for early stopping
        :param str default_save_path: Default path for logs+weights if no logger/ckpt_callback passed
        :param int gradient_clip_val: 0 means don't clip.
        :param int gradient_clip: 0 means don't clip. Deprecated.
        :param process_position: shown in the tqdm bar
        :param int num_nodes: number of GPU nodes
        :param list|str|int gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] OR '0,1'
            OR '-1' / -1 to use all available gpus
        :param str log_gpu_memory: None, 'min_max', 'all'
        :param bool show_progress_bar: If true shows tqdm bar
        :param float overfit_pct: uses this much of all datasets
        :param int track_grad_norm: -1 no tracking. Otherwise tracks that norm
        :param int check_val_every_n_epoch: check val every n train epochs
        :param bool fast_dev_run: runs full iteration over everything to find bugs
        :param int accumulate_grad_batches: Accumulates grads every k batches
        :param int max_num_epochs: Stop training once this number of epochs is reached.
        :param int min_num_epochs: Force training for at least this many epochs.
        :param float train_percent_check: How much of train set to check
        :param float val_percent_check: How much of val set to check
        :param float test_percent_check: How much of test set to check
        :param float|int val_check_interval: If float, fraction of the training epoch. If int, check every n batches
        :param int log_save_interval: Writes logs to disk this often
        :param int row_log_interval: How often to add logging rows
        :param int add_row_log_interval: How often to add logging rows. Deprecated.
        :param str distributed_backend: Options: 'dp', 'ddp', 'ddp2'.
        :param bool use_amp: If true uses apex for 16bit precision
        :param bool print_nan_grads: Prints nan gradients
        :param str weights_summary: Options: 'full', 'top', None to not print.
        :param str weights_save_path: Where to save weights if on cluster
        :param str amp_level: Check nvidia docs for level
        :param int num_sanity_val_steps: How many val steps before a full train loop.
        :param int truncated_bptt_steps: Enables multiple backward passes for each batch.

        .. warning:: The following arguments are deprecated and will be removed in v0.8.0:

            - `gradient_clip`
            - `nb_gpu_nodes`
            - `max_nb_epochs`
            - `min_nb_epochs`
            - `add_row_log_interval`
            - `nb_sanity_val_steps`

        """
        # Transfer params
        if nb_gpu_nodes is not None:  # Backward compatibility
            warnings.warn(
                "`nb_gpu_nodes` has been renamed to `num_nodes` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not num_nodes:  # in case you did not set the proper value
                num_nodes = nb_gpu_nodes
        self.num_gpu_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory
        if gradient_clip is not None:  # Backward compatibility
            warnings.warn(
                "`gradient_clip` has been renamed to `gradient_clip_val` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not gradient_clip_val:  # in case you did not set the proper value
                gradient_clip_val = gradient_clip
        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
        self.process_position = process_position
        self.weights_summary = weights_summary
        if max_nb_epochs is not None:  # Backward compatibility
            warnings.warn(
                "`max_nb_epochs` has been renamed to `max_num_epochs` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not max_num_epochs:  # in case you did not set the proper value
                max_num_epochs = max_nb_epochs
        self.max_num_epochs = max_num_epochs
        if min_nb_epochs is not None:  # Backward compatibility
            warnings.warn(
                "`min_nb_epochs` has been renamed to `min_num_epochs` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not min_num_epochs:  # in case you did not set the proper value
                min_num_epochs = min_nb_epochs
        self.min_num_epochs = min_num_epochs
        if nb_sanity_val_steps is not None:  # Backward compatibility
            warnings.warn(
                "`nb_sanity_val_steps` has been renamed to `num_sanity_val_steps` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not num_sanity_val_steps:  # in case you did not set the proper value
                num_sanity_val_steps = nb_sanity_val_steps
        self.num_sanity_val_steps = num_sanity_val_steps
        self.print_nan_grads = print_nan_grads
        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.num_sanity_val_steps = 1
            self.max_num_epochs = 1
            m = '''
            Running in fast_dev_run mode: will run a full train,
            val loop using a single batch
            '''
            logging.info(m)

        # set default save path if user didn't provide one
        self.default_save_path = default_save_path
        if self.default_save_path is None:
            self.default_save_path = os.getcwd()

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_idx = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.get_train_dataloader = None
        self.get_test_dataloaders = None
        self.get_val_dataloaders = None
        self.is_iterable_train_dataloader = False

        # training state
        self.model = None
        self.testing = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure early stop callback
        # creates a default one if none passed in
        self.early_stop_callback = None
        self.configure_early_stopping(early_stop_callback, logger)

        self.reduce_lr_on_plateau_scheduler = None

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.data_parallel_device_ids = parse_gpu_ids(gpus)
        self.root_gpu = determine_root_gpu_device(
            self.data_parallel_device_ids)

        # distributed backend choice
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend, num_nodes)

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        self.configure_slurm_ddp(num_nodes)

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks,
                              self.data_parallel_device_ids)

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval
        if add_row_log_interval is not None:
            # backward compatibility
            warnings.warn(
                "`add_row_log_interval` has been renamed to `row_log_interval` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not row_log_interval:  # in case you did not set the proper value
                row_log_interval = add_row_log_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # 16 bit mixed precision training using apex
        self.amp_level = amp_level
        self.init_amp(use_amp)
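
Both constructors repeat the same backward-compatibility shim for every renamed argument: warn with a DeprecationWarning, then fall back to the old value only if the new argument was left at a falsy default. A compact sketch of that pattern as a standalone helper; the helper name _handle_renamed_arg is ours, not part of the library.

import warnings


def _handle_renamed_arg(old_value, new_value, old_name, new_name, remove_in='v0.8.0'):
    # warn about the deprecated name and fall back to its value if the new one was not set
    if old_value is not None:
        warnings.warn(f"`{old_name}` has been renamed to `{new_name}` since v0.5.0"
                      f" and will be removed in {remove_in}", DeprecationWarning)
        if not new_value:  # in case the caller did not set the new argument
            new_value = old_value
    return new_value


# e.g. the gradient_clip -> gradient_clip_val shim in the constructors above:
# gradient_clip_val = _handle_renamed_arg(gradient_clip, gradient_clip_val,
#                                         'gradient_clip', 'gradient_clip_val')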
Example 3
def test_parse_gpu_fail_on_non_existent_id_2(mocked_device_count):
    with pytest.raises(MisconfigurationException):
        parse_gpu_ids([1, 2, 19])
Example 4
def test_parse_gpu_returns_None_when_no_devices_are_available(
        mocked_device_count_0, gpus):
    with pytest.raises(MisconfigurationException):
        parse_gpu_ids(gpus)
Example 5
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
    with pytest.raises(MisconfigurationException):
        parse_gpu_ids(gpus)
Example 6
def test_parse_gpu_fail_on_non_existent_id(mocked_device_count_0, gpus):
    with pytest.raises(MisconfigurationException):
        parse_gpu_ids(gpus)
Example 7
def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
    assert parse_gpu_ids(gpus) == expected_gpu_ids
Example 8
def test_parse_gpu_fail_on_empty_string(mocked_device_count, gpus):
    # This currently results in a ValueError instead of MisconfigurationException
    with pytest.raises(ValueError):
        parse_gpu_ids(gpus)
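
The tests in the preceding examples rely on pytest fixtures (mocked_device_count, mocked_device_count_0) and @pytest.mark.parametrize decorators that are not shown. A plausible reconstruction of that scaffolding follows; the patched device count, the parameter values and the import paths are assumptions rather than text from the original test file.

import pytest
import torch

# import paths are assumptions -- these names have moved between releases
from pytorch_lightning.trainer.distrib_parts import parse_gpu_ids
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@pytest.fixture
def mocked_device_count(monkeypatch):
    # pretend the machine exposes 8 CUDA devices
    monkeypatch.setattr(torch.cuda, 'device_count', lambda: 8)
    monkeypatch.setattr(torch.cuda, 'is_available', lambda: True)


@pytest.fixture
def mocked_device_count_0(monkeypatch):
    # pretend no CUDA devices are available
    monkeypatch.setattr(torch.cuda, 'device_count', lambda: 0)
    monkeypatch.setattr(torch.cuda, 'is_available', lambda: False)


# illustrative parametrization for the "non existent id" failure case
@pytest.mark.parametrize('gpus', [1, [0], [1]])
def test_parse_gpu_fail_on_non_existent_id(mocked_device_count_0, gpus):
    with pytest.raises(MisconfigurationException):
        parse_gpu_ids(gpus)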
Example 9
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: Optional[List[Callback]] = None,
            default_root_dir: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            num_processes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            auto_select_gpus: bool = False,
            tpu_cores: Optional[Union[List[int], int]] = None,
            log_gpu_memory: Optional[str] = None,
            progress_bar_refresh_rate: int = 1,
            overfit_pct: float = 0.0,
            track_grad_norm: int = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            train_percent_check: float = 1.0,
            val_percent_check: float = 1.0,
            test_percent_check: float = 1.0,
            val_check_interval: float = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 10,
            add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
            distributed_backend: Optional[str] = None,
            precision: int = 32,
            print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
            weights_summary: Optional[str] = 'full',
            weights_save_path: Optional[str] = None,
            num_sanity_val_steps: int = 2,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[Union[BaseProfiler, bool]] = None,
            benchmark: bool = False,
            deterministic: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            auto_lr_find: Union[bool, str] = False,
            replace_sampler_ddp: bool = True,
            terminate_on_nan: bool = False,
            auto_scale_batch_size: Union[str, bool] = False,
            num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
            amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
            default_save_path=None,  # backward compatible, todo: remove in v0.8.0
            gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
            nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
            max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            use_amp=None,  # backward compatible, todo: remove in v0.9.0
            show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
            nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
    ):
        r"""

        Customize every aspect of training via flags

        Args:
            logger: Logger (or iterable collection of loggers) for experiment tracking.

            checkpoint_callback: Callback for checkpointing.

            early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`): Callback for early stopping.

            callbacks: Add a list of callbacks.

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed

            default_save_path:
                .. warning:: .. deprecated:: 0.7.3

                    Use `default_root_dir` instead. Will remove 0.9.0.

            gradient_clip_val: 0 means don't clip.

            gradient_clip:
                .. warning:: .. deprecated:: 0.7.0

                    Use `gradient_clip_val` instead. Will remove 0.9.0.

            process_position: orders the progress bar when running multiple models on same machine.

            num_nodes: number of GPU nodes for distributed training.

            nb_gpu_nodes:
                .. warning:: .. deprecated:: 0.7.0

                    Use `num_nodes` instead. Will remove 0.9.0.

            gpus: Which GPUs to train on.

            auto_select_gpus:

                If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

            num_tpu_cores: How many TPU cores to train on (1 or 8)
                .. warning:: .. deprecated:: 0.7.6. Will remove 0.9.0.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            show_progress_bar:
                .. warning:: .. deprecated:: 0.7.2

                        Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

            overfit_pct: How much of training-, validation-, and test dataset to check.

            track_grad_norm: -1 no tracking. Otherwise tracks that norm

            check_val_every_n_epoch: Check val every n train epochs.

            fast_dev_run: runs 1 batch of train, val and test to find any bugs (ie: a sort of unit test).

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            max_epochs: Stop training once this number of epochs is reached.

            max_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0

                    Use `max_epochs` instead. Will remove 0.9.0.

            min_epochs: Force training for at least these many epochs

            min_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0

                    Use `min_epochs` instead. Will remove 0.9.0.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least these number of steps. Disabled by default (None).

            train_percent_check: How much of training dataset to check.

            val_percent_check: How much of validation dataset to check.

            test_percent_check: How much of test dataset to check.

            val_check_interval: How often within one training epoch to check the validation set

            log_save_interval: Writes logs to disk this often

            row_log_interval: How often to add logging rows (does not write to disk)

            add_row_log_interval:
                .. warning:: .. deprecated:: 0.7.0

                    Use `row_log_interval` instead. Will remove 0.9.0.

            distributed_backend: The distributed backend to use.

            use_amp:
                .. warning:: .. deprecated:: 0.7.0

                    Use `precision` instead. Will remove 0.9.0.

            precision: Full precision (32), half precision (16).

            print_nan_grads:
                .. warning:: .. deprecated:: 0.7.2

                    Has no effect. When detected, NaN grads will be printed automatically.
                    Will remove 0.9.0.

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                    for checkpoints only. Use this if for whatever reason you need the checkpoints
                    stored in a different place than the logs written in `default_root_dir`.

            amp_level: The optimization level to use (O1, O2, etc...).

            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.

            nb_sanity_val_steps:
                .. warning:: .. deprecated:: 0.7.0

                    Use `num_sanity_val_steps` instead. Will remove 0.8.0.

            truncated_bptt_steps: Truncated backprop through time performs backprop every k steps of a much longer sequence.

            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.

            profiler: To profile individual steps during training and assist in identifying bottlenecks.

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

            auto_lr_find: If set to True, will `initially` run a learning rate finder,
                trying to optimize the initial learning rate for faster convergence. Sets the learning
                rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key, set a string instead of True with the key name.
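
                A plausible usage sketch, following the description above::

                    # run the learning rate finder before training; result stored in self.lr
                    trainer = Trainer(auto_lr_find=True)

                    # or write the found rate to a custom attribute, e.g. self.my_lr
                    trainer = Trainer(auto_lr_find='my_lr')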

            replace_sampler_ddp: Explicitly enables or disables sampler replacement.
                If not specified, this is toggled automatically when DDP is used.

            benchmark: If true enables cudnn.benchmark.

            deterministic: If true enables cudnn.deterministic

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            auto_scale_batch_size: If set to True, will `initially` run a batch size
                finder trying to find the largest batch size that fits into memory.
                The result will be stored in self.batch_size in the LightningModule.
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.
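
                A plausible usage sketch (assumes the LightningModule defines ``self.batch_size``)::

                    # estimate the largest batch size that fits in memory with a power search
                    trainer = Trainer(auto_scale_batch_size='power')

                    # or refine the estimate with a binary search
                    trainer = Trainer(auto_scale_batch_size='binsearch')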
        """
        super().__init__()

        self.deterministic = deterministic
        torch.backends.cudnn.deterministic = self.deterministic
        if self.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # Init callbacks
        self.callbacks = callbacks or []
        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.benchmark

        # Transfer params
        self.num_nodes = num_nodes
        # Backward compatibility, TODO: remove in v0.8.0
        if nb_gpu_nodes is not None:
            rank_zero_warn(
                "Argument `nb_gpu_nodes` has been renamed to `num_nodes` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.num_gpu_nodes = nb_gpu_nodes
        self.log_gpu_memory = log_gpu_memory

        self.gradient_clip_val = gradient_clip_val
        # Backward compatibility, TODO: remove in v0.8.0
        if gradient_clip is not None:
            rank_zero_warn(
                "Argument `gradient_clip` has been renamed to `gradient_clip_val` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.gradient_clip = gradient_clip

        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False

        # tpu config
        if num_tpu_cores is not None:
            rank_zero_warn(
                "Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
                " and this argument will be removed in v0.9.0",
                DeprecationWarning)

        if tpu_cores is None:
            tpu_cores = num_tpu_cores
        self.on_tpu = tpu_cores is not None
        self.tpu_cores = tpu_cores
        assert self.tpu_cores in (1, 8, None) or (
            isinstance(self.tpu_cores,
                       (list, tuple, set)) and len(self.tpu_cores)
            == 1), '`tpu_cores` can only be 1, 8 or [<1-8>]'

        self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn(
                "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
            )
        self.num_processes = num_processes

        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        # Backward compatibility, TODO: remove in v0.8.0
        if max_nb_epochs is not None:
            rank_zero_warn(
                "Argument `max_nb_epochs` has been renamed to `max_epochs` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.max_nb_epochs = max_nb_epochs

        self.min_epochs = min_epochs
        # Backward compatibility, TODO: remove in v0.8.0
        if min_nb_epochs is not None:
            rank_zero_warn(
                "Argument `min_nb_epochs` has been renamed to `min_epochs` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.min_nb_epochs = min_nb_epochs

        self.max_steps = max_steps
        self.min_steps = min_steps

        self.num_sanity_val_steps = num_sanity_val_steps
        # Backward compatibility, TODO: remove in v0.8.0
        if nb_sanity_val_steps is not None:
            rank_zero_warn(
                "Argument `nb_sanity_val_steps` has been renamed to "
                "`num_sanity_val_steps` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.nb_sanity_val_steps = nb_sanity_val_steps

        # Backward compatibility, TODO: remove in v0.9.0
        if print_nan_grads:
            rank_zero_warn(
                "Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
                " NaN grads will be printed automatically when detected.",
                DeprecationWarning)

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            log.info('Running in fast_dev_run mode: will run a full train,'
                     ' val and test loop using a single batch')

        # set default save path if user didn't provide one
        self.default_root_dir = default_root_dir

        # Backward compatibility, TODO: remove in v0.8.0
        if default_save_path is not None:
            self.default_root_dir = default_save_path

        if self.default_root_dir is None:
            self.default_root_dir = os.getcwd()

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.progress_bar_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False

        # configure logger
        self.configure_logger(logger)

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
        # creates a default one if none passed in
        self.configure_early_stopping(early_stop_callback)

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(
            self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.on_tpu:
            self.init_tpu()

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.configure_slurm_ddp(self.num_nodes)
        self.node_rank = self.determine_ddp_node_rank()

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks,
                              self.data_parallel_device_ids)

        # backward compatibility
        if show_progress_bar is not None:
            self.show_progress_bar = show_progress_bar

        self._progress_bar_callback = self.configure_progress_bar(
            progress_bar_refresh_rate, process_position)

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval

        # backward compatibility
        if add_row_log_interval is not None:
            rank_zero_warn(
                "`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            if not row_log_interval:  # in case you did not set the proper value
                row_log_interval = add_row_log_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.overfit_pct = overfit_pct
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(
            torch.cuda.amp, "autocast")
        self.precision = precision
        self.scaler = None

        # TODO: remove for v0.8.0
        self.amp_level = amp_level
        self.init_amp(use_amp)

        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv(
            'KAGGLE_URL_BASE')

        # Callback system
        self.on_init_end()
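
A minimal usage sketch for the constructor above, assuming an installed pytorch_lightning whose Trainer matches this signature; `LitModel` stands in for a hypothetical LightningModule, and every flag shown appears in the signature or docstring:

    import os

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # checkpoint on validation loss, mirroring the default shown in the docstring
    checkpoint_callback = ModelCheckpoint(filepath=os.getcwd(), monitor='val_loss', mode='min')

    trainer = Trainer(
        checkpoint_callback=checkpoint_callback,
        gradient_clip_val=0.5,     # clip gradient norm at 0.5 (0 disables clipping)
        max_epochs=10,
        num_sanity_val_steps=2,    # run 2 val batches before training starts
        fast_dev_run=False,        # set True to smoke-test one train/val/test batch
    )

    # trainer.fit(LitModel())  # LitModel is a hypothetical LightningModule
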
Example #10
0
    def __init__(
            self,
            logger=True,
            checkpoint_callback=True,
            early_stop_callback=None,
            default_save_path=None,
            gradient_clip_val=0,
            gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
            process_position=0,
            nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
            num_nodes=1,
            gpus=None,
            log_gpu_memory=None,
            show_progress_bar=True,
            overfit_pct=0.0,
            track_grad_norm=-1,
            check_val_every_n_epoch=1,
            fast_dev_run=False,
            accumulate_grad_batches=1,
            max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            max_epochs=1000,
            min_epochs=1,
            train_percent_check=1.0,
            val_percent_check=1.0,
            test_percent_check=1.0,
            val_check_interval=1.0,
            log_save_interval=100,
            row_log_interval=10,
            add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
            distributed_backend=None,
            use_amp=False,
            print_nan_grads=False,
            weights_summary='full',
            weights_save_path=None,
            amp_level='O1',
            nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
            num_sanity_val_steps=5,
            truncated_bptt_steps=None,
            resume_from_checkpoint=None,
            profiler=None):
        r"""

        Customize every aspect of training via flags

        Args:
            logger (:class:`.Logger`): Logger for experiment tracking.
                Example::

                    from pytorch_lightning.loggers import TensorBoardLogger

                    # default logger used by trainer
                    logger = TensorBoardLogger(
                        save_dir=os.getcwd(),
                        version=self.slurm_job_id,
                        name='lightning_logs'
                    )

                    Trainer(logger=logger)

            checkpoint_callback (:class:`CheckpointCallback`): Callback for checkpointing.
                Example::

                    from pytorch_lightning.callbacks import ModelCheckpoint

                    # default used by the Trainer
                    checkpoint_callback = ModelCheckpoint(
                        filepath=os.getcwd(),
                        save_best_only=True,
                        verbose=True,
                        monitor='val_loss',
                        mode='min',
                        prefix=''
                    )

                    trainer = Trainer(checkpoint_callback=checkpoint_callback)

            early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If
                set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
                Will raise an error if ``'val_loss'`` is not found.
                If set to ``False``, then early stopping will be disabled.
                If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
                If ``'val_loss'`` is not found, it will work as if early stopping were disabled.
                Default: ``None``.
                Example::

                    from pytorch_lightning.callbacks import EarlyStopping

                    # default used by the Trainer
                    early_stop_callback = EarlyStopping(
                        monitor='val_loss',
                        patience=3,
                        strict=False,
                        verbose=False,
                        mode='min'
                    )

                    trainer = Trainer(early_stop_callback=early_stop_callback)

            default_save_path (str): Default path for logs and weights when no logger/ckpt_callback passed
                Example::

                    # default used by the Trainer
                    trainer = Trainer(default_save_path=os.getcwd())

            gradient_clip_val (float): 0 means don't clip.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(gradient_clip_val=0.0)

            gradient_clip (int):
                .. deprecated:: 0.5.0
                    Use `gradient_clip_val` instead. Will be removed in 0.8.0.

            process_position (int): orders the tqdm bar when running multiple models on same machine.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(process_position=0)

            num_nodes (int): number of GPU nodes for distributed training.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(num_nodes=1)

                    # to train on 8 nodes
                    trainer = Trainer(num_nodes=8)

            nb_gpu_nodes (int):
                .. deprecated:: 0.5.0
                    Use `num_nodes` instead. Will be removed in 0.8.0.

            gpus (list|str|int): Which GPUs to train on.
                Example::

                    # default used by the Trainer (ie: train on CPU)
                    trainer = Trainer(gpus=None)

                    # int: train on 2 gpus
                    trainer = Trainer(gpus=2)

                    # list: train on GPUs 1, 4 (by bus ordering)
                    trainer = Trainer(gpus=[1, 4])
                    trainer = Trainer(gpus='1, 4') # equivalent

                    # -1: train on all gpus
                    trainer = Trainer(gpus=-1)
                    trainer = Trainer(gpus='-1') # equivalent

                    # combine with num_nodes to train on multiple GPUs across nodes
                    trainer = Trainer(gpus=2, num_nodes=4) # uses 8 gpus in total

            log_gpu_memory (str): None, 'min_max', 'all'. Might slow performance
                because it uses the output of nvidia-smi.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(log_gpu_memory=None)

                    # log all the GPUs (on master node only)
                    trainer = Trainer(log_gpu_memory='all')

                    # log only the min and max memory on the master node
                    trainer = Trainer(log_gpu_memory='min_max')

            show_progress_bar (bool): If true shows tqdm progress bar
                Example::

                    # default used by the Trainer
                    trainer = Trainer(show_progress_bar=True)

            overfit_pct (float): uses this much data of all datasets.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(overfit_pct=0.0)

                    # use only 1% of the train, test, val datasets
                    trainer = Trainer(overfit_pct=0.01)

            track_grad_norm (int): -1 no tracking. Otherwise tracks that norm
                Example::

                    # default used by the Trainer
                    trainer = Trainer(track_grad_norm=-1)

                    # track the 2-norm
                    trainer = Trainer(track_grad_norm=2)

            check_val_every_n_epoch (int): Check val every n train epochs.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(check_val_every_n_epoch=1)

                    # run val loop every 10 training epochs
                    trainer = Trainer(check_val_every_n_epoch=10)

            fast_dev_run (bool): runs 1 batch of train, test  and val to find any bugs (ie: a sort of unit test).
                Example::

                    # default used by the Trainer
                    trainer = Trainer(fast_dev_run=False)

                    # runs 1 train, val, test  batch and program ends
                    trainer = Trainer(fast_dev_run=True)

            accumulate_grad_batches (int|dict): Accumulates grads every k batches or as set up in the dict.
                Example::

                    # default used by the Trainer (no accumulation)
                    trainer = Trainer(accumulate_grad_batches=1)

                    # accumulate every 4 batches (effective batch size is batch*4)
                    trainer = Trainer(accumulate_grad_batches=4)

                    # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
                    trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

            max_epochs (int): Stop training once this number of epochs is reached.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(max_epochs=1000)

            max_nb_epochs (int):
                .. deprecated:: 0.5.0
                    Use `max_epochs` instead. Will be removed in 0.8.0.

            min_epochs (int): Force training for at least this many epochs.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(min_epochs=1)

            min_nb_epochs (int):
                .. deprecated:: 0.5.0
                    Use `min_epochs` instead. Will be removed in 0.8.0.

            train_percent_check (float): How much of the training dataset to check.
                Useful when debugging or testing something that happens at the end of an epoch.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(train_percent_check=1.0)

                    # run through only 25% of the training set each epoch
                    trainer = Trainer(train_percent_check=0.25)

            val_percent_check (float): How much of the validation dataset to check.
                Useful when debugging or testing something that happens at the end of an epoch.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(val_percent_check=1.0)

                    # run through only 25% of the validation set each epoch
                    trainer = Trainer(val_percent_check=0.25)

            test_percent_check (float): How much of the test dataset to check.
                Useful when debugging or testing something that happens at the end of an epoch.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(test_percent_check=1.0)

                    # run through only 25% of the test set each epoch
                    trainer = Trainer(test_percent_check=0.25)

            val_check_interval (float|int): How often within one training epoch to check the validation set.
                If float, fraction of the training epoch. If int, check every n batches.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(val_check_interval=1.0)

                    # check validation set 4 times during a training epoch
                    trainer = Trainer(val_check_interval=0.25)

                    # check validation set every 1000 training batches
                    # use this when using iterableDataset and your dataset has no length
                    # (ie: production cases with streaming data)
                    trainer = Trainer(val_check_interval=1000)

            log_save_interval (int): Writes logs to disk this often
                Example::

                    # default used by the Trainer
                    trainer = Trainer(log_save_interval=100)

            row_log_interval (int): How often to add logging rows (does not write to disk)
                Example::

                    # default used by the Trainer
                    trainer = Trainer(row_log_interval=10)

            add_row_log_interval (int):
                .. deprecated:: 0.5.0
                    Use `row_log_interval` instead. Will be removed in 0.8.0.

            distributed_backend (str): The distributed backend to use.
                Options: 'dp', 'ddp', 'ddp2'.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(distributed_backend=None)

                    # dp = DataParallel (split a batch onto k gpus on same machine).
                    trainer = Trainer(gpus=2, distributed_backend='dp')

                    # ddp = DistributedDataParallel
                    # Each gpu trains by itself on a subset of the data.
                    # Gradients sync across all gpus and all machines.
                    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')

                    # ddp2 = DistributedDataParallel + dp
                    # behaves like dp on every node
                    # syncs gradients across nodes like ddp
                    # useful for things like increasing the number of negative samples
                    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

            use_amp (bool): If true uses apex for 16bit precision
                Example::

                    # default used by the Trainer
                    trainer = Trainer(use_amp=False)

            print_nan_grads (bool): Prints gradients with nan values
                Example::

                    # default used by the Trainer
                    trainer = Trainer(print_nan_grads=False)

            weights_summary (str): Prints a summary of the weights when training begins.
                Options: 'full', 'top', None.
                Example::

                    # default used by the Trainer (ie: print all weights)
                    trainer = Trainer(weights_summary='full')

                    # print only the top level modules
                    trainer = Trainer(weights_summary='top')

                    # don't print a summary
                    trainer = Trainer(weights_summary=None)

            weights_save_path (str): Where to save weights if specified.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(weights_save_path=os.getcwd())

                    # save to your custom path
                    trainer = Trainer(weights_save_path='my/path')

                    # if checkpoint callback used, then overrides the weights path
                    # **NOTE: this saves weights to some/path NOT my/path
                    checkpoint_callback = ModelCheckpoint(filepath='some/path')
                    trainer = Trainer(
                        checkpoint_callback=checkpoint_callback,
                        weights_save_path='my/path'
                    )

            amp_level (str): The optimization level to use (O1, O2, etc...).
                Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
                Example::

                    # default used by the Trainer
                    trainer = Trainer(amp_level='O1')

            num_sanity_val_steps (int): Sanity check runs n batches of val before starting the training routine.
                This catches any bugs in your validation without having to wait for the first validation check.
                The Trainer uses 5 steps by default. Turn it off or modify it here.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(num_sanity_val_steps=5)

                    # turn it off
                    trainer = Trainer(num_sanity_val_steps=0)

            nb_sanity_val_steps (int):
                .. deprecated:: 0.5.0
                    Use `num_sanity_val_steps` instead. Will be removed in 0.8.0.

            truncated_bptt_steps (int): Truncated backpropagation through time performs backprop every k steps
                of a much longer sequence. If this is enabled, your batches will automatically get truncated
                and the trainer will apply Truncated Backprop to them. Make sure your batches have a sequence
                dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
                recurrent network trajectories."
                <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
                Example::

                    # default used by the Trainer (ie: disabled)
                    trainer = Trainer(truncated_bptt_steps=None)

                    # backprop every 5 steps in a batch
                    trainer = Trainer(truncated_bptt_steps=5)

                Using this feature requires updating your LightningModule's `training_step()` to include
                a `hiddens` arg.


            resume_from_checkpoint (str): To resume training from a specific checkpoint, pass in the path here.
                Example::

                    # default used by the Trainer
                    trainer = Trainer(resume_from_checkpoint=None)

                    # resume from a specific checkpoint
                    trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
            profiler (BaseProfiler):  To profile individual steps during training and assist in
                identifying bottlenecks.
                Example::

                    from pytorch_lightning.profiler import Profiler, AdvancedProfiler

                    # default used by the Trainer
                    trainer = Trainer(profiler=None)

                    # to profile standard training events
                    trainer = Trainer(profiler=True)

                    # equivalent to profiler=True
                    profiler = Profiler()
                    trainer = Trainer(profiler=profiler)

                    # advanced profiler for function-level stats
                    profiler = AdvancedProfiler()
                    trainer = Trainer(profiler=profiler)

        .. warning:: The following arguments are deprecated and will be removed in v0.8.0:

            - `nb_sanity_val_steps`

        """

        # Transfer params
        # Backward compatibility
        if nb_gpu_nodes is not None:
            warnings.warn(
                "`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not num_nodes:  # in case you did not set the proper value
                num_nodes = nb_gpu_nodes
        self.num_gpu_nodes = num_nodes

        self.log_gpu_memory = log_gpu_memory

        # Backward compatibility
        if gradient_clip is not None:
            warnings.warn(
                "`gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not gradient_clip_val:  # in case you did not set the proper value
                gradient_clip_val = gradient_clip
        self.gradient_clip_val = gradient_clip_val

        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
        self.process_position = process_position
        self.weights_summary = weights_summary

        # Backward compatibility
        if max_nb_epochs is not None:
            warnings.warn(
                "`max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not max_epochs:  # in case you did not set the proper value
                max_epochs = max_nb_epochs
        self.max_epochs = max_epochs

        # Backward compatibility
        if min_nb_epochs is not None:
            warnings.warn(
                "`min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not min_epochs:  # in case you did not set the proper value
                min_epochs = min_nb_epochs
        self.min_epochs = min_epochs

        # Backward compatibility
        if nb_sanity_val_steps is not None:
            warnings.warn(
                "`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not num_sanity_val_steps:  # in case you did not set the proper value
                num_sanity_val_steps = nb_sanity_val_steps

        self.num_sanity_val_steps = num_sanity_val_steps
        self.print_nan_grads = print_nan_grads
        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.num_sanity_val_steps = 1
            self.max_epochs = 1
            m = '''
            Running in fast_dev_run mode: will run a full train,
            val loop using a single batch
            '''
            log.info(m)

        # set default save path if user didn't provide one
        self.default_save_path = default_save_path
        if self.default_save_path is None:
            self.default_save_path = os.getcwd()

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_idx = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.get_train_dataloader = None
        self.get_test_dataloaders = None
        self.get_val_dataloaders = None
        self.is_iterable_train_dataloader = False

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure logger
        self.configure_logger(logger)

        # configure profiler
        if profiler is True:
            profiler = Profiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
        # creates a default one if none passed in
        self.configure_early_stopping(early_stop_callback)

        self.reduce_lr_on_plateau_scheduler = None

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.data_parallel_device_ids = parse_gpu_ids(gpus)
        self.root_gpu = determine_root_gpu_device(
            self.data_parallel_device_ids)

        # distributed backend choice
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend, num_nodes)

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        self.configure_slurm_ddp(num_nodes)

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks,
                              self.data_parallel_device_ids)

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval

        # backward compatibility
        if add_row_log_interval is not None:
            warnings.warn(
                "`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
                " and will be removed in v0.8.0", DeprecationWarning)
            if not row_log_interval:  # in case you did not set the proper value
                row_log_interval = add_row_log_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # 16 bit mixed precision training using apex
        self.amp_level = amp_level
        self.init_amp(use_amp)
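
For comparison, a sketch of how this older signature might be driven, again assuming a pytorch_lightning version that matches it; `CoolModel` is a hypothetical LightningModule, and all flags shown come from the signature and docstring above:

    from pytorch_lightning import Trainer

    trainer = Trainer(
        gpus=2,                    # train on 2 GPUs, per the `gpus` docstring above
        distributed_backend='dp',  # DataParallel: split each batch across the 2 GPUs
        show_progress_bar=True,
        use_amp=False,             # apex-based 16-bit training, off by default
        val_check_interval=0.25,   # run validation 4 times per training epoch
    )

    # trainer.fit(CoolModel())  # CoolModel is a hypothetical LightningModule
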
Example #11
0
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: List[Callback] = [],
            default_save_path: Optional[str] = None,
            gradient_clip_val: float = 0,
            gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
            process_position: int = 0,
            nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
            num_nodes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            num_tpu_cores: Optional[int] = None,
            log_gpu_memory: Optional[str] = None,
            show_progress_bar: bool = True,
            progress_bar_refresh_rate: int = 50,
            overfit_pct: float = 0.0,
            track_grad_norm: int = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            train_percent_check: float = 1.0,
            val_percent_check: float = 1.0,
            test_percent_check: float = 1.0,
            val_check_interval: float = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 10,
            add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
            distributed_backend: Optional[str] = None,
            use_amp=False,  # backward compatible, todo: remove in v0.9.0
            precision: int = 32,
            print_nan_grads: bool = False,
            weights_summary: str = 'full',
            weights_save_path: Optional[str] = None,
            amp_level: str = 'O1',
            nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
            num_sanity_val_steps: int = 5,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[BaseProfiler] = None,
            benchmark: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            **kwargs):
        r"""

        Customize every aspect of training via flags

        Args:
            logger: Logger (or iterable collection of loggers) for experiment tracking.

            checkpoint_callback: Callback for checkpointing.

            early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):

            callbacks: Add a list of callbacks.

            default_save_path: Default path for logs and weights when no logger/ckpt_callback passed

            gradient_clip_val: 0 means don't clip.

            gradient_clip:
                .. warning:: Deprecated since 0.7.0. Use `gradient_clip_val` instead. Will be removed in 0.9.0.

            process_position: orders the tqdm bar when running multiple models on same machine.

            num_nodes: number of GPU nodes for distributed training.

            nb_gpu_nodes:
                .. warning:: .. deprecated:: 0.7.0
                    Use `num_nodes` instead. Will be removed in 0.9.0.

            gpus: Which GPUs to train on.

            num_tpu_cores: How many TPU cores to train on (1 or 8).
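                Example (illustrative sketch; assumes a TPU-enabled environment)::

                    # train on 8 TPU cores
                    trainer = Trainer(num_tpu_cores=8)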

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            show_progress_bar: If true shows tqdm progress bar

            progress_bar_refresh_rate: How often to refresh progress bar (in steps)

            track_grad_norm: -1 no tracking. Otherwise tracks that norm

            check_val_every_n_epoch: Check val every n train epochs.

            fast_dev_run: runs 1 batch of train, test  and val to find any bugs (ie: a sort of unit test).

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            max_epochs: Stop training once this number of epochs is reached.

            max_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0
                    Use `max_epochs` instead. Will be removed in 0.9.0.

            min_epochs: Force training for at least this many epochs.

            min_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0
                    Use `min_epochs` instead. Will be removed in 0.9.0.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least this number of steps. Disabled by default (None).
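                Example (illustrative sketch; the step counts are arbitrary)::

                    # stop after 100 optimizer steps
                    trainer = Trainer(max_steps=100)

                    # force training for at least 100 steps
                    trainer = Trainer(min_steps=100)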

            train_percent_check: How much of training dataset to check.

            val_percent_check: How much of validation dataset to check.

            test_percent_check: How much of test dataset to check.

            val_check_interval: How often within one training epoch to check the validation set

            log_save_interval: Writes logs to disk this often

            row_log_interval: How often to add logging rows (does not write to disk)

            add_row_log_interval:
                .. warning:: .. deprecated:: 0.7.0
                    Use `row_log_interval` instead. Will be removed in 0.9.0.

            distributed_backend: The distributed backend to use.

            use_amp:
                .. warning:: .. deprecated:: 0.7.0
                    Use `precision` instead. Will be removed in 0.9.0.

            precision: Full precision (32), half precision (16).
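                Example (illustrative sketch; at this version 16-bit relies on apex being installed)::

                    # default used by the Trainer
                    trainer = Trainer(precision=32)

                    # 16-bit mixed precision
                    trainer = Trainer(precision=16)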

            print_nan_grads: Prints gradients with nan values

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified.

            amp_level: The optimization level to use (O1, O2, etc...).

            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.

            nb_sanity_val_steps:
                .. warning:: .. deprecated:: 0.7.0
                    Use `num_sanity_val_steps` instead. Will be removed in 0.8.0.

            truncated_bptt_steps: Truncated backpropagation through time performs backprop every k steps of a much longer sequence.

            resume_from_checkpoint: To resume training from a specific checkpoint, pass in the path here.

            profiler: To profile individual steps during training and assist in identifying bottlenecks.

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

            benchmark: If true enables cudnn.benchmark.
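                Example (illustrative sketch)::

                    # enable cudnn benchmarking; most useful when input sizes are fixed
                    trainer = Trainer(benchmark=True)
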
        """

        # Init callbacks
        self.callbacks = callbacks
        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        if benchmark:
            torch.backends.cudnn.benchmark = True

        # Transfer params
        self.num_nodes = num_nodes
        # Backward compatibility, TODO: remove in v0.8.0
        if nb_gpu_nodes is not None:
            warnings.warn(
                "Argument `nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.num_gpu_nodes = nb_gpu_nodes
        self.log_gpu_memory = log_gpu_memory

        self.gradient_clip_val = gradient_clip_val
        # Backward compatibility, TODO: remove in v0.8.0
        if gradient_clip is not None:
            warnings.warn(
                "Argument `gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.gradient_clip = gradient_clip

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False

        # tpu config
        self.on_tpu = num_tpu_cores is not None
        self.num_tpu_cores = num_tpu_cores
        assert num_tpu_cores in [1, 8,
                                 None], 'num_tpu_cores can only be 1 or 8'

        self.process_position = process_position
        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        # Backward compatibility, TODO: remove in v0.8.0
        if max_nb_epochs is not None:
            warnings.warn(
                "Argument `max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.max_nb_epochs = max_nb_epochs

        self.min_epochs = min_epochs
        # Backward compatibility, TODO: remove in v0.8.0
        if min_nb_epochs is not None:
            warnings.warn(
                "Argument `min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.min_nb_epochs = min_nb_epochs

        self.max_steps = max_steps
        self.min_steps = min_steps

        self.num_sanity_val_steps = num_sanity_val_steps
        # Backward compatibility, TODO: remove in v0.8.0
        if nb_sanity_val_steps is not None:
            warnings.warn(
                "Argument `nb_sanity_val_steps` has renamed to "
                "`num_sanity_val_steps` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            self.nb_sanity_val_steps = nb_sanity_val_steps
        self.print_nan_grads = print_nan_grads
        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.num_sanity_val_steps = 1
            self.max_epochs = 1
            m = '''
            Running in fast_dev_run mode: will run a full train,
            val loop using a single batch
            '''
            log.info(m)

        # set default save path if user didn't provide one
        self.default_save_path = default_save_path
        if self.default_save_path is None:
            self.default_save_path = os.getcwd()

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_idx = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure logger
        self.configure_logger(logger)

        # configure profiler
        if profiler is True:
            profiler = Profiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
        # creates a default one if none passed in
        self.configure_early_stopping(early_stop_callback)

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.gpus = gpus
        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(
            self.data_parallel_device_ids)
        root_device = (torch.device("cuda", self.root_gpu)
                       if self.root_gpu else torch.device("cpu"))
        torch.cuda.set_device(root_device)

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend, self.num_nodes)

        # override dist backend when using tpus
        if self.on_tpu:
            self.init_tpu()
            self.current_tpu_idx = None

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        self.configure_slurm_ddp(self.num_nodes)

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks,
                              self.data_parallel_device_ids)

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval

        # backward compatibility
        if add_row_log_interval is not None:
            warnings.warn(
                "`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
                " and this method will be removed in v0.8.0",
                DeprecationWarning)
            if not row_log_interval:  # in case you did not set the proper value
                row_log_interval = add_row_log_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.overfit_pct = overfit_pct
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # 16 bit mixed precision training using apex
        self.amp_level = amp_level
        self.precision = precision

        assert self.precision in (16,
                                  32), 'only 32 or 16 bit precision supported'

        if self.precision == 16 and self.num_tpu_cores is None:
            use_amp = True
        self.init_amp(use_amp)

        # Callback system
        self.on_init_end()
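
Finally, a sketch exercising the flags that distinguish this variant (precision, benchmark, max_steps), under the same assumption that the installed Trainer matches this signature; with this code path, precision=16 flips use_amp on internally and therefore typically requires apex:

    from pytorch_lightning import Trainer

    trainer = Trainer(
        gpus=1,
        precision=16,      # enables `use_amp` in the branch near the end of __init__
        amp_level='O1',    # apex optimization level, as documented above
        benchmark=True,    # sets torch.backends.cudnn.benchmark = True
        max_steps=1000,    # stop after 1000 optimizer steps
    )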