def restore_checkpoint(output_location, model, optimizer, iterations_per_epoch):
    """Restore training state from a checkpoint file, if one exists.

    Loads the checkpoint at the standard location under `output_location`,
    reconciles the state dict's 'module.' key prefix with whether the current
    platform wraps the model in DataParallel, and restores both the model and
    optimizer state.

    Returns:
        (step, logger): the `Step` at which training should resume and the
        restored `MetricLogger`; `(None, None)` when no checkpoint exists.
    """
    checkpoint_location = paths.checkpoint(output_location)

    # No checkpoint on disk: nothing to restore.
    if not get_platform().exists(checkpoint_location):
        return None, None

    checkpoint = get_platform().load_model(
        checkpoint_location, map_location=torch.device('cpu'))

    # DataParallel stores parameters under a 'module.' prefix. Reconcile the
    # saved keys with what the current (parallel or not) platform expects.
    # Note: the saved dict is unchanged between the two checks, so computing
    # the prefix test once is equivalent to re-evaluating it in the elif.
    state_dict = checkpoint['model_state_dict']
    saved_with_prefix = all(k.startswith('module.') for k in state_dict)
    running_parallel = get_platform().is_parallel

    if running_parallel and not saved_with_prefix:
        # Saved without the wrapper; add the prefix for a parallel model.
        state_dict = {'module.' + k: v for k, v in state_dict.items()}
    elif saved_with_prefix and not running_parallel:
        # Saved from a parallel model; strip the prefix for a plain model.
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    checkpoint['model_state_dict'] = state_dict

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    step = Step.from_epoch(
        checkpoint['ep'], checkpoint['it'], iterations_per_epoch)
    logger = MetricLogger.create_from_string(checkpoint['logger'])
    return step, logger
def test_create_from_string(self):
    """A logger serialized via str() round-trips through create_from_string."""
    original = TestMetricLogger.create_logger()
    round_tripped = MetricLogger.create_from_string(str(original))

    # Per-metric data must survive the round trip.
    for metric in ('train_accuracy', 'test_accuracy'):
        self.assertEqual(
            original.get_data(metric), round_tripped.get_data(metric))

    # The serialized forms must match as well.
    self.assertEqual(str(original), str(round_tripped))