def test_construct_with_non_dict(self):
    """Constructing with a non-Mapping `client_ids_to_files` raises TypeError.

    The error message is expected to mention `collections.abc.Mapping`.
    """
    # Dots are escaped so the regex matches the literal type name instead of
    # treating each '.' as "any character".
    with self.assertRaisesRegex(TypeError,
                                r'Expected collections\.abc\.Mapping'):
        file_per_user_client_data.FilePerUserClientData(
            client_ids_to_files=[],  # A list, i.e. not a Mapping.
            dataset_fn=tf.data.TFRecordDataset,
        )
def test_construct_with_non_callable(self):
    """Passing `dataset_fn=None` raises a TypeError about a non-callable."""
    user_data = FilePerUserClientDataTest.fake_user_data
    expected_message = r'found non-callable'
    with self.assertRaisesRegex(TypeError, expected_message):
        file_per_user_client_data.FilePerUserClientData(
            dataset_fn=None,
            client_ids_to_files=user_data.client_data_file_dict,
        )
def _create_fake_client_data(self):
    """Return a FilePerUserClientData built from the class-level fake data.

    The file mapping and dataset factory are both taken from
    `FilePerUserClientDataTest.fake_user_data` and passed positionally.
    """
    data = FilePerUserClientDataTest.fake_user_data
    files_by_client = data.client_data_file_dict
    make_dataset = data.create_test_dataset_fn
    return file_per_user_client_data.FilePerUserClientData(
        files_by_client, make_dataset)
def test_construct_with_non_list(self):
    """A dict passed as `client_ids` is rejected with a TypeError."""
    not_a_list = {}
    # six keeps the assertion spelling portable across Python 2 and 3.
    with six.assertRaisesRegex(self, TypeError, r'Expected list, found dict'):
        file_per_user_client_data.FilePerUserClientData(
            client_ids=not_a_list,
            create_tf_dataset_fn=tf.data.TFRecordDataset,
        )
def test_construct_with_non_callable(self):
    """Passing `create_tf_dataset_fn=None` raises a TypeError (non-callable)."""
    with six.assertRaisesRegex(self, TypeError, r'found non-callable'):
        user_data = FilePerUserClientDataTest.fake_user_data
        file_per_user_client_data.FilePerUserClientData(
            client_ids=user_data.client_ids,
            create_tf_dataset_fn=None,
        )
def _create_fake_client_data(self):
    """Return a FilePerUserClientData wired to the class-level fake data.

    Uses the fixture's `client_ids` and its `create_test_dataset_fn` factory.
    """
    data = FilePerUserClientDataTest.fake_user_data
    ids = data.client_ids
    dataset_factory = data.create_test_dataset_fn
    return file_per_user_client_data.FilePerUserClientData(
        client_ids=ids,
        create_tf_dataset_fn=dataset_factory,
    )