def set_checkpoint(config):
    """
    Set checkpoint information

    Parameters
    ----------
    config : CfgNode
        Model configuration

    Returns
    -------
    checkpoint : CfgNode
        ``config.checkpoint``, updated in place
    """
    # If checkpoint is enabled.
    # Equality (not identity) comparison: ``x is not ''`` relies on string
    # interning and triggers a SyntaxWarning on Python 3.8+.
    if config.checkpoint.filepath != '':
        # Create proper monitor string: "<validation dataset prefix>-<metric>"
        config.checkpoint.monitor = '{}-{}'.format(
            prepare_dataset_prefix(config.datasets.validation,
                                   config.checkpoint.monitor_index),
            config.checkpoint.monitor)
        # Join checkpoint folder with run name. The last component keeps its
        # braces literal ("{epoch:02d}_{<monitor>:.3f}"); presumably it is
        # filled in later by the checkpoint callback -- TODO confirm.
        config.checkpoint.filepath = os.path.join(
            config.checkpoint.filepath, config.name,
            '{epoch:02d}_{%s:.3f}' % config.checkpoint.monitor)
        # Set s3 url
        if config.checkpoint.s3_path != '':
            config.checkpoint.s3_url = s3_url(config)
    else:
        # If not saving checkpoint, do not sync to s3
        config.checkpoint.s3_path = ''
    return config.checkpoint
def prep_logger_and_checkpoint(model):
    """
    Use logger and checkpoint information to update configuration

    Only acts when a logger is attached and ``config.wandb.dry_run`` is
    falsy. In that case the logger's run name replaces ``config.name`` and
    ``config.wandb.name``, and the logger's run URL is stored in
    ``config.wandb.url``. If checkpointing is enabled, the checkpoint
    filepath and the trainer's checkpoint dirpath are re-pointed at a folder
    with the new run name, that folder is created, and ``s3_url`` is
    recomputed. The updated configuration is then logged.

    Parameters
    ----------
    model : nn.Module
        Module to update; reads ``model.logger``, ``model.config`` and, when
        checkpointing is enabled, ``model.trainer.checkpoint.dirpath``
    """
    # Change run name to be the wandb assigned name
    if model.logger and not model.config.wandb.dry_run:
        model.config.name = model.config.wandb.name = model.logger.run_name
        model.config.wandb.url = model.logger.run_url
        # If we are saving models we need to update the path
        # (equality, not identity: ``is not ''`` is unreliable on strings)
        if model.config.checkpoint.filepath != '':
            # Change checkpoint filepath: replace the parent folder of the
            # last path component with the new run name. Assumes the path
            # ends in "<run name>/<filename template>" -- TODO confirm.
            # os.path is used instead of split('/') so separators are handled
            # the same way os.path.join produced them.
            head, template = os.path.split(model.config.checkpoint.filepath)
            model.config.checkpoint.filepath = os.path.join(
                os.path.dirname(head), model.config.name, template)
            # Change callback dirpath: same parent folder, new run name
            dirpath = os.path.join(
                os.path.dirname(model.trainer.checkpoint.dirpath),
                model.config.name)
            model.trainer.checkpoint.dirpath = dirpath
            os.makedirs(dirpath, exist_ok=True)
            model.config.checkpoint.s3_url = s3_url(model.config)
        # Log updated configuration (inside the guard: requires a logger)
        model.logger.log_config(model.config)