Code example #1
    def __call__(cls, *args, **kwargs):
        """A wrapper for LightningDataModule that:

            1. Runs user defined subclass's __init__
            2. Assures prepare_data() runs on rank 0
            3. Lets you check prepare_data and setup to see if they've been called
        """

        # Track prepare_data calls and make sure it runs on rank zero
        cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
        # Track setup calls
        cls.setup = track_data_hook_calls(cls.setup)

        # Get instance of LightningDataModule by mocking its __init__ via __call__
        obj = type.__call__(cls, *args, **kwargs)

        return obj
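Every example on this page wraps callables in rank_zero_only from pytorch_lightning.utilities. As a rough mental model only, and not the library's exact source, the decorator can be sketched as below; the environment-variable names used to resolve the rank are assumptions of this sketch.

import os
from functools import wraps

def rank_zero_only_sketch(fn):
    # Sketch: run fn only on the rank-0 process, return None everywhere else.
    @wraps(fn)
    def wrapped(*args, **kwargs):
        if rank_zero_only_sketch.rank == 0:
            return fn(*args, **kwargs)
        return None
    return wrapped

# DDP launchers usually export the process rank via environment variables
# (the exact variable names here are an assumption of this sketch).
rank_zero_only_sketch.rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))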
Code example #2
def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger."""

    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
            "debug",
            "info",
            "warning",
            "error",
            "exception",
            "fatal",
            "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger
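A hypothetical usage of this helper in a multi-GPU run (the message text is made up):

import logging

logging.basicConfig(level=logging.INFO)  # make INFO records visible
log = get_logger(__name__)
log.info("starting epoch 0")  # emitted once by rank 0, not by every GPU process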
Code example #3
    def wrap_functions_into_rank_zero_only(self):
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)
        self.describe = rank_zero_only(self.describe)
Code example #4
    def __init__(self):
        super().__init__()
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)
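Code examples #3 and #4 apply the same trick: rebinding methods on the instance at construction time so that, on every rank other than 0, the call silently returns None. A minimal self-contained illustration of that pattern (TinyProfiler is a made-up class, not the Lightning profiler):

from pytorch_lightning.utilities import rank_zero_only

class TinyProfiler:
    """Made-up class showing the rebinding pattern from examples #3 and #4."""

    def __init__(self):
        # Replace the bound method with a rank-zero-only wrapper; on other
        # ranks, self.summary() now returns None instead of the report.
        self.summary = rank_zero_only(self.summary)

    def summary(self) -> str:
        return "profile report ..."

profiler = TinyProfiler()
report = profiler.summary()  # the report string on rank 0, None elsewhere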
Code example #5
from abc import abstractmethod
from argparse import Namespace
from idlelib.config import _warn  # borrows a private stdlib helper that prints a warning to stderr once per key
from typing import Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities import AttributeDict, rank_zero_only

from domain.base import Hyperparameters
from properties import APPLICATION_PROPERTIES

rank_zero_warn = rank_zero_only(_warn)


class DataModuleBase(pl.LightningDataModule):
    def __init__(self, *args, **kwargs):
        self.arg = Hyperparameters(kwargs)
        self._convert_arguments(arg=self.arg)
        super(DataModuleBase, self).__init__()

        # Dataset
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    @abstractmethod
    def prepare_data(self, *args, **kwargs):
        pass

    @abstractmethod
    def setup(self, stage: Optional[str] = None):
        pass
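A hypothetical concrete subclass of DataModuleBase might look like the sketch below. The class name and tensor shapes are made up, and the sketch assumes _convert_arguments may be a no-op; only the rank_zero_warn helper and the dataset attributes come from the example above.

import torch
from torch.utils.data import TensorDataset

class RandomDataModule(DataModuleBase):
    """Hypothetical subclass, for illustration only."""

    def _convert_arguments(self, arg):
        # Hook invoked by DataModuleBase.__init__; a no-op in this sketch.
        pass

    def prepare_data(self, *args, **kwargs):
        # Download / cache data here; it runs on rank 0 only.
        pass

    def setup(self, stage: Optional[str] = None):
        if stage not in (None, "fit"):
            rank_zero_warn(f"unhandled stage: {stage}")
        self.train_dataset = TensorDataset(torch.randn(128, 8))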
Code example #6
File: train.py  Project: zju3dv/LoFTR
def main():
    # parse arguments
    args = parse_args()
    rank_zero_only(pprint.pprint)(vars(args))

    # init default-cfg and merge it with the main- and data-cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility
    # TODO: Use different seeds for each dataloader workers
    # This is needed for data augmentation

    # scale lr and warmup-step automatically
    args.gpus = _n_gpus = setup_gpus(args.gpus)
    config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
    config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
    _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
    config.TRAINER.SCALING = _scaling
    config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
    config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP /
                                            _scaling)

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
    loguru_logger.info(f"LoFTR LightningModule initialized!")

    # lightning data
    data_module = MultiSceneDataModule(args, config)
    loguru_logger.info(f"LoFTR DataModule initialized!")

    # TensorBoard Logger
    logger = TensorBoardLogger(save_dir='logs/tb_logs',
                               name=args.exp_name,
                               default_hp_metric=False)
    ckpt_dir = Path(logger.log_dir) / 'checkpoints'

    # Callbacks
    # TODO: update ModelCheckpoint to monitor multiple metrics
    ckpt_callback = ModelCheckpoint(
        monitor='auc@10',
        verbose=True,
        save_top_k=5,
        mode='max',
        save_last=True,
        dirpath=str(ckpt_dir),
        filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor]
    if not args.disable_ckpt:
        callbacks.append(ckpt_callback)

    # Lightning Trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        plugins=DDPPlugin(find_unused_parameters=False,
                          num_nodes=args.num_nodes,
                          sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
        gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
        callbacks=callbacks,
        logger=logger,
        sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
        replace_sampler_ddp=False,  # use custom sampler
        reload_dataloaders_every_epoch=False,  # avoid repeated samples!
        weights_summary='full',
        profiler=profiler)
    loguru_logger.info(f"Trainer initialized!")
    loguru_logger.info(f"Start training!")
    trainer.fit(model, datamodule=data_module)
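The excerpt ends here; a script like this is typically launched through the standard Python entry-point guard (not shown above):

if __name__ == '__main__':
    main()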