def parameter_count_from_model(
    model: Union[model_lib.Model, Callable[[], model_lib.Model]]
) -> collections.OrderedDict:
  """Computes count of trainable parameters for a `model`.

  Args:
    model: A `model_lib.Model` instance, or a no-arg callable returning one.

  Returns:
    The `collections.OrderedDict` produced by
    `type_analysis.count_tensors_in_type` for the model's trainable weights
    type (tensor count, parameter count, and unspecified-tensor count).
  """
  trainable_type = weights_type_from_model(model).trainable
  return type_analysis.count_tensors_in_type(trainable_type)
def test_initial_weights_pulled_from_model(self, server_optimizer):
  """Checks `initialize` reflects the weights produced by `model_fn`."""
  self.skipTest('b/184855264')

  def _make_model_fn(fill_like):
    """Returns a model_fn whose weights are all rewritten via `fill_like`.

    Args:
      fill_like: A function such as `tf.zeros_like` or `tf.ones_like`,
        applied elementwise to every trainable and non-trainable weight.

    Returns:
      A no-arg `model_fn` producing a `LinearRegression` model with
      constant weights.
    """

    def _model_fn():
      # Fix: the original code assigned the class itself
      # (`model_examples.LinearRegression` without parentheses); the model
      # must be instantiated so it has weights to overwrite.
      linear_regression_model = model_examples.LinearRegression()
      weights = model_utils.ModelWeights.from_model(linear_regression_model)
      constant_weights = model_utils.ModelWeights(
          trainable=[fill_like(x) for x in weights.trainable],
          non_trainable=[fill_like(x) for x in weights.non_trainable])
      constant_weights.assign_weights_to(linear_regression_model)
      return linear_regression_model

    return _model_fn

  iterative_process_returning_zeros = (
      optimizer_utils.build_model_delta_optimizer_process(
          model_fn=_make_model_fn(tf.zeros_like),
          model_to_client_delta_fn=DummyClientDeltaFn,
          server_optimizer_fn=server_optimizer()))
  iterative_process_returning_ones = (
      optimizer_utils.build_model_delta_optimizer_process(
          model_fn=_make_model_fn(tf.ones_like),
          model_to_client_delta_fn=DummyClientDeltaFn,
          server_optimizer_fn=server_optimizer()))
  zero_weights_expected = iterative_process_returning_zeros.initialize().model
  one_weights_expected = iterative_process_returning_ones.initialize().model
  # All-zero weights must sum to exactly zero.
  self.assertEqual(
      sum(tf.reduce_sum(x) for x in zero_weights_expected.trainable) +
      sum(tf.reduce_sum(x) for x in zero_weights_expected.non_trainable), 0)
  # All-one weights must sum to the total number of (fully specified)
  # parameters in the server model type.
  self.assertEqual(
      sum(tf.reduce_sum(x) for x in one_weights_expected.trainable) +
      sum(tf.reduce_sum(x) for x in one_weights_expected.non_trainable),
      type_analysis.count_tensors_in_type(
          iterative_process_returning_ones.initialize.type_signature
          .result.member.model)['parameters'])
def _check_aggregated_scalar_count(self, aggregator, max_scalars,
                                   min_scalars=0):
  """Asserts the aggregator's scalar count lies in [min_scalars, max_scalars).

  Args:
    aggregator: The aggregator under test; wrapped via `_mrfify_aggregator`
      before conversion to map-reduce form.
    max_scalars: Exclusive upper bound on the aggregated scalar count.
    min_scalars: Inclusive lower bound on the aggregated scalar count.

  Returns:
    The `MapReduceForm` derived from the aggregator.
  """
  mrf = form_utils.get_map_reduce_form_for_iterative_process(
      _mrfify_aggregator(aggregator))
  scalar_count = type_analysis.count_tensors_in_type(
      mrf.work.type_signature.result)['parameters']
  self.assertLess(scalar_count, max_scalars)
  self.assertGreaterEqual(scalar_count, min_scalars)
  return mrf
def test_skips_unspecified_params(self):
  """Tensors with an unknown dimension count as tensors but add no params."""
  struct_type = computation_types.StructType([
      ('a', computation_types.TensorType(tf.int32, shape=[2, 2])),
      # Leading dimension is unspecified, so this tensor's parameter
      # count cannot be computed.
      ('b', computation_types.TensorType(tf.int32, shape=[None, 1])),
  ])
  self.assertEqual(
      type_analysis.count_tensors_in_type(struct_type),
      collections.OrderedDict(
          num_tensors=2, parameters=4, num_unspecified_tensors=1))
def test_tensor_filter_only_counts_matching_tensors(self):
  """A tensor_filter restricts counting to the tensors it accepts."""
  struct_type = computation_types.StructType([
      ('a', computation_types.TensorType(tf.float32, shape=[2, 2])),
      ('b', computation_types.TensorType(tf.int32, shape=[2, 1])),
  ])

  def _is_float(tensor_type):
    return tensor_type.dtype == tf.float32

  # Only 'a' matches the filter: one tensor with 2 * 2 = 4 parameters.
  self.assertEqual(
      type_analysis.count_tensors_in_type(struct_type, _is_float),
      collections.OrderedDict(
          num_tensors=1, parameters=4, num_unspecified_tensors=0))
def test_raises_non_type(self):
  """Passing a non-type value (an int) must raise `TypeError`."""
  self.assertRaises(TypeError, type_analysis.count_tensors_in_type, 0)