Example #1
  def test_is_tf_release(self):
    mock_tf_module = _FakeTFModule('2.2.2')
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.2.1', mock_tf_module))
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.1.0', mock_tf_module))
    self.assertFalse(
        version_check.is_tensorflow_version_newer('2.2.4', mock_tf_module))
    self.assertFalse(
        version_check.is_tensorflow_version_newer('2.3.0', mock_tf_module))
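
_FakeTFModule is not shown on this page; it is a test double that stands in for the imported tensorflow module. A minimal sketch that would satisfy these tests, assuming the check only reads the module's __version__ attribute, is:

class _FakeTFModule(object):
  """Test double exposing only the `__version__` attribute the check inspects."""

  def __init__(self, version):
    self.__version__ = version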
Example #2
  def test_is_tf_nightly(self):
    # TF-nightly builds always count as newer.
    mock_tf_module = _FakeTFModule('2.2.2-dev202004016')
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.2.4', mock_tf_module))
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.3.0', mock_tf_module))
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.2.0', mock_tf_module))
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.1.0', mock_tf_module))
Example #3
  def test_is_tf_release_candidate(self):
    # Release candidates behave the same as regular releases.
    mock_tf_module = _FakeTFModule('2.2.2-rc2')
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.2.1', mock_tf_module))
    self.assertTrue(
        version_check.is_tensorflow_version_newer('2.1.0', mock_tf_module))
    self.assertFalse(
        version_check.is_tensorflow_version_newer('2.2.4', mock_tf_module))
    self.assertFalse(
        version_check.is_tensorflow_version_newer('2.3.0', mock_tf_module))
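
Taken together, the three tests above pin down the contract of version_check.is_tensorflow_version_newer: nightly builds always count as newer, while releases and release candidates are compared by their base version number. The TFF implementation is not shown on this page; a hedged sketch that reproduces exactly this behavior could look like:

from packaging import version as version_parser


def is_tensorflow_version_newer(version, tf_module):
  """Returns True if `tf_module.__version__` is at least `version`.

  Nightly builds (a 'dev' marker in the version string) are always treated
  as newer; release candidates such as '2.2.2-rc2' compare like '2.2.2'.
  """
  tf_version = tf_module.__version__
  if 'dev' in tf_version:
    return True
  return (version_parser.parse(tf_version).release >=
          version_parser.parse(version).release)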
Example #4
def length(ds):
  if version_check.is_tensorflow_version_newer('2.3.0', tf):
    # ds.cardinality() only works for RangeDataset at HEAD,
    # and is not in a released version of TensorFlow yet.
    return ds.cardinality().numpy()
  else:
    return tf.data.experimental.cardinality(ds).numpy()
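
As a quick sanity check of the helper above (illustrative usage, assuming eager execution and the same imports as the snippet), both branches report the number of elements of a finite dataset:

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(3)
assert length(ds) == 4  # Batches: [0..2], [3..5], [6..8], [9].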
Example #5
def _compiled_comp_equal(comp_1, comp_2):
  """Returns `True` iff the computations are entirely identical.

  Args:
    comp_1: A `building_blocks.CompiledComputation` to test.
    comp_2: A `building_blocks.CompiledComputation` to test.

  Raises:
    TypeError: if `comp_1` or `comp_2` is not a
      `building_blocks.CompiledComputation`.
  """
  py_typecheck.check_type(comp_1, building_blocks.CompiledComputation)
  py_typecheck.check_type(comp_2, building_blocks.CompiledComputation)

  tensorflow_1 = comp_1.proto.tensorflow
  tensorflow_2 = comp_2.proto.tensorflow
  if tensorflow_1.initialize_op != tensorflow_2.initialize_op:
    return False
  if tensorflow_1.parameter != tensorflow_2.parameter:
    return False
  if tensorflow_1.result != tensorflow_2.result:
    return False

  graphdef_1 = serialization_utils.unpack_graph_def(tensorflow_1.graph_def)
  graphdef_2 = serialization_utils.unpack_graph_def(tensorflow_2.graph_def)
  # TODO(b/174605105): Remove this gating when TFF updates its TensorFlow
  # dependency.
  if version_check.is_tensorflow_version_newer('2.6.0', tf):
    return tf.__internal__.graph_util.graph_defs_equal(
        graphdef_1, graphdef_2, treat_nan_as_equal=True)
  else:
    return graphdef_1.SerializeToString(
        deterministic=True) == graphdef_2.SerializeToString(
            deterministic=True)
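
The else branch relies only on deterministic proto serialization, so the same comparison idea can be exercised on plain GraphDefs outside of TFF (illustrative sketch, not TFF's API):

import tensorflow as tf


def graph_defs_byte_equal(graph_def_1, graph_def_2):
  """Byte-compares two GraphDef protos using deterministic serialization."""
  return (graph_def_1.SerializeToString(deterministic=True) ==
          graph_def_2.SerializeToString(deterministic=True))


@tf.function
def add_one(x):
  return x + 1


graph_def = add_one.get_concrete_function(
    tf.TensorSpec([], tf.int32)).graph.as_graph_def()
assert graph_defs_byte_equal(graph_def, graph_def)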
Example #6
  def create_tf_dataset_from_all_clients(
      self, seed: Optional[int] = None) -> tf.data.Dataset:
    """Creates a new `tf.data.Dataset` containing _all_ client examples.

    This function is intended for use in training centralized, non-distributed
    models (num_clients=1). This can be useful as a point of comparison
    against federated models.

    Currently, the implementation produces a dataset that contains
    all examples from a single client in order, so additional shuffling
    should generally be performed.

    Args:
      seed: Optional, a seed to determine the order in which clients are
        processed in the joined dataset. The seed can be any 32-bit unsigned
        integer or an array of such integers.

    Returns:
      A `tf.data.Dataset` object.
    """
    # Note: simply chaining Dataset.concatenate() calls results in excessive
    # recursion depth.
    # Note: Tests are via the simple concrete from_tensor_slices_client_data.

    # TODO(b/154763092): remove this check and only use the newer path.
    if version_check.is_tensorflow_version_newer('2.3.0', tf):
      logging.info('Using newer tf.data.Dataset construction behavior.')
      # This works in tf-nightly, but is not in a released TensorFlow
      # version yet.
      client_datasets = list(self.datasets(seed=seed))
      nested_dataset = tf.data.Dataset.from_tensor_slices(client_datasets)
      example_dataset = nested_dataset.flat_map(lambda x: x)
    else:
      logging.info('Old TensorFlow version detected; defaulting to slower '
                   'tf.data.Dataset construction.')

      def _generator():
        for dataset in self.datasets(seed=seed):
          for example in dataset:
            yield example

      types = tf.nest.map_structure(lambda t: t.dtype,
                                    self.element_type_structure)
      shapes = tf.nest.map_structure(lambda t: t.shape,
                                     self.element_type_structure)
      example_dataset = tf.data.Dataset.from_generator(_generator, types,
                                                       shapes)
    return example_dataset
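
The fallback path is an instance of a general pattern: flattening several datasets into one via a Python generator. A standalone sketch of the same idea with plain tf.data (no TFF types, using the newer output_signature argument to from_generator) looks like:

import tensorflow as tf


def concatenate_datasets(datasets, element_spec):
  """Flattens an iterable of datasets into a single dataset, in order."""

  def _generator():
    for dataset in datasets:
      for example in dataset:
        yield example

  return tf.data.Dataset.from_generator(_generator,
                                        output_signature=element_spec)


client_datasets = [tf.data.Dataset.range(3), tf.data.Dataset.range(2)]
combined = concatenate_datasets(client_datasets,
                                tf.TensorSpec(shape=(), dtype=tf.int64))
print(list(combined.as_numpy_iterator()))  # [0, 1, 2, 0, 1]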