def test_unused_subclass_model_hdf5_save_format(
    self,
    dummy_tf_subclassed_model,
    dummy_x_train,
    dummy_y_train,
    dummy_x_test,
    filepath,
):
    """Check that saving a subclassed model in HDF5 format raises ``DataSetError``.

    Background from the TF docs: a subclassed model that has never been used
    cannot be saved, because it needs to be called on some data in order to
    create its weights. The model is therefore fitted and used for a
    prediction before the save is attempted.
    """
    expected_error = (
        r"Saving the model to HDF5 format requires the model to be a Functional model or a "
        r"Sequential model. It does not work for subclassed models, because such models are "
        r"defined via the body of a Python method, which isn\'t safely serializable. Consider "
        r"saving to the Tensorflow SavedModel format \(by setting save_format=\"tf\"\) "
        r"or using `save_weights`."
    )
    h5_dataset = TensorFlowModelDataset(
        filepath=filepath,
        save_args={"save_format": "h5"},
    )

    # Train for one epoch and predict once to demonstrate this is a working model.
    dummy_tf_subclassed_model.fit(
        dummy_x_train,
        dummy_y_train,
        batch_size=64,
        epochs=1,
    )
    dummy_tf_subclassed_model.predict(dummy_x_test)

    with pytest.raises(DataSetError, match=expected_error):
        h5_dataset.save(dummy_tf_subclassed_model)
def test_catalog_release(self, mocker):
    """Check that ``release()`` invalidates the filesystem cache for the dataset path.

    ``fsspec.filesystem`` is patched, so the dataset works against a mock
    filesystem object whose ``invalidate_cache`` call can be inspected.
    """
    target_path = "test.tf"
    mocked_fs = mocker.patch("fsspec.filesystem").return_value
    dataset = TensorFlowModelDataset(filepath=target_path)

    # An unversioned dataset keeps nothing in the version cache.
    assert dataset._version_cache.currsize == 0

    dataset.release()

    mocked_fs.invalidate_cache.assert_called_once_with(target_path)
    assert dataset._version_cache.currsize == 0
def test_hdf5_save_format(self, dummy_tf_base_model, dummy_x_test, filepath):
    """Round-trip a model through HDF5 and check its predictions are preserved.

    Predictions made before saving are compared with predictions from the
    reloaded model, within a tolerance of ``1e-6`` (relative and absolute).
    """
    dataset = TensorFlowModelDataset(
        filepath=filepath,
        save_args={"save_format": "h5"},
    )

    expected = dummy_tf_base_model.predict(dummy_x_test)

    dataset.save(dummy_tf_base_model)
    restored_model = dataset.load()
    actual = restored_model.predict(dummy_x_test)

    np.testing.assert_allclose(expected, actual, rtol=1e-6, atol=1e-6)
def test_fs_args(self, fs_args, mocker):
    """Check how ``fsspec.filesystem`` is called when ``fs_args`` are supplied.

    The patched factory must be called exactly once, with the ``"file"``
    protocol, ``auto_mkdir=True`` and ``storage_option="value"``.
    """
    filesystem_factory = mocker.patch("fsspec.filesystem")

    TensorFlowModelDataset("test.tf", fs_args=fs_args)

    filesystem_factory.assert_called_once_with(
        "file",
        auto_mkdir=True,
        storage_option="value",
    )
def test_http_filesystem_no_versioning(self):
    """Check that a versioned dataset on an HTTPS path raises ``DataSetError``."""
    expected_error = r"HTTP\(s\) DataSet doesn't support versioning\."

    with pytest.raises(DataSetError, match=expected_error):
        TensorFlowModelDataset(
            filepath="https://example.com/file.tf",
            version=Version(None, None),
        )
def test_protocol_usage(self, filepath, instance_type):
    """Test that the dataset can be instantiated with mocked arbitrary file systems.

    Checks the type of the resolved filesystem and that the stored path is a
    ``PurePosixPath`` equal to ``filepath`` with its protocol prefix removed.
    """
    dataset = TensorFlowModelDataset(filepath=filepath)
    assert isinstance(dataset._fs, instance_type)

    # Keep whatever follows the first protocol delimiter; a path without a
    # delimiter is expected unchanged.
    *_, expected_path = filepath.split(PROTOCOL_DELIMITER, 1)

    assert str(dataset._filepath) == expected_path
    assert isinstance(dataset._filepath, PurePosixPath)
def test_protocol_usage(self, filepath, instance_type):
    """Test that the dataset can be instantiated with mocked arbitrary file systems.

    Checks the type of the resolved filesystem and that the stored path is a
    ``PurePosixPath`` equal to ``filepath`` without its protocol.
    """
    dataset = TensorFlowModelDataset(filepath=filepath)
    assert isinstance(dataset._fs, instance_type)

    # _strip_protocol() doesn't strip the http(s) protocol, so for https the
    # expected path is derived by splitting the string manually.
    expected_path = (
        filepath.split("://")[-1]
        if dataset._protocol == "https"
        else dataset._fs._strip_protocol(filepath)
    )

    assert str(dataset._filepath) == expected_path
    assert isinstance(dataset._filepath, PurePosixPath)
def versioned_tf_model_dataset(filepath, load_version, save_version):
    """Return a ``TensorFlowModelDataset`` for ``filepath`` with an explicit version.

    The version is built as ``Version(load_version, save_version)``.
    """
    dataset_version = Version(load_version, save_version)
    return TensorFlowModelDataset(filepath=filepath, version=dataset_version)
def tf_model_dataset(filepath, load_args, save_args, fs_args):
    """Return a ``TensorFlowModelDataset`` built from the given path and argument sets.

    ``load_args``, ``save_args`` and ``fs_args`` are passed through unchanged
    as the keyword arguments of the same names.
    """
    dataset = TensorFlowModelDataset(
        filepath=filepath,
        load_args=load_args,
        save_args=save_args,
        fs_args=fs_args,
    )
    return dataset