def main():
    """Train an ``ImpactDetector`` on the NFL impact dataset.

    The CLI combines this script's options, ``ImpactDetector``'s own options
    and every ``pl.Trainer`` flag. The model is fitted on an
    ``ImpactDataModule`` built from those options.

    Normal runs log to Neptune (API token read from the ``NEPTUNE_API_TOKEN``
    environment variable) with a top-3 ``val_loss`` checkpoint callback and a
    per-step learning-rate monitor. ``--debug`` instead forces a short run
    (1 epoch, 10 train and 10 val batches) with a Trainer built from the CLI
    flags alone.
    """
    # Seed before anything random happens.
    pl.seed_everything(0)

    # ----------
    # args
    # ----------
    parser = ArgumentParser()
    add_arg = parser.add_argument
    add_arg("--exp_name", default="Test")
    add_arg("--model_name", default="tf_efficientdet_d0")
    add_arg("--dataset_dir", default="../dataset", metavar="DIR", help="path to dataset")
    add_arg("--batch_size", default=4, type=int)
    add_arg("--num_workers", default=2, type=int)
    for switch in ("--debug", "--impactonly", "--impactdefinitive"):
        add_arg(switch, action="store_true")
    add_arg("--overlap", default=None, type=int)
    for switch in ("--oversample", "--seqmode", "--fullsizeimage"):
        add_arg(switch, action="store_true")
    add_arg("--anchor_scale", default=4, type=int)
    add_arg("--fold_index", default=0, type=int)
    parser = ImpactDetector.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # ----------
    # data
    # ----------
    datamodule = ImpactDataModule(
        data_dir=args.dataset_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        impactonly=args.impactonly,
        impactdefinitive=args.impactdefinitive,
        overlap=args.overlap,
        oversample=args.oversample,
        seqmode=args.seqmode,
        fullsizeimage=args.fullsizeimage,
        fold_index=args.fold_index,
    )

    # ----------
    # model
    # ----------
    # Every parsed option is forwarded to the model constructor.
    impact_detector = ImpactDetector(**vars(args))
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        filename="impact-detector-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    # ----------
    # logger + training
    # ----------
    if args.debug:
        # Shrink the run; these Trainer flags are read by from_argparse_args.
        args.max_epochs = 1
        args.limit_train_batches = 10
        args.limit_val_batches = 10
        trainer = pl.Trainer.from_argparse_args(args)
    else:
        neptune_logger = NeptuneLogger(
            api_key=os.environ["NEPTUNE_API_TOKEN"],
            project_name="sunghyun.jun/nfl-impact",
            experiment_name=args.exp_name,
            params={
                "model_name": args.model_name,
                "batch_size": args.batch_size,
                "num_workers": args.num_workers,
                "init_lr": args.init_lr,
                "weight_decay": args.weight_decay,
            },
            tags=["pytorch-lightning"],
        )
        lr_monitor = LearningRateMonitor(logging_interval="step")
        trainer = pl.Trainer.from_argparse_args(
            args, logger=neptune_logger, callbacks=[checkpoint_callback, lr_monitor]
        )

    trainer.fit(impact_detector, datamodule)


def test_manual_optimization_and_accumulated_gradient(tmpdir):
    """
    Verify that with `automatic_optimization=False`, the optimizer steps only
    on the batches where the model itself decides to step, even though the
    Trainer is configured with `accumulate_grad_batches=4`.

    `count` starts at 1 and is incremented after every batch. Backward runs
    only when `count` is even, and `opt.step()` + `opt.zero_grad()` run only
    when `count` is a multiple of 4. The batch-end hook asserts that the
    weights change exactly on those batches and stay unchanged otherwise.
    Needs a GPU (`gpus=1`, native 16-bit AMP).
    """
    seed_everything(234)

    class ExtendedModel(BoringModel):

        # batch counter; starts at 1 and is incremented in on_train_batch_end
        count = 1
        # hook name -> number of times it was called
        called = collections.defaultdict(int)
        # when True, training_step returns a detached loss
        detach = False

        def __init__(self):
            super().__init__()
            # the model drives backward/step itself
            self.automatic_optimization = False

        @property
        def should_update(self):
            # run backward on every second batch (even counts)
            return self.count % 2 == 0

        @property
        def should_have_updated(self):
            # step the optimizer on every fourth batch
            return self.count % 4 == 0

        @property
        def has_gradient(self):
            # helper; not referenced by the assertions in this test
            return self.layer.weight.grad is not None

        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
            self.called["on_train_batch_start"] += 1
            # snapshot so on_train_batch_end can tell whether the weights moved
            self.weight_before = self.layer.weight.clone()

        def training_step(self, batch, batch_idx):
            self.called["training_step"] += 1
            opt = self.optimizers()
            output = self.layer(batch)

            loss = self.loss(batch, output)
            # rescale to the constant value 0.1 while keeping the graph: only
            # the numerator carries gradient
            loss /= loss.clone().detach()
            loss *= 0.1

            if self.should_update:

                # accumulate into .grad; step (and clear) only every fourth batch
                self.manual_backward(loss, opt)
                if self.should_have_updated:
                    opt.step()
                    opt.zero_grad()

            return loss.detach() if self.detach else loss

        def on_train_batch_end(self, outputs, batch, batch_idx,
                               dataloader_idx):
            self.called["on_train_batch_end"] += 1
            # weights after this batch, compared with the start-of-batch snapshot
            after_before = self.layer.weight.clone()
            if self.should_update and self.should_have_updated:
                # stepped this batch: weights changed and grads were zeroed
                # NOTE(review): assumes zero_grad() zeroes .grad rather than
                # setting it to None (newer PyTorch defaults to
                # set_to_none=True) — confirm against the pinned torch version.
                assert not torch.equal(self.weight_before,
                                       after_before), self.count
                assert torch.all(self.layer.weight.grad == 0)
            else:
                # no step this batch: weights must be unchanged
                assert torch.equal(self.weight_before, after_before)
                # on the first batch no backward has run yet, so .grad is unset
                if self.count > 1:
                    if self.count % 4 == 1:
                        # previous batch stepped and zeroed the grads, and odd
                        # batches run no backward
                        assert torch.all(self.layer.weight.grad == 0)
                    else:
                        # gradients accumulated since the last step remain
                        assert torch.sum(self.layer.weight.grad) != 0
            self.count += 1

        def on_train_epoch_end(self, *_, **__):
            # limit_train_batches=20: every hook fired once per batch
            assert self.called["training_step"] == 20
            assert self.called["on_train_batch_start"] == 20
            assert self.called["on_train_batch_end"] == 20

    model = ExtendedModel()
    # NOTE(review): presumably disables these hooks so Lightning skips them
    model.training_step_end = None
    model.training_epoch_end = None

    # Trainer-level accumulation is set to 4 on purpose: with manual
    # optimization, stepping is decided by the model, not by this setting.
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=20,
        limit_test_batches=0,
        limit_val_batches=0,
        precision=16,
        amp_backend='native',
        accumulate_grad_batches=4,
        gpus=1,
    )
    trainer.fit(model)
Exemplo n.º 3
0
    if args.affinity != "disabled":
        set_affinity(int(os.getenv("LOCAL_RANK", "0")),
                     args.gpus,
                     mode=args.affinity)

    # Limit number of CPU threads
    os.environ["OMP_NUM_THREADS"] = "1"
    # Set device limit on the current device cudaLimitMaxL2FetchGranularity = 0x05
    _libcudart = ctypes.CDLL("libcudart.so")
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    assert pValue.contents.value == 128

    set_cuda_devices(args)
    seed_everything(args.seed)
    data_module = DataModule(args)
    data_module.prepare_data()
    data_module.setup()
    ckpt_path = verify_ckpt_path(args)

    callbacks = None
    model_ckpt = None
    if args.benchmark:
        model = NNUnet(args)
        batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
        filnename = args.logname if args.logname is not None else "perf1.json"
        callbacks = [
            LoggingCallback(
                log_dir=args.results,
                filnename=filnename,
Exemplo n.º 4
0
Options:
    --config <model config path>  Path to YAML file for model configuration  [default: pretrained_model/deterministic_conv/config.yaml] [type: path]
    --weights-filepath <weights file path>  Path to weights file for model  [default: pretrained_model/deterministic_conv/weights.ckpt] [type: path]    
    --image-path <image path> Path to image filepath for inference  [default: images/cifar_frog(6).png]
            
    -h --help  Show this.
"""
import pytorch_lightning
import torch
from omegaconf import DictConfig
from PIL import Image

from src.runner.predictor import Predictor
from src.utils import get_config

# Make inference reproducible: seed the global RNGs via Lightning and force
# deterministic cuDNN kernels (autotuning disabled, since it may select
# non-deterministic algorithms).
pytorch_lightning.seed_everything(777)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def infer(hparams: dict):
    """Build a ``Predictor`` from the config and load trained weights into it.

    Args:
        hparams: parsed command-line options keyed by option name
            (``"--config"``, ``"--weights-filepath"``, ``"--image-path"``).
            A missing key becomes the string ``"None"`` because values are
            passed through ``str()``.

    NOTE(review): the visible body stops after switching the model to eval
    mode and never uses ``image_path``; the function appears truncated here —
    confirm the remaining inference steps.
    """
    weight_filepath = str(hparams.get("--weights-filepath"))
    image_path = str(hparams.get("--image-path"))

    # options that get_config presumably reads to build the config — verify
    config_list = ["--config"]
    config: DictConfig = get_config(hparams=hparams, options=config_list)

    predictor = Predictor(config=config)
    # The weights file must hold a dict with a "state_dict" entry. torch.load
    # unpickles the file, so only load trusted checkpoints.
    predictor.load_state_dict(torch.load(weight_filepath)["state_dict"])
    predictor.eval()
Exemplo n.º 5
0
def test_against_sklearn(sklearn_metric: Callable,
                         torch_class_metric: Metric) -> None:
    """Compare PL metrics to sklearn version.

    Random retrieval problems of every (batch size, query size) combination
    are scored by ``torch_class_metric`` and by averaging ``sklearn_metric``
    per query. This is done for each ``query_without_relevant_docs``
    behaviour, and the two results must agree.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed_everything(0)

    rounds = 20
    sizes = [1, 4, 10, 100]
    batch_sizes = [1, 4, 10]
    query_without_relevant_docs_options = ['skip', 'pos', 'neg']

    def compute_sklearn_metric(target: List[np.ndarray],
                               preds: List[np.ndarray],
                               behaviour: str) -> torch.Tensor:
        """Average ``sklearn_metric`` over queries, resolving NaN per ``behaviour``.

        A NaN score is dropped for ``'skip'`` and counted as 1.0 for ``'pos'``.
        Any other behaviour counts it as 0.0. With no scores left the result
        is 0.0.
        """
        tensor_kwargs = {'device': device, 'dtype': torch.float32}
        nan_replacement = {'skip': None, 'pos': 1.0}.get(behaviour, 0.0)

        scores = []
        for query_target, query_preds in zip(target, preds):
            score = sklearn_metric(query_target, query_preds)
            if not math.isnan(score):
                scores.append(torch.tensor(score, **tensor_kwargs))
            elif nan_replacement is not None:
                scores.append(torch.tensor(nan_replacement, **tensor_kwargs))

        if not scores:
            return torch.tensor(0.0, **tensor_kwargs)
        return torch.stack(scores).mean()

    def do_test(batch_size: int, size: int) -> None:
        """For each behaviour, score a fresh random batch both ways and compare."""
        shape = (size, )
        for behaviour in query_without_relevant_docs_options:
            metric = torch_class_metric(query_without_relevant_docs=behaviour)

            indexes, preds, target = [], [], []
            for query_idx in range(batch_size):
                indexes.append(np.ones(shape, dtype=int) * query_idx)
                preds.append(np.random.randn(*shape))
                target.append(np.random.randn(*shape) > 0)

            sk_results = compute_sklearn_metric(target, preds, behaviour)

            flat = [torch.cat([torch.tensor(arr) for arr in arrays])
                    for arrays in (indexes, preds, target)]

            # lets assume data are not ordered: one shared permutation keeps
            # the triples aligned, so the lightning metric must regroup and
            # sort the documents itself
            perm = torch.randperm(flat[0].nelement())
            indexes_tensor, preds_tensor, target_tensor = (
                t.view(-1)[perm].view(t.size()) for t in flat)

            pl_result = metric(indexes_tensor, preds_tensor, target_tensor)

            assert torch.allclose(sk_results.float(),
                                  pl_result.float(),
                                  equal_nan=True)

    for batch_size, size in [(b, s) for b in batch_sizes for s in sizes]:
        for _ in range(rounds):
            do_test(batch_size, size)
"""
Runs a model on the CPU on a single node.
"""
import os
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from pl_examples.models.lightning_template import LightningTemplateModel

# Seed the global RNGs once at import time so every run of this example
# starts from the same random state.
seed_everything(234)


def main(args):
    """Main training routine specific for this project.

    Args:
        args: parsed ``argparse.Namespace``. Every attribute is forwarded to
            ``LightningTemplateModel``, and the Trainer reads its own flags
            from the same namespace.
    """
    hparams = vars(args)
    # 1. model  2. trainer  3. fit
    model = LightningTemplateModel(**hparams)
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)


def run_cli():
Exemplo n.º 7
0
def main():
    """Seed, parse ``sys.argv``, train, then run ``test`` on the trainer."""
    seed_everything(42)
    cli_options = vars(parse_args(sys.argv[1:]))
    test(train(**cli_options))
Exemplo n.º 8
0
def test_loop_restart_progress_multiple_optimizers(tmpdir, n_optimizers,
                                                   stop_optimizer, stop_epoch,
                                                   stop_batch):
    """Test that Lightning can resume from a point where a training_step failed while in the middle of processing
    several optimizer steps for one batch.

    Three runs share identical seeding and Trainer settings:

    1. an uninterrupted run, whose final weights are the reference;
    2. a run that raises ``CustomException`` inside ``training_step`` at
       (``stop_epoch``, ``stop_batch``, ``stop_optimizer``);
    3. a run resumed from ``tmpdir / ".pl_auto_save.ckpt"``.

    The test asserts that we end up with the same trained weights as if no failure occurred, and that
    ``optimizer_step`` sees the expected sequence of optimizer indices in each run.
    """

    n_batches = 3
    n_epochs = 2

    def _assert_optimizer_sequence(method_mock, expected):
        # positional argument 3 of each optimizer_step call is compared with
        # the expected optimizer indices
        positional_args = [c[0] for c in method_mock.call_args_list]
        sequence = [arg[3] for arg in positional_args]
        assert sequence == expected

    # number of optimizer_step calls made before the failing one
    num_optimizers_incomplete = (stop_epoch * n_batches * n_optimizers
                                 + stop_batch * n_optimizers + stop_optimizer)

    opt_idx_sequence_complete = list(range(
        n_optimizers)) * n_epochs * n_batches  # [0, 1, 2, 0, 1, 2, 0, 1, ...]
    # +1 because we fail inside the closure inside optimizer_step()
    opt_idx_sequence_incomplete = opt_idx_sequence_complete[:(
        num_optimizers_incomplete + 1)]
    # the resumed run repeats the step that failed
    opt_idx_sequence_resumed = opt_idx_sequence_complete[
        num_optimizers_incomplete:]

    class MultipleOptimizerModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            # `fail` is read from the enclosing test at call time
            if (fail and self.current_epoch == stop_epoch
                    and batch_idx == stop_batch
                    and optimizer_idx == stop_optimizer):
                raise CustomException
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            return [
                torch.optim.SGD(self.parameters(), lr=0.1)
                for _ in range(n_optimizers)
            ]

    def _fresh_run():
        # Same seed, model setup and Trainer flags for all three runs.
        seed_everything(0)
        model = MultipleOptimizerModel()
        model.training_epoch_end = None
        model.optimizer_step = Mock(wraps=model.optimizer_step)
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=n_epochs,
            limit_train_batches=n_batches,
            limit_val_batches=0,
            num_sanity_val_steps=0,
            logger=False,
            enable_checkpointing=False,
        )
        return model, trainer

    def _snapshot_weights(model):
        # Materialise copies instead of keeping a lazy parameters() generator
        # over live tensors, so the comparison below is explicit and complete.
        return [p.detach().clone() for p in model.parameters()]

    # run without a failure, collect weights
    fail = False
    model, trainer = _fresh_run()
    trainer.fit(model)
    weights_complete = _snapshot_weights(model)
    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete)

    # simulate a failure
    fail = True
    model, trainer = _fresh_run()
    with pytest.raises(CustomException):
        trainer.fit(model)

    _assert_optimizer_sequence(model.optimizer_step,
                               opt_idx_sequence_incomplete)

    # resume from failure and collect weights
    fail = False
    model, trainer = _fresh_run()
    trainer.fit(model, ckpt_path=str(tmpdir / ".pl_auto_save.ckpt"))
    weights_resumed = _snapshot_weights(model)

    # check that the final weights of a resumed run match the weights of a run that never failed;
    # zip() would silently truncate, so make sure both sides are complete
    assert weights_complete
    assert len(weights_complete) == len(weights_resumed)
    for w0, w1 in zip(weights_complete, weights_resumed):
        assert torch.allclose(w0, w1)

    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)
Exemplo n.º 9
0
def set_seed(args, cfg):
    """Seed all RNGs for reproducibility.

    A truthy ``args["--seed"]`` (parsed with ``int``) takes precedence.
    Otherwise the seed is ``cfg['experiment']['seed']``.
    """
    cli_seed = args["--seed"]
    if cli_seed:
        seed = int(cli_seed)
    else:
        seed = cfg['experiment']['seed']
    pl.seed_everything(seed)
Exemplo n.º 10
0
            type=int,
            default=4,
            help="how many steps of gradient descent to perform on each batch")
        parser.add_argument(
            "--clip_ratio",
            type=float,
            default=0.2,
            help="hyperparameter for clipping in the policy objective")

        return parent_parser


def main(args) -> None:
    """Fit a ``PPOLightning`` agent; ``args`` configures both agent and Trainer."""
    agent = PPOLightning(**vars(args))
    pl.Trainer.from_argparse_args(args).fit(agent)


if __name__ == "__main__":
    cli_lightning_logo()
    # Seed before the model is built so runs are reproducible.
    pl.seed_everything(0)

    # Trainer flags go on a parent parser created without -h/--help; the
    # model then adds its own options via add_model_specific_args.
    # NOTE(review): if add_model_specific_args returns this same parser object,
    # the final parser has add_help=False, i.e. no -h/--help — confirm.
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = pl.Trainer.add_argparse_args(parent_parser)

    parser = PPOLightning.add_model_specific_args(parent_parser)
    args = parser.parse_args()

    main(args)
Exemplo n.º 11
0
def parse_args(argv=None):
    """Parse model, Trainer and script options.

    Args:
        argv: argument strings to parse, without the program name. ``None``
            or an empty value parses an empty list, giving all defaults;
            ``sys.argv`` is not read implicitly.

    Returns:
        argparse.Namespace with the parsed values.
    """
    if not argv:
        argv = []

    # model options first, then every Trainer flag, then script options
    parser = Trainer.add_argparse_args(
        SCAEMNIST.add_model_specific_args(ArgumentParser())
    )
    parser.add_argument('--save_top_k', type=int, default=1)

    return parser.parse_args(argv)


if __name__ == '__main__':
    import sys

    from torch_scae_experiments.umonths.hparams import model_params
    print("starting")
    # Seed before parsing/training so the run is reproducible.
    seed_everything(42)

    # Parse the real command line (program name dropped).
    args = parse_args(sys.argv[1:])

    # model_params is passed positionally, then every parsed option as a
    # keyword argument.
    train(model_params, **vars(args))
Exemplo n.º 12
0
                               lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with model options.

        Adds ``--learning_rate``, ``--input_dim``, ``--bias``, ``--batch_size``
        and ``--optimizer``. ``--bias`` takes a boolean value (``true``/
        ``false``, ``yes``/``no``, ``1``/``0``) and defaults to ``True``.
        """
        def _str2bool(value):
            # argparse reports a ValueError here as an invalid-value usage error.
            if isinstance(value, bool):
                return value
            lowered = value.strip().lower()
            if lowered in ('true', 't', 'yes', 'y', '1'):
                return True
            if lowered in ('false', 'f', 'no', 'n', '0'):
                return False
            raise ValueError(value)

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        parser.add_argument('--input_dim', type=int, default=None)
        # Previously `default='store_true'`: the default was the truthy string
        # 'store_true' and any value given was kept as a (truthy) string, so
        # bias could never be switched off. Keep the enabled-by-default
        # behaviour, but parse the value as a real boolean.
        parser.add_argument('--bias', type=_str2bool, default=True)
        parser.add_argument('--batch_size', type=int, default=16)
        parser.add_argument('--optimizer', type=str, default='Adam')
        return parser


if __name__ == '__main__':  # pragma: no cover
    from argparse import ArgumentParser
    # Seed before data loading / model construction for reproducibility.
    pl.seed_everything(1234)

    # create dataset
    # NOTE(review): sklearn.datasets.load_boston was deprecated in
    # scikit-learn 1.0 and removed in 1.2 — this example needs an older
    # scikit-learn or a different dataset; verify the pinned version.
    from sklearn.datasets import load_boston
    X, y = load_boston(return_X_y=True)  # these are numpy arrays
    loaders = SklearnDataModule(X, y)

    # args: model options first, then every Trainer flag
    parser = ArgumentParser()
    parser = LinearRegression.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # model: every parsed option is forwarded to the constructor
    model = LinearRegression(**vars(args))
Exemplo n.º 13
0
    def __init__(
        self,
        model_class: Optional[Union[Type[LightningModule],
                                    Callable[..., LightningModule]]] = None,
        datamodule_class: Optional[Union[Type[LightningDataModule], Callable[
            ..., LightningDataModule]]] = None,
        save_config_callback: Optional[
            Type[SaveConfigCallback]] = SaveConfigCallback,
        save_config_filename: str = "config.yaml",
        save_config_overwrite: bool = False,
        save_config_multifile: bool = False,
        trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
        trainer_defaults: Optional[Dict[str, Any]] = None,
        seed_everything_default: Optional[int] = None,
        description: str = "pytorch-lightning trainer command line tool",
        env_prefix: str = "PL",
        env_parse: bool = False,
        parser_kwargs: Optional[Union[Dict[str, Any],
                                      Dict[str, Dict[str, Any]]]] = None,
        subclass_mode_model: bool = False,
        subclass_mode_data: bool = False,
        run: bool = True,
    ) -> None:
        """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
        are called / instantiated using a parsed configuration file and / or command line args.

        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
        A full configuration yaml would be parsed from ``PL_CONFIG`` if set.
        Individual settings are likewise parsed from variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

        Construction does all the work. It builds and runs the argument parser
        and seeds everything when a ``seed_everything`` value is configured.
        It then instantiates the configured classes and, if ``run`` is true,
        executes the selected subcommand.

        For more info, read :ref:`the CLI docs <common/lightning_cli:LightningCLI>`.

        .. warning:: ``LightningCLI`` is in beta and subject to change.

        Args:
            model_class: An optional :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a
                callable which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when
                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
                callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
                called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
            save_config_callback: A callback class to save the training config.
            save_config_filename: Filename for the config file.
            save_config_overwrite: Whether to overwrite an existing config file.
            save_config_multifile: When input is multiple config files, saved config preserves this structure.
            trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
                callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
            trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
            seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
                seed argument.
            description: Description of the tool shown when running ``--help``.
            env_prefix: Prefix for environment variables.
            env_parse: Whether environment variable parsing is enabled.
            parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
            subclass_mode_model: Whether model can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            subclass_mode_data: Whether datamodule can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                method. If set to ``False``, the trainer and model classes will be instantiated only.
        """
        self.save_config_callback = save_config_callback
        self.save_config_filename = save_config_filename
        self.save_config_overwrite = save_config_overwrite
        self.save_config_multifile = save_config_multifile
        self.trainer_class = trainer_class
        self.trainer_defaults = trainer_defaults or {}
        self.seed_everything_default = seed_everything_default

        self.model_class = model_class
        # used to differentiate between the original value and the processed value
        self._model_class = model_class or LightningModule
        # with no model class given, subclass mode is forced on so the model
        # can be chosen on the command line
        self.subclass_mode_model = (model_class is None) or subclass_mode_model

        self.datamodule_class = datamodule_class
        # used to differentiate between the original value and the processed value
        self._datamodule_class = datamodule_class or LightningDataModule
        # same rule as for the model class
        self.subclass_mode_data = (datamodule_class is
                                   None) or subclass_mode_data

        # presumably splits parser_kwargs into kwargs for the main parser and
        # per-subcommand kwargs, merged with these defaults — verify in helper
        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
            parser_kwargs
            or {},  # type: ignore  # github.com/python/mypy/issues/6463
            {
                "description": description,
                "env_prefix": env_prefix,
                "default_env": env_parse
            },
        )
        # setup_parser is expected to create self.parser, which is parsed next
        self.setup_parser(run, main_kwargs, subparser_kwargs)
        self.parse_arguments(self.parser)

        self.subcommand = self.config["subcommand"] if run else None

        # seed (including dataloader workers) before any class is instantiated
        seed = self._get(self.config, "seed_everything")
        if seed is not None:
            seed_everything(seed, workers=True)

        self.before_instantiate_classes()
        self.instantiate_classes()

        if self.subcommand is not None:
            self._run_subcommand(self.subcommand)
Exemplo n.º 14
0
 def configure_seed(self) -> None:
     """Seed all RNGs with ``self.config.training.seed``."""
     seed_everything(self.config.training.seed)
Exemplo n.º 15
0
from pytorch_lightning import seed_everything

from utils import *
from models import CelebAModel
from trainer import get_trainer

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false`` or ``1``/``0``.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
    because every non-empty string is truthy. A ValueError is reported by
    argparse as an invalid-value usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(value)


parser = ArgumentParser()

parser.add_argument('--config', type=str, help='Config file path')
parser.add_argument('--path', type=str, help='Save path')
parser.add_argument('--gpus', type=str, help='Gpus used')
parser.add_argument('--eval', type=_str2bool, help='Whether only do test')

args = parser.parse_args()
seed_everything(1234)  # reproducibility
debug = False

config = parse_config(args.config)
# "--gpus 0,1" -> [0, 1]
gpus = [int(x) for x in args.gpus.strip().split(',')]
# makedirs also creates missing parent directories and avoids the
# check-then-create race of isdir() followed by mkdir().
os.makedirs(args.path, exist_ok=True)

criterion = get_loss(config['criterion'])
model = CelebAModel(criterion=criterion,
                    config=config,
                    path=args.path,
                    batch_size=config['batch_size'],
                    **config['model'])
trainer = get_trainer(gpus=gpus,
                      path=args.path,
def cli_main():  # pragma: no-cover
    """Fine-tune a linear head on a pretrained SwAV backbone, then test it.

    ``--dataset`` selects ``stl10`` (labeled splits) or ``imagenet``. Any other
    value raises ``NotImplementedError``. Backbone weights come from
    ``--ckpt_path``. Training uses 16-bit precision with the ``ddp`` backend on
    ``--gpus`` GPUs, with synchronised batch-norm when more than one is used.
    """
    from pl_bolts.datamodules import STL10DataModule, ImagenetDataModule

    def _str2bool(value):
        # `type=bool` would turn every non-empty string (even "False") into
        # True; parse the usual spellings explicitly. argparse reports the
        # ValueError as an invalid-value usage error.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(value)

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        help='stl10, imagenet',
                        default='stl10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_path',
                        type=str,
                        help='path to dataset',
                        default=os.getcwd())

    parser.add_argument("--batch_size",
                        default=64,
                        type=int,
                        help="batch size per gpu")
    parser.add_argument("--num_workers",
                        default=8,
                        type=int,
                        help="num of workers per GPU")
    parser.add_argument("--gpus", default=4, type=int, help="number of GPUs")
    parser.add_argument('--num_epochs',
                        default=100,
                        type=int,
                        help="number of epochs")

    # fine-tuner params
    parser.add_argument('--in_features', type=int, default=2048)
    parser.add_argument('--dropout', type=float, default=0.)
    parser.add_argument('--learning_rate', type=float, default=0.3)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--nesterov', type=_str2bool, default=False)
    parser.add_argument('--scheduler_type', type=str, default='cosine')
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--final_lr', type=float, default=0.)

    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_path,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

        # fine-tune on the labeled split only
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 0

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False)
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True)

        args.maxpool1 = False
        args.first_conv = True
    elif args.dataset == 'imagenet':
        dm = ImagenetDataModule(data_dir=args.data_path,
                                batch_size=args.batch_size,
                                num_workers=args.num_workers)

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False)
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True)

        dm.test_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True)

        args.num_samples = 1
        args.maxpool1 = True
        args.first_conv = True
    else:
        raise NotImplementedError(
            "other datasets have not been implemented till now")

    # NOTE(review): load_from_checkpoint is a classmethod; calling it on this
    # instance builds a fresh model from the checkpoint's stored
    # hyperparameters, so the constructor arguments below are presumably
    # discarded — confirm whether they were meant as overrides.
    backbone = SwAV(
        gpus=args.gpus,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        datamodule=dm,
        maxpool1=args.maxpool1,
        first_conv=args.first_conv,
        dataset='imagenet',
    ).load_from_checkpoint(args.ckpt_path, strict=False)

    tuner = SSLFineTuner(backbone,
                         in_features=args.in_features,
                         num_classes=dm.num_classes,
                         epochs=args.num_epochs,
                         hidden_dim=None,
                         dropout=args.dropout,
                         learning_rate=args.learning_rate,
                         weight_decay=args.weight_decay,
                         nesterov=args.nesterov,
                         scheduler_type=args.scheduler_type,
                         gamma=args.gamma,
                         final_lr=args.final_lr)

    trainer = pl.Trainer(
        gpus=args.gpus,
        precision=16,
        max_epochs=args.num_epochs,
        distributed_backend='ddp',
        sync_batchnorm=args.gpus > 1,
    )

    trainer.fit(tuner, dm)
    trainer.test(datamodule=dm)
def cli_main():
    """Train and test ``LitAutoEncoder`` on MNIST, logging to MLflow.

    The ``MLFlowLogger`` is pointed at the run opened by
    ``mlflow.start_run()`` by overwriting its private ``_run_id``, so
    Lightning's metrics land in that run (used to link with an AzureML run).

    NOTE(review): ``--hidden_dim`` is parsed but ``LitAutoEncoder()`` is
    constructed without it.
    """
    pl.seed_everything(1234)

    # ---- args ----
    parser = ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # ---- data: 55k/5k train/val split of the MNIST training set ----
    full_train = MNIST("",
                       train=True,
                       download=True,
                       transform=transforms.ToTensor())
    mnist_test = MNIST("",
                       train=False,
                       download=True,
                       transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(full_train, [55000, 5000])

    train_loader, val_loader, test_loader = [
        DataLoader(split, batch_size=args.batch_size)
        for split in (mnist_train, mnist_val, mnist_test)
    ]

    # ---- model ----
    model = LitAutoEncoder()

    # ---- logging, training and testing inside one MLflow run ----
    with mlflow.start_run() as run:
        experiment_name = mlflow.get_experiment(run.info.experiment_id).name
        tracking_uri = mlflow.get_tracking_uri()

        mlf_logger = MLFlowLogger(experiment_name=experiment_name,
                                  tracking_uri=tracking_uri)
        # link the mlflowlogger run ID to the azureml run ID
        mlf_logger._run_id = run.info.run_id

        trainer = pl.Trainer.from_argparse_args(args, logger=mlf_logger)
        trainer.fit(model, train_loader, val_loader)

        print(trainer.test(test_dataloaders=test_loader))
Exemplo n.º 18
0
def train_treelstm(config: DictConfig):
    """Train and then test a TreeLSTM-to-sequence model described by ``config``.

    Seeds both ``seed_everything`` and DGL with ``config.seed``, prepares the
    JSONL data module, and builds ``TypedTreeLSTM2Seq`` when the config
    defines both ``max_types`` and ``max_type_parts`` (``TreeLSTM2Seq``
    otherwise). Logging goes to Weights & Biases, and checkpoints are written
    into the W&B run directory. Finishes with ``trainer.fit`` then
    ``trainer.test``.
    """
    filter_warnings()
    seed_everything(config.seed)
    dgl.seed(config.seed)

    print_config(config, ["hydra", "log_offline"])

    data_module = JsonlDataModule(config)
    data_module.prepare_data()
    data_module.setup()

    typed_variant = "max_types" in config and "max_type_parts" in config
    model_cls = TypedTreeLSTM2Seq if typed_variant else TreeLSTM2Seq
    model: LightningModule = model_cls(config, data_module.vocabulary)

    # Weights & Biases logger (model artifacts are not logged by it).
    wandb_logger = WandbLogger(project=f"tree-lstm-{config.dataset}",
                               log_model=False,
                               offline=config.log_offline)
    wandb_logger.watch(model)
    run_dir = wandb_logger.experiment.dir

    # Keep every checkpoint (save_top_k=-1) inside the W&B run directory.
    checkpointer = ModelCheckpoint(
        dirpath=run_dir,
        filename="{epoch:02d}-{val_loss:.4f}",
        period=config.save_every_epoch,
        save_top_k=-1,
    )
    uploader = UploadCheckpointCallback(run_dir)
    early_stopper = EarlyStopping(patience=config.patience,
                                  monitor="val_loss",
                                  verbose=True,
                                  mode="min")
    epoch_printer = PrintEpochResultCallback("train", "val")
    # One GPU when CUDA is available, CPU otherwise.
    gpus = 1 if torch.cuda.is_available() else None
    lr_monitor = LearningRateMonitor("step")

    trainer = Trainer(
        max_epochs=config.n_epochs,
        gradient_clip_val=config.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.val_every_epoch,
        log_every_n_steps=config.log_every_step,
        logger=wandb_logger,
        gpus=gpus,
        progress_bar_refresh_rate=config.progress_bar_refresh_rate,
        callbacks=[
            lr_monitor,
            early_stopper,
            checkpointer,
            uploader,
            epoch_printer,
        ],
        resume_from_checkpoint=config.resume_checkpoint,
    )

    trainer.fit(model=model, datamodule=data_module)
    trainer.test()
Exemplo n.º 19
0
    return model, trainer


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Train segmentation model')
    parser.add_argument('-c',
                        '--config',
                        help='training config file',
                        required=True)
    parser.add_argument('-t', '--train', action='store_true')
    parser.add_argument('-e', '--export', action='store_true')
    args = parser.parse_args()

    config = load_json(args.config)

    # os.environ values must be strings: accept a device list given as a
    # number (e.g. 0) too. A null value leaves the variable untouched instead
    # of raising TypeError.
    visible_devices = config.run.visible_devices
    if visible_devices is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(visible_devices)

    # Draw a random seed only when none is configured. The previous
    # ``config.run.seed or ...`` also discarded an explicit seed of 0.
    seed = config.run.seed
    if seed is None:
        seed = random.randint(1, 10000)
    seed_everything(seed)
    print('Using manual seed: {}'.format(seed))
    print('Config: ', config)

    # build_trainer returns a (model, trainer) pair.
    model, trainer = build_trainer(config, seed, args)

    if args.train:
        trainer.fit(model)

    elif args.export:
        model.export_models('./')
Exemplo n.º 20
0
from pytorch_lightning import Trainer, seed_everything
import torch
from torch.utils.data import DataLoader, Subset
from transforms.scene import (
    SeqToTensor,
    Padding_shift_ori_model,
    Augment_rotation,
    Augment_jitterring,
    Get_cat_shift_info,
)
from datasets.suncg_shift_seperate_dataset_deepsynth import SUNCG_Dataset
from separate_models.scene_shift_ori_col import scene_transformer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from utils.config import read_config

# Fix the global random seed once at import time (PyTorch Lightning's
# seed_everything) so runs of this script are reproducible.
seed_everything(1)


def log_metrics(self, metrics, step=None):
    """Write ``metrics`` through ``self.experiment`` at the given ``step``.

    Intended to be bound onto a logger object as a replacement method, so
    ``self`` is the logger. A dict value is forwarded as one grouped
    ``add_scalars(name, value, step)`` call. Any other value goes through
    ``add_scalar(name, value, step)``, and ``torch.Tensor`` values are first
    converted with ``.item()``.
    """
    for name, value in metrics.items():
        if isinstance(value, dict):
            self.experiment.add_scalars(name, value, step)
            continue
        scalar = value.item() if isinstance(value, torch.Tensor) else value
        self.experiment.add_scalar(name, scalar, step)


def monkeypatch_tensorboardlogger(logger):
    import types
Exemplo n.º 21
0
    parser.add_argument("--gamma",
                        type=float,
                        default=0.7,
                        metavar="M",
                        help="Learning rate step gamma (default: 0.7)")
    parser.add_argument("--dry-run",
                        action="store_true",
                        default=False,
                        help="quickly check a single pass")
    parser.add_argument("--seed",
                        type=int,
                        default=1,
                        metavar="S",
                        help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--save-model",
                        action="store_true",
                        default=False,
                        help="For Saving the current Model")
    hparams = parser.parse_args()

    seed_everything(hparams.seed)

    Lite(accelerator="cpu", devices=1).run(hparams)
Exemplo n.º 22
0
def train(config: DictConfig, resume_from_checkpoint: str = None):
    """Train and then test a code2seq-family model selected by ``config.name``.

    Args:
        config: experiment configuration. ``config.name`` must be one of
            ``"code2seq"``, ``"code2class"`` or ``"typed-code2seq"``.
        resume_from_checkpoint: optional checkpoint path, passed unchanged to
            ``Trainer(resume_from_checkpoint=...)``.

    Raises:
        KeyError: if ``config.name`` is not a known model. It is raised before
            the vocabulary is loaded.
    """
    filter_warnings()
    print_config(config)
    seed_everything(config.seed)

    known_models = {
        "code2seq": get_code2seq,
        "code2class": get_code2class,
        "typed-code2seq": get_typed_code2seq
    }
    if config.name not in known_models:
        # Previously this only printed and carried on, loading the vocabulary
        # before failing on the dict lookup below. Fail fast with the same
        # exception type and a readable message.
        raise KeyError(
            f"Unknown model: {config.name}, try one of {list(known_models)}")

    vocabulary = Vocabulary.load_vocabulary(
        join(config.data_folder, config.dataset.name, config.vocabulary_name))
    model, data_module = known_models[config.name](config, vocabulary)

    # define logger
    wandb_logger = WandbLogger(project=f"{config.name}-{config.dataset.name}",
                               log_model=True,
                               offline=config.log_offline)
    wandb_logger.watch(model)
    # define model checkpoint callback: keep every checkpoint (save_top_k=-1)
    # inside the W&B run directory
    checkpoint_callback = ModelCheckpoint(
        dirpath=wandb_logger.experiment.dir,
        filename="{epoch:02d}-{val_loss:.4f}",
        period=config.save_every_epoch,
        save_top_k=-1,
    )
    upload_checkpoint_callback = UploadCheckpointCallback(
        wandb_logger.experiment.dir)
    # define early stopping callback
    early_stopping_callback = EarlyStopping(
        patience=config.hyper_parameters.patience,
        monitor="val_loss",
        verbose=True,
        mode="min")
    # define callback for printing intermediate result
    print_epoch_result_callback = PrintEpochResultCallback("train", "val")
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateMonitor("step")
    trainer = Trainer(
        max_epochs=config.hyper_parameters.n_epochs,
        gradient_clip_val=config.hyper_parameters.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.val_every_epoch,
        # NOTE(review): a config key named "log_every_epoch" feeds a *step*
        # interval here; confirm this is the intended key.
        log_every_n_steps=config.log_every_epoch,
        logger=wandb_logger,
        gpus=gpu,
        progress_bar_refresh_rate=config.progress_bar_refresh_rate,
        callbacks=[
            lr_logger,
            early_stopping_callback,
            checkpoint_callback,
            upload_checkpoint_callback,
            print_epoch_result_callback,
        ],
        resume_from_checkpoint=resume_from_checkpoint,
    )

    trainer.fit(model=model, datamodule=data_module)
    trainer.test()
Exemplo n.º 23
0
import numpy as np
import pytorch_lightning as pl
import lightly
from loss import BarlowTwinsLoss

from utils import knn_predict, BenchmarkModule

# Benchmark settings (module-level constants read by the rest of the script).
num_workers = 8
max_epochs = 800
# presumably the k and temperature passed to knn_predict for kNN evaluation
# — confirm against the call site
knn_k = 200
knn_t = 0.1
# number of target classes (presumably CIFAR-10 given input_size=32 — confirm)
classes = 10
batch_size = 512
seed = 1

pl.seed_everything(seed)

# use a GPU if available
# NOTE(review): ``torch`` is not among this snippet's own imports; it must be
# bound earlier in the module for this line to run.
gpus = 1 if torch.cuda.is_available() else 0
device = 'cuda' if gpus else 'cpu'

# Use SimCLR augmentations, additionally, disable blur
collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=32,
    gaussian_blur=0.,
)

# No additional augmentations for the test set
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
sys.path.append(PARENT_DIR)

# Third party imports
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torchdiffeq import odeint

# local application imports
from ablations.PixelHNN import HNN, PixelHNN, MLP, MLPAutoencoder, PixelHNNDataset
from utils import arrange_data, from_pickle, my_collate

# Fix the global random seed once at import time (PyTorch Lightning's
# seed_everything) so runs of this script are reproducible.
seed_everything(0)


class Model(pl.LightningModule):
    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        autoencoder = MLPAutoencoder(input_dim=2 * 32**2,
                                     hidden_dim=200,
                                     latent_dim=2,
                                     nonlinearity='relu')
        self.pixelHNN = PixelHNN(input_dim=2,
                                 hidden_dim=200,
Exemplo n.º 25
0
from multiprocessing import cpu_count

import pytorch_lightning as pl
import toml
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, ConcatDataset

from models.blocked_phononet import BlockedPhononet
from src import *

pl.seed_everything(42)

config = toml.load('carnatic.toml')
# Read the MBID allow-list with a context manager so the file handle is
# closed (json.load(open(...)) leaked it). JSON text is UTF-8.
with open(config['data']['limit_songs'], encoding='utf-8') as limit_songs_file:
    limit_mbids = json.load(limit_songs_file)
fcd_carnatic = FullChromaDataset(json_path=config['data']['metadata'],
                                 data_folder=config['data']['chroma_folder'],
                                 include_mbids=limit_mbids,
                                 carnatic=True)
# Presumably 70% train, with the remaining 30% halved into validation and test.
train_carnatic, fcd_not_train_carnatic = fcd_carnatic.train_test_split(
    train_size=0.70)
val_carnatic, test_carnatic = fcd_not_train_carnatic.train_test_split(
    test_size=0.5)

chunk_sizes = (100, )
strides = (10, 10, 10)
# NOTE(review): zip() stops at the shorter tuple, so only a single
# (chunk_size=100, stride=10) chunked dataset is built; the extra strides
# are unused.
chunked_data = [
    ChromaChunkDataset(train_carnatic, chunk_size=chunk_size, stride=stride)
    for chunk_size, stride in zip(chunk_sizes, strides)
]
train_carnatic = ConcatDataset(chunked_data)
Exemplo n.º 26
0
    for i, label in enumerate(words):
        submission[label] = 0.
    for i, label in enumerate(words):
        submission.loc[:,label] = np.array(predictions_proba)[:,i]

    train_batch_size,_, n_folds, img_size, n_epochs, arch = params.values()

    csv_file = f'GIZ_SIZE_{img_size}_arch_{arch}_n_folds_{n_folds}_num_epochs_{n_epochs}_train_bs_{train_batch_size}.csv'
    submission.to_csv(os.path.join(submissions_folder, csv_file), index=False)

    print(f'[INFO] Submission file save to {os.path.join(submissions_folder, csv_file)}')

if __name__ == '__main__':
    args = parser.parse_args()

    _ = seed_everything(args.seed_value)

    side = args.img_size
    # Albumentations pipelines, keyed by split.
    data_transforms = {}
    data_transforms['train'] = al.Compose([
        al.Resize(side, side),
        al.Cutout(p=.6, max_h_size=15, max_w_size=10, num_holes=4),
        al.Rotate(limit=35, p=.04),
        al.Normalize((0.1307,), (0.3081,)),
    ])
    # NOTE(review): the test pipeline also applies random Cutout (p=0.6), so
    # evaluation images are randomly occluded. Confirm this is intentional
    # (e.g. test-time augmentation).
    data_transforms['test'] = al.Compose([
        al.Resize(side, side),
        al.Cutout(p=.6, max_h_size=15, max_w_size=10, num_holes=4),
        al.Normalize((0.1307,), (0.3081,)),
    ])
def test_eval_logging_auto_reduce(tmpdir):
    """
    Check that a metric logged from ``validation_step`` with
    ``on_step=True, on_epoch=True`` is exposed both per step
    (``val_loss_step/epoch_0``) and as an epoch mean (``val_loss_epoch``),
    and that the epoch value equals the mean of the step losses.
    """
    seed_everything(1234)

    class TestModel(BoringModel):
        def on_pretrain_routine_end(self) -> None:
            # Initialise the manual bookkeeping used by the assertions below.
            self.seen_vals = []
            self.manual_epoch_end_mean = None

        def on_validation_epoch_start(self) -> None:
            # Reset at every validation epoch, presumably so values from an
            # earlier pass (e.g. the sanity check) are discarded.
            self.seen_vals = []

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            # Track each step loss manually to compare with what was logged.
            self.seen_vals.append(loss)
            self.log('val_loss',
                     loss,
                     on_epoch=True,
                     on_step=True,
                     prog_bar=True)
            return {"x": loss}

        def validation_epoch_end(self, outputs) -> None:
            # The step outputs must match the manually tracked losses.
            for passed_in, manually_tracked in zip(outputs, self.seen_vals):
                assert passed_in['x'] == manually_tracked
            self.manual_epoch_end_mean = torch.stack([x['x'] for x in outputs
                                                      ]).mean()

    model = TestModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=3,
        limit_val_batches=3,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # make sure all the metrics are available for callbacks
    # ('debug_epoch' is not logged by the code visible here — presumably it
    # comes from BoringModel or the trainer)
    manual_mean = model.manual_epoch_end_mean
    callback_metrics = set(trainer.callback_metrics.keys())
    assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'}

    # make sure values are correct
    assert trainer.logged_metrics['val_loss_epoch'] == manual_mean
    assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics[
        'val_loss_step/epoch_0']

    # make sure correct values were logged
    logged_val = trainer.dev_debugger.logged_metrics

    # 3 val batches
    assert logged_val[0]['val_loss_step/epoch_0'] == model.seen_vals[0]
    assert logged_val[1]['val_loss_step/epoch_0'] == model.seen_vals[1]
    assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[2]

    # epoch mean
    assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean

    # only those logged
    assert len(logged_val) == 4
Exemplo n.º 28
0
def main():
    """Pre-train a SimCLR model and record the run in Comet.

    Steps:
    1. Load the general training config and the SimCLR model config (the
       heatmap variant when ``--heatmap`` is set), and apply command-line
       overrides to both.
    2. Seed, build the train/validation loaders, the Comet logger, the model
       and the callbacks.
    3. Log code, parameters and tags to the Comet experiment, then call
       ``trainer.fit``.
    """
    experiment_type = "simclr"
    console_logger = get_console_logger(__name__)
    args = get_general_args("simclr training script.")

    # ---- configuration ----
    train_param = update_train_params(
        args, edict(read_json(TRAINING_CONFIG_PATH)))
    model_cfg_file = SIMCLR_HEATMAP_CONFIG if args.heatmap else SIMCLR_CONFIG
    model_param = edict(read_json(model_cfg_file))
    console_logger.info(f"Train parameters {pformat(train_param)}")
    seed_everything(train_param.seed)

    # ---- data ----
    data = get_data(Data_Set,
                    train_param,
                    sources=args.sources,
                    experiment_type=experiment_type)
    train_loader, val_loader = get_train_val_split(
        data,
        batch_size=train_param.batch_size,
        num_workers=train_param.num_workers)

    # ---- logger ----
    experiment_name = prepare_name(f"{experiment_type}_",
                                   train_param,
                                   hybrid_naming=False)
    comet_logger = CometLogger(**COMET_KWARGS, experiment_name=experiment_name)

    # ---- model ----
    model_param = update_model_params(model_param, args, len(data),
                                      train_param)
    console_logger.info(f"Model parameters {pformat(model_param)}")
    model_cls = get_model(experiment_type=experiment_type,
                          heatmap_flag=args.heatmap,
                          denoiser_flag=args.denoiser)
    model = model_cls(config=model_param)

    # ---- callbacks (returned as Trainer keyword arguments) ----
    callbacks = get_callbacks(
        logging_interval=args.log_interval,
        experiment_type=experiment_type,
        save_top_k=args.save_top_k,
        period=args.save_period,
    )

    # ---- trainer ----
    trainer = Trainer(
        accumulate_grad_batches=train_param.accumulate_grad_batches,
        # NOTE(review): PyTorch Lightning 1.x parses the *string* "0" as the
        # integer 0, i.e. no GPUs. GPU index 0 is usually written gpus=[0] or
        # gpus="0,". Confirm the intent for the installed PL version.
        gpus="0",
        logger=comet_logger,
        max_epochs=train_param.epochs,
        precision=train_param.precision,
        amp_backend="native",
        **callbacks,
    )

    # ---- experiment bookkeeping in Comet ----
    trainer.logger.experiment.set_code(
        overwrite=True,
        filename=os.path.join(MASTER_THESIS_DIR, "src", "experiments",
                              "simclr_experiment.py"),
    )
    if args.meta_file is not None:
        save_experiment_key(
            experiment_name=experiment_name,
            experiment_key=trainer.logger.experiment.get_key(),
            filename=args.meta_file,
        )
    for params in (train_param, model_param):
        trainer.logger.experiment.log_parameters(params)
    trainer.logger.experiment.add_tags(["pretraining", "simclr"] + args.tag)

    # ---- training ----
    trainer.fit(model, train_loader, val_loader)
Exemplo n.º 29
0
        args.name = (
            f"{args.model}_{args.dataset}_l{args.layer_count}_d{args.dim_model}_seq{args.seq_len}"
            + f"_{int(time.time())%10000}"
        )

    if args.dataset in ["ednet", "ednet_medium"]:
        args.num_item = 14000
        args.num_skill = 300
    else:
        full_df = pd.read_csv(
            os.path.join("data", args.dataset, "preprocessed_data.csv"), sep="\t"
        )
        args.num_item = int(full_df["item_id"].max() + 1)
        args.num_skill = int(full_df["skill_id"].max() + 1)
    # set random seed
    pl.seed_everything(args.random_seed)

    print_args(args)
    if args.model.lower().startswith('saint'):
        model = SAINT(args)
    elif args.model.lower().startswith('sakt'):
        model = SAKT(args)
    elif args.model.lower().startswith('dkt'):
        model = DKT(args)
    else:
        raise NotImplementedError

    checkpoint_callback = ModelCheckpoint(
        monitor="val_auc",
        dirpath=f"save/{args.model}/{args.dataset}",
        filename=f"{args.name}",
def train(param):
    """Build, configure and fit a conditioned source-separation model.

    Args:
        param: an ``argparse.Namespace`` or a ``dict`` of options. Keys read
            here include ``model``, ``spec_type``, ``seed``,
            ``resume_from_checkpoint``, ``ckpt_root_path``, ``run_id``,
            ``save_top_k``, ``patience``, ``log``, ``n_fft``, ``hop_length``,
            ``num_frame`` and ``auto_lr_find``. The function works on a
            shallow copy, so the caller's object is not modified.

    Returns:
        ``0`` when ``auto_lr_find`` is set: the learning-rate finder runs and
        its suggestion is printed instead of training. Otherwise ``None``
        after ``trainer.fit`` returns.
    """
    # Work on a shallow copy. vars() returns the Namespace's live __dict__,
    # so the callbacks, logger and rewritten resume path assigned below used
    # to leak back into the caller's object. A second call with that object
    # then rebuilt the resume path and run id from already-rewritten values.
    if not isinstance(param, dict):
        args = dict(vars(param))
    else:
        args = dict(param)

    framework = get_class_by_name('conditioned_separation', args['model'])
    if args['spec_type'] != 'magnitude':
        # Non-magnitude spectrogram types are configured with 4 input channels.
        args['input_channels'] = 4

    # Seed only when not resuming from a checkpoint.
    if args['resume_from_checkpoint'] is None:
        if args['seed'] is not None:
            seed_everything(args['seed'])

    model = framework(**args)

    if args['last_activation'] != 'identity' and args[
            'spec_est_mode'] != 'masking':
        warn(
            'Please check if you really want to use a mapping-based spectrogram estimation method '
            'with a final activation function. ')
    ##########################################################

    # -- checkpoint directory: <ckpt_root_path>/<model>/<run_id>
    ckpt_path = Path(args['ckpt_root_path'])
    mkdir_if_not_exists(ckpt_path)
    ckpt_path = ckpt_path.joinpath(args['model'])
    mkdir_if_not_exists(ckpt_path)
    run_id = args['run_id']
    ckpt_path = ckpt_path.joinpath(run_id)
    mkdir_if_not_exists(ckpt_path)
    save_top_k = args['save_top_k']

    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=save_top_k,
        verbose=False,
        monitor='val_loss',
        save_last=False,
        save_weights_only=args['save_weights_only'])
    # Reaches the Trainer via trainer_kwargs below, but only if the installed
    # Trainer.__init__ has a parameter of this name.
    args['checkpoint_callback'] = checkpoint_callback

    # -- early stop
    patience = args['patience']
    early_stop_callback = EarlyStopping(monitor='val_loss',
                                        min_delta=0.0,
                                        patience=patience,
                                        verbose=False)
    args['early_stop_callback'] = early_stop_callback

    # -- resume: the option holds a checkpoint file name relative to this
    # run's checkpoint directory; resolve it to a full path string.
    if args['resume_from_checkpoint'] is not None:
        run_id = run_id + "_resume_" + args['resume_from_checkpoint']
        args['resume_from_checkpoint'] = Path(args['ckpt_root_path']).joinpath(
            args['model']).joinpath(args['run_id']).joinpath(
                args['resume_from_checkpoint'])
        args['resume_from_checkpoint'] = str(args['resume_from_checkpoint'])

    # -- logger setting
    log = args['log']
    if log == 'False':
        args['logger'] = False
    elif log == 'wandb':
        # wandb expects a sequence of tags; a bare string would be treated
        # as a sequence of single characters.
        args['logger'] = WandbLogger(project='lasaft',
                                     tags=[args['model']],
                                     offline=False,
                                     id=run_id)
        args['logger'].log_hyperparams(model.hparams)
        args['logger'].watch(model, log='all')
    elif log == 'tensorboard':
        raise NotImplementedError
    else:
        args['logger'] = True  # default
        # NOTE(review): this directory is created but never passed to the
        # Trainer (e.g. as default_root_dir), so default logs are not
        # written here.
        default_save_path = 'etc/lightning_logs'
        mkdir_if_not_exists(default_save_path)

    # Forward only the options that Trainer.__init__ actually accepts.
    valid_kwargs = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = {
        name: args[name] for name in valid_kwargs if name in args
    }

    # DATASET
    ##########################################################
    data_provider = DataProvider(**args)
    ##########################################################
    # Trainer Definition

    # Trainer
    trainer = Trainer(**trainer_kwargs)
    n_fft, hop_length, num_frame = args['n_fft'], args['hop_length'], args[
        'num_frame']
    train_data_loader = data_provider.get_train_dataloader(
        n_fft, hop_length, num_frame)
    valid_data_loader = data_provider.get_valid_dataloader(
        n_fft, hop_length, num_frame)

    for key in sorted(args.keys()):
        print('{}:{}'.format(key, args[key]))

    if args['auto_lr_find']:
        lr_finder = trainer.lr_find(model,
                                    train_data_loader,
                                    valid_data_loader,
                                    early_stop_threshold=None)
        print(lr_finder.results)
        # torch.save(lr_finder.results, 'lr_result.cache')
        new_lr = lr_finder.suggestion()
        print('new_lr_suggestion:', new_lr)
        return 0

    print(model)

    trainer.fit(model, train_data_loader, valid_data_loader)

    return None