def ddp_train(self, process_idx, model):
    """
    Entry point for ddp

    Args:
        process_idx: current process index
        model: pointer to the current :class:`LightningModule`

    Returns:
        Dict with evaluation results
    """
    seed = os.environ.get("PL_GLOBAL_SEED")
    if seed is not None:
        seed_everything(int(seed))

    # show progressbar only on progress_rank 0
    if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
        self.trainer.progress_bar_callback.disable()

    # determine which process we are and world size
    self.set_world_ranks(process_idx)

    # set warning rank
    rank_zero_only.rank = self.trainer.global_rank

    # Initialize cuda device
    self.init_device(process_idx)

    # set up server using proc 0's ip address
    # try to init for 20 times at max in case ports are taken
    # where to store ip_table
    model.trainer = self.trainer
    self.init_ddp_connection(
        self.trainer.global_rank,
        self.trainer.world_size,
        self.trainer.is_slurm_managing_tasks,
    )

    if isinstance(self.ddp_plugin, RPCPlugin):
        if not self.ddp_plugin.is_main_rpc_process:
            self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
            self.ddp_plugin.exit_rpc_process()
            if self.ddp_plugin.return_after_exit_rpc_process:
                return
        else:
            self.ddp_plugin.on_main_rpc_connection(self.trainer)

    # call setup after the ddp process has connected
    self.trainer.call_setup_hook(model)

    # on global rank 0 let everyone know training is starting
    if self.trainer.is_global_zero and not torch.distributed.is_initialized():
        log.info('-' * 100)
        log.info(f'distributed_backend={self.trainer.distributed_backend}')
        log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
        log.info('-' * 100)

    # call sync_bn before .cuda(), configure_apex and configure_ddp
    if self.trainer.sync_batchnorm:
        model = self.configure_sync_batchnorm(model)

    # move the model to the correct device
    self.model_to_device(model)

    # CHOOSE OPTIMIZER
    # allow for lr schedulers as well
    self.setup_optimizers(model)

    # set model properties before going into wrapper
    self.trainer.model_connector.copy_trainer_model_properties(model)

    # 16-bit
    model = self.trainer.precision_connector.connect(model)

    self.trainer.convert_to_lightning_optimizers()

    # device ids change depending on the DDP setup
    device_ids = self.get_device_ids()

    # allow user to configure ddp
    model = self.configure_ddp(model, device_ids)

    # set up training routine
    self.barrier('ddp_setup')
    self.trainer.train_loop.setup_training(model)

    # train or test
    results = self.train_or_test()

    # clean up memory
    torch.cuda.empty_cache()

    return results
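# Why ddp_train re-reads PL_GLOBAL_SEED at the top: seed_everything() exports the
# chosen seed to the environment, so processes spawned for DDP can re-seed
# themselves identically. A minimal sketch of that round trip (the helper names
# below are illustrative, not Lightning's internals):
import os
import random

import numpy as np
import torch


def _seed_everything_sketch(seed: int) -> int:
    os.environ["PL_GLOBAL_SEED"] = str(seed)  # visible to spawned child processes
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed


def _reseed_in_child_sketch() -> None:
    # same guard as at the top of ddp_train above
    seed = os.environ.get("PL_GLOBAL_SEED")
    if seed is not None:
        _seed_everything_sketch(int(seed))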
# PL_GLOBAL_SEED must be set for this test to pass; mock.patch.dict is one way
# to arrange it (requires `import os` and `from unittest import mock`).
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"})
def test_correct_seed_with_environment_variable():
    """Ensure that the PL_GLOBAL_SEED environment variable is read."""
    assert seed_utils.seed_everything() == 2020
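# The test above relies on the env-var fallback of seed_everything(): when
# called without an argument, it reads PL_GLOBAL_SEED. A minimal standalone
# demonstration (assuming a pytorch_lightning version with this behavior):
import os

from pytorch_lightning import seed_everything

os.environ["PL_GLOBAL_SEED"] = "2020"
assert seed_everything() == 2020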
def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
    """
    Entry point for ddp

    Args:
        process_idx: current process index (before applying proc_offset)
        mp_queue: multiprocessing queue
        model: pointer to the current :class:`LightningModule`
        is_master: whether this process is the master process
        proc_offset: offset added to process_idx

    Returns:
        results from train_or_test() on global rank 0, otherwise None
    """
    seed = os.environ.get("PL_GLOBAL_SEED")
    if seed is not None:
        seed_everything(int(seed))

    # offset the process id if requested
    process_idx = process_idx + proc_offset

    # show progressbar only on progress_rank 0
    if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
        self.trainer.progress_bar_callback.disable()

    # determine which process we are and world size
    self.set_world_ranks(process_idx)

    # set warning rank
    rank_zero_only.rank = self.trainer.global_rank

    # set up server using proc 0's ip address
    # try to init for 20 times at max in case ports are taken
    # where to store ip_table
    model.trainer = self.trainer
    model.init_ddp_connection(
        self.trainer.global_rank,
        self.trainer.world_size,
        self.trainer.is_slurm_managing_tasks,
    )

    # call setup after the ddp process has connected
    self.trainer.call_setup_hook(model)

    # on global rank 0 let everyone know training is starting
    if self.trainer.is_global_zero:
        log.info('-' * 100)
        log.info(f'distributed_backend={self.trainer.distributed_backend}')
        log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
        log.info('-' * 100)

    # call sync_bn before .cuda(), configure_apex and configure_ddp
    if self.trainer.sync_batchnorm:
        model = model.configure_sync_batchnorm(model)

    # move the model to the correct device
    self.model_to_device(model, process_idx, is_master)

    # CHOOSE OPTIMIZER
    # allow for lr schedulers as well
    self.setup_optimizers(model)

    # set model properties before going into wrapper
    self.trainer.model_connector.copy_trainer_model_properties(model)

    # 16-bit
    model = self.trainer.precision_connector.connect(model)

    # device ids change depending on the DDP setup
    device_ids = self.get_device_ids()

    # allow user to configure ddp
    model = model.configure_ddp(model, device_ids)

    # set up training routine
    self.trainer.train_loop.setup_training(model)

    # train or test
    results = self.train_or_test()

    # get original model
    model = self.trainer.get_model()

    # persist info in ddp_spawn
    self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

    # clean up memory
    torch.cuda.empty_cache()

    if self.trainer.global_rank == 0:
        return results
    # (tail of a checkpoint-loading helper; the enclosing def sits above this excerpt)
    ckpt = glob.glob(
        "{}/lightning_logs/version_0/checkpoints/epoch*.ckpt".format(model_dir))[0]
    print("Loading checkpoint: {}".format(ckpt))
    m = UNet.load_from_checkpoint(ckpt)
    m.freeze()
    summary(m, (3, img_size, img_size), device='cpu')
    return m


if __name__ == '__main__':
    seed_everything(seed=45)

    model_dir = ""
    while not os.path.isdir(model_dir):
        model_dir = input("Enter model directory: ")

    # where to store model params
    model_dir = "{}/{}".format(model_output_parent, model_dir)

    datasets = get_datasets(
        _same_image_all_channels=same_image_all_channels,
        model_dir=model_dir,
        new_ds_split=False,
        train_list="{}/train.txt".format(model_dir),
        val_list="{}/val.txt".format(model_dir),
        test_list="{}/test.txt".format(model_dir),
    )
import copy
import random

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from pytorch_lightning.utilities.seed import seed_everything

# Set seeds (seed_everything already seeds random, numpy and torch,
# so the three manual calls below are redundant but kept for clarity)
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
seed_everything(seed)

# Tricky: https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
# torch.use_deterministic_algorithms(True)

# Test randomness
print(f"\t- [INFO]: Testing random seed ({seed}):")
print(f"\t\t- random: {random.random()}")
print(f"\t\t- numpy: {np.random.rand(1)}")
print(f"\t\t- torch: {torch.rand(1)}")


def create_big_model(ds_small, ds_big):
    max_length = 100
    src_vocab_small = Vocabulary(max_tokens=max_length).build_from_ds(ds=ds_small, lang=ds_small.src_lang)
    trg_vocab_small = Vocabulary(max_tokens=max_length).build_from_ds(ds=ds_small, lang=ds_small.trg_lang)
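# On the "Tricky" comment above: torch.use_deterministic_algorithms(True)
# forces deterministic kernels and raises on ops that have no deterministic
# implementation; some CUDA ops additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable, per the linked PyTorch docs.
# A guarded sketch of enabling it:
import os
import torch

os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # required by some CUDA ops
torch.use_deterministic_algorithms(True)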
import os
from argparse import ArgumentParser
from pprint import pprint

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from src.model.efficient_el import EfficientEL

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--dirpath", type=str, default="models")
    parser.add_argument("--save_top_k", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)

    parser = EfficientEL.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args, _ = parser.parse_known_args()
    pprint(args.__dict__)

    seed_everything(seed=args.seed)

    logger = TensorBoardLogger(args.dirpath, name=None)

    callbacks = [
        ModelCheckpoint(
            mode="max",
            monitor="micro_f1",
            dirpath=os.path.join(logger.log_dir, "checkpoints"),
            save_top_k=args.save_top_k,
            filename="model-epoch={epoch:02d}-micro_f1={micro_f1:.4f}-ed_micro_f1={ed_micro_f1:.4f}",
        ),
        LearningRateMonitor(logging_interval="step"),
    ]
def execute_single_run(params, hparams, log_dirs, experiment: bool):
    """
    :param params:     [argparse.Namespace] attr: experiment_name, run_name, pretrained_model_name, dataset_name, ..
    :param hparams:    [argparse.Namespace] attr: batch_size, max_seq_length, max_epochs, lr_*
    :param log_dirs:   [argparse.Namespace] attr: mlflow, tensorboard
    :param experiment: [bool] whether run is part of an experiment w/ multiple runs
    :return: -
    """
    # seed
    seed = params.seed + int(params.run_name_nr.split("-")[1])  # run_name_nr = e.g. 'runA-1', 'runA-2', 'runB-1', ..
    seed_everything(seed)

    default_logger = DefaultLogger(
        __file__, log_file=log_dirs.log_file, level=params.logging_level
    )  # python logging
    default_logger.clear()

    print_run_information(params, hparams, default_logger, seed)

    lightning_hparams = unify_parameters(params, hparams, log_dirs, experiment)

    tb_logger = logging_start(params, log_dirs)

    with mlflow.start_run(run_name=params.run_name_nr, nested=experiment):
        model = NerModelTrain(lightning_hparams)
        callbacks = get_callbacks(params, hparams, log_dirs)

        trainer = Trainer(
            max_epochs=hparams.max_epochs,
            gpus=torch.cuda.device_count() if params.device.type == "cuda" else None,
            precision=16 if (params.fp16 and params.device.type == "cuda") else 32,
            # amp_level="O1",
            logger=tb_logger,
            callbacks=list(callbacks),
        )
        trainer.fit(model)

        callback_info = get_callback_info(callbacks, params, hparams)
        default_logger.log_info(
            f"---\n"
            f"LOAD BEST CHECKPOINT {callback_info['checkpoint_best']}\n"
            f"FOR TESTING AND DETAILED RESULTS\n"
            f"---"
        )
        model_best = NerModelTrain.load_from_checkpoint(
            checkpoint_path=callback_info["checkpoint_best"]
        )
        trainer_best = Trainer(
            gpus=torch.cuda.device_count() if params.device.type == "cuda" else None,
            precision=16 if (params.fp16 and params.device.type == "cuda") else 32,
            logger=tb_logger,
        )
        trainer_best.validate(model_best)
        trainer_best.test(model_best)

        # logging end
        logging_end(tb_logger, callback_info, hparams, model, model_best, default_logger)

        # remove checkpoint
        if params.checkpoints is False:
            remove_checkpoint(callback_info["checkpoint_best"], default_logger)
# The expected value of 123 comes from mocking the random-seed fallback;
# patching seed_utils is one way to arrange this (assumption: the fallback
# helper is _select_seed_randomly, as in pytorch_lightning.utilities.seed).
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
@mock.patch.object(seed_utils, "_select_seed_randomly", lambda *_: 123)
def test_invalid_seed():
    """Ensure that we still fix the seed even if an invalid seed is given."""
    with pytest.warns(UserWarning, match="Invalid seed found"):
        seed = seed_utils.seed_everything()
    assert seed == 123
# The seed values are example out-of-bounds inputs, and the fallback mock is
# the same assumption as in test_invalid_seed above.
@pytest.mark.parametrize("seed", (10e9, -10e9))
@mock.patch.object(seed_utils, "_select_seed_randomly", lambda *_: 123)
def test_out_of_bounds_seed(seed):
    """Ensure that we still fix the seed even if an out-of-bounds seed is given."""
    with pytest.warns(UserWarning, match="is not in bounds"):
        actual = seed_utils.seed_everything(seed)
    assert actual == 123
def fit(
    self,
    train: pd.DataFrame,
    validation: Optional[pd.DataFrame] = None,
    test: Optional[pd.DataFrame] = None,
    loss: Optional[torch.nn.Module] = None,
    metrics: Optional[List[Callable]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = {},
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    reset: bool = False,
    seed: Optional[int] = None,
    trained_backbone: Optional[pl.LightningModule] = None,
    callbacks: Optional[List[pl.Callback]] = None,
) -> None:
    """The fit method which takes in the data and triggers the training.

    Args:
        train (pd.DataFrame): Training dataframe.
        validation (Optional[pd.DataFrame], optional): If provided, will be used as the validation set while
            training. Used in early stopping and logging. If left empty, 20% of the train data is used as
            validation. Defaults to None.
        test (Optional[pd.DataFrame], optional): If provided, will be used as hold-out data on which you can
            check performance after the model is trained. Defaults to None.
        loss (Optional[torch.nn.Module], optional): Custom loss function which is not in the standard PyTorch
            library.
        metrics (Optional[List[Callable]], optional): Custom metric functions (Callable) which have the signature
            metric_fn(y_hat, y) and work on torch tensor inputs.
        optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizer which is a drop-in replacement for
            a standard PyTorch optimizer. This should be the class and not the initialized object.
        optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.
        train_sampler (Optional[torch.utils.data.Sampler], optional): Custom PyTorch batch sampler which will be
            passed to the DataLoaders. Useful for dealing with imbalanced data and other custom batching
            strategies.
        target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional): If provided, applies the
            transform to the target before modelling and inverts the transform during prediction. The parameter
            can either be a sklearn Transformer which has an inverse_transform method, or a tuple of callables
            (transform_func, inverse_transform_func).
        max_epochs (Optional[int]): Overwrite maximum number of epochs to be run.
        min_epochs (Optional[int]): Overwrite minimum number of epochs to be run.
        reset (bool): Flag to reset the model and train again from scratch.
        seed (Optional[int]): If provided, overrides the default seed set as part of the ModelConfig.
        trained_backbone (pl.LightningModule): This module contains the weights for a pretrained backbone.
        callbacks (Optional[List[pl.Callback]], optional): Custom callbacks to be used during training.
    """
    seed_everything(seed if seed is not None else self.config.seed)
    train_loader, val_loader = self._pre_fit(
        train,
        validation,
        test,
        loss,
        metrics,
        optimizer,
        optimizer_params,
        train_sampler,
        target_transform,
        max_epochs,
        min_epochs,
        reset,
        trained_backbone,
        callbacks,
    )
    self.model.train()
    if self.config.auto_lr_find and (not self.config.fast_dev_run):
        self.trainer.tune(self.model, train_loader, val_loader)
        # Parameters in models need to be initialized again after LR find
        self.model.data_aware_initialization(self.datamodule)
        self.model.train()
    self.trainer.fit(self.model, train_loader, val_loader)
    logger.info("Training the model completed...")
    if self.config.load_best:
        self.load_best_model()
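# The target_transform parameter above accepts either a sklearn transformer or
# a (transform, inverse_transform) tuple. A self-contained example of the tuple
# form using numpy (log1p/expm1 are a common variance-stabilizing pair):
import numpy as np

target_transform = (np.log1p, np.expm1)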
from modules import SimpleModule, ShopeeDataset
from pytorch_lightning.utilities.seed import seed_everything

# seed before the remaining (heavier) imports
seed_everything(1)

import argparse
import datetime
import os

import numpy as np
import pandas as pd
import torch
from torch import nn
from PIL import Image
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from torchvision import models
from torch.utils.data import DataLoader

root_folder = os.path.dirname(__file__)
default_train = os.path.join(root_folder, "../data/train.csv")
default_test = os.path.join(root_folder, "../data/test.csv")


def parse_args():
    """
    Helper to load hyperparameters from the command line.
    """
def hyper_tuner(config, NetClass, dataset):
    learning_rate = config["learning_rate"]
    batch_size = config["batch_size"]

    if "structure" in config and config["structure"] == "mpl":
        structure = {"mpl": "mpl"}
    else:
        structure = {
            "lstm": {
                "bidirectional": config["bidirectional"],
                "hidden_dim": config["hidden_dim"],
                "num_layers": config["num_layers"],
            }
        }

    monitor = config["monitor"]
    shared_output_size = config["shared_output_size"]
    opt_step_size = config["opt_step_size"]
    weight_decay = config["weight_decay"]
    dropout_input_layers = config["dropout_input_layers"]
    dropout_inner_layers = config["dropout_inner_layers"]

    sleep_metrics = ['sleepEfficiency']
    loss_method = 'equal'

    X, Y = get_data(dataset)

    train = DataLoader(myXYDataset(X["train"], Y["train"]), batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
    val = DataLoader(myXYDataset(X["val"], Y["val"]), batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)
    test = DataLoader(myXYDataset(X["test"], Y["test"]), batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)

    results = []

    seed.seed_everything(42)

    path_ckps = "./lightning_logs/test/"

    if monitor == "mcc":
        early_stop_callback = EarlyStopping(min_delta=0.00, verbose=False, monitor='mcc', mode='max', patience=3)
        ckp = ModelCheckpoint(filename=path_ckps + "{epoch:03d}-{loss:.3f}-{mcc:.3f}",
                              save_top_k=1, verbose=False, prefix="", monitor="mcc", mode="max")
    else:
        early_stop_callback = EarlyStopping(min_delta=0.00, verbose=False, monitor='loss', mode='min', patience=3)
        ckp = ModelCheckpoint(filename=path_ckps + "{epoch:03d}-{loss:.3f}-{mcc:.3f}",
                              save_top_k=1, verbose=False, prefix="", monitor="loss", mode="min")

    hparams = Namespace(
        batch_size=batch_size,
        shared_output_size=shared_output_size,
        sleep_metrics=sleep_metrics,
        dropout_input_layers=dropout_input_layers,
        dropout_inner_layers=dropout_inner_layers,
        structure=structure,
        #
        # Optimizer configs
        #
        opt_learning_rate=learning_rate,
        opt_weight_decay=weight_decay,
        opt_step_size=opt_step_size,
        opt_gamma=0.5,
        #
        # Loss combination method
        #
        loss_method=loss_method,  # Options: equal, alex, dwa
        #
        # Output layer
        #
        output_strategy="linear",  # Options: attention, linear
        dataset=dataset,
        monitor=monitor,
    )

    model = NetClass(hparams)
    model.double()

    tune_metrics = {"loss": "loss", "mcc": "mcc", "acc": "acc", "macroF1": "macroF1"}
    tune_cb = TuneReportCallback(tune_metrics, on="validation_end")

    trainer = Trainer(gpus=0, min_epochs=2, max_epochs=100,
                      callbacks=[early_stop_callback, ckp, tune_cb])
    trainer.fit(model, train, val)
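# A hypothetical driver for hyper_tuner above, using the standard Ray Tune API
# (Net and the dataset name "mesa" are placeholders, and the search-space
# values are illustrative). tune.with_parameters binds the non-searched
# arguments; the config keys match those read by hyper_tuner.
from ray import tune

search_space = {
    "learning_rate": tune.loguniform(1e-4, 1e-2),
    "batch_size": tune.choice([32, 64]),
    "bidirectional": tune.choice([True, False]),
    "hidden_dim": tune.choice([64, 128]),
    "num_layers": tune.choice([1, 2]),
    "monitor": "loss",
    "shared_output_size": 32,
    "opt_step_size": 10,
    "weight_decay": 1e-4,
    "dropout_input_layers": 0.25,
    "dropout_inner_layers": 0.25,
}
analysis = tune.run(
    tune.with_parameters(hyper_tuner, NetClass=Net, dataset="mesa"),
    config=search_space,
    metric="loss",
    mode="min",
)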
        if batch_idx % self.hparams.sync_batches == 0:
            self.model.alpha_sync(self.hparams.polyak)

        return actor_loss_v

    def validation_step(self, batch, batch_idx):
        to_log = dict()
        for k, v in batch.items():
            to_log[k] = v.detach().cpu().numpy()
        to_log['epoch_nr'] = int(self.current_epoch)
        if self.logger is not None:
            self.logger.experiment.log(to_log)


if __name__ == '__main__':
    mp.set_start_method('spawn')
    hparams = get_args()

    if hparams.debug:
        hparams.logger = None
        hparams.profiler = SimpleProfiler()
    else:
        hparams.logger = WandbLogger(project=hparams.project)

    seed_everything(hparams.seed)

    her = HER(hparams)
    trainer = pl.Trainer.from_argparse_args(hparams)
    trainer.callbacks.append(SpawnCallback())
    trainer.fit(her)
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GCN2Conv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from pytorch_lightning.utilities.seed import seed_everything

seed_everything(seed=0)


############################# GCN ###########################
class GCN(pl.LightningModule):
    def __init__(self, num_of_features, hid_size, num_of_classes, activation=F.relu, dropout=0.6):
        super(GCN, self).__init__()
        self._layer1 = GCNConv(num_of_features, hid_size)
        self._activation = activation
        self._layer2 = GCNConv(hid_size, num_of_classes)
        self._dropout = dropout

    def forward(self, data: Data):
        x = self._layer1(data.x, data.edge_index)
        z = self._activation(x)
        # training=self.training is needed so dropout is disabled in eval mode
        z = F.dropout(z, self._dropout, training=self.training)
    total_dev_acc = 0
    total_samples = 0
    for k in results.keys():
        devacc = results[k]['devacc'] * results[k]['ndev']
        total_samples += results[k]['ndev']
        total_dev_acc += devacc
    # dev accuracy averaged over all tasks, weighted by dev-set size
    return total_dev_acc / total_samples


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # encoders
    parser.add_argument("--model", default='awe', type=str)
    parser.add_argument("--save_results", default='results_senteval/', type=str)
    parser.add_argument("--checkpoint_path", default='', type=str)
    parser.add_argument('-p', "--prototype", action='store_true')
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--batch", default=64, type=int)

    args = parser.parse_args()

    wandb.init(project="atcs-seneval", config=args)
    wandb.log({"model_name": args.model})

    seed.seed_everything(args.seed)

    args.checkpoint_path = get_checkpoint_path(args)
    run_seneval(args)
        logger=[logger, wandb_logger],
        gpus=1,
        num_sanity_val_steps=0,
        auto_lr_find=True,
    )
    trainer.tune(model)
    trainer.fit(model)


def get_args():
    parser = ArgumentParser()
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--model', type=str, default='efficientnet_b1')
    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--n_workers', type=int, default=24)
    parser.add_argument('--data_dir', type=str, required=True,
                        help='dir of data to train on (cell-tiles)')
    return parser.parse_args()


if __name__ == '__main__':
    seed_everything(144)
    args = get_args()
    main(args)
def __init__(
    self,
    model_class: Type[LightningModule],
    datamodule_class: Type[LightningDataModule] = None,
    save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
    trainer_class: Type[Trainer] = Trainer,
    trainer_defaults: Dict[str, Any] = None,
    seed_everything_default: int = None,
    description: str = 'pytorch-lightning trainer command line tool',
    env_prefix: str = 'PL',
    env_parse: bool = False,
    parser_kwargs: Dict[str, Any] = None,
    subclass_mode_model: bool = False,
    subclass_mode_data: bool = False,
) -> None:
    """
    Receives as input pytorch-lightning classes, which are instantiated using a parsed
    configuration file and/or command line args and then runs trainer.fit.

    Parsing of configuration from environment variables can be enabled by setting
    ``env_parse=True``. A full configuration yaml would be parsed from ``PL_CONFIG``
    if set. Individual settings are likewise parsed from variables named, for example,
    ``PL_TRAINER__MAX_EPOCHS``.

    Example: first implement the ``trainer.py`` tool as::

        from mymodels import MyModel
        from pytorch_lightning.utilities.cli import LightningCLI
        LightningCLI(MyModel)

    Then in a shell, run the tool with the desired configuration::

        $ python trainer.py --print_config > config.yaml
        $ nano config.yaml  # modify the config as desired
        $ python trainer.py --cfg config.yaml

    .. warning:: ``LightningCLI`` is in beta and subject to change.

    Args:
        model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
        datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
        save_config_callback: A callback class to save the training config.
        trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
        trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
        seed_everything_default: Default value for the
            :func:`~pytorch_lightning.utilities.seed.seed_everything` seed argument.
        description: Description of the tool shown when running ``--help``.
        env_prefix: Prefix for environment variables.
        env_parse: Whether environment variable parsing is enabled.
        parser_kwargs: Additional arguments to instantiate LightningArgumentParser.
        subclass_mode_model: Whether model can be any `subclass
            <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
            of the given class.
        subclass_mode_data: Whether datamodule can be any `subclass
            <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
            of the given class.
""" assert issubclass(trainer_class, Trainer) assert issubclass(model_class, LightningModule) if datamodule_class is not None: assert issubclass(datamodule_class, LightningDataModule) self.model_class = model_class self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback self.trainer_class = trainer_class self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults self.seed_everything_default = seed_everything_default self.subclass_mode_model = subclass_mode_model self.subclass_mode_data = subclass_mode_data self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse}) self.init_parser() self.add_core_arguments_to_parser() self.add_arguments_to_parser(self.parser) self.parse_arguments() if self.config['seed_everything'] is not None: seed_everything(self.config['seed_everything'], workers=True) self.before_instantiate_classes() self.instantiate_classes() self.prepare_fit_kwargs() self.before_fit() self.fit() self.after_fit()
def training_loop(
    run_dir='.',                 # Output directory.
    training_set_kwargs={},      # Options for training set.
    data_loader_kwargs={},       # Options for torch.utils.data.DataLoader.
    G_kwargs={},                 # Options for generator network.
    D_kwargs={},                 # Options for discriminator network.
    G_opt_kwargs={},             # Options for generator optimizer.
    D_opt_kwargs={},             # Options for discriminator optimizer.
    augment_kwargs=None,         # Options for augmentation pipeline. None = disable.
    loss_kwargs={},              # Options for loss function.
    metrics=[],                  # Metrics to evaluate during training.
    random_seed=0,               # Global random seed.
    num_gpus=1,                  # Number of GPUs participating in the training.
    # rank=0,                    # Rank of the current process in [0, num_gpus).
    batch_size=4,                # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu=4,                 # Number of samples processed at a time by one GPU.
    ema_kimg=10,                 # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup=None,             # EMA ramp-up coefficient.
    G_reg_interval=4,            # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval=16,           # How often to perform regularization for D? None = disable lazy regularization.
    augment_p=0,                 # Initial value of augmentation probability.
    ada_target=None,             # ADA target value. None = fixed p.
    ada_interval=4,              # How often to perform ADA adjustment?
    ada_kimg=500,                # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg=25000,            # Total length of the training, measured in thousands of real images.
    kimg_per_tick=4,             # Progress snapshot interval.
    image_snapshot_ticks=50,     # How often to save image snapshots? None = disable.
    network_snapshot_ticks=50,   # How often to save network snapshots? None = disable.
    resume_pkl=None,             # Network pickle to resume training from.
    cudnn_benchmark=True,        # Enable torch.backends.cudnn.benchmark?
    allow_tf32=False,            # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn=None,               # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn=None,            # Callback function for updating training progress. Called for all ranks.
):
    # Initialize.
    start_time = time.time()
    # device = torch.device('cuda', rank)
    # np.random.seed(random_seed * num_gpus + rank)
    # torch.manual_seed(random_seed * num_gpus + rank)
    # torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    seed_everything(random_seed)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    # if rank == 0:
    #     print('Loading training set...')
    # training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    # training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    # training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    # if rank == 0:
    #     print()
    #     print('Num images: ', len(training_set))
    #     print('Image shape:', training_set.image_shape)
    #     print('Label shape:', training_set.label_shape)
    #     print()

    # Construct networks.
    # if rank == 0:
    #     print('Constructing networks...')
    training_set_pl = StyleGANDataModule(batch_gpu, training_set_kwargs, data_loader_kwargs)
    training_set = training_set_pl.training_set
    common_kwargs = dict(c_dim=training_set.label_dim,
                         img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs)  # subclass of torch.nn.Module

    # # Resume from existing pickle.
    # if (resume_pkl is not None) and (rank == 0):
    #     print(f'Resuming from "{resume_pkl}"')
    #     with dnnlib.util.open_url(resume_pkl) as f:
    #         resume_data = legacy.load_network_pkl(f)
    #     for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
    #         misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # # Print network summary tables.
    # if rank == 0:
    #     z = torch.empty([batch_gpu, G.z_dim], device=device)
    #     c = torch.empty([batch_gpu, G.c_dim], device=device)
    #     img = misc.print_module_summary(G, [z, c])
    #     misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    # if rank == 0:
    #     print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        # if ada_target is not None:
        #     ada_stats = training_stats.Collector(regex='Loss/signs/real')

    fid50k = FID(max_real=None, num_gen=50000)
    ema_kimg /= num_gpus
    ada_kimg /= num_gpus
    kimg_per_tick /= num_gpus
    gpu_stats = GPUStatsMonitor(intra_step_time=True)

    net = StyleGAN2(G=G,
                    D=D,
                    G_opt_kwargs=G_opt_kwargs,
                    D_opt_kwargs=D_opt_kwargs,
                    augment_pipe=augment_pipe,
                    datamodule=training_set_pl,
                    G_reg_interval=G_reg_interval,
                    D_reg_interval=D_reg_interval,
                    ema_kimg=ema_kimg,
                    ema_rampup=ema_rampup,
                    ada_target=ada_target,
                    ada_interval=ada_interval,
                    ada_kimg=ada_kimg,
                    metrics=[fid50k],
                    kimg_per_tick=kimg_per_tick,
                    random_seed=random_seed,
                    **loss_kwargs)

    trainer = pl.Trainer(gpus=num_gpus,
                         accelerator='ddp',
                         weights_summary='full',
                         fast_dev_run=10,
                         benchmark=cudnn_benchmark,
                         max_steps=total_kimg // batch_size * 1000,
                         plugins=[DDPPlugin(broadcast_buffers=False, find_unused_parameters=True)],
                         callbacks=[gpu_stats],
                         accumulate_grad_batches=num_gpus)
    trainer.fit(net, datamodule=training_set_pl)
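# Note on seeding: the original StyleGAN2 loop (commented out above) seeded
# each process with random_seed * num_gpus + rank, while this port calls
# seed_everything(random_seed) once, giving all ranks identical RNG streams.
# A sketch that restores per-rank seeds under DDP (rank lookup via
# torch.distributed; whether identical or distinct streams are wanted depends
# on the training code):
import torch.distributed as dist
from pytorch_lightning import seed_everything


def seed_per_rank(base_seed: int, num_gpus: int) -> None:
    rank = dist.get_rank() if dist.is_initialized() else 0
    seed_everything(base_seed * num_gpus + rank)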