def on_train_end(self):
    if self._teardown_already_run:
        return
    self._teardown_already_run = True

    # Save latest checkpoint
    rank_zero_warn('Saving latest checkpoint..')
    self.check_checkpoint_callback(should_check_val=False, force_save=True)

    # hook
    self.trainer.call_hook('on_train_end')

    # kill loggers
    if self.trainer.logger is not None:
        self.trainer.logger.finalize("success")

    # summarize profile results
    if self.trainer.global_rank == 0:
        self.trainer.profiler.describe()

    if self.trainer.global_rank == 0:
        for proc in self.trainer.interactive_ddp_procs:
            subprocess.Popen.kill(proc)

    # clean up dist group
    if self.trainer.use_ddp or self.trainer.use_ddp2:
        torch_distrib.destroy_process_group()

    # clear mem
    if self.trainer.on_gpu:
        model = self.trainer.get_model()
        model.cpu()
        torch.cuda.empty_cache()
def compute_max_steps(hparams: Namespace, trainer: Trainer) -> int:
    r"""
    Compute the total number of training steps if not specified by the user.
    This value may be required, for example, by learning-rate schedulers or optimizers.
    """
    # if already defined, skip
    if hparams.max_steps is not None:
        return hparams.max_steps

    if not hasattr(trainer, 'datamodule'):
        rank_zero_warn(
            "You tried to fix `max_steps` but didn't provide a datamodule to"
            " the trainer.fit function. Returning `max_steps=None`")
        return None

    dataset_len = len(trainer.datamodule.train_dataset)
    total_devices = utils.get_total_devices(trainer=trainer)

    num_training_batches = math.ceil(dataset_len / hparams.batch_size)
    training_batches_per_epoch = num_training_batches // total_devices
    steps_per_epoch = math.ceil(training_batches_per_epoch / hparams.accumulate_grad_batches)
    steps = hparams.max_epochs * steps_per_epoch

    rank_zero_warn(f"Automatically computed max_steps={steps}. If this looks correct, ignore this warning.")
    return steps
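# A minimal sketch of calling `compute_max_steps` above. The `hparams` field names
# follow the function body; attaching a datamodule to the trainer by plain attribute
# assignment (and `my_datamodule` itself) are assumptions for illustration only.
from argparse import Namespace

from pytorch_lightning import Trainer

hparams = Namespace(max_steps=None, batch_size=32, accumulate_grad_batches=1, max_epochs=10)
trainer = Trainer(max_epochs=hparams.max_epochs)
trainer.datamodule = my_datamodule  # hypothetical LightningDataModule exposing `train_dataset`
hparams.max_steps = compute_max_steps(hparams, trainer)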
def select_accelerator(self) -> Accelerator:
    if isinstance(self.distributed_backend, Accelerator):
        # custom accelerator from user
        if self._precision_plugin is not None or self._training_type_plugin is not None:
            # plugins also specified by user
            rank_zero_warn(
                'Specified `Precision` and `TrainingType` plugins will be ignored,'
                ' since an `Accelerator` instance was provided.'
            )
        return self.distributed_backend

    if self.on_gpu:
        acc_cls = GPUAccelerator
    elif self.on_tpu:
        acc_cls = TPUAccelerator
    elif self.on_ipu:
        acc_cls = IPUAccelerator
    else:
        acc_cls = CPUAccelerator

    # as precision_plugin depends on training_type_plugin, make sure
    # training_type_plugin is selected first, then precision_plugin
    return acc_cls(
        training_type_plugin=self.training_type_plugin,
        precision_plugin=self.precision_plugin,
    )
def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
    if self.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
        return

    # track the best model path
    best_model_path = None
    if self.checkpoint_callback is not None:
        best_model_path = self.checkpoint_callback.best_model_path

    if self.global_rank == 0 and mp_queue is not None:
        rank_zero_warn('cleaning up ddp environment...')

        # todo: pass the complete checkpoint as a state dictionary
        mp_queue.put(best_model_path)
        mp_queue.put(results)

        # save the last weights
        last_path = None
        if not self.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)

            # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
            # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
            # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
            if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
                torch.save(model.state_dict(), last_path, _use_new_zipfile_serialization=False)
            else:
                torch.save(model.state_dict(), last_path)

        mp_queue.put(last_path)
def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
    if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
        return

    # track the best model path
    best_model_path = None
    if self.trainer.checkpoint_callback is not None:
        best_model_path = self.trainer.checkpoint_callback.best_model_path

    if self.trainer.global_rank == 0 and mp_queue is not None:
        rank_zero_warn('cleaning up ddp environment...')

        # todo: pass the complete checkpoint as a state dictionary
        mp_queue.put(best_model_path)
        mp_queue.put(results)

        # save the last weights
        last_path = None
        if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
            atomic_save(model.state_dict(), last_path)
        mp_queue.put(last_path)
def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
    closure_loss = amp.scale_loss(closure_loss, optimizer)

    # enter apex context
    self.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
    context = closure_loss
    closure_loss = closure_loss.__enter__()

    # do backward pass
    if self.trainer.train_loop.automatic_optimization:
        model = self.trainer.get_model()
        model.backward(closure_loss, optimizer, opt_idx)
    else:
        closure_loss.backward(*args, **kwargs)

    # exit amp context (no active exception, so pass None for exc_type/exc_value/traceback)
    error = context.__exit__(None, None, None)
    if error:
        raise Exception('apex unscale error')

    # once backward has been applied, release graph
    closure_loss = closure_loss.detach()
    return closure_loss
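# For reference, the documented Apex idiom that the `backward` above unrolls manually
# through `__enter__`/`__exit__`; a sketch assuming NVIDIA Apex is installed and a GPU
# is available:
import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(4, 10).cuda()).mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:  # scales the loss inside the context
    scaled_loss.backward()                            # gradients are unscaled on exit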
def _process_result(self, training_step_output, split_batch):
    training_step_output.track_batch_size(len(split_batch))

    m = """
    TrainResult and EvalResult were deprecated in 0.9.1 and support will be dropped in 1.0.0.
    Use self.log and .write from the LightningModule to log metrics and write predictions.
    training_step can now only return a scalar (for the loss) or a dictionary with anything you want.

    Option 1:
    return loss

    Option 2:
    return {'loss': loss, 'anything_else': ...}

    Option 3:
    return {'loss': loss, 'hiddens': hiddens, 'anything_else': ...}
    """
    rank_zero_warn(m)

    # don't allow EvalResult in the training_step
    if isinstance(training_step_output, EvalResult):
        raise MisconfigurationException(
            "training_step cannot return EvalResult,"
            " use a dict or TrainResult instead"
        )

    training_step_output_for_epoch_end = copy(training_step_output)
    training_step_output_for_epoch_end.detach()

    return training_step_output_for_epoch_end
def __init__(
    self,
    embedding_dim: Optional[int] = None,
    backbone: str = "swav-imagenet",
    pretrained: bool = True,
    loss_fn: Callable = F.cross_entropy,
    optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
    metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
    learning_rate: float = 1e-3,
    pooling_fn: Callable = torch.max,
):
    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metrics=metrics,
        learning_rate=learning_rate,
    )

    self.save_hyperparameters()
    self.backbone_name = backbone
    self.embedding_dim = embedding_dim
    assert pooling_fn in [torch.mean, torch.max]
    self.pooling_fn = pooling_fn

    self.backbone, num_features = backbone_and_num_features(backbone, pretrained)

    if embedding_dim is None:
        self.head = nn.Identity()
    else:
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_features, embedding_dim),
        )
        rank_zero_warn('embedding_dim is not None. Remember to finetune first!')
def pre_configure_ddp(self):
    # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``, breaking manual_optimization
    if (
        _TORCH_GREATER_EQUAL_1_7
        and not self.lightning_module.automatic_optimization
        and not self._ddp_kwargs.get("find_unused_parameters", False)
    ):
        rank_zero_warn(
            "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True``"
            " to properly work with DDP."
        )
        self._ddp_kwargs["find_unused_parameters"] = True
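# Sketch of the user-facing way to control `find_unused_parameters` instead of relying
# on the override above; `DDPPlugin` forwards extra kwargs to
# `torch.nn.parallel.DistributedDataParallel` (API as of the Lightning 1.2-1.4 era):
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    plugins=DDPPlugin(find_unused_parameters=False),  # skip the costly unused-parameter search
)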
def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    if self.on_ipu:
        return IPUPrecisionPlugin(self.precision)

    if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
        return DeepSpeedPrecisionPlugin(self.precision)

    if self.precision == 32:
        return PrecisionPlugin()
    elif self.precision == 64:
        return DoublePrecisionPlugin()
    elif self.precision == 16:
        if self.on_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if self.on_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )
            elif not _NATIVE_AMP_AVAILABLE:
                msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                )
                if _APEX_AVAILABLE:
                    self.amp_type = AMPType.APEX
                    msg += " We will attempt to use NVIDIA Apex for this session."
                    rank_zero_warn(msg)
                else:
                    raise MisconfigurationException(msg)
            else:
                log.info("Using native 16bit precision.")
                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin()
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin()
                return NativeMixedPrecisionPlugin()

        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP,"
                    " please use native AMP for 16-bit precision."
                )
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise NotImplementedError("We only support precisions 64, 32 and 16!")
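# The selection above is driven by user-facing Trainer flags; a sketch of the common
# combinations (flag names as in the Lightning version this snippet targets):
from pytorch_lightning import Trainer

trainer = Trainer(gpus=1, precision=16)                      # native AMP -> NativeMixedPrecisionPlugin
trainer = Trainer(gpus=1, precision=16, amp_backend="apex")  # -> ApexMixedPrecisionPlugin
trainer = Trainer(precision=64)                              # -> DoublePrecisionPlugin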
def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
    # track the best model path
    best_model_path = None
    if self.trainer.checkpoint_callback is not None:
        best_model_path = self.trainer.checkpoint_callback.best_model_path

    if self.trainer.global_rank == 0 and mp_queue is not None:
        rank_zero_warn('cleaning up ddp environment...')

        # todo: pass the complete checkpoint as a state dictionary
        mp_queue.put(best_model_path)
        mp_queue.put(results)
def _logger_is_supported(self, trainer):
    """Check whether the trainer's logger is one of the supported logger types."""
    for logger_type in self.SUPPORTED_LOGGERS:
        if isinstance(trainer.logger, logger_type):
            return True

    rank_zero_warn(
        f"Unsupported logger: '{trainer.logger}', will not log any media to logger this run."
        f" Supported loggers: {[sup_log.__name__ for sup_log in self.SUPPORTED_LOGGERS]}."
    )
    return False
def __init__(self, *datasets):
    super().__init__()
    assert len(datasets) > 0, 'datasets should not be an empty iterable'
    self.datasets = list(datasets)
    if not utils.functional.all_equal_in_iterable([len(d) for d in self.datasets]):
        rank_zero_warn(
            "Datasets do not all have the same length: "
            + ", ".join(f"{d.__class__.__name__}: {len(d)}" for d in self.datasets)
        )
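# Usage sketch for the composite dataset above. The class name `StackDataset` is
# hypothetical, since only the constructor is shown:
import torch
from torch.utils.data import TensorDataset

ds_a = TensorDataset(torch.zeros(100, 3))
ds_b = TensorDataset(torch.zeros(90, 3))
combined = StackDataset(ds_a, ds_b)  # warns: datasets do not all have the same length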
def __init__(self, log_dir: str) -> None:
    self.hparams = {}
    self.metrics = []

    self.log_dir = log_dir
    if os.path.exists(self.log_dir) and os.listdir(self.log_dir):
        rank_zero_warn(
            f"Experiment logs directory {self.log_dir} exists and is not empty."
            " Previous log files in this directory will be deleted when the new ones are saved!"
        )
    os.makedirs(self.log_dir, exist_ok=True)

    self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
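# Usage sketch for the writer above; the class name `ExperimentWriter` is an
# assumption, since only `__init__` is shown:
writer = ExperimentWriter(log_dir="logs/exp0")  # warns if "logs/exp0" already contains files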
def on_trace_ready(profiler):
    filename = f"{action_name}_{self.local_rank}"

    if self.dirpath is not None:
        if self._export_to_chrome:
            handler = tensorboard_trace_handler(self.dirpath, filename)
            handler(profiler)

        if self._export_to_flame_graph:
            path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack"))
            profiler.export_stacks(path, metric=self._metric)
    else:
        rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")
def log_graph(self, model: LightningModule, input_array=None):
    if self._log_graph:
        if input_array is None:
            input_array = model.example_input_array

        if input_array is not None:
            self.experiment.add_graph(model, model._apply_batch_transfer_handler(input_array))
        else:
            rank_zero_warn(
                'Could not log computational graph since neither the'
                ' `model.example_input_array` attribute is set nor'
                ' `input_array` was given',
                UserWarning,
            )
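# Sketch of what makes `log_graph` succeed without an explicit `input_array`:
# the LightningModule sets `example_input_array`, which the logger traces through
# `forward`:
import torch
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)
        self.example_input_array = torch.rand(1, 28 * 28)  # enables graph logging

    def forward(self, x):
        return self.layer(x)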
def pre_configure_ddp(self):
    # If unset, default `find_unused_parameters` to `True`.
    # Many models require setting this parameter to True, as there are corner cases
    # where not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
    # This flag does come with a performance hit, so it is suggested to disable it where possible.
    self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)

    # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()``, breaking manual_optimization
    if (
        _TORCH_GREATER_EQUAL_1_7
        and not self.lightning_module.automatic_optimization
        and not self._ddp_kwargs.get("find_unused_parameters", False)
    ):
        rank_zero_warn(
            "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True``"
            " to properly work with DDP."
        )
        self._ddp_kwargs["find_unused_parameters"] = True
def _map_deprecated_dist_backend(self, accelerator, distributed_backend):
    if distributed_backend is not None:
        rank_zero_warn(
            DeprecationWarning(
                '`distributed_backend` has been renamed to `accelerator`.'
                ' Deprecated in 1.0.0, will be removed in 1.2.0'
            )
        )

    # temporary mapping until we remove all the distributed_backend references
    if accelerator is not None:
        self.accelerator = accelerator
        if isinstance(accelerator, Accelerator):
            self.accelerator.trainer = self
            distributed_backend = self.accelerator.nickname
        else:
            distributed_backend = accelerator
    return distributed_backend
def select_precision_plugin(self) -> PrecisionPlugin:
    if self.precision == 32:
        self.amp_type = None
        return PrecisionPlugin()

    elif self.precision == 16:
        if self.on_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == "native":
            if not _NATIVE_AMP_AVAILABLE:
                rank_zero_warn(
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                    " We will attempt to use NVIDIA Apex for this session."
                )
                if not _APEX_AVAILABLE and self.on_cpu:
                    raise MisconfigurationException(
                        "You have asked for native AMP on CPU, but AMP is only available on GPU."
                    )
                self.amp_type = "apex"
            elif self.on_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )
            else:
                log.info("Using native 16bit precision.")
                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                    return ShardedNativeMixedPrecisionPlugin()
                self.amp_type = AMPType.NATIVE
                return NativeMixedPrecisionPlugin()

        if self.amp_type == "apex":
            if not _APEX_AVAILABLE:
                rank_zero_warn(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            else:
                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP,"
                        " please use native AMP for 16-bit precision."
                    )
                log.info("Using APEX 16bit precision.")
                self.amp_type = AMPType.APEX
                return ApexMixedPrecisionPlugin(self.amp_level)
    else:
        raise NotImplementedError("We only support precisions 32 and 16!")
def log_graph(self, model: LightningModule, input_array=None):
    if self._log_graph:
        if input_array is None:
            input_array = model.example_input_array

        if input_array is not None:
            # use the resolved `input_array`, not `model.example_input_array`,
            # so an explicitly passed array is not silently ignored
            self.experiment.add_graph(
                model, model.transfer_batch_to_device(input_array, model.device)
            )
        else:
            rank_zero_warn(
                'Could not log computational graph since the'
                ' `model.example_input_array` attribute is not set'
                ' or `input_array` was not given',
                UserWarning,
            )
def __init__(
    self,
    embedding_dim: Optional[int] = None,
    backbone: str = "swav-imagenet",
    pretrained: bool = True,
    loss_fn: Callable = F.cross_entropy,
    optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
    metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
    learning_rate: float = 1e-3,
    pooling_fn: Callable = torch.max,
):
    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metrics=metrics,
        learning_rate=learning_rate,
    )

    self.save_hyperparameters()
    self.backbone_name = backbone
    self.embedding_dim = embedding_dim
    assert pooling_fn in [torch.mean, torch.max]
    self.pooling_fn = pooling_fn

    if backbone in _models:
        config = _load_model(backbone)
        self.backbone = config['model']
        num_features = config['num_features']
    elif backbone not in _backbones:
        raise NotImplementedError(f"Backbone {backbone} is not yet supported")
    else:
        backbone_fn, split, num_feats = _backbones[backbone]
        backbone = backbone_fn(pretrained=pretrained)
        self.backbone = split(backbone)
        num_features = num_feats(backbone)

    if embedding_dim is None:
        self.head = nn.Identity()
    else:
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_features, embedding_dim),
        )
        rank_zero_warn('embedding_dim is not None. Remember to finetune first!')
def __init__(
    self,
    compute_on_step: bool = True,
    dist_sync_on_step: bool = False,
    process_group: Optional[Any] = None,
    dist_sync_fn: Callable = None,
):
    rank_zero_warn(
        "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`."
        " It will be removed in v1.5.0",
        DeprecationWarning,
    )
    super(Metric, self).__init__(
        compute_on_step=compute_on_step,
        dist_sync_on_step=dist_sync_on_step,
        process_group=process_group,
        dist_sync_fn=dist_sync_fn,
    )
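# The deprecation above points at `torchmetrics`; a minimal custom metric written
# directly against that API (sketch):
import torch
import torchmetrics

class SumMetric(torchmetrics.Metric):
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value: torch.Tensor):
        self.total += value.sum()

    def compute(self):
        return self.total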
def _save_media_to_disk(self, trainer, pred_data: Optional[PredData], mode: Mode, batch_idx: Optional[int] = None):
    """For a given mode (train/val/test), save the results to disk."""
    if not self.save_to_disk:
        return
    if pred_data is None:  # Empty queue
        rank_zero_warn(f"Empty queue! Mode: {mode}")
        return

    # Create output filename
    if self.save_latest_only:
        output_filename = f"results.{mode.name.lower()}.png"
    else:
        if batch_idx is None:
            output_filename = f"results-epoch{trainer.current_epoch}.{mode.name.lower()}.png"
        else:
            output_filename = f"results-epoch{trainer.current_epoch}-step{batch_idx}.{mode.name.lower()}.png"
    output_filename = self.exp_dir / output_filename

    # Get the latest batches from the data queue in LightningModule
    inputs, labels, preds = pred_data.inputs, pred_data.labels, pred_data.preds

    # Colorize labels and predictions
    label2rgb = LabelToRGB()
    labels_rgb = [label2rgb.map_color_palette(lbl, Palette.LAPA) for lbl in labels]
    preds_rgb = [label2rgb.map_color_palette(pred, Palette.LAPA) for pred in preds]
    inputs_l = [ipt for ipt in inputs]

    # Create collage of results: combine each inp/lbl/pred triple into a single image
    results_l = []
    for inp, lbl, pred in zip(inputs_l, labels_rgb, preds_rgb):
        res_combined = np.concatenate((inp, lbl, pred), axis=1)
        results_l.append(res_combined)

    # Create grid from combined imgs
    n_imgs = len(results_l)
    n_cols = 4  # Fix num of columns
    n_rows = int(math.ceil(n_imgs / n_cols))
    img_h, img_w, _ = results_l[0].shape
    grid_results = np.zeros((img_h * n_rows, img_w * n_cols, 3), dtype=np.uint8)
    for idy in range(n_rows):
        for idx in range(n_cols):
            img_idx = idy * n_cols + idx  # row-major index into results_l
            if img_idx >= n_imgs:
                break
            grid_results[idy * img_h:(idy + 1) * img_h, idx * img_w:(idx + 1) * img_w, :] = results_l[img_idx]

    # Save collage
    if not cv2.imwrite(str(output_filename), cv2.cvtColor(grid_results, cv2.COLOR_RGB2BGR)):
        rank_zero_warn(f"Error in writing image: {output_filename}")
def add_model_specific_args(parser: ArgumentParser):
    r"""
    Add the usual parameters used by AdamW and a linear scheduler.
    Also checks that the learning rate has a reasonable value.
    """
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--max_sequence_length', type=int, default=128)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    parser.add_argument('--adam_betas', nargs=2, type=float, default=[0.9, 0.999])
    parser.add_argument('--max_grad_norm', type=float, default=1e-8)
    parser.add_argument('--warmup_steps', type=int, default=0)
    parser.add_argument('--beg_scheduler_step', type=int, default=0)

    tmp_args, _ = parser.parse_known_args()
    if tmp_args.learning_rate > 1:
        rank_zero_warn(f"You specified a very large learning rate: {tmp_args.learning_rate}")
    return parser
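# Usage sketch for the argparse helper above, assuming it is reachable as a plain
# function. Note that the learning-rate check runs inside the helper, against the
# real command line via `parse_known_args()`:
from argparse import ArgumentParser

parser = ArgumentParser()
parser = add_model_specific_args(parser)
hparams = parser.parse_args()  # e.g. `python train.py --learning_rate 5` triggers the warning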
def transfer_distrib_spawn_state_on_fit_end(self, results):
    # TODO: is there a better way than accessing the callback through model -> trainer -> callback?
    checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

    if self.global_rank == 0 and self.mp_queue is not None:
        rank_zero_warn("cleaning up ddp environment...")

        # save the last weights
        last_path = None
        # TODO: is there a better way than accessing the trainer through model -> trainer?
        if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

        # todo: pass the complete checkpoint as a state dictionary
        self.mp_queue.put(best_model_path)
        self.mp_queue.put(last_path)
        self.mp_queue.put(results)
def transfer_distrib_spawn_state_on_fit_end(self, results):
    checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

    if self.global_rank == 0 and self.mp_queue is not None:
        rank_zero_warn("cleaning up ddp environment...")

        # save the last weights
        last_path = None
        if (
            self.lightning_module.trainer.state == TrainerState.FITTING
            and best_model_path is not None
            and len(best_model_path) > 0
        ):
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

        # todo: pass the complete checkpoint as a state dictionary
        self.mp_queue.put(best_model_path)
        self.mp_queue.put(last_path)
        self.mp_queue.put(results)
def transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
    if self.distributed_backend not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
        return

    # track the best model path
    best_model_path = None
    if self.checkpoint_callback is not None:
        best_model_path = self.checkpoint_callback.best_model_path

    if self.global_rank == 0 and q is not None:
        rank_zero_warn('cleaning up ddp environment...')
        q.put(best_model_path)
        q.put(results)

        # save the last weights
        last_path = None
        if not self.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
            torch.save(model.state_dict(), last_path)
        q.put(last_path)
def __deprecation_check(
    self,
    profiled_functions: Optional[List[str]],
    record_functions: Optional[Set[str]],
) -> Set[str]:
    if record_functions is None:
        record_functions = set()

    if profiled_functions is not None:
        rank_zero_warn(
            "`PyTorchProfiler.profiled_functions` has been renamed to"
            " `record_functions` in v1.3 and will be removed in v1.5",
            DeprecationWarning,
        )
        if not record_functions:
            record_functions |= set(profiled_functions)
        else:
            raise MisconfigurationException(
                "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
                " Please use only the latter."
            )

    return record_functions
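# The rename in practice (sketch): the old keyword still works but emits the
# deprecation warning above, and passing both raises a MisconfigurationException.
from pytorch_lightning.profiler import PyTorchProfiler

profiler = PyTorchProfiler(record_functions={"training_step"})    # new name
profiler = PyTorchProfiler(profiled_functions=["training_step"])  # deprecated, warns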
def __init__(
    self,
    embedding_dim: Optional[int] = None,
    backbone: str = "resnet101",
    pretrained: bool = True,
    loss_fn: Callable = F.cross_entropy,
    optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
    metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
    learning_rate: float = 1e-3,
    pooling_fn: Callable = torch.max,
):
    if not _IMAGE_AVAILABLE:
        raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")

    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metrics=metrics,
        learning_rate=learning_rate,
        preprocess=ImageClassificationPreprocess(),
    )

    self.save_hyperparameters()
    self.backbone_name = backbone
    self.embedding_dim = embedding_dim
    assert pooling_fn in [torch.mean, torch.max]
    self.pooling_fn = pooling_fn

    self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained)

    if embedding_dim is None:
        self.head = nn.Identity()
    else:
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_features, embedding_dim),
        )
        rank_zero_warn('embedding_dim is not None. Remember to finetune first!')
def __init__(self, log_dir: str, metrics_file="metrics.csv", hparams_file="hparams.yaml") -> None:
    self.NAME_HPARAMS_FILE = hparams_file
    self.NAME_METRICS_FILE = metrics_file

    self.hparams = {}
    self.metrics = []
    self.log_dir = log_dir
    self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)

    if os.path.exists(self.log_dir) and os.path.exists(self.metrics_file_path):
        rank_zero_warn(
            f"Experiment logs directory {self.log_dir} exists and is not empty."
            " Loading previous results."
        )

        with io.open(self.metrics_file_path, "r") as f:
            metrics_keys = f.readline()[:-1].split(",")

        with io.open(self.metrics_file_path, "r", newline="") as f:
            reader = csv.DictReader(f, fieldnames=metrics_keys)
            next(reader)  # skip the header row
            for row in reader:
                # each row is a dict, so test for the key rather than an attribute
                if "step" in row:
                    step = row["step"]
                    del row["step"]
                else:
                    step = None
                self.log_metrics(row, step)

        self.log_hparams(load_hparams_from_yaml(os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)))

    os.makedirs(self.log_dir, exist_ok=True)