def test_restore_v2(self):
    """Check that a v2-style checkpoint can be restored into the test model.

    After `model_lib_v2.load_fine_tune_checkpoint` runs with
    `checkpoint_version=train_pb2.CheckpointVersion.V2`, the model's
    `weight` is expected to equal 42. That is presumably the value the
    fixture checkpoint at `self._ckpt_path` holds; confirm in the test's
    setUp.
    """
    train_dataset = self._train_input_fn()
    model_lib_v2.load_fine_tune_checkpoint(
        self._model,
        self._ckpt_path,
        checkpoint_type='',
        checkpoint_version=train_pb2.CheckpointVersion.V2,
        input_dataset=train_dataset,
        unpad_groundtruth_tensors=True,
    )
    restored_weight = self._model.weight.numpy()
    np.testing.assert_allclose(restored_weight, 42)
def test_restore_map_incompatible_error(self):
    """Check that restoring with an incompatible restore map raises TypeError.

    The error message must match `received a (str -> ResourceVariable)`.
    The fixtures are built before entering `assertRaisesRegex`. This way a
    TypeError raised while constructing the model or the input pipeline
    cannot satisfy the assertion by accident. Only the checkpoint load is
    expected to raise.
    """
    incompatible_model = IncompatibleModel()
    train_dataset = self._train_input_fn()
    with self.assertRaisesRegex(
            TypeError, r'.*received a \(str -> ResourceVariable\).*'):
        model_lib_v2.load_fine_tune_checkpoint(
            incompatible_model,
            self._ckpt_path,
            checkpoint_type='',
            checkpoint_version=train_pb2.CheckpointVersion.V2,
            input_dataset=train_dataset,
            unpad_groundtruth_tensors=True,
        )