Example #1
    def ddp_train(self, process_idx, model):
        """
        Entry point for ddp

        Args:
            process_idx:
            mp_queue: multiprocessing queue
            model:

        Returns:
            Dict with evaluation results

        """
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0
            ) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # Initialize cuda device
        self.init_device(process_idx)

        # set up server using proc 0's ip address
        # try to init at most 20 times in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        self.init_ddp_connection(self.trainer.global_rank,
                                 self.trainer.world_size,
                                 self.trainer.is_slurm_managing_tasks)

        if isinstance(self.ddp_plugin, RPCPlugin):
            if not self.ddp_plugin.is_main_rpc_process:
                self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
                self.ddp_plugin.exit_rpc_process()
                if self.ddp_plugin.return_after_exit_rpc_process:
                    return
            else:
                self.ddp_plugin.on_main_rpc_connection(self.trainer)

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on global rank 0, let everyone know training is starting
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(
                f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes'
            )
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = self.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # 16-bit
        model = self.trainer.precision_connector.connect(model)

        self.trainer.convert_to_lightning_optimizers()

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = self.configure_ddp(model, device_ids)

        # set up training routine
        self.barrier('ddp_setup')
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # clean up memory
        torch.cuda.empty_cache()

        return results
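
The PL_GLOBAL_SEED read at the top of ddp_train works because seed_everything exports the chosen seed to that environment variable, which spawned DDP workers inherit. A minimal sketch of that round trip (the string comparison assumes the export format used by PyTorch Lightning 1.x):

import os

from pytorch_lightning import seed_everything

# Parent process: seeds python, numpy and torch RNGs and exports the seed.
seed_everything(2020)
assert os.environ.get("PL_GLOBAL_SEED") == "2020"

# Child process (mirrors the top of ddp_train): re-read and re-apply the seed.
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
    seed_everything(int(seed))
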
Example #2
def test_correct_seed_with_environment_variable():
    """
    Ensure that the PL_GLOBAL_SEED environment variable is read
    """
    assert seed_utils.seed_everything() == 2020
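
The assertion only holds if PL_GLOBAL_SEED is already set to 2020 when the test runs; the snippet omits that setup. A hedged sketch of how it is typically arranged (the mock.patch.dict decorator is an assumption, not part of the original test file):

import os
from unittest import mock

from pytorch_lightning.utilities import seed as seed_utils


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
    """Ensure that the PL_GLOBAL_SEED environment variable is read."""
    assert seed_utils.seed_everything() == 2020
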
Example #3
    def ddp_train_tmp(self,
                      process_idx,
                      mp_queue,
                      model,
                      is_master=False,
                      proc_offset=0):
        """
        Entry point for ddp

        Args:
            process_idx:
            mp_queue: multiprocessing queue
            model:

        Returns:

        """
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # offset the process id if requested
        process_idx = process_idx + proc_offset

        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0
            ) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init at most 20 times in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        model.init_ddp_connection(self.trainer.global_rank,
                                  self.trainer.world_size,
                                  self.trainer.is_slurm_managing_tasks)

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on global rank 0, let everyone know training is starting
        if self.trainer.is_global_zero:
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(
                f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes'
            )
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx, is_master)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # 16-bit
        model = self.trainer.precision_connector.connect(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # get original model
        model = self.trainer.get_model()

        # persist info in ddp_spawn
        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

        # clean up memory
        torch.cuda.empty_cache()

        if self.trainer.global_rank == 0:
            return results
Example #4
    ckpt = glob.glob(
        "{}/lightning_logs/version_0/checkpoints/epoch*.ckpt".format(
            model_dir))[0]

    m = UNet.load_from_checkpoint(ckpt)
    m.freeze()

    summary(m, (3, img_size, img_size), device='cpu')

    print("Loading checkpoint: {}".format(ckpt))

    return m


if __name__ == '__main__':
    seed_everything(seed=45)

    model_dir = ""

    while not os.path.isdir(model_dir):
        model_dir = input("Enter model directory: ")

        # where to store model params
        model_dir = "{}/{}".format(model_output_parent, model_dir)

    datasets = get_datasets(_same_image_all_channels=same_image_all_channels,
                            model_dir=model_dir,
                            new_ds_split=False,
                            train_list="{}/train.txt".format(model_dir),
                            val_list="{}/val.txt".format(model_dir),
                            test_list="{}/test.txt".format(model_dir))
Example #5
import copy
from sklearn.preprocessing import StandardScaler

import torch
import random
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything

# Define and set seeds
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
seed_everything(seed)

# Tricky: https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
# torch.use_deterministic_algorithms(True)

# Test randomness
print(f"\t- [INFO]: Testing random seed ({seed}):")
print(f"\t\t- random: {random.random()}")
print(f"\t\t- numpy: {np.random.rand(1)}")
print(f"\t\t- torch: {torch.rand(1)}")


def create_big_model(ds_small, ds_big):
    max_length = 100
    src_vocab_small = Vocabulary(max_tokens=max_length).build_from_ds(ds=ds_small, lang=ds_small.src_lang)
    trg_vocab_small = Vocabulary(max_tokens=max_length).build_from_ds(ds=ds_small, lang=ds_small.trg_lang)
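
Note that seed_everything(seed) already seeds random, NumPy and torch, so the three manual calls above are belt-and-braces rather than required. A small reproducibility check under that assumption:

import random

import numpy as np
import torch
from pytorch_lightning.utilities.seed import seed_everything


def draw():
    # One sample from each RNG that seed_everything seeds.
    return random.random(), np.random.rand(1).item(), torch.rand(1).item()


seed_everything(1234)
first = draw()
seed_everything(1234)
assert draw() == first  # identical values after re-seeding
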
Example #6
import os
from argparse import ArgumentParser
from pprint import pprint

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from src.model.efficient_el import EfficientEL

if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument("--dirpath", type=str, default="models")
    parser.add_argument("--save_top_k", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)

    parser = EfficientEL.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args, _ = parser.parse_known_args()
    pprint(args.__dict__)

    seed_everything(seed=args.seed)

    logger = TensorBoardLogger(args.dirpath, name=None)

    callbacks = [
        ModelCheckpoint(
            mode="max",
            monitor="micro_f1",
            dirpath=os.path.join(logger.log_dir, "checkpoints"),
            save_top_k=args.save_top_k,
            filename="model-epoch={epoch:02d}-micro_f1={micro_f1:.4f}-ed_micro_f1={ed_micro_f1:.4f}",
        ),
        LearningRateMonitor(
            logging_interval="step",
        ),
    ]
Example #7
def execute_single_run(params, hparams, log_dirs, experiment: bool):
    """
    :param params:     [argparse.Namespace] attr: experiment_name, run_name, pretrained_model_name, dataset_name, ..
    :param hparams:    [argparse.Namespace] attr: batch_size, max_seq_length, max_epochs, lr_*
    :param log_dirs:   [argparse.Namespace] attr: mlflow, tensorboard
    :param experiment: [bool] whether run is part of an experiment w/ multiple runs
    :return: -
    """
    # seed
    seed = params.seed + int(
        params.run_name_nr.split("-")[1]
    )  # run_name_nr = e.g. 'runA-1', 'runA-2', 'runB-1', ..
    seed_everything(seed)

    default_logger = DefaultLogger(
        __file__, log_file=log_dirs.log_file, level=params.logging_level
    )  # python logging
    default_logger.clear()

    print_run_information(params, hparams, default_logger, seed)

    lightning_hparams = unify_parameters(params, hparams, log_dirs, experiment)

    tb_logger = logging_start(params, log_dirs)
    with mlflow.start_run(run_name=params.run_name_nr, nested=experiment):

        model = NerModelTrain(lightning_hparams)
        callbacks = get_callbacks(params, hparams, log_dirs)

        trainer = Trainer(
            max_epochs=hparams.max_epochs,
            gpus=torch.cuda.device_count() if params.device.type == "cuda" else None,
            precision=16 if (params.fp16 and params.device.type == "cuda") else 32,
            # amp_level="O1",
            logger=tb_logger,
            callbacks=list(callbacks),
        )
        trainer.fit(model)
        callback_info = get_callback_info(callbacks, params, hparams)

        default_logger.log_info(
            f"---\n"
            f"LOAD BEST CHECKPOINT {callback_info['checkpoint_best']}\n"
            f"FOR TESTING AND DETAILED RESULTS\n"
            f"---"
        )
        model_best = NerModelTrain.load_from_checkpoint(
            checkpoint_path=callback_info["checkpoint_best"]
        )
        trainer_best = Trainer(
            gpus=torch.cuda.device_count() if params.device.type == "cuda" else None,
            precision=16 if (params.fp16 and params.device.type == "cuda") else 32,
            logger=tb_logger,
        )
        trainer_best.validate(model_best)
        trainer_best.test(model_best)

        # logging end
        logging_end(
            tb_logger, callback_info, hparams, model, model_best, default_logger
        )

        # remove checkpoint
        if params.checkpoints is False:
            remove_checkpoint(callback_info["checkpoint_best"], default_logger)
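
The per-run seed is derived by adding the numeric suffix of run_name_nr to the base seed, so every run of an experiment is reproducible but distinct. A small illustration with hypothetical values (note that runs from different experiments with the same suffix share a seed):

from pytorch_lightning import seed_everything

base_seed = 42
for run_name_nr in ["runA-1", "runA-2", "runB-1"]:
    seed = base_seed + int(run_name_nr.split("-")[1])  # 43, 44, 43
    seed_everything(seed)
    print(run_name_nr, "->", seed)
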
Example #8
def test_invalid_seed():
    """Ensure that we still fix the seed even if an invalid seed is given."""
    with pytest.warns(UserWarning, match="Invalid seed found"):
        seed = seed_utils.seed_everything()
    assert seed == 123
Example #9
def test_out_of_bounds_seed(seed):
    """Ensure that we still fix the seed even if an out-of-bounds seed is given."""
    with pytest.warns(UserWarning, match="is not in bounds"):
        actual = seed_utils.seed_everything(seed)
    assert actual == 123
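
Both tests rely on seed_everything warning and falling back to a randomly selected in-range seed (mocked to 123 in the test fixtures) when the given seed is unusable. Without the fixtures, the observable behaviour can be checked directly; the 2 ** 32 bound is an assumption based on the uint32 range that NumPy accepts:

import pytest
from pytorch_lightning import seed_everything

# An out-of-bounds seed triggers a warning and a random in-range fallback seed.
with pytest.warns(UserWarning, match="is not in bounds"):
    chosen = seed_everything(10 ** 10)
assert 0 <= chosen < 2 ** 32
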
Example #10
    def fit(
        self,
        train: pd.DataFrame,
        validation: Optional[pd.DataFrame] = None,
        test: Optional[pd.DataFrame] = None,
        loss: Optional[torch.nn.Module] = None,
        metrics: Optional[List[Callable]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        optimizer_params: Dict = {},
        train_sampler: Optional[torch.utils.data.Sampler] = None,
        target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        reset: bool = False,
        seed: Optional[int] = None,
        trained_backbone: Optional[pl.LightningModule] = None,
        callbacks: Optional[List[pl.Callback]] = None,
    ) -> None:
        """The fit method which takes in the data and triggers the training

        Args:
            train (pd.DataFrame): Training Dataframe

            validation (Optional[pd.DataFrame], optional): If provided, this dataframe is used as validation data during training.
                Used for early stopping and logging. If left empty, 20% of the train data is used as validation. Defaults to None.

            test (Optional[pd.DataFrame], optional): If provided, used as hold-out data on which
                performance can be checked after the model is trained. Defaults to None.

            loss (Optional[torch.nn.Module], optional): Custom loss function that is not in the standard PyTorch library

            metrics (Optional[List[Callable]], optional): Custom metric functions (Callable) which have the
                signature ``metric_fn(y_hat, y)`` and work on torch tensor inputs

            optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizer which is a drop-in replacement for standard PyTorch optimizers.
                This should be the class and not an initialized object

            optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

            train_sampler (Optional[torch.utils.data.Sampler], optional): Custom PyTorch batch samplers which will be passed to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies

            target_transform (Optional[Union[TransformerMixin, Tuple[Callable, Callable]]], optional): If provided, applies the transform to the target before modelling
                and inverts the transform during prediction. The parameter can either be a sklearn Transformer which has an ``inverse_transform`` method, or
                a tuple of callables ``(transform_func, inverse_transform_func)``

            max_epochs (Optional[int]): Overwrite maximum number of epochs to be run

            min_epochs (Optional[int]): Overwrite minimum number of epochs to be run

            reset: (bool): Flag to reset the model and train again from scratch

            seed: (int): If provided, overrides the default seed set as part of ModelConfig

            trained_backbone (pl.LightningModule): this module contains the weights for a pretrained backbone

            callbacks (Optional[List[pl.Callback]], optional): Custom callbacks to be used during training.
        """
        seed_everything(seed if seed is not None else self.config.seed)
        train_loader, val_loader = self._pre_fit(
            train,
            validation,
            test,
            loss,
            metrics,
            optimizer,
            optimizer_params,
            train_sampler,
            target_transform,
            max_epochs,
            min_epochs,
            reset,
            trained_backbone,
            callbacks,
        )
        self.model.train()
        if self.config.auto_lr_find and (not self.config.fast_dev_run):
            self.trainer.tune(self.model, train_loader, val_loader)
            # Parameters in the model need to be initialized again after LR find
            self.model.data_aware_initialization(self.datamodule)
        self.model.train()
        self.trainer.fit(self.model, train_loader, val_loader)
        logger.info("Training the model completed...")
        if self.config.load_best:
            self.load_best_model()
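
For the target_transform tuple form described in the docstring, the two callables are applied before training and inverted at prediction time. A hedged usage sketch, where model stands for an already-configured instance of the class exposing this fit method:

import numpy as np
import pandas as pd

# Hypothetical training frame; `model` is assumed to be an initialized tabular model.
train_df = pd.DataFrame({"feature": np.arange(100.0), "target": np.arange(100.0) ** 2})

model.fit(
    train=train_df,
    target_transform=(np.log1p, np.expm1),  # (transform_func, inverse_transform_func)
    max_epochs=5,
    seed=42,  # passed through to seed_everything, overriding ModelConfig's seed
)
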
Example #11
from modules import SimpleModule, ShopeeDataset
from pytorch_lightning.utilities.seed import seed_everything

seed_everything(1)

import argparse
import os
import numpy as np
import torch
from torch import nn
from PIL import Image

from sklearn.model_selection import train_test_split
import pandas as pd
import pytorch_lightning as pl

from torchvision import models
from torch.utils.data import DataLoader
import datetime


root_folder = os.path.dirname(__file__)
default_train = "../data/train.csv"
default_train = os.path.join(root_folder, default_train)

default_test = "../data/test.csv"
default_test = os.path.join(root_folder, default_test)


def parse_args():
    """ Load hyper parameters helper """
Example #12
def hyper_tuner(config, NetClass, dataset):

    learning_rate = config["learning_rate"]
    batch_size = config["batch_size"]

    if "structure" in config and config["structure"] == "mpl":
        structure = {"mpl": "mpl"}
    else:
        structure = {
            "lstm": {
                "bidirectional": config["bidirectional"],
                "hidden_dim": config["hidden_dim"],
                "num_layers": config["num_layers"]
            }
        }

    monitor = config["monitor"]
    shared_output_size = config["shared_output_size"]
    opt_step_size = config["opt_step_size"]
    weight_decay = config["weight_decay"]
    dropout_input_layers = config["dropout_input_layers"]
    dropout_inner_layers = config["dropout_inner_layers"]

    sleep_metrics = ['sleepEfficiency']
    loss_method = 'equal'

    X, Y = get_data(dataset)

    train = DataLoader(myXYDataset(X["train"], Y["train"]),
                       batch_size=batch_size,
                       shuffle=True,
                       drop_last=True,
                       num_workers=8)
    val = DataLoader(myXYDataset(X["val"], Y["val"]),
                     batch_size=batch_size,
                     shuffle=False,
                     drop_last=True,
                     num_workers=8)
    test = DataLoader(myXYDataset(X["test"], Y["test"]),
                      batch_size=batch_size,
                      shuffle=False,
                      drop_last=True,
                      num_workers=8)

    results = []

    seed.seed_everything(42)

    path_ckps = "./lightning_logs/test/"

    if monitor == "mcc":
        early_stop_callback = EarlyStopping(min_delta=0.00,
                                            verbose=False,
                                            monitor='mcc',
                                            mode='max',
                                            patience=3)
        ckp = ModelCheckpoint(filename=path_ckps +
                              "{epoch:03d}-{loss:.3f}-{mcc:.3f}",
                              save_top_k=1,
                              verbose=False,
                              prefix="",
                              monitor="mcc",
                              mode="max")
    else:
        early_stop_callback = EarlyStopping(min_delta=0.00,
                                            verbose=False,
                                            monitor='loss',
                                            mode='min',
                                            patience=3)
        ckp = ModelCheckpoint(filename=path_ckps +
                              "{epoch:03d}-{loss:.3f}-{mcc:.3f}",
                              save_top_k=1,
                              verbose=False,
                              prefix="",
                              monitor="loss",
                              mode="min")

    hparams = Namespace(
        batch_size=batch_size,
        shared_output_size=shared_output_size,
        sleep_metrics=sleep_metrics,
        dropout_input_layers=dropout_input_layers,
        dropout_inner_layers=dropout_inner_layers,
        structure=structure,
        #
        # Optimizer configs
        #
        opt_learning_rate=learning_rate,
        opt_weight_decay=weight_decay,
        opt_step_size=opt_step_size,
        opt_gamma=0.5,
        #
        # Loss combination method
        #
        loss_method=loss_method,  # Options: equal, alex, dwa
        #
        # Output layer
        #
        output_strategy="linear",  # Options: attention, linear
        dataset=dataset,
        monitor=monitor,
    )

    model = NetClass(hparams)
    model.double()

    tune_metrics = {
        "loss": "loss",
        "mcc": "mcc",
        "acc": "acc",
        "macroF1": "macroF1"
    }
    tune_cb = TuneReportCallback(tune_metrics, on="validation_end")

    trainer = Trainer(gpus=0,
                      min_epochs=2,
                      max_epochs=100,
                      callbacks=[early_stop_callback, ckp, tune_cb])
    trainer.fit(model, train, val)
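
hyper_tuner is written as a Ray Tune trainable (note the TuneReportCallback), so it is launched through tune.run with a search space covering every key read from config. A hedged launch sketch; MyNet and the dataset name are placeholders:

from ray import tune

# Search space with one entry per key hyper_tuner reads from `config`.
search_space = {
    "learning_rate": tune.loguniform(1e-4, 1e-2),
    "batch_size": tune.choice([16, 32, 64]),
    "bidirectional": tune.choice([True, False]),
    "hidden_dim": tune.choice([64, 128]),
    "num_layers": tune.choice([1, 2]),
    "monitor": "mcc",
    "shared_output_size": 32,
    "opt_step_size": 10,
    "weight_decay": 1e-4,
    "dropout_input_layers": 0.2,
    "dropout_inner_layers": 0.2,
}

analysis = tune.run(
    tune.with_parameters(hyper_tuner, NetClass=MyNet, dataset="sleep"),  # placeholders
    config=search_space,
    metric="mcc",
    mode="max",
    num_samples=20,
)
print(analysis.best_config)
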
Example #13
            if batch_idx % self.hparams.sync_batches == 0:
                self.model.alpha_sync(self.hparams.polyak)

            return actor_loss_v

    def validation_step(self, batch, batch_idx):
        to_log = dict()
        for k, v in batch.items():
            to_log[k] = v.detach().cpu().numpy()
        to_log['epoch_nr'] = int(self.current_epoch)
        if self.logger is not None:
            self.logger.experiment.log(to_log)


if __name__ == '__main__':
    mp.set_start_method('spawn')

    hparams = get_args()

    if hparams.debug:
        hparams.logger = None
        hparams.profiler = SimpleProfiler()
    else:
        hparams.logger = WandbLogger(project=hparams.project)

    seed_everything(hparams.seed)
    her = HER(hparams)
    trainer = pl.Trainer.from_argparse_args(hparams)
    trainer.callbacks.append(SpawnCallback())
    trainer.fit(her)
Example #14
from torch_geometric.nn import GCNConv, GATConv, GCN2Conv
import torch
from torch_geometric.data import Data
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import torch_geometric.transforms as T

from pytorch_lightning.utilities.seed import seed_everything
seed_everything(seed=0)


############################# GCN ###########################
class GCN(pl.LightningModule):
    def __init__(self,
                 num_of_features,
                 hid_size,
                 num_of_classes,
                 activation=F.relu,
                 dropout=0.6):
        super(GCN, self).__init__()
        self._layer1 = GCNConv(num_of_features, hid_size)
        self._activation = activation
        self._layer2 = GCNConv(hid_size, num_of_classes)
        self._dropout = dropout

    def forward(self, data: Data):
        x = self._layer1(data.x, data.edge_index)
        z = self._activation(x)
        # pass training=self.training so dropout is disabled during evaluation
        z = F.dropout(z, self._dropout, training=self.training)
Example #15
    total_dev_acc = 0
    total_samples = 0
    len_metrics = len(results.keys())
    for k in results.keys():
        devacc = results[k]['devacc'] * results[k]['ndev']
        total_samples += results[k]['ndev']
        total_dev_acc += devacc
    len_metrics *= total_samples
    return total_dev_acc / total_samples


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # encoders
    parser.add_argument("--model", default='awe', type=str)
    parser.add_argument("--save_results",
                        default='results_senteval/',
                        type=str)
    parser.add_argument("--checkpoint_path", default='', type=str)
    parser.add_argument('-p', "--prototype", action='store_true')
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--batch", default=64, type=int)

    args = parser.parse_args()
    wandb.init(project="atcs-seneval", config=args)
    wandb.log({"model_name": args.model})
    seed.seed_everything(args.seed)
    args.checkpoint_path = get_checkpoint_path(args)
    run_seneval(args)
Example #16
        logger=[logger, wandb_logger],
        gpus=1,
        num_sanity_val_steps=0,
        auto_lr_find=True,
    )
    trainer.tune(model)
    trainer.fit(model)


def get_args():
    parser = ArgumentParser()
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=32)

    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--model', type=str, default='efficientnet_b1')

    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--n_workers', type=int, default=24)
    parser.add_argument('--data_dir',
                        type=str,
                        required=True,
                        help='dir of data to train on (cell-tiles)')
    return parser.parse_args()


if __name__ == '__main__':
    seed_everything(144)
    args = get_args()
    main(args)
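
seed_everything(144) here seeds the main process before the dataloaders are built; on PyTorch Lightning 1.3+ the additional workers=True flag (used in the LightningCLI example below) also derives per-worker seeds for DataLoader workers. A hedged variant of the entry point under that assumption:

from pytorch_lightning import seed_everything

if __name__ == '__main__':
    # workers=True additionally seeds DataLoader worker processes (PL >= 1.3).
    seed_everything(144, workers=True)
    args = get_args()
    main(args)
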
Example #17
    def __init__(
        self,
        model_class: Type[LightningModule],
        datamodule_class: Type[LightningDataModule] = None,
        save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
        trainer_class: Type[Trainer] = Trainer,
        trainer_defaults: Dict[str, Any] = None,
        seed_everything_default: int = None,
        description: str = 'pytorch-lightning trainer command line tool',
        env_prefix: str = 'PL',
        env_parse: bool = False,
        parser_kwargs: Dict[str, Any] = None,
        subclass_mode_model: bool = False,
        subclass_mode_data: bool = False
    ) -> None:
        """
        Receives as input pytorch-lightning classes, which are instantiated
        using a parsed configuration file and/or command line args and then runs
        trainer.fit. Parsing of configuration from environment variables can
        be enabled by setting ``env_parse=True``. A full configuration yaml is
        parsed from ``PL_CONFIG`` if set. Individual settings can likewise be parsed from
        variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

        Example, first implement the ``trainer.py`` tool as::

            from mymodels import MyModel
            from pytorch_lightning.utilities.cli import LightningCLI
            LightningCLI(MyModel)

        Then in a shell, run the tool with the desired configuration::

            $ python trainer.py --print_config > config.yaml
            $ nano config.yaml  # modify the config as desired
            $ python trainer.py --cfg config.yaml

        .. warning:: ``LightningCLI`` is in beta and subject to change.

        Args:
            model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
            save_config_callback: A callback class to save the training config.
            trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
            trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
            seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
                seed argument.
            description: Description of the tool shown when running ``--help``.
            env_prefix: Prefix for environment variables.
            env_parse: Whether environment variable parsing is enabled.
            parser_kwargs: Additional arguments to instantiate LightningArgumentParser.
            subclass_mode_model: Whether model can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            subclass_mode_data: Whether datamodule can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
        """
        assert issubclass(trainer_class, Trainer)
        assert issubclass(model_class, LightningModule)
        if datamodule_class is not None:
            assert issubclass(datamodule_class, LightningDataModule)
        self.model_class = model_class
        self.datamodule_class = datamodule_class
        self.save_config_callback = save_config_callback
        self.trainer_class = trainer_class
        self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
        self.seed_everything_default = seed_everything_default
        self.subclass_mode_model = subclass_mode_model
        self.subclass_mode_data = subclass_mode_data
        self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs
        self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse})

        self.init_parser()
        self.add_core_arguments_to_parser()
        self.add_arguments_to_parser(self.parser)
        self.parse_arguments()
        if self.config['seed_everything'] is not None:
            seed_everything(self.config['seed_everything'], workers=True)
        self.before_instantiate_classes()
        self.instantiate_classes()
        self.prepare_fit_kwargs()
        self.before_fit()
        self.fit()
        self.after_fit()
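
seed_everything_default feeds the seed_everything entry of the parsed config, which is what triggers the seed_everything(..., workers=True) call in the constructor. A minimal wiring sketch; MyModel is a placeholder and a real run would need a full LightningModule and data:

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class MyModel(LightningModule):
    # A real model would define training_step, configure_optimizers and dataloaders.
    pass


# Seeds every run with 42 unless overridden, e.g. via `--seed_everything 7`
# on the command line or a `seed_everything:` entry in the config file.
LightningCLI(MyModel, seed_everything_default=42)
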
Example #18
def training_loop(
        run_dir='.',  # Output directory.
        training_set_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        G_kwargs={},  # Options for generator network.
        D_kwargs={},  # Options for discriminator network.
        G_opt_kwargs={},  # Options for generator optimizer.
        D_opt_kwargs={},  # Options for discriminator optimizer.
        augment_kwargs=None,  # Options for augmentation pipeline. None = disable.
        loss_kwargs={},  # Options for loss function.
        metrics=[],  # Metrics to evaluate during training.
        random_seed=0,  # Global random seed.
        num_gpus=1,  # Number of GPUs participating in the training.
        #rank                    = 0,        # Rank of the current process in [0, num_gpus[.
        batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
        batch_gpu=4,  # Number of samples processed at a time by one GPU.
        ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        ema_rampup=None,  # EMA ramp-up coefficient.
        G_reg_interval=4,  # How often to perform regularization for G? None = disable lazy regularization.
        D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
        augment_p=0,  # Initial value of augmentation probability.
        ada_target=None,  # ADA target value. None = fixed p.
        ada_interval=4,  # How often to perform ADA adjustment?
        ada_kimg=500,  # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
        resume_pkl=None,  # Network pickle to resume training from.
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        allow_tf32=False,  # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
        abort_fn=None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
        progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
):
    # Initialize.
    start_time = time.time()
    #device = torch.device('cuda', rank)
    #np.random.seed(random_seed * num_gpus + rank)
    #torch.manual_seed(random_seed * num_gpus + rank)
    #torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    seed_everything(random_seed)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.

    # Load training set.
    # if rank == 0:
    #     print('Loading training set...')
    # training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
    # training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    # training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    # if rank == 0:
    #     print()
    #     print('Num images: ', len(training_set))
    #     print('Image shape:', training_set.image_shape)
    #     print('Label shape:', training_set.label_shape)
    #     print()

    # Construct networks.
    # if rank == 0:
    #     print('Constructing networks...')
    training_set_pl = StyleGANDataModule(batch_gpu, training_set_kwargs,
                                         data_loader_kwargs)
    training_set = training_set_pl.training_set

    common_kwargs = dict(c_dim=training_set.label_dim,
                         img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(
        **G_kwargs, **common_kwargs)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(
        **D_kwargs, **common_kwargs)  # subclass of torch.nn.Module
    # # Resume from existing pickle.
    # if (resume_pkl is not None) and (rank == 0):
    #     print(f'Resuming from "{resume_pkl}"')
    #     with dnnlib.util.open_url(resume_pkl) as f:
    #         resume_data = legacy.load_network_pkl(f)
    #     for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
    #         misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # # Print network summary tables.
    # if rank == 0:
    #     z = torch.empty([batch_gpu, G.z_dim], device=device)
    #     c = torch.empty([batch_gpu, G.c_dim], device=device)
    #     img = misc.print_module_summary(G, [z, c])
    #     misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    # if rank == 0:
    #     print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0
                                         or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(
            **augment_kwargs)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        # if ada_target is not None:
        #     ada_stats = training_stats.Collector(regex='Loss/signs/real')

    fid50k = FID(max_real=None, num_gen=50000)
    ema_kimg /= num_gpus
    ada_kimg /= num_gpus
    kimg_per_tick /= num_gpus

    gpu_stats = GPUStatsMonitor(intra_step_time=True)

    net = StyleGAN2(G=G,
                    D=D,
                    G_opt_kwargs=G_opt_kwargs,
                    D_opt_kwargs=D_opt_kwargs,
                    augment_pipe=augment_pipe,
                    datamodule=training_set_pl,
                    G_reg_interval=G_reg_interval,
                    D_reg_interval=D_reg_interval,
                    ema_kimg=ema_kimg,
                    ema_rampup=ema_rampup,
                    ada_target=ada_target,
                    ada_interval=ada_interval,
                    ada_kimg=ada_kimg,
                    metrics=[fid50k],
                    kimg_per_tick=kimg_per_tick,
                    random_seed=random_seed,
                    **loss_kwargs)

    trainer = pl.Trainer(gpus=num_gpus,
                         accelerator='ddp',
                         weights_summary='full',
                         fast_dev_run=10,
                         benchmark=cudnn_benchmark,
                         max_steps=total_kimg // (batch_size) * 1000,
                         plugins=[
                             DDPPlugin(broadcast_buffers=False,
                                       find_unused_parameters=True)
                         ],
                         callbacks=[gpu_stats],
                         accumulate_grad_batches=num_gpus)
    trainer.fit(net, datamodule=training_set_pl)
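
The max_steps bound converts the total_kimg budget (thousands of real images) into optimizer steps by dividing by the per-step batch size. With the defaults above:

total_kimg = 25000   # training length in thousands of real images
batch_size = 4       # images consumed per optimizer step
max_steps = total_kimg // batch_size * 1000
print(max_steps)     # 6_250_000 steps, i.e. 25M images at 4 per step
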