def add_experiment_logger(
    prev_logger: LightningLoggerBase, new_logger: LightningLoggerBase
) -> "LoggerCollection | None":
    """Append ``new_logger`` to an existing logger, producing a ``LoggerCollection``.

    Args:
        prev_logger: the logger currently attached (may be falsy if none was set).
        new_logger: the logger to add.

    Returns:
        A ``LoggerCollection`` containing the previous logger(s) plus ``new_logger``,
        or ``None`` when no logger existed previously.
    """
    # If no logger existed previously don't do anything
    if not prev_logger:
        return None
    # Flatten an existing collection instead of nesting collections inside collections.
    if isinstance(prev_logger, LoggerCollection):
        return LoggerCollection([*prev_logger._logger_iterable, new_logger])
    return LoggerCollection([prev_logger, new_logger])
def test_trainer_loggers_setters():
    """Test the behavior of setters for trainer.logger and trainer.loggers."""
    first = CustomLogger()
    second = CustomLogger()
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        pair_collection = LoggerCollection([first, second])
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        single_collection = LoggerCollection([second])

    trainer = Trainer()
    assert type(trainer.logger) is TensorBoardLogger
    assert trainer.loggers == [trainer.logger]

    # --- setters for trainer.logger ---
    trainer.logger = first
    assert trainer.logger == first
    assert trainer.loggers == [first]

    trainer.logger = pair_collection
    with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
        assert trainer.logger._logger_iterable == pair_collection._logger_iterable
    assert trainer.loggers == [first, second]

    # A size-1 LoggerCollection collapses trainer.logger to the contained logger.
    trainer.logger = single_collection
    assert trainer.logger == second
    assert trainer.loggers == [second]

    trainer.logger = None
    assert trainer.logger is None
    assert trainer.loggers == []

    # --- setters for trainer.loggers ---
    trainer.loggers = [first, second]
    assert trainer.loggers == [first, second]
    with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
        assert trainer.logger._logger_iterable == pair_collection._logger_iterable

    trainer.loggers = [first]
    assert trainer.loggers == [first]
    assert trainer.logger == first

    trainer.loggers = []
    assert trainer.loggers == []
    assert trainer.logger is None

    trainer.loggers = None
    assert trainer.loggers == []
    assert trainer.logger is None
def configure_logger(self, logger):
    """Resolve ``logger`` (bool, single logger, or iterable of loggers) into ``self.logger``."""
    if logger is False:
        self.logger = None
        return
    if logger is True:
        # default logger
        self.logger = TensorBoardLogger(
            save_dir=self.default_root_dir,
            version=self.slurm_job_id,
            name='lightning_logs',
        )
        return
    # wrap multiple loggers into a collection; pass a single logger through unchanged
    self.logger = LoggerCollection(logger) if isinstance(logger, Iterable) else logger
def test_logger_collection():
    """LoggerCollection fans every call out to each wrapped logger."""
    first, second = MagicMock(), MagicMock()
    collection = LoggerCollection([first, second])

    assert collection[0] == first
    assert collection[1] == second
    assert collection.experiment[0] == first.experiment
    assert collection.experiment[1] == second.experiment
    assert collection.save_dir is None

    collection.update_agg_funcs({'test': np.mean}, np.sum)
    for wrapped in (first, second):
        wrapped.update_agg_funcs.assert_called_once_with({'test': np.mean}, np.sum)

    collection.agg_and_log_metrics({'test': 2.0}, 4)
    for wrapped in (first, second):
        wrapped.agg_and_log_metrics.assert_called_once_with({'test': 2.0}, 4)

    collection.close()
    for wrapped in (first, second):
        wrapped.close.assert_called_once()
def test_logger_collection():
    """LoggerCollection forwards logging calls to each wrapped logger (deprecated API)."""
    left = MagicMock()
    right = MagicMock()
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection([left, right])

    assert collection[0] == left
    assert collection[1] == right
    assert collection.experiment[0] == left.experiment
    assert collection.experiment[1] == right.experiment
    assert collection.save_dir is None

    collection.update_agg_funcs({"test": np.mean}, np.sum)
    left.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)
    right.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)

    collection.log_metrics(metrics={"test": 2.0}, step=4)
    left.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)
    right.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)

    collection.finalize("success")
    left.finalize.assert_called_once()
    right.finalize.assert_called_once()
def test_logger_collection():
    """Indexing, experiment access, and close() on a two-logger collection."""
    backends = [MagicMock(), MagicMock()]
    collection = LoggerCollection(backends)

    for idx, backend in enumerate(backends):
        assert collection[idx] == backend
        assert collection.experiment[idx] == backend.experiment

    collection.close()
    for backend in backends:
        backend.close.assert_called_once()
def test_unsupported_logger_warning(tmpdir):
    """TrainingDataMonitor warns when the trainer's logger is a LoggerCollection."""
    monitor = TrainingDataMonitor()
    collection = LoggerCollection([TensorBoardLogger(tmpdir)])
    trainer = Trainer(logger=collection, callbacks=[monitor])
    with pytest.warns(UserWarning, match="does not support logging with LoggerCollection"):
        monitor.on_train_start(trainer, pl_module=None)
def test_logger_collection_unique_versions():
    """Loggers sharing one version collapse to that single version string."""
    shared_version = "1"
    collection = LoggerCollection(
        [CustomLogger(version=shared_version), CustomLogger(version=shared_version)]
    )
    assert collection.version == shared_version
def test_logger_collection_names_order():
    """Duplicate names are dropped while preserving first-seen order."""
    members = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(members)
    expected = f"{members[0].name}_{members[1].name}_{members[3].name}"
    assert collection.name == expected
def test_logger_collection_unique_names():
    """Loggers sharing one name collapse to that single name string."""
    shared_name = "name1"
    collection = LoggerCollection(
        [CustomLogger(name=shared_name), CustomLogger(name=shared_name)]
    )
    assert collection.name == shared_name
def test_logger_collection_unique_versions():
    """Loggers sharing one version collapse to that version (deprecated API)."""
    shared = "1"
    members = [CustomLogger(version=shared), CustomLogger(version=shared)]
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(members)
    assert collection.version == shared
def test_logger_collection_unique_names():
    """Loggers sharing one name collapse to that name (deprecated API)."""
    shared = "name1"
    members = [CustomLogger(name=shared), CustomLogger(name=shared)]
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(members)
    assert collection.name == shared
def test_v1_7_0_lightning_logger_base_close(tmpdir):
    """close() is deprecated on both the base logger and LoggerCollection."""
    base = CustomLogger()
    with pytest.deprecated_call(
        match="`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
    ):
        base.close()
    with pytest.deprecated_call(
        match="`LoggerCollection.close` method is deprecated in v1.5 and will be removed in v1.7."
    ):
        collection = LoggerCollection([base])
        collection.close()
def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
    """Translate the ``logger`` flag/instance(s) into ``self.trainer.logger``."""
    if logger is False:
        self.trainer.logger = None
        return
    if logger is True:
        # default logger
        self.trainer.logger = TensorBoardLogger(
            save_dir=self.trainer.default_root_dir,
            version=self.trainer.slurm_job_id,
            name="lightning_logs",
        )
        return
    # wrap multiple loggers into a collection; pass a single logger through unchanged
    self.trainer.logger = LoggerCollection(logger) if isinstance(logger, Iterable) else logger
def configure_logger(
    self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]
) -> None:
    """Translate the ``logger`` argument into ``self.trainer.logger``."""
    if isinstance(logger, bool):
        # True -> default TensorBoard logger; False -> no logger at all
        if logger:
            self.trainer.logger = TensorBoardLogger(
                save_dir=self.trainer.default_root_dir,
                version=SLURMEnvironment.job_id(),
                name="lightning_logs",
            )
        else:
            self.trainer.logger = None
    elif isinstance(logger, Iterable):
        self.trainer.logger = LoggerCollection(logger)
    else:
        self.trainer.logger = logger
def configure_logger(self, logger):
    """Resolve the user-supplied ``logger`` into ``self.trainer.logger``."""
    if logger is False:
        self.trainer.logger = None
    elif logger is True:
        # default logger; version may be pinned via the PL_EXP_VERSION env var
        version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
        self.trainer.logger = TensorBoardLogger(
            save_dir=self.trainer.default_root_dir,
            version=version,
            name='lightning_logs',
        )
    elif isinstance(logger, Iterable):
        self.trainer.logger = LoggerCollection(logger)
    else:
        self.trainer.logger = logger
def test_v1_8_0_logger_collection(tmpdir):
    """trainer.logger with multiple loggers and LoggerCollection both warn."""
    csv_a = CSVLogger(tmpdir)
    csv_b = CSVLogger(tmpdir)
    single = Trainer(logger=csv_a)
    multi = Trainer(logger=[csv_a, csv_b])

    # Should have no deprecation warning
    single.logger
    single.loggers
    multi.loggers

    with pytest.deprecated_call(match="logger` will return the first logger"):
        _ = multi.logger

    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        _ = LoggerCollection([csv_a, csv_b])
def test_logger_collection_versions_order():
    """Duplicate versions are dropped while keeping first-seen order."""
    members = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(members)
    expected = f"{members[0].version}_{members[1].version}_{members[3].version}"
    assert collection.version == expected
class TrainerLoggingMixin(ABC):
    """Trainer mixin that groups logging concerns: logger configuration,
    metric logging, tensor-to-scalar conversion, and step-output processing.
    """

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    current_epoch: int
    on_gpu: bool
    log_gpu_memory: ...
    logger: Union[LightningLoggerBase, bool]
    tqdm_metrics: ...
    global_step: int
    proc_rank: int
    use_dp: bool
    use_ddp2: bool
    default_save_path: str
    slurm_job_id: int
    num_gpus: int

    def configure_logger(self, logger):
        """Resolve ``logger`` (bool, logger instance, or iterable) into ``self.logger``."""
        if logger is True:
            # default logger
            self.logger = TensorBoardLogger(
                save_dir=self.default_save_path,
                version=self.slurm_job_id,
                name='lightning_logs'
            )
            self.logger.rank = 0
        elif logger is False:
            self.logger = None
        else:
            if isinstance(logger, Iterable):
                self.logger = LoggerCollection(logger)
            else:
                self.logger = logger
            # only rank-0 writes logs; mirrors the default-logger branch above
            self.logger.rank = 0

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        """Logs the metric dict passed in.

        If the `step` parameter is None and a `step` key is present in metrics,
        uses metrics["step"] as the step.

        Args:
            metrics (dict): Metric values
            grad_norm_dic (dict): Gradient norms
            step (int): Step for which metrics should be logged.
                Default value corresponds to `self.global_step`
        """
        # add gpu memory
        if self.on_gpu and self.log_gpu_memory:
            mem_map = memory.get_memory_profile(self.log_gpu_memory)
            metrics.update(mem_map)

        # add norms
        metrics.update(grad_norm_dic)

        # turn all tensors to scalars
        scalar_metrics = self.metrics_to_scalars(metrics)

        if "step" in scalar_metrics and step is None:
            # caller embedded the step inside the metrics dict
            step = scalar_metrics.pop("step")
        else:
            # added metrics by Lightning for convenience
            metrics['epoch'] = self.current_epoch
            step = step if step is not None else self.global_step

        # log actual metrics (only rank 0 writes, and only when a logger exists)
        if self.proc_rank == 0 and self.logger is not None:
            self.logger.log_metrics(scalar_metrics, step=step)
            self.logger.save()

    def add_tqdm_metrics(self, metrics):
        """Store metrics for the progress bar, converting tensors to Python scalars."""
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            self.tqdm_metrics[k] = v

    def metrics_to_scalars(self, metrics):
        """Return a copy of ``metrics`` with tensors converted to Python scalars,
        recursing into nested dicts."""
        new_metrics = {}
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            if isinstance(v, dict):
                v = self.metrics_to_scalars(v)

            new_metrics[k] = v

        return new_metrics

    def process_output(self, output, train=False):
        """Reduces output according to the training mode.

        Separates loss from logging and tqdm metrics

        Returns:
            tuple of (loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens)
        """
        # ---------------
        # EXTRACT CALLBACK KEYS
        # ---------------
        # all keys not progress_bar or log are candidates for callbacks
        callback_metrics = {}
        for k, v in output.items():
            if k not in ['progress_bar', 'log', 'hiddens']:
                callback_metrics[k] = v

        if train and (self.use_dp or self.use_ddp2):
            num_gpus = self.num_gpus
            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)

        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()

        # ---------------
        # EXTRACT PROGRESS BAR KEYS
        # ---------------
        # broad except: a missing 'progress_bar' key (or non-dict output) just
        # means there are no progress-bar metrics
        try:
            progress_output = output['progress_bar']

            # reduce progress metrics for tqdm when using dp
            if train and (self.use_dp or self.use_ddp2):
                num_gpus = self.num_gpus
                progress_output = self.reduce_distributed_output(progress_output, num_gpus)

            progress_bar_metrics = progress_output
        except Exception:
            progress_bar_metrics = {}

        # ---------------
        # EXTRACT LOGGING KEYS
        # ---------------
        # extract metrics to log to experiment; same best-effort handling as above
        try:
            log_output = output['log']

            # reduce progress metrics for tqdm when using dp
            if train and (self.use_dp or self.use_ddp2):
                num_gpus = self.num_gpus
                log_output = self.reduce_distributed_output(log_output, num_gpus)

            log_metrics = log_output
        except Exception:
            log_metrics = {}

        # ---------------
        # EXTRACT LOSS
        # ---------------
        # if output dict doesn't have the keyword loss
        # then assume the output=loss if scalar
        loss = None
        if train:
            try:
                loss = output['loss']
            except Exception:
                if isinstance(output, torch.Tensor):
                    loss = output
                else:
                    raise RuntimeError(
                        'No `loss` value in the dictionary returned from `model.training_step()`.'
                    )

            # when using dp need to reduce the loss
            if self.use_dp or self.use_ddp2:
                loss = self.reduce_distributed_output(loss, self.num_gpus)

        # ---------------
        # EXTRACT HIDDEN
        # ---------------
        hiddens = output.get('hiddens')

        # use every metric passed in as a candidate for callback
        callback_metrics.update(progress_bar_metrics)
        callback_metrics.update(log_metrics)

        # convert tensors to Python scalars
        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()

        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens

    def reduce_distributed_output(self, output, num_gpus):
        """Average per-GPU outputs produced by DP/DDP2 into single values.

        No-op when running on a single device.
        """
        if num_gpus <= 1:
            return output

        # when using DP, we get one output per gpu
        # average outputs and return
        if isinstance(output, torch.Tensor):
            return output.mean()

        for k, v in output.items():
            # recurse on nested dics
            if isinstance(output[k], dict):
                output[k] = self.reduce_distributed_output(output[k], num_gpus)

            # do nothing when there's a scalar
            elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
                pass

            # reduce only metrics that have the same number of gpus
            # NOTE(review): assumes non-tensor, non-dict values also expose .size() —
            # a plain Python number here would raise; confirm callers only pass tensors
            elif output[k].size(0) == num_gpus:
                reduced = torch.mean(output[k])
                output[k] = reduced

        return output
def train_model(
    method,
    data_factory,
    train_params=None,
    archi_params=None,
    method_name=None,
    method_params=None,
    seed=98347,
    fix_few_seed=0,
    gpus=None,
    mlflow_uri=None,
    tensorboard_dir=None,
    checkpoint_dir=None,
    fast=False,
    try_to_resume=True,
):
    """This is the main function where a single model is created and trained, for a single seed value.

    Args:
        method (archis.Method): type of method, used to decide which networks to build
            and how to use some parameters.
        data_factory (DataFactory): dataset description to get dataset loaders, as well
            as useful information for some networks.
        train_params (dict, optional): Hyperparameters for training (see network config).
            Defaults to None.
        archi_params (dict, optional): Parameters of the network (see network config).
            Defaults to None.
        method_name (string, optional): A unique name describing the method, with its
            parameters. Used for logging results. Defaults to None.
        method_params (dict, optional): Parameters to be fed to the model that are
            specific to `method`. Defaults to None.
        seed (int, optional): Global seed for reproducibility. Defaults to 98347.
        fix_few_seed (int, optional): Seed for semi-supervised setting, fixing which
            target samples are labeled. Defaults to 0.
        gpus (list of int, optional): Which GPU ids to use. Defaults to None.
        mlflow_uri (int|string, optional): if a string, must be formatted like <uri>:<port>.
            If a port, will try to log to a MLFlow server on localhost:port. If None,
            ignores MLFlow logging. Defaults to None.
        tensorboard_dir (string, optional): directory for TensorBoard logs; if None,
            TensorBoard logging is disabled. Defaults to None.
        checkpoint_dir (string, optional): root directory for model checkpoints; if None,
            checkpointing and resuming are disabled. Defaults to None.
        fast (bool, optional): Whether to activate the `fast_dev_run` option of
            PyTorch-Lightning, training only on 1 batch per epoch for debugging.
            Defaults to False.
        try_to_resume (bool, optional): whether to resume training from the most recent
            checkpoint found under `checkpoint_dir`. Defaults to True.

    Returns:
        2-elements tuple containing:

        - pl.Trainer: object containing the resulting metrics, used for evaluation.
        - BaseAdaptTrainer: pl.LightningModule object (derived class depending on
          `method`), containing both the dataset & trained networks.
    """
    # Accept a plain string method name for convenience.
    if type(method) is str:
        method = archis.Method(method)
    if method_name is None:
        method_name = method.value
    # Copy so that **train_params_local unpacking below can't mutate the caller's dict.
    train_params_local = deepcopy(train_params)

    set_all_seeds(seed)
    if fix_few_seed > 0:
        archi_params["random_state"] = fix_few_seed
    else:
        archi_params["random_state"] = seed

    dataset = data_factory.get_multi_domain_dataset(seed)
    n_classes, data_dim, args = data_factory.get_data_args()
    network_factory = NetworkFactory(archi_params)
    # setup feature extractor
    feature_network = network_factory.get_feature_extractor(data_dim, *args)
    # setup classifier
    feature_dim = feature_network.output_size()
    classifier_network = network_factory.get_task_classifier(feature_dim, n_classes)

    method_params = {} if method_params is None else method_params
    if method.is_mmd_method():
        model = archis.create_mmd_based(
            method=method,
            dataset=dataset,
            feature_extractor=feature_network,
            task_classifier=classifier_network,
            **method_params,
            **train_params_local,
        )
    else:
        critic_input_size = feature_dim
        # setup critic network
        if method.is_cdan_method():
            # CDAN's critic sees features x class predictions, unless random projection
            # is enabled, which fixes the critic input dimension.
            if method_params is not None and method_params.get("use_random", False):
                critic_input_size = method_params["random_dim"]
            else:
                critic_input_size = feature_dim * n_classes
        critic_network = network_factory.get_critic_network(critic_input_size)
        model = archis.create_dann_like(
            method=method,
            dataset=dataset,
            feature_extractor=feature_network,
            task_classifier=classifier_network,
            critic=critic_network,
            **method_params,
            **train_params_local,
        )

    data_name = data_factory.get_data_short_name()

    if checkpoint_dir is not None:
        # Sanitize the method name so it is safe to use as a path component.
        path_method_name = re.sub(r"[^-/\w\.]", "_", method_name)
        full_checkpoint_dir = os.path.join(
            checkpoint_dir, path_method_name, f"seed_{seed}"
        )
        checkpoint_callback = ModelCheckpoint(
            filepath=os.path.join(full_checkpoint_dir, "{epoch}"),
            monitor="last_epoch",
            mode="max",
        )
        # Most recent checkpoint last (sorted by modification time).
        checkpoints = sorted(
            glob.glob(f"{full_checkpoint_dir}/*.ckpt"), key=os.path.getmtime
        )
        if len(checkpoints) > 0 and try_to_resume:
            last_checkpoint_file = checkpoints[-1]
            if method is archis.Method.WDGRL:
                # WDGRL doesn't resume training gracefully: only reuse the checkpoint
                # when it corresponds to the final epoch of a completed run.
                last_epoch = (
                    train_params_local["nb_init_epochs"]
                    + train_params_local["nb_adapt_epochs"]
                )
                if f"epoch={last_epoch - 1}" not in last_checkpoint_file:
                    last_checkpoint_file = None
        else:
            last_checkpoint_file = None
    else:
        checkpoint_callback = None
        last_checkpoint_file = None

    if mlflow_uri is not None:
        # A bare port number means "MLFlow server on localhost".
        if mlflow_uri.isdecimal():
            mlflow_uri = f"http://127.0.0.1:{mlflow_uri}"
        mlf_logger = MLFlowLogger(
            experiment_name=data_name,
            tracking_uri=mlflow_uri,
            tags=dict(
                method=method_name,
                data_variant=data_factory.get_data_long_name(),
                script=__file__,
            ),
        )
    else:
        mlf_logger = None
    if tensorboard_dir is not None:
        tnb_logger = TensorBoardLogger(
            save_dir=tensorboard_dir,
            name=f"{data_name}_{method_name}",
        )
    else:
        tnb_logger = None
    loggers = [logger for logger in [mlf_logger, tnb_logger] if logger is not None]
    if len(loggers) == 0:
        # `logger=False` disables logging in pl.Trainer below.
        logger = False
    else:
        logger = LoggerCollection(loggers)
        logger.log_hyperparams(
            {
                "seed": seed,
                "feature_network": archi_params["feature"]["name"],
                "method group": method.value,
                "method": method_name,
                "start time": create_timestamp_string("%Y-%m-%d %H:%M:%S"),
            }
        )

    # NOTE(review): the else-branch reads `train_params` while everything else uses
    # `train_params_local` — values are equal (deepcopy), but confirm this is intended.
    max_nb_epochs = (
        train_params_local["nb_adapt_epochs"] * 5
        if method is archis.Method.WDGRLMod
        else train_params["nb_adapt_epochs"]
    )
    # Refresh the progress bar less often on large datasets.
    pb_refresh = 1 if len(dataset) < 1000 else 10
    row_log_interval = max(10, len(dataset) // train_params_local["batch_size"] // 10)
    if gpus is not None and len(gpus) > 1 and method is archis.Method.WDGRL:
        logging.warning("WDGRL is not compatible with multi-GPU.")
        gpus = [gpus[0]]
    trainer = pl.Trainer(
        progress_bar_refresh_rate=pb_refresh,  # in steps
        row_log_interval=row_log_interval,
        min_epochs=train_params_local["nb_init_epochs"],
        max_epochs=max_nb_epochs + train_params_local["nb_init_epochs"],
        early_stop_callback=False,
        num_sanity_val_steps=5,
        check_val_every_n_epoch=1,
        checkpoint_callback=checkpoint_callback,
        resume_from_checkpoint=last_checkpoint_file,
        gpus=gpus,
        logger=logger,
        weights_summary=None,  # 'full' is default
        fast_dev_run=fast,
    )

    if last_checkpoint_file is None:
        logging.info(f"Training model with {method.name} {param_to_str(method_params)}")
    else:
        logging.info(
            f"Resuming training with {method.name} {param_to_str(method_params)}, from {last_checkpoint_file}."
        )
    trainer.fit(model)
    if trainer.interrupted:
        raise KeyboardInterrupt("Trainer was interrupted and shutdown gracefully.")
    if logger:
        logger.log_hyperparams(
            {"finish time": create_timestamp_string("%Y-%m-%d %H:%M:%S")}
        )
    return trainer, model
def test_logger_collection_versions_order():
    """Duplicate versions are dropped while keeping first-seen order."""
    members = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
    collection = LoggerCollection(members)
    expected = f"{members[0].version}_{members[1].version}_{members[3].version}"
    assert collection.version == expected
def test_logger_collection_names_order():
    """Duplicate names are dropped while keeping first-seen order."""
    members = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
    collection = LoggerCollection(members)
    expected = f"{members[0].name}_{members[1].name}_{members[3].name}"
    assert collection.name == expected