コード例 #1
0
ファイル: checkpointing.py プロジェクト: sbam13/open_lth
def restore_checkpoint(output_location, model, optimizer,
                       iterations_per_epoch):
    """Restore model, optimizer and training progress from a checkpoint.

    The checkpoint path is obtained from ``paths.checkpoint(output_location)``.
    If the platform reports that nothing exists at that path, no state is
    loaded.

    Before loading, the keys of the saved model state dict are reconciled
    with the current platform's ``is_parallel`` flag:

    * flag truthy and not every key starts with ``'module.'``: the prefix
      is prepended to every key;
    * flag falsy and every key starts with ``'module.'``: the prefix is
      stripped from every key;
    * otherwise the keys are left untouched.

    Args:
        output_location: Passed to ``paths.checkpoint`` to locate the file.
        model: Object whose ``load_state_dict`` receives the saved
            ``'model_state_dict'`` entry.
        optimizer: Object whose ``load_state_dict`` receives the saved
            ``'optimizer_state_dict'`` entry.
        iterations_per_epoch: Forwarded to ``Step.from_epoch`` together with
            the saved ``'ep'`` and ``'it'`` entries.

    Returns:
        A ``(step, logger)`` pair built from the checkpoint contents, or
        ``(None, None)`` when no checkpoint exists.
    """
    ckpt_path = paths.checkpoint(output_location)
    if not get_platform().exists(ckpt_path):
        # Nothing has been saved at this location.
        return None, None

    # The load is requested with the CPU device as map_location.
    saved = get_platform().load_model(ckpt_path,
                                      map_location=torch.device('cpu'))

    # Reconcile the 'module.' key prefix (DataParallel) with the platform.
    prefix = 'module.'
    wants_prefix = get_platform().is_parallel
    saved_weights = saved['model_state_dict']
    fully_prefixed = all(name.startswith(prefix) for name in saved_weights)

    if wants_prefix and not fully_prefixed:
        saved['model_state_dict'] = {
            prefix + name: value for name, value in saved_weights.items()
        }
    elif fully_prefixed and not wants_prefix:
        saved['model_state_dict'] = {
            name[len(prefix):]: value
            for name, value in saved_weights.items()
        }

    model.load_state_dict(saved['model_state_dict'])
    optimizer.load_state_dict(saved['optimizer_state_dict'])

    restored_step = Step.from_epoch(saved['ep'], saved['it'],
                                    iterations_per_epoch)
    restored_logger = MetricLogger.create_from_string(saved['logger'])
    return restored_step, restored_logger
コード例 #2
0
 def test_create_from_string(self):
     """A logger rebuilt from its string form should match the source logger."""
     original = TestMetricLogger.create_logger()
     restored = MetricLogger.create_from_string(str(original))

     # Per-metric data must survive the round trip through the string form.
     for metric in ('train_accuracy', 'test_accuracy'):
         self.assertEqual(original.get_data(metric),
                          restored.get_data(metric))

     # The full string forms must also agree.
     self.assertEqual(str(original), str(restored))