def test_is_tf_release(self):
  # A plain release version is compared component-wise against the queried
  # version: queries below the installed '2.2.2' are satisfied, queries above
  # it are not.
  fake_tf = _FakeTFModule('2.2.2')
  for older_than_installed in ('2.2.1', '2.1.0'):
    self.assertTrue(
        version_check.is_tensorflow_version_newer(older_than_installed,
                                                  fake_tf))
  for newer_than_installed in ('2.2.4', '2.3.0'):
    self.assertFalse(
        version_check.is_tensorflow_version_newer(newer_than_installed,
                                                  fake_tf))
def test_is_tf_nightly(self):
  # A nightly (`-dev`) build must report as newer than every version queried
  # here, including versions above its own base version '2.2.2'.
  # NOTE(review): the dev date suffix below has 9 digits; nightly suffixes
  # are usually YYYYMMDD -- confirm this is intentional.
  fake_tf = _FakeTFModule('2.2.2-dev202004016')
  for queried_version in ('2.2.4', '2.3.0', '2.2.0', '2.1.0'):
    self.assertTrue(
        version_check.is_tensorflow_version_newer(queried_version, fake_tf))
def test_is_tf_release_candidate(self):
  # An `-rc` suffix must not change the comparison: '2.2.2-rc2' is expected
  # to behave exactly like the plain '2.2.2' release.
  fake_tf = _FakeTFModule('2.2.2-rc2')
  for older_than_installed in ('2.2.1', '2.1.0'):
    self.assertTrue(
        version_check.is_tensorflow_version_newer(older_than_installed,
                                                  fake_tf))
  for newer_than_installed in ('2.2.4', '2.3.0'):
    self.assertFalse(
        version_check.is_tensorflow_version_newer(newer_than_installed,
                                                  fake_tf))
def length(ds):
  """Returns the cardinality of `ds` as a NumPy value.

  Args:
    ds: Presumably a `tf.data.Dataset` (confirm against callers).

  Returns:
    The `.numpy()` value of the dataset's cardinality tensor. Per the
    TensorFlow docs this can be a negative sentinel (infinite or unknown
    cardinality) rather than an element count.
  """
  # `ds.cardinality()` was historically only available at TF HEAD, not in a
  # released version, so older TensorFlow uses the experimental function.
  if not version_check.is_tensorflow_version_newer('2.3.0', tf):
    return tf.data.experimental.cardinality(ds).numpy()
  return ds.cardinality().numpy()
def _compiled_comp_equal(comp_1, comp_2):
  """Checks whether two compiled computations are entirely identical.

  The TensorFlow protos of both computations must agree on their
  `initialize_op`, `parameter` and `result` fields, and their unpacked
  `GraphDef`s must compare equal.

  Args:
    comp_1: A `building_blocks.CompiledComputation` to compare.
    comp_2: A `building_blocks.CompiledComputation` to compare.

  Returns:
    `True` iff the two computations are identical, `False` otherwise.

  Raises:
    TypeError: if `comp_1` or `comp_2` is not a
      `building_blocks.CompiledComputation`.
  """
  for comp in (comp_1, comp_2):
    py_typecheck.check_type(comp, building_blocks.CompiledComputation)
  proto_1 = comp_1.proto.tensorflow
  proto_2 = comp_2.proto.tensorflow
  # Compare the cheap scalar/binding fields first, before unpacking graphs.
  for field_name in ('initialize_op', 'parameter', 'result'):
    if getattr(proto_1, field_name) != getattr(proto_2, field_name):
      return False
  graph_def_1 = serialization_utils.unpack_graph_def(proto_1.graph_def)
  graph_def_2 = serialization_utils.unpack_graph_def(proto_2.graph_def)
  # TODO(b/174605105): Remove this gating when TFF updates its TensorFlow
  # dependency.
  if not version_check.is_tensorflow_version_newer('2.6.0', tf):
    # Fallback for TensorFlow versions the gate excludes: compare
    # deterministic serializations byte for byte.
    serialized_1 = graph_def_1.SerializeToString(deterministic=True)
    serialized_2 = graph_def_2.SerializeToString(deterministic=True)
    return serialized_1 == serialized_2
  # `treat_nan_as_equal=True` keeps NaN-valued constants from making
  # otherwise identical graphs compare unequal.
  return tf.__internal__.graph_util.graph_defs_equal(
      graph_def_1, graph_def_2, treat_nan_as_equal=True)
def create_tf_dataset_from_all_clients(
    self, seed: Optional[int] = None) -> tf.data.Dataset:
  """Creates a new `tf.data.Dataset` containing _all_ client examples.

  This function is intended for use training centralized, non-distributed
  models (num_clients=1). This can be useful as a point of comparison
  against federated models.

  Currently, the implementation emits every example of one client, in
  order, before moving on to the next client, so generally additional
  shuffling should be performed.

  Args:
    seed: Optional, a seed to determine the order in which clients are
      processed in the joined dataset. The seed can be any 32-bit unsigned
      integer or an array of such integers.

  Returns:
    A `tf.data.Dataset` object.
  """
  # Note: simply calling Dataset.concatenate() will result in too deep
  # recursion depth.
  # Note: Tests are via the simple concrete from_tensor_slices_client_data.
  # TODO(b/154763092): remove this check and only use the newer path.
  if version_check.is_tensorflow_version_newer('2.3.0', tf):
    logging.info('Using newer tf.data.Dataset construction behavior.')
    # Build a dataset whose elements are the per-client datasets, then
    # flatten it so examples come out one client after another. This path
    # was originally only supported in tf-nightly, hence the version gate.
    client_datasets = list(self.datasets(seed=seed))
    nested_dataset = tf.data.Dataset.from_tensor_slices(client_datasets)
    return nested_dataset.flat_map(lambda client_dataset: client_dataset)

  logging.info(
      'Old TensorFlow version detected; defaulting to slower '
      'tf.data.Dataset construction.')

  def _generator():
    # Streams every example of every client, client by client.
    for dataset in self.datasets(seed=seed):
      yield from dataset

  # `from_generator` needs the element dtypes and shapes up front; derive
  # them from this object's element type structure.
  element_structure = self.element_type_structure
  types = tf.nest.map_structure(lambda spec: spec.dtype, element_structure)
  shapes = tf.nest.map_structure(lambda spec: spec.shape, element_structure)
  return tf.data.Dataset.from_generator(_generator, types, shapes)