Example #1
def load_checkpoint(model,
                    chkpt_file,
                    optimizer=None,
                    model_device=None,
                    *,
                    lean_checkpoint=False,
                    strict=False):
    """Load a pytorch training checkpoint.

    Args:
        model: the pytorch model to which we will load the parameters
        chkpt_file: the checkpoint file
        lean_checkpoint: if set, read into model only 'state_dict' field
        optimizer: [deprecated argument]
        model_device [str]: if set, call model.to($model_device)
                This should be set to either 'cpu' or 'cuda'.
    :returns: updated model, compression_scheduler, optimizer, start_epoch
    """
    if not os.path.isfile(chkpt_file):
        raise IOError(ENOENT, 'Could not find a checkpoint file at',
                      chkpt_file)

    msglogger.info("=> loading checkpoint %s", chkpt_file)
    checkpoint = torch.load(chkpt_file,
                            map_location=lambda storage, loc: storage)
    msglogger.info('=> Checkpoint contents:\n%s\n' %
                   get_contents_table(checkpoint))
    if 'extras' in checkpoint:
        msglogger.info("=> Checkpoint['extras'] contents:\n{}\n".format(
            get_contents_table(checkpoint['extras'])))

    if 'state_dict' not in checkpoint:
        raise ValueError(
            "Checkpoint must contain the model parameters under the key 'state_dict'"
        )

    checkpoint_epoch = checkpoint.get('epoch', None)
    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0

    compression_scheduler = None
    normalize_dataparallel_keys = False
    if 'compression_sched' in checkpoint:
        compression_scheduler = distiller.CompressionScheduler(model)
        try:
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        except KeyError:
            # A very common source of this KeyError is loading a GPU model on the CPU.
            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
            normalize_dataparallel_keys = True
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        msglogger.info(
            "Loaded compression schedule from checkpoint (epoch {})".format(
                checkpoint_epoch))
    else:
        msglogger.info(
            "Warning: compression schedule data does not exist in the checkpoint"
        )

    if 'thinning_recipes' in checkpoint:
        if 'compression_sched' not in checkpoint:
            raise KeyError(
                "Found thinning_recipes key, but missing mandatory key compression_sched"
            )
        msglogger.info("Loaded a thinning recipe from the checkpoint")
        # Cache the recipes in case we need them later
        model.thinning_recipes = checkpoint['thinning_recipes']
        if normalize_dataparallel_keys:
            model.thinning_recipes = [
                distiller.get_normalized_recipe(recipe)
                for recipe in model.thinning_recipes
            ]
        distiller.execute_thinning_recipes_list(
            model, compression_scheduler.zeros_mask_dict,
            model.thinning_recipes)

    if 'quantizer_metadata' in checkpoint:
        msglogger.info('Loaded quantizer metadata from the checkpoint')
        qmd = checkpoint['quantizer_metadata']
        quantizer = qmd['type'](model, **qmd['params'])
        quantizer.prepare_model(qmd['dummy_input'])

    if normalize_dataparallel_keys:
        checkpoint['state_dict'] = {
            normalize_module_name(k): v
            for k, v in checkpoint['state_dict'].items()
        }
    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
    if anomalous_keys:
        # This is pytorch 1.1+
        missing_keys, unexpected_keys = anomalous_keys
        if unexpected_keys:
            msglogger.warning(
                "Warning: the loaded checkpoint (%s) contains %d unexpected state keys"
                % (chkpt_file, len(unexpected_keys)))
        if missing_keys:
            raise ValueError(
                "The loaded checkpoint (%s) is missing %d state keys" %
                (chkpt_file, len(missing_keys)))

    if model_device is not None:
        model.to(model_device)

    if lean_checkpoint:
        msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(
            str(chkpt_file)))
        return (model, None, None, 0)

    def _load_optimizer(cls, src_state_dict, model):
        """Initiate optimizer with model parameters and load src_state_dict"""
        # initiate the dest_optimizer with a dummy learning rate,
        # this is required to support SGD.__init__()
        dest_optimizer = cls(model.parameters(), lr=1)
        dest_optimizer.load_state_dict(src_state_dict)
        return dest_optimizer

    try:
        optimizer = _load_optimizer(checkpoint['optimizer_type'],
                                    checkpoint['optimizer_state_dict'], model)
    except KeyError:
        # Older checkpoints did not support optimizer loading: they either had an 'optimizer'
        # field (a different name) which was not used during the load, or they didn't even
        # checkpoint the optimizer.
        optimizer = None

    if optimizer is not None:
        msglogger.info(
            'Optimizer of type {type} was loaded from checkpoint'.format(
                type=type(optimizer)))
        msglogger.info('Optimizer Args: {}'.format(
            dict((k, v)
                 for k, v in optimizer.state_dict()['param_groups'][0].items()
                 if k != 'params')))
    else:
        msglogger.warning('Optimizer could not be loaded from checkpoint.')

    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(
        f=str(chkpt_file), e=checkpoint_epoch))
    return (model, compression_scheduler, optimizer, start_epoch)
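
A minimal usage sketch for this variant, assuming the function is exposed as distiller.apputils.load_checkpoint and that the model builder and checkpoint path below are hypothetical placeholders; lean_checkpoint is keyword-only here, and the call always returns a 4-tuple.

import distiller.apputils as apputils

# Placeholders, for illustration only: build the model however your project does,
# and point chkpt_path at a real checkpoint file.
model = build_model()              # hypothetical helper
chkpt_path = 'checkpoint.pth.tar'  # hypothetical path

# Full load: also restores the compression schedule and optimizer state.
model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
    model, chkpt_path, model_device='cuda')

# Lean load: only 'state_dict' is applied; scheduler and optimizer come back as None.
model, _, _, _ = apputils.load_checkpoint(
    model, chkpt_path, model_device='cpu', lean_checkpoint=True)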
Example #2
def load_checkpoint(model,
                    chkpt_file,
                    optimizer=None,
                    model_device=None,
                    lean_checkpoint=False,
                    strict=False):
    """Load a pytorch training checkpoint.

    Args:
        model: the pytorch model to which we will load the parameters.  You can
        specify model=None if the checkpoint contains enough metadata to infer
        the model.  The order of the arguments is misleading and clunky, and is
        kept this way for backward compatibility.
        chkpt_file: the checkpoint file
        lean_checkpoint: if set, read into model only 'state_dict' field
        optimizer: [deprecated argument]
        model_device [str]: if set, call model.to($model_device)
                This should be set to either 'cpu' or 'cuda'.
    :returns: updated model, compression_scheduler, optimizer, start_epoch
    """
    def _load_compression_scheduler():
        normalize_keys = False
        try:
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_keys)
        except KeyError:
            # A very common source of this KeyError is loading a GPU model on the CPU.
            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
            normalize_keys = True
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_keys)
        msglogger.info(
            "Loaded compression schedule from checkpoint (epoch {})".format(
                checkpoint_epoch))
        return normalize_keys

    def _load_and_execute_thinning_recipes():
        msglogger.info("Loaded a thinning recipe from the checkpoint")
        # Cache the recipes in case we need them later
        model.thinning_recipes = checkpoint['thinning_recipes']
        if normalize_dataparallel_keys:
            model.thinning_recipes = [
                distiller.get_normalized_recipe(recipe)
                for recipe in model.thinning_recipes
            ]
        distiller.execute_thinning_recipes_list(
            model, compression_scheduler.zeros_mask_dict,
            model.thinning_recipes)

    def _load_optimizer():
        """Initialize optimizer with model parameters and load src_state_dict"""
        try:
            cls = checkpoint['optimizer_type']
            src_state_dict = checkpoint['optimizer_state_dict']
            # Initialize the dest_optimizer with a dummy learning rate,
            # this is required to support SGD.__init__()
            dest_optimizer = cls(model.parameters(), lr=1)
            dest_optimizer.load_state_dict(src_state_dict)
            msglogger.info(
                'Optimizer of type {type} was loaded from checkpoint'.format(
                    type=type(dest_optimizer)))
            optimizer_param_groups = dest_optimizer.state_dict()['param_groups']
            msglogger.info('Optimizer Args: {}'.format(
                dict((k, v) for k, v in optimizer_param_groups[0].items()
                     if k != 'params')))
            return dest_optimizer
        except KeyError:
            # Older checkpoints did not support optimizer loading: they either had an 'optimizer'
            # field (a different name) which was not used during the load, or they didn't even
            # checkpoint the optimizer.
            msglogger.warning('Optimizer could not be loaded from checkpoint.')
            return None

    def _create_model_from_ckpt():
        try:
            return distiller.models.create_model(False,
                                                 checkpoint['dataset'],
                                                 checkpoint['arch'],
                                                 checkpoint['is_parallel'],
                                                 device_ids=None)
        except KeyError:
            return None

    def _sanity_check():
        try:
            if model.arch != checkpoint["arch"]:
                raise ValueError(
                    "The model architecture does not match the checkpoint architecture"
                )
        except (NameError, KeyError):
            # One of the values is missing so we can't perform the comparison
            pass

    if not os.path.isfile(chkpt_file):
        raise IOError(ENOENT, 'Could not find a checkpoint file at',
                      chkpt_file)
    assert optimizer is None, "argument optimizer is deprecated and must be set to None"

    msglogger.info("=> loading checkpoint %s", chkpt_file)
    checkpoint = torch.load(chkpt_file,
                            map_location=lambda storage, loc: storage)
    msglogger.info('=> Checkpoint contents:\n%s\n' %
                   get_contents_table(checkpoint))
    if 'extras' in checkpoint:
        msglogger.info("=> Checkpoint['extras'] contents:\n{}\n".format(
            get_contents_table(checkpoint['extras'])))

    if 'state_dict' not in checkpoint:
        raise ValueError(
            "Checkpoint must contain the model parameters under the key 'state_dict'"
        )

    if not model:
        model = _create_model_from_ckpt()
        if not model:
            raise ValueError(
                "You didn't provide a model, and the checkpoint %s doesn't contain "
                "enough information to create one" % chkpt_file)

    checkpoint_epoch = checkpoint.get('epoch', None)
    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
    compression_scheduler = None
    normalize_dataparallel_keys = False
    if 'compression_sched' in checkpoint:
        compression_scheduler = distiller.CompressionScheduler(model)
        normalize_dataparallel_keys = _load_compression_scheduler()
    else:
        msglogger.info(
            "Warning: compression schedule data does not exist in the checkpoint"
        )

    if 'thinning_recipes' in checkpoint:
        if not compression_scheduler:
            msglogger.warning(
                "Found thinning_recipes key, but missing key compression_sched"
            )
            compression_scheduler = distiller.CompressionScheduler(model)
        _load_and_execute_thinning_recipes()

    if 'quantizer_metadata' in checkpoint:
        msglogger.info('Loaded quantizer metadata from the checkpoint')
        qmd = checkpoint['quantizer_metadata']
        quantizer = qmd['type'](model, **qmd['params'])
        quantizer.prepare_model(qmd['dummy_input'])

    if normalize_dataparallel_keys:
        checkpoint['state_dict'] = {
            normalize_module_name(k): v
            for k, v in checkpoint['state_dict'].items()
        }
    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
    if anomalous_keys:
        # This is pytorch 1.1+
        missing_keys, unexpected_keys = anomalous_keys
        if unexpected_keys:
            msglogger.warning(
                "Warning: the loaded checkpoint (%s) contains %d unexpected state keys"
                % (chkpt_file, len(unexpected_keys)))
        if missing_keys:
            raise ValueError(
                "The loaded checkpoint (%s) is missing %d state keys" %
                (chkpt_file, len(missing_keys)))

    if model_device is not None:
        model.to(model_device)

    if lean_checkpoint:
        msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(
            str(chkpt_file)))
        return model, None, None, 0

    optimizer = _load_optimizer()
    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(
        f=str(chkpt_file), e=checkpoint_epoch))
    _sanity_check()
    return model, compression_scheduler, optimizer, start_epoch
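
This variant additionally accepts model=None and then tries to rebuild the model from the checkpoint's own metadata. A short sketch under the assumption that the checkpoint was saved with 'dataset', 'arch' and 'is_parallel' entries (otherwise a ValueError is raised) and that the function lives under distiller.apputils:

import distiller.apputils as apputils

# Let the checkpoint rebuild the model itself; this relies on 'dataset', 'arch'
# and 'is_parallel' keys being present in the checkpoint.
model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
    None, 'checkpoint.pth.tar', model_device='cuda')

# The deprecated optimizer argument must remain None; passing one trips the assert.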
Example #3
def load_checkpoint(model, chkpt_file, optimizer=None):
    """Load a pytorch training checkpoint

    Args:
        model: the pytorch model to which we will load the parameters
        chkpt_file: the checkpoint file
        optimizer: the optimizer to which we will load the serialized state
    """
    if not os.path.isfile(chkpt_file):
        raise IOError(ENOENT, 'Could not find a checkpoint file at',
                      chkpt_file)

    msglogger.info("=> loading checkpoint %s", chkpt_file)
    checkpoint = torch.load(chkpt_file,
                            map_location=lambda storage, loc: storage)
    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))

    if 'state_dict' not in checkpoint:
        raise ValueError(
            "Checkpoint must contain the model parameters under the key 'state_dict'"
        )

    checkpoint_epoch = checkpoint.get('epoch', None)
    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0

    best_top1 = checkpoint.get('best_top1', None)
    if best_top1 is not None:
        msglogger.info("   best top@1: %.3f", best_top1)

    compression_scheduler = None
    normalize_dataparallel_keys = False
    if 'compression_sched' in checkpoint:
        compression_scheduler = distiller.CompressionScheduler(model)
        try:
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        except KeyError:
            # A very common source of this KeyError is loading a GPU model on the CPU.
            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
            normalize_dataparallel_keys = True
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        msglogger.info(
            "Loaded compression schedule from checkpoint (epoch {})".format(
                checkpoint_epoch))
    else:
        msglogger.info(
            "Warning: compression schedule data does not exist in the checkpoint"
        )

    if 'thinning_recipes' in checkpoint:
        if 'compression_sched' not in checkpoint:
            raise KeyError(
                "Found thinning_recipes key, but missing mandatory key compression_sched"
            )
        msglogger.info("Loaded a thinning recipe from the checkpoint")
        # Cache the recipes in case we need them later
        model.thinning_recipes = checkpoint['thinning_recipes']
        if normalize_dataparallel_keys:
            model.thinning_recipes = [
                distiller.get_normalized_recipe(recipe)
                for recipe in model.thinning_recipes
            ]
        distiller.execute_thinning_recipes_list(
            model, compression_scheduler.zeros_mask_dict,
            model.thinning_recipes)

    if 'quantizer_metadata' in checkpoint:
        msglogger.info('Loaded quantizer metadata from the checkpoint')
        qmd = checkpoint['quantizer_metadata']
        quantizer = qmd['type'](model, **qmd['params'])
        quantizer.prepare_model()

    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(
        f=str(chkpt_file), e=checkpoint_epoch))
    if normalize_dataparallel_keys:
        checkpoint['state_dict'] = {
            normalize_module_name(k): v
            for k, v in checkpoint['state_dict'].items()
        }
    model.load_state_dict(checkpoint['state_dict'])
    return (model, compression_scheduler, start_epoch)
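
The oldest variant returns only a 3-tuple and, although an optimizer argument is accepted and documented, the function body never reads it. A hedged call sketch, again assuming the distiller.apputils location:

import distiller.apputils as apputils

model, compression_scheduler, start_epoch = apputils.load_checkpoint(
    model, 'checkpoint.pth.tar')
# Note: the optimizer argument is accepted but its state is not restored here.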