def create_tf_dataset_for_client(self, client_id):
  """Creates the `tf.data.Dataset` for `client_id` and validates its signature.

  The dataset is produced by `self._create_tf_dataset_fn`. Its output types and
  output shapes are then compared against `self._output_types` and
  `self._output_shapes`.

  Args:
    client_id: The identifier of the client whose dataset should be created.

  Returns:
    The dataset returned by `self._create_tf_dataset_fn(client_id)`.

  Raises:
    TypeError, ValueError: Raised by `tensor_utils.check_nested_equal` when the
      dataset's output types or shapes do not match the expected ones.
  """
  tf_dataset = self._create_tf_dataset_fn(client_id)
  # The `output_types` / `output_shapes` attributes exist only on TF1-style
  # (`DatasetV1`) datasets. The `tf.compat.v1.data` accessors accept both V1
  # and V2 datasets, so this check keeps working under TF2.
  tensor_utils.check_nested_equal(
      tf.compat.v1.data.get_output_types(tf_dataset), self._output_types)
  tensor_utils.check_nested_equal(
      tf.compat.v1.data.get_output_shapes(tf_dataset), self._output_shapes)
  return tf_dataset
def create_tf_dataset_for_client(self, client_id):
  """Builds one client's dataset and verifies its element types and shapes.

  Args:
    client_id: Identifier of the client whose data is requested.

  Returns:
    The dataset produced by `self._create_dataset(client_id)`.

  Raises:
    TypeError, ValueError: Propagated from `tensor_utils.check_nested_equal`
      when the dataset's types or shapes differ from `self._output_types` /
      `self._output_shapes`.
  """
  dataset = self._create_dataset(client_id)
  actual_types = tf.compat.v1.data.get_output_types(dataset)
  tensor_utils.check_nested_equal(actual_types, self._output_types)
  actual_shapes = tf.compat.v1.data.get_output_shapes(dataset)
  tensor_utils.check_nested_equal(actual_shapes, self._output_shapes)
  return dataset
def create_tf_dataset_for_client(self, client_id):
  """Returns the dataset for a known client after checking its element spec.

  Args:
    client_id: Identifier of the client; must be one of `self.client_ids`.

  Returns:
    The dataset produced by `self._create_dataset(client_id)`.

  Raises:
    ValueError: If `client_id` is not in `self.client_ids`.
    TypeError, ValueError: Propagated from `tensor_utils.check_nested_equal`
      when the dataset's `element_spec` differs from
      `self._element_type_structure`.
  """
  if client_id in self.client_ids:
    dataset = self._create_dataset(client_id)
    tensor_utils.check_nested_equal(dataset.element_spec,
                                    self._element_type_structure)
    return dataset
  message = ("ID [{i}] is not a client in this ClientData. See "
             "property `client_ids` for the list of valid ids.")
  raise ValueError(message.format(i=client_id))
def create_tf_dataset_for_client(self, client_id: str) -> tf.data.Dataset:
  """Returns a new `tf.data.Dataset` with the training examples of one client.

  Only ids listed in the `client_ids` property of this
  `FilePerUserClientData` are accepted. The dataset is obtained by calling
  `self.serializable_dataset_fn` on `client_id` wrapped in `tf.constant`. In
  contrast to `self.serializable_dataset_fn`, this method itself is not
  serializable.

  Args:
    client_id: String identifier of the requested client.

  Returns:
    The client's `tf.data.Dataset`.

  Raises:
    ValueError: If `client_id` is not one of `self.client_ids`.
  """
  known_ids = self.client_ids
  if client_id not in known_ids:
    raise ValueError(
        'ID [{i}] is not a client in this ClientData. See '
        'property `client_ids` for the list of valid ids.'.format(i=client_id))
  id_tensor = tf.constant(client_id)
  dataset = self.serializable_dataset_fn(id_tensor)
  # Guard against a dataset function whose output disagrees with the
  # structure this ClientData advertises.
  tensor_utils.check_nested_equal(dataset.element_spec,
                                  self._element_type_structure)
  return dataset
def test_check_nested_equal(self):
  """Tests `tensor_utils.check_nested_equal` on equal and unequal structures.

  Covers nested dicts, lists of pairs, flat dicts, nested dtypes and nested
  shapes. It also covers the failure modes: different container types
  (TypeError), different structures and different leaf values (ValueError).
  Finally, it exercises the optional custom equality function.
  """
  nested_dict = {
      'KEY1': {
          'NESTED_KEY': 0
      },
      'KEY2': 1,
  }
  nested_list = [('KEY1', ('NESTED_KEY', 0)), ('KEY2', 1)]
  flat_dict = {
      'KEY1': 0,
      'KEY2': 1,
  }
  nested_dtypes = {
      'x': [tf.int32, tf.float32],
      'y': tf.float32,
  }
  nested_shapes = {
      # N.B. tf.TensorShape([None]) == tf.TensorShape([None])
      # returns False, so we can't use a None shape here.
      'x': [[1], [3, 5]],
      'y': [1],
  }

  # Should not raise an exception.
  tensor_utils.check_nested_equal(nested_dict, nested_dict)
  tensor_utils.check_nested_equal(nested_list, nested_list)
  tensor_utils.check_nested_equal(flat_dict, flat_dict)
  tensor_utils.check_nested_equal(nested_dtypes, nested_dtypes)
  tensor_utils.check_nested_equal(nested_shapes, nested_shapes)

  with self.assertRaises(TypeError):
    # Same information, but a dict compared with a list of pairs.
    tensor_utils.check_nested_equal(nested_dict, nested_list)

  with self.assertRaises(ValueError):
    # Different nested structures.
    tensor_utils.check_nested_equal(nested_dict, flat_dict)

  # Same as nested_dict, but using float values. Equality still holds for
  # 0 == 0.0 despite different types.
  nested_dict_different_types = {
      'KEY1': {
          'NESTED_KEY': 0.0
      },
      'KEY2': 1.0,
  }
  nest.assert_same_structure(nested_dict, nested_dict_different_types)
  # The claim above is about value equality, so also assert that
  # check_nested_equal accepts it (previously only the structure was checked).
  tensor_utils.check_nested_equal(nested_dict, nested_dict_different_types)

  # Same as nested_dict but with one different value
  nested_dict_different_value = {
      'KEY1': {
          'NESTED_KEY': 0.5
      },
      'KEY2': 1.0,
  }
  with self.assertRaises(ValueError):
    tensor_utils.check_nested_equal(nested_dict, nested_dict_different_value)

  # None leaves should not raise.
  tensor_utils.check_nested_equal([None], [None])

  def always_neq(x, y):
    del x, y
    return False

  # The third argument replaces the default equality; one that always returns
  # False must make even identical inputs fail.
  with self.assertRaises(ValueError):
    tensor_utils.check_nested_equal([1], [1], always_neq)
def create_tf_dataset_for_client(self, client_id):
  """Creates the dataset for `client_id` and checks its element structure.

  Args:
    client_id: The identifier of the client whose dataset should be created.

  Returns:
    The dataset returned by `self._create_dataset(client_id)`.

  Raises:
    TypeError, ValueError: Raised by `tensor_utils.check_nested_equal` when the
      dataset's element structure differs from `self._element_type_structure`.
  """
  tf_dataset = self._create_dataset(client_id)
  # `Dataset.element_spec` is the non-deprecated replacement for
  # `tf.data.experimental.get_structure(dataset)` and describes the same
  # element structure.
  tensor_utils.check_nested_equal(tf_dataset.element_spec,
                                  self._element_type_structure)
  return tf_dataset
def create_tf_dataset_for_client(self, client_id):
  """Builds a client's dataset via `self._create_tf_dataset_fn` and validates it.

  Args:
    client_id: Identifier of the client whose data is requested.

  Returns:
    The dataset produced by `self._create_tf_dataset_fn(client_id)`.

  Raises:
    TypeError, ValueError: Propagated from `tensor_utils.check_nested_equal`
      when the dataset's `element_spec` differs from
      `self._element_type_structure`.
  """
  dataset = self._create_tf_dataset_fn(client_id)
  expected_spec = self._element_type_structure
  tensor_utils.check_nested_equal(dataset.element_spec, expected_spec)
  return dataset