示例#1
0
def set_checkpoint(config):
    """
    Set checkpoint information

    When checkpointing is enabled (non-empty ``config.checkpoint.filepath``),
    the monitored metric name is prefixed with the validation dataset prefix
    and the filepath is expanded to ``<filepath>/<run name>/<template>``,
    where ``<template>`` is a format string with an ``epoch`` field and a
    field named after the monitored metric. When checkpointing is disabled,
    S3 syncing is disabled as well (``s3_path`` is cleared).

    Parameters
    ----------
    config : CfgNode
        Model configuration (its ``checkpoint`` node is modified in place)

    Returns
    -------
    checkpoint : CfgNode
        Updated checkpoint configuration (``config.checkpoint``)
    """
    # If checkpoint is enabled. Use truthiness rather than `is not ''`:
    # identity comparison with a literal relies on string interning and
    # triggers a SyntaxWarning on Python 3.8+.
    if config.checkpoint.filepath:
        # Create proper monitor string: "<dataset-prefix>-<metric>"
        config.checkpoint.monitor = '{}-{}'.format(
            prepare_dataset_prefix(config.datasets.validation,
                                   config.checkpoint.monitor_index),
            config.checkpoint.monitor)
        # Join checkpoint folder with run name; the last component is a
        # format template with `epoch` and monitor-name fields
        config.checkpoint.filepath = os.path.join(
            config.checkpoint.filepath, config.name,
            '{epoch:02d}_{%s:.3f}' % config.checkpoint.monitor)
        # Set s3 url only when an s3 path was configured
        if config.checkpoint.s3_path:
            config.checkpoint.s3_url = s3_url(config)
    else:
        # If not saving checkpoint, do not sync to s3
        config.checkpoint.s3_path = ''
    return config.checkpoint
示例#2
0
def prep_logger_and_checkpoint(model):
    """
    Use logger and checkpoint information to update configuration

    If a logger is attached and wandb is not in dry-run mode, the run name
    and URL reported by the logger are copied into the configuration. When
    checkpointing is enabled, the checkpoint filepath and the checkpoint
    callback directory are re-pointed at the new run name, the directory is
    created, and the S3 url is recomputed. Finally the updated configuration
    is logged. Otherwise nothing is changed.

    Parameters
    ----------
    model : nn.Module
        Module to update in place. Reads ``model.logger`` and
        ``model.config``, and ``model.trainer.checkpoint`` when
        checkpointing is enabled.
    """
    # Nothing to adopt without a live (non dry-run) logger
    if not model.logger or model.config.wandb.dry_run:
        return
    config = model.config
    # Change run name to be the wandb assigned name
    config.name = config.wandb.name = model.logger.run_name
    config.wandb.url = model.logger.run_url
    # If we are saving models we need to update the path
    # (truthiness instead of `is not ''`, which compares identity)
    if config.checkpoint.filepath:
        # Filepath has the form "<root>/<run name>/<template>": replace the
        # run-name directory. os.path is used instead of splitting on '/'
        # so this matches how set_checkpoint builds it with os.path.join.
        filepath = config.checkpoint.filepath
        config.checkpoint.filepath = os.path.join(
            os.path.dirname(os.path.dirname(filepath)),
            config.name, os.path.basename(filepath))
        # Change callback dirpath to a sibling folder named after the run
        dirpath = os.path.join(
            os.path.dirname(model.trainer.checkpoint.dirpath),
            config.name)
        model.trainer.checkpoint.dirpath = dirpath
        os.makedirs(dirpath, exist_ok=True)
        config.checkpoint.s3_url = s3_url(config)
    # Log updated configuration
    model.logger.log_config(config)