def create_tf_dataset_for_client(self, client_id):
    """Builds the dataset for `client_id` and validates its types and shapes.

    The dataset is produced by `self._create_tf_dataset_fn`. Its
    `output_types` and `output_shapes` are then compared against
    `self._output_types` and `self._output_shapes` respectively.

    Args:
      client_id: Identifier forwarded unchanged to `self._create_tf_dataset_fn`.

    Returns:
      The dataset returned by `self._create_tf_dataset_fn`.

    Raises:
      Any error raised by `tensor_utils.check_nested_equal` when the actual
      and expected structures disagree.
    """
    # NOTE(review): `output_types` / `output_shapes` are read as attributes of
    # the dataset object; confirm they exist in the TensorFlow version in use.
    dataset = self._create_tf_dataset_fn(client_id)

    actual_types = dataset.output_types
    tensor_utils.check_nested_equal(actual_types, self._output_types)

    actual_shapes = dataset.output_shapes
    tensor_utils.check_nested_equal(actual_shapes, self._output_shapes)

    return dataset
Exemplo n.º 2
0
 def create_tf_dataset_for_client(self, client_id):
   """Builds the dataset for `client_id` and validates its types and shapes.

   The dataset comes from `self._create_dataset`. Its output types and output
   shapes (read through the `tf.compat.v1.data` accessors) are compared
   against `self._output_types` and `self._output_shapes`.

   Args:
     client_id: Identifier forwarded unchanged to `self._create_dataset`.

   Returns:
     The dataset returned by `self._create_dataset`.

   Raises:
     Any error raised by `tensor_utils.check_nested_equal` when the actual
     and expected structures disagree.
   """
   dataset = self._create_dataset(client_id)

   actual_types = tf.compat.v1.data.get_output_types(dataset)
   tensor_utils.check_nested_equal(actual_types, self._output_types)

   actual_shapes = tf.compat.v1.data.get_output_shapes(dataset)
   tensor_utils.check_nested_equal(actual_shapes, self._output_shapes)

   return dataset
Exemplo n.º 3
0
 def create_tf_dataset_for_client(self, client_id):
     """Returns the dataset for a known client after validating its structure.

     Args:
       client_id: Identifier that must be contained in `self.client_ids`.

     Returns:
       The dataset produced by `self._create_dataset(client_id)`.

     Raises:
       ValueError: If `client_id` is not in `self.client_ids`.
       Any error raised by `tensor_utils.check_nested_equal` when the
       dataset's `element_spec` differs from `self._element_type_structure`.
     """
     is_known_client = client_id in self.client_ids
     if not is_known_client:
         message_template = (
             "ID [{i}] is not a client in this ClientData. See "
             "property `client_ids` for the list of valid ids.")
         raise ValueError(message_template.format(i=client_id))

     dataset = self._create_dataset(client_id)
     actual_spec = dataset.element_spec
     tensor_utils.check_nested_equal(actual_spec,
                                     self._element_type_structure)
     return dataset
Exemplo n.º 4
0
    def create_tf_dataset_for_client(self, client_id: str) -> tf.data.Dataset:
        """Creates a new `tf.data.Dataset` with one client's training examples.

        A dataset is only created when `client_id` is one of the entries of
        the `client_ids` property of the `FilePerUserClientData`. Unlike
        `self.serializable_dataset_fn`, this method is not serializable.

        Args:
          client_id: The string identifier for the desired client.

        Returns:
          A `tf.data.Dataset` object.

        Raises:
          ValueError: If `client_id` is not in `self.client_ids`.
          Any error raised by `tensor_utils.check_nested_equal` when the
          dataset's `element_spec` differs from `self._element_type_structure`.
        """
        is_known_client = client_id in self.client_ids
        if not is_known_client:
            message_template = (
                'ID [{i}] is not a client in this ClientData. See '
                'property `client_ids` for the list of valid ids.')
            raise ValueError(message_template.format(i=client_id))

        # The serializable function is called with the id wrapped in a
        # `tf.constant` rather than the raw Python string.
        client_id_tensor = tf.constant(client_id)
        dataset = self.serializable_dataset_fn(client_id_tensor)

        actual_spec = dataset.element_spec
        tensor_utils.check_nested_equal(actual_spec,
                                        self._element_type_structure)
        return dataset
Exemplo n.º 5
0
    def test_check_nested_equal(self):
        """Exercises `tensor_utils.check_nested_equal` on equal and unequal inputs."""
        dict_of_dicts = {'KEY1': {'NESTED_KEY': 0}, 'KEY2': 1}
        list_of_pairs = [('KEY1', ('NESTED_KEY', 0)), ('KEY2', 1)]
        single_level_dict = {'KEY1': 0, 'KEY2': 1}
        dtype_structure = {'x': [tf.int32, tf.float32], 'y': tf.float32}
        # N.B. tf.TensorShape([None]) == tf.TensorShape([None])
        # returns False, so we can't use a None shape here.
        shape_structure = {'x': [[1], [3, 5]], 'y': [1]}

        # Comparing each structure with itself should not raise an exception.
        self_equal_structures = (
            dict_of_dicts,
            list_of_pairs,
            single_level_dict,
            dtype_structure,
            shape_structure,
        )
        for structure in self_equal_structures:
            tensor_utils.check_nested_equal(structure, structure)

        # A dict compared against a list is expected to raise TypeError.
        with self.assertRaises(TypeError):
            tensor_utils.check_nested_equal(dict_of_dicts, list_of_pairs)

        # Different nested structures.
        with self.assertRaises(ValueError):
            tensor_utils.check_nested_equal(dict_of_dicts, single_level_dict)

        # Float-valued twin of `dict_of_dicts` (0 == 0.0 and 1 == 1.0 despite
        # the different types). Only the structure is asserted here.
        float_valued_dict = {'KEY1': {'NESTED_KEY': 0.0}, 'KEY2': 1.0}
        nest.assert_same_structure(dict_of_dicts, float_valued_dict)

        # Same layout as `dict_of_dicts` but with one different value.
        one_value_changed = {'KEY1': {'NESTED_KEY': 0.5}, 'KEY2': 1.0}
        with self.assertRaises(ValueError):
            tensor_utils.check_nested_equal(dict_of_dicts, one_value_changed)

        # `None` leaves compare equal to each other.
        tensor_utils.check_nested_equal([None], [None])

        def never_equal(x, y):
            # Custom comparator that reports every pair as unequal.
            del x, y
            return False

        # With a comparator that always fails, identical inputs must raise.
        with self.assertRaises(ValueError):
            tensor_utils.check_nested_equal([1], [1], never_equal)
Exemplo n.º 6
0
 def create_tf_dataset_for_client(self, client_id):
     """Builds the dataset for `client_id` and validates its structure.

     The structure reported by `tf.data.experimental.get_structure` for the
     new dataset is compared against `self._element_type_structure`.

     Args:
       client_id: Identifier forwarded unchanged to `self._create_dataset`.

     Returns:
       The dataset returned by `self._create_dataset`.

     Raises:
       Any error raised by `tensor_utils.check_nested_equal` when the actual
       and expected structures disagree.
     """
     dataset = self._create_dataset(client_id)
     actual_structure = tf.data.experimental.get_structure(dataset)
     tensor_utils.check_nested_equal(actual_structure,
                                     self._element_type_structure)
     return dataset
Exemplo n.º 7
0
 def create_tf_dataset_for_client(self, client_id):
     """Builds the dataset for `client_id` and validates its element spec.

     The dataset's `element_spec` is compared against
     `self._element_type_structure`.

     Args:
       client_id: Identifier forwarded unchanged to
         `self._create_tf_dataset_fn`.

     Returns:
       The dataset returned by `self._create_tf_dataset_fn`.

     Raises:
       Any error raised by `tensor_utils.check_nested_equal` when the actual
       and expected structures disagree.
     """
     dataset = self._create_tf_dataset_fn(client_id)
     actual_spec = dataset.element_spec
     tensor_utils.check_nested_equal(actual_spec,
                                     self._element_type_structure)
     return dataset