def __init__(self, config, trainer):
    """Initialise timers, interval settings and the optional TensorBoard writer.

    Args:
        config (mmf_typings.DictConfig): Config for the callback.
        trainer (Type[BaseTrainer]): Trainer object.
    """
    super().__init__(config, trainer)
    cfg = self.training_config

    self.total_timer = Timer()
    self.log_interval = cfg.log_interval
    self.evaluation_interval = cfg.evaluation_interval
    self.checkpoint_interval = cfg.checkpoint_interval

    # Total iterations for a snapshot: validation dataset size divided by the
    # configured batch size. NOTE(review): floor division, so a trailing
    # partial batch is not counted (and a dataset smaller than one batch
    # yields 0) — confirm this is intended.
    self.snapshot_iterations = len(self.trainer.val_dataset) // cfg.batch_size

    self.tb_writer = None
    if not cfg.tensorboard:
        return

    # The default output folder is always set up first; a non-empty
    # "tensorboard_logdir" env setting then takes precedence over it.
    default_dir = setup_output_folder(folder_only=True)
    tb_dir = get_mmf_env(key="tensorboard_logdir") or default_dir
    self.tb_writer = TensorboardLogger(tb_dir, self.trainer.current_iteration)
def _load_loggers(self) -> None:
    """Create the TensorBoard writer if enabled; otherwise leave it as None."""
    self.tb_writer = None
    if not self.training_config.tensorboard:
        return

    # TODO: @sash PL logger upgrade
    # The default output folder is always set up first; a non-empty
    # "tensorboard_logdir" env setting then takes precedence over it.
    default_dir = setup_output_folder(folder_only=True)
    tb_dir = get_mmf_env(key="tensorboard_logdir") or default_dir
    self.tb_writer = TensorboardLogger(tb_dir)
def __init__(self, config, trainer):
    """Initialise timers, intervals and the optional TensorBoard / W&B loggers.

    Args:
        config (mmf_typings.DictConfig): Config for the callback.
        trainer (Type[BaseTrainer]): Trainer object.
    """
    super().__init__(config, trainer)
    cfg = self.training_config

    self.total_timer = Timer()
    self.log_interval = cfg.log_interval
    self.evaluation_interval = cfg.evaluation_interval
    self.checkpoint_interval = cfg.checkpoint_interval

    # Total iterations for a snapshot: the length of the validation loader
    # (number of batches per GPU, i.e. the max number of updates).
    self.snapshot_iterations = len(self.trainer.val_loader)

    self.tb_writer = None
    self.wandb_logger = None

    if cfg.tensorboard:
        # Default folder is set up first; a non-empty env override wins.
        default_dir = setup_output_folder(folder_only=True)
        tb_dir = get_mmf_env(key="tensorboard_logdir") or default_dir
        self.tb_writer = TensorboardLogger(tb_dir, self.trainer.current_iteration)

    if cfg.wandb.enabled:
        default_dir = setup_output_folder(folder_only=True)
        # NOTE(review): this directory is resolved but never passed to
        # WandbLogger below, so "wandb_logdir" has no visible effect here —
        # confirm whether WandbLogger should receive it.
        wandb_dir = get_mmf_env(key="wandb_logdir") or default_dir
        wandb_cfg = config.training.wandb
        self.wandb_logger = WandbLogger(
            entity=wandb_cfg.entity,
            config=config,
            project=wandb_cfg.project,
        )