Example #1
0
def load_checkpoint(model, chkpt_file, optimizer=None):
    """Load a pytorch training checkpoint.

    Args:
        model: the pytorch model to which we will load the parameters
        chkpt_file: the checkpoint file
        optimizer: the optimizer to which we will load the serialized state
            (NOTE(review): accepted but never used in this body — confirm
            against callers before removing)

    Returns:
        (model, compression_scheduler, start_epoch)

    Raises:
        IOError: if `chkpt_file` does not exist.
        KeyError: if the checkpoint contains thinning recipes but lacks the
            mandatory 'compression_sched' entry needed to execute them.
    """
    # Guard clause: fail fast on a missing file instead of nesting the whole
    # loading logic under an `if` (the old `start_epoch = 0` before this
    # check was dead code, since the missing-file branch raised).
    if not os.path.isfile(chkpt_file):
        raise IOError(ENOENT, 'Could not find a checkpoint file at',
                      chkpt_file)

    compression_scheduler = None
    msglogger.info("=> loading checkpoint %s", chkpt_file)
    # map_location keeps GPU-saved tensors on the CPU so the checkpoint
    # loads on machines without a GPU.
    checkpoint = torch.load(chkpt_file,
                            map_location=lambda storage, loc: storage)
    # Iterating a dict yields its keys; no generator expression needed.
    msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(checkpoint)))
    start_epoch = checkpoint['epoch'] + 1
    best_top1 = checkpoint.get('best_top1', None)
    if best_top1 is not None:
        msglogger.info("   best top@1: %.3f", best_top1)

    if 'compression_sched' in checkpoint:
        compression_scheduler = distiller.CompressionScheduler(model)
        compression_scheduler.load_state_dict(
            checkpoint['compression_sched'])
        msglogger.info(
            "Loaded compression schedule from checkpoint (epoch %d)",
            checkpoint['epoch'])
    else:
        msglogger.info(
            "Warning: compression schedule data does not exist in the checkpoint"
        )

    if 'thinning_recipes' in checkpoint:
        # Recipes can only be replayed with the masks held by the
        # compression scheduler, hence the mandatory-key check.
        if 'compression_sched' not in checkpoint:
            raise KeyError(
                "Found thinning_recipes key, but missing mandatory key compression_sched"
            )
        msglogger.info("Loaded a thinning recipe from the checkpoint")
        # Cache the recipes in case we need them later
        model.thinning_recipes = checkpoint['thinning_recipes']
        distiller.execute_thinning_recipes_list(
            model, compression_scheduler.zeros_mask_dict,
            model.thinning_recipes)

    if 'quantizer_metadata' in checkpoint:
        msglogger.info('Loaded quantizer metadata from the checkpoint')
        qmd = checkpoint['quantizer_metadata']
        # Re-instantiate the quantizer class recorded in the checkpoint and
        # re-apply the model transformation it performed before saving.
        quantizer = qmd['type'](model, **qmd['params'])
        quantizer.prepare_model()

    msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file,
                   checkpoint['epoch'])

    model.load_state_dict(checkpoint['state_dict'])
    return model, compression_scheduler, start_epoch
Example #2
0
 def _load_and_execute_thinning_recipes():
     """Cache the checkpoint's thinning recipes on the model and execute them.

     NOTE(review): this is a nested closure — `msglogger`, `model`,
     `checkpoint`, `normalize_dataparallel_keys`, `compression_scheduler`
     and `distiller` must all come from the enclosing scope, which is not
     visible in this excerpt; confirm against the full file.
     """
     msglogger.info("Loaded a thinning recipe from the checkpoint")
     # Cache the recipes in case we need them later
     model.thinning_recipes = checkpoint['thinning_recipes']
     # When DataParallel key prefixes were stripped from the state dict,
     # presumably the recipes must be normalized the same way — TODO confirm.
     if normalize_dataparallel_keys:
         model.thinning_recipes = [distiller.get_normalized_recipe(recipe)
                                   for recipe in model.thinning_recipes]
     distiller.execute_thinning_recipes_list(model,
                                             compression_scheduler.zeros_mask_dict,
                                             model.thinning_recipes)
Example #3
0
def load_checkpoint(model, chkpt_file, optimizer=None):
    """Load a pytorch training checkpoint.

    Args:
        model: the pytorch model to which we will load the parameters
        chkpt_file: the checkpoint file
        optimizer: the optimizer to which we will load the serialized state
            (NOTE(review): accepted but never used in this body — confirm
            against callers before removing)

    Returns:
        (model, compression_scheduler, start_epoch)

    Raises:
        IOError: if `chkpt_file` does not exist.
        KeyError: if the checkpoint contains thinning recipes but lacks the
            mandatory 'compression_sched' entry.
    """
    if not os.path.isfile(chkpt_file):
        # Bug fix: raise instead of exit(1) — a library function must not
        # kill the process; let the caller decide how to handle a missing
        # checkpoint (consistent with the other loaders in this file).
        raise IOError(ENOENT, 'Could not find a checkpoint file at',
                      chkpt_file)

    compression_scheduler = None
    msglogger.info("=> loading checkpoint %s", chkpt_file)
    # Bug fix: map GPU-saved tensors to CPU so the checkpoint loads on
    # machines without a GPU (the other variants already did this).
    checkpoint = torch.load(chkpt_file,
                            map_location=lambda storage, loc: storage)
    start_epoch = checkpoint['epoch'] + 1
    best_top1 = checkpoint.get('best_top1', None)
    if best_top1 is not None:
        msglogger.info("   best top@1: %.3f", best_top1)

    if 'compression_sched' in checkpoint:
        compression_scheduler = distiller.CompressionScheduler(model)
        compression_scheduler.load_state_dict(
            checkpoint['compression_sched'])
        msglogger.info(
            "Loaded compression schedule from checkpoint (epoch %d)",
            checkpoint['epoch'])
    else:
        # Bug fix: this warning belongs to the missing-compression-schedule
        # case; it was previously emitted when thinning recipes were absent.
        msglogger.info(
            "Warning: compression schedule data does not exist in the checkpoint"
        )

    if 'thinning_recipes' in checkpoint:
        if 'compression_sched' not in checkpoint:
            # Typo fix: "mandatoy" -> "mandatory".
            raise KeyError(
                "Found thinning_recipes key, but missing mandatory key compression_sched"
            )
        msglogger.info("Loaded a thinning recipe from the checkpoint")
        # Cache the recipes in case we need them later
        model.thinning_recipes = checkpoint['thinning_recipes']
        distiller.execute_thinning_recipes_list(
            model, compression_scheduler.zeros_mask_dict,
            model.thinning_recipes)

    # Bug fix: log the successful load unconditionally, not only on the
    # no-thinning-recipes branch.
    msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file,
                   checkpoint['epoch'])

    model.load_state_dict(checkpoint['state_dict'])
    return model, compression_scheduler, start_epoch
Example #4
0
def load_checkpoint(model,
                    chkpt_file,
                    optimizer=None,
                    model_device=None,
                    *,
                    lean_checkpoint=False,
                    strict=False):
    """Load a pytorch training checkpoint.

    Args:
        model: the pytorch model to which we will load the parameters
        chkpt_file: the checkpoint file
        lean_checkpoint: if set, read into model only 'state_dict' field
        optimizer: [deprecated argument]
        model_device [str]: if set, call model.to($model_device)
                This should be set to either 'cpu' or 'cuda'.
        strict: forwarded to model.load_state_dict(); when False, a
                mismatched state dict is reported rather than rejected
                outright by pytorch.
    :returns: updated model, compression_scheduler, optimizer, start_epoch

    Raises:
        IOError: if `chkpt_file` does not exist.
        ValueError: if the checkpoint has no 'state_dict', or is missing
            state keys that the model requires.
        KeyError: if thinning recipes are present without a compression
            schedule.
    """
    if not os.path.isfile(chkpt_file):
        raise IOError(ENOENT, 'Could not find a checkpoint file at',
                      chkpt_file)

    msglogger.info("=> loading checkpoint %s", chkpt_file)
    # map_location keeps GPU-saved tensors on the CPU so the checkpoint
    # loads on machines without a GPU.
    checkpoint = torch.load(chkpt_file,
                            map_location=lambda storage, loc: storage)
    msglogger.info('=> Checkpoint contents:\n%s\n' %
                   get_contents_table(checkpoint))
    if 'extras' in checkpoint:
        msglogger.info("=> Checkpoint['extras'] contents:\n{}\n".format(
            get_contents_table(checkpoint['extras'])))

    if 'state_dict' not in checkpoint:
        raise ValueError(
            "Checkpoint must contain the model parameters under the key 'state_dict'"
        )

    # Resume from the epoch after the saved one; a checkpoint without an
    # 'epoch' field starts training from scratch.
    checkpoint_epoch = checkpoint.get('epoch', None)
    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0

    compression_scheduler = None
    normalize_dataparallel_keys = False
    if 'compression_sched' in checkpoint:
        compression_scheduler = distiller.CompressionScheduler(model)
        try:
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        except KeyError:  # fix: exception variable `e` was bound but unused
            # A very common source of this KeyError is loading a GPU model on the CPU.
            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
            normalize_dataparallel_keys = True
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        msglogger.info(
            "Loaded compression schedule from checkpoint (epoch {})".format(
                checkpoint_epoch))
    else:
        msglogger.info(
            "Warning: compression schedule data does not exist in the checkpoint"
        )

    if 'thinning_recipes' in checkpoint:
        # Recipes can only be replayed with the scheduler's masks, hence
        # the mandatory-key check.
        if 'compression_sched' not in checkpoint:
            raise KeyError(
                "Found thinning_recipes key, but missing mandatory key compression_sched"
            )
        msglogger.info("Loaded a thinning recipe from the checkpoint")
        # Cache the recipes in case we need them later
        model.thinning_recipes = checkpoint['thinning_recipes']
        if normalize_dataparallel_keys:
            model.thinning_recipes = [
                distiller.get_normalized_recipe(recipe)
                for recipe in model.thinning_recipes
            ]
        distiller.execute_thinning_recipes_list(
            model, compression_scheduler.zeros_mask_dict,
            model.thinning_recipes)

    if 'quantizer_metadata' in checkpoint:
        msglogger.info('Loaded quantizer metadata from the checkpoint')
        qmd = checkpoint['quantizer_metadata']
        # Re-instantiate the quantizer class recorded in the checkpoint and
        # re-apply the model transformation it performed before saving.
        quantizer = qmd['type'](model, **qmd['params'])
        quantizer.prepare_model(qmd['dummy_input'])

    if normalize_dataparallel_keys:
        checkpoint['state_dict'] = {
            normalize_module_name(k): v
            for k, v in checkpoint['state_dict'].items()
        }
    anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
    if anomalous_keys:
        # This is pytorch 1.1+
        missing_keys, unexpected_keys = anomalous_keys
        if unexpected_keys:
            msglogger.warning(
                "Warning: the loaded checkpoint (%s) contains %d unexpected state keys"
                % (chkpt_file, len(unexpected_keys)))
        if missing_keys:
            raise ValueError(
                "The loaded checkpoint (%s) is missing %d state keys" %
                (chkpt_file, len(missing_keys)))

    if model_device is not None:
        model.to(model_device)

    if lean_checkpoint:
        msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(
            str(chkpt_file)))
        return (model, None, None, 0)

    def _load_optimizer(cls, src_state_dict, model):
        """Initiate optimizer with model parameters and load src_state_dict"""
        # initiate the dest_optimizer with a dummy learning rate,
        # this is required to support SGD.__init__()
        dest_optimizer = cls(model.parameters(), lr=1)
        dest_optimizer.load_state_dict(src_state_dict)
        return dest_optimizer

    try:
        optimizer = _load_optimizer(checkpoint['optimizer_type'],
                                    checkpoint['optimizer_state_dict'], model)
    except KeyError:
        # Older checkpoints do not support optimizer loading: They either had an
        # 'optimizer' field (different name) which was not used during the load,
        # or they didn't even checkpoint the optimizer.
        optimizer = None

    if optimizer is not None:
        msglogger.info(
            'Optimizer of type {type} was loaded from checkpoint'.format(
                type=type(optimizer)))
        msglogger.info('Optimizer Args: {}'.format(
            dict((k, v)
                 for k, v in optimizer.state_dict()['param_groups'][0].items()
                 if k != 'params')))
    else:
        msglogger.warning('Optimizer could not be loaded from checkpoint.')

    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(
        f=str(chkpt_file), e=checkpoint_epoch))
    return (model, compression_scheduler, optimizer, start_epoch)
Example #5
0
def load_checkpoint(model, chkpt_file, optimizer=None):
    """Load a pytorch training checkpoint.

    Args:
        model: the pytorch model to which we will load the parameters
        chkpt_file: the checkpoint file
        optimizer: the optimizer to which we will load the serialized state
            (NOTE(review): accepted but never used in this body — confirm
            against callers before removing)

    Returns:
        (model, compression_scheduler, start_epoch)

    Raises:
        IOError: if `chkpt_file` does not exist.
        ValueError: if the checkpoint has no 'state_dict' entry.
        KeyError: if thinning recipes are present without a compression
            schedule.
    """
    if not os.path.isfile(chkpt_file):
        raise IOError(ENOENT, 'Could not find a checkpoint file at',
                      chkpt_file)

    msglogger.info("=> loading checkpoint %s", chkpt_file)
    # map_location keeps GPU-saved tensors on the CPU so the checkpoint
    # loads on machines without a GPU.
    checkpoint = torch.load(chkpt_file,
                            map_location=lambda storage, loc: storage)
    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))

    if 'state_dict' not in checkpoint:
        raise ValueError(
            "Checkpoint must contain the model parameters under the key 'state_dict'"
        )

    # Resume from the epoch after the saved one; a checkpoint without an
    # 'epoch' field starts training from scratch.
    checkpoint_epoch = checkpoint.get('epoch', None)
    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0

    best_top1 = checkpoint.get('best_top1', None)
    if best_top1 is not None:
        msglogger.info("   best top@1: %.3f", best_top1)

    compression_scheduler = None
    normalize_dataparallel_keys = False
    if 'compression_sched' in checkpoint:
        compression_scheduler = distiller.CompressionScheduler(model)
        try:
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        except KeyError:  # fix: exception variable `e` was bound but unused
            # A very common source of this KeyError is loading a GPU model on the CPU.
            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
            normalize_dataparallel_keys = True
            compression_scheduler.load_state_dict(
                checkpoint['compression_sched'], normalize_dataparallel_keys)
        msglogger.info(
            "Loaded compression schedule from checkpoint (epoch {})".format(
                checkpoint_epoch))
    else:
        msglogger.info(
            "Warning: compression schedule data does not exist in the checkpoint"
        )

    if 'thinning_recipes' in checkpoint:
        # Recipes can only be replayed with the scheduler's masks, hence
        # the mandatory-key check.
        if 'compression_sched' not in checkpoint:
            raise KeyError(
                "Found thinning_recipes key, but missing mandatory key compression_sched"
            )
        msglogger.info("Loaded a thinning recipe from the checkpoint")
        # Cache the recipes in case we need them later
        model.thinning_recipes = checkpoint['thinning_recipes']
        if normalize_dataparallel_keys:
            model.thinning_recipes = [
                distiller.get_normalized_recipe(recipe)
                for recipe in model.thinning_recipes
            ]
        distiller.execute_thinning_recipes_list(
            model, compression_scheduler.zeros_mask_dict,
            model.thinning_recipes)

    if 'quantizer_metadata' in checkpoint:
        msglogger.info('Loaded quantizer metadata from the checkpoint')
        qmd = checkpoint['quantizer_metadata']
        # Re-instantiate the quantizer class recorded in the checkpoint and
        # re-apply the model transformation it performed before saving.
        quantizer = qmd['type'](model, **qmd['params'])
        quantizer.prepare_model()

    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(
        f=str(chkpt_file), e=checkpoint_epoch))
    if normalize_dataparallel_keys:
        checkpoint['state_dict'] = {
            normalize_module_name(k): v
            for k, v in checkpoint['state_dict'].items()
        }
    model.load_state_dict(checkpoint['state_dict'])
    return (model, compression_scheduler, start_epoch)