def wrap_training_step(wrapped, instance: "pl.LightningModule", args, kwargs):
    """Call a training step and forward any ``"log"`` entry to ``instance.log_dict``.

    Parameters
    ----------
    wrapped:
        The training-step callable; invoked as ``wrapped(*args, **kwargs)``.
    instance:
        Object whose ``log_dict`` receives the popped ``"log"`` mapping,
        called with ``on_step=True``.
    args:
        Positional arguments forwarded to ``wrapped``.
    kwargs:
        Keyword arguments forwarded to ``wrapped``.

    Returns
    -------
    The return value of ``wrapped``. If it is a ``dict`` containing a ``"log"``
    key, that key is removed from the dict in place before it is returned.
    """
    output_dict = wrapped(*args, **kwargs)
    # isinstance() already excludes None, so no separate None check is needed.
    if isinstance(output_dict, dict) and "log" in output_dict:
        # Pop rather than read, so the returned dict no longer carries "log".
        instance.log_dict(output_dict.pop("log"), on_step=True)
    return output_dict
def common_epoch_end(
    self,
    step_outputs,
    prefix="train/",
    exclude_keys=frozenset({"pred", "target"}),
):
    """Aggregate per-step outputs at epoch end and log the resulting metrics.

    Non-excluded step values are averaged across steps. Epoch-level metrics
    are computed from the concatenated predictions/targets, and both sets are
    logged together. Figures are logged separately.

    Parameters
    ----------
    step_outputs:
        Sequence of per-step dicts. Every dict must contain ``"pred"`` and
        ``"target"``. The keys of the FIRST dict decide which step metrics
        are averaged, so every other dict must contain those keys as well.
        Non-excluded values are combined with ``torch.stack`` and
        ``torch.mean``, so they must be tensors of matching shape
        (presumably scalars — TODO confirm).
    prefix:
        Prefix applied via ``prefix_keys`` to the epoch-level metrics. With
        every ``"/"`` removed, it is also passed to
        ``collect_epoch_metrics`` as its third argument.
    exclude_keys:
        Keys that are not averaged as step metrics. An immutable
        ``frozenset`` is used as the default so the shared default object
        cannot be mutated between calls.

    Raises
    ------
    IndexError
        If ``step_outputs`` is empty, because the first element is read to
        discover the metric keys.
    """
    keys = list(step_outputs[0].keys())
    mean_step_metrics = {
        k: torch.mean(torch.stack([x[k] for x in step_outputs]))
        for k in keys
        if k not in exclude_keys
    }

    preds, targets = zip(*[(s["pred"], s["target"]) for s in step_outputs])
    preds = cat_steps(preds)
    targets = cat_steps(targets)

    epoch_metrics = prefix_keys(
        prefix,
        self.collect_epoch_metrics(preds, targets, prefix.replace("/", "")),
    )
    # NOTE(review): sort_out_figures appears to split figure entries from
    # scalar metrics (inferred from its name and the two return values) — confirm.
    epoch_metrics, epoch_figures = sort_out_figures(epoch_metrics)

    # On a key collision, the epoch-level value overrides the step mean.
    all_metrics = {**mean_step_metrics, **epoch_metrics}
    # Called on the base class explicitly, presumably to bypass any log_dict
    # override on self — confirm.
    LightningModule.log_dict(self, all_metrics, sync_dist=self._sync_dist)
    log_figures(self, epoch_figures)
def wrap_training_step(wrapped, instance: "LightningModule", args, kwargs):
    """
    Wraps the training step of the LightningModule.

    The wrapped step is called with the given arguments. If it returns a
    ``dict`` containing a ``"log"`` key, that entry is removed and passed to
    ``instance.log_dict`` with ``on_step=True``.

    Parameters
    ----------
    wrapped:
        The wrapped function; invoked as ``wrapped(*args, **kwargs)``.
    instance:
        The LightningModule instance whose ``log_dict`` receives the
        ``"log"`` mapping.
    args:
        The positional arguments passed to the wrapped function.
    kwargs:
        The keyword arguments passed to the wrapped function.

    Returns
    -------
    The return value of the wrapped function. A returned ``dict`` has its
    ``"log"`` key (if any) removed in place.
    """
    output_dict = wrapped(*args, **kwargs)
    # isinstance() already excludes None, so no separate None check is needed.
    if isinstance(output_dict, dict) and "log" in output_dict:
        # Pop rather than read, so the returned dict no longer carries "log".
        instance.log_dict(output_dict.pop("log"), on_step=True)
    return output_dict
def override_unsupported_nud(lm: pl.LightningModule, context: PyTorchTrialContext) -> None:
    """Replace ``lm.print``, ``lm.log`` and ``lm.log_dict`` with trial-aware stand-ins.

    The replacements are assigned directly on the ``lm`` instance.

    - ``print`` only prints on distributed rank 0.
    - ``log`` / ``log_dict`` write numeric scalars to a ``pytorch.TorchWriter``,
      using ``context.current_train_batch()`` as the step.
    - ``log`` / ``log_dict`` raise ``InvalidModelException`` if given any
      extra positional or keyword argument.

    Parameters
    ----------
    lm:
        The module whose ``print``/``log``/``log_dict`` attributes are overwritten.
    context:
        Trial context. It supplies the distributed rank (for ``print``) and
        the current training batch (used as the scalar step).

    Returns
    -------
    None. ``lm`` is patched in place.
    """
    writer = pytorch.TorchWriter()

    def lm_print(*args: Any, **kwargs: Any) -> None:
        # Only rank 0 prints (presumably to avoid duplicated output across workers).
        if context.distributed.get_rank() == 0:
            print(*args, **kwargs)

    def lm_log_dict(a_dict: Dict, *args: Any, **kwargs: Any) -> None:
        # Any extra logging option is rejected rather than silently ignored.
        if args or kwargs:
            raise InvalidModelException(
                f"unsupported arguments to LightningModule.log {args} {kwargs}"
            )
        # Loop-invariant: every metric in this call shares the same step.
        step = context.current_train_batch()
        for metric, value in a_dict.items():
            # Write int/float scalars, including subclasses such as numpy.float64.
            # bool is an int subclass but is deliberately skipped, as before.
            # NOTE(review): any other value (e.g. a tensor) is silently dropped.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                writer.add_scalar(metric, value, step)

    def lm_log(name: str, value: Any, *args: Any, **kwargs: Any) -> None:
        # Single-metric form; delegates to lm_log_dict for validation and writing.
        lm_log_dict({name: value}, *args, **kwargs)

    lm.print = lm_print  # type: ignore
    lm.log = lm_log  # type: ignore
    lm.log_dict = lm_log_dict  # type: ignore