Beispiel #1
0
def parameter_count_from_model(
    model: Union[model_lib.Model, Callable[[], model_lib.Model]]
) -> collections.OrderedDict:
  """Counts tensors and scalar parameters among a model's trainable weights.

  Args:
    model: A `model_lib.Model`, or a no-argument callable returning one. It is
      forwarded unchanged to `weights_type_from_model`.

  Returns:
    The `collections.OrderedDict` produced by
    `type_analysis.count_tensors_in_type` for the `trainable` part of the
    model's weights type (keys `num_tensors`, `parameters` and
    `num_unspecified_tensors`), not a single integer.
  """
  return type_analysis.count_tensors_in_type(
      weights_type_from_model(model).trainable)
Beispiel #2
0
    def test_initial_weights_pulled_from_model(self, server_optimizer):
        """Checks that `initialize()` reports the weights set by `model_fn`.

        Two iterative processes are built from model functions that overwrite
        every trainable and non-trainable weight with zeros and with ones,
        respectively. The initial model of the zeros process must sum to 0.
        The initial model of the ones process must sum to the number of scalar
        parameters in its model type.

        Args:
          server_optimizer: A no-argument callable whose return value is passed
            as `server_optimizer_fn` (presumably supplied by test
            parameterization -- confirm at the decorator).
        """
        # Skipped pending b/184855264.
        self.skipTest('b/184855264')

        def _make_model_fn(fill_like):
            """Returns a model_fn whose weights are all replaced via `fill_like`.

            `fill_like` maps each weight to its replacement value, e.g.
            `tf.zeros_like` or `tf.ones_like`.
            """

            def _model_fn():
                # Build a model *instance*; the original code passed the
                # `LinearRegression` class itself to `from_model` and
                # `assign_weights_to`.
                linear_regression_model = model_examples.LinearRegression()
                weights = model_utils.ModelWeights.from_model(
                    linear_regression_model)
                filled_weights = model_utils.ModelWeights(
                    trainable=[fill_like(x) for x in weights.trainable],
                    non_trainable=[
                        fill_like(x) for x in weights.non_trainable
                    ])
                filled_weights.assign_weights_to(linear_regression_model)
                return linear_regression_model

            return _model_fn

        def _sum_of_weights(model_weights):
            """Sums every element of all trainable and non-trainable weights."""
            return (
                sum(tf.reduce_sum(x) for x in model_weights.trainable) +
                sum(tf.reduce_sum(x) for x in model_weights.non_trainable))

        iterative_process_returning_zeros = (
            optimizer_utils.build_model_delta_optimizer_process(
                model_fn=_make_model_fn(tf.zeros_like),
                model_to_client_delta_fn=DummyClientDeltaFn,
                server_optimizer_fn=server_optimizer()))

        iterative_process_returning_ones = (
            optimizer_utils.build_model_delta_optimizer_process(
                model_fn=_make_model_fn(tf.ones_like),
                model_to_client_delta_fn=DummyClientDeltaFn,
                server_optimizer_fn=server_optimizer()))

        zero_initial_model = iterative_process_returning_zeros.initialize(
        ).model
        one_initial_model = iterative_process_returning_ones.initialize(
        ).model

        self.assertEqual(_sum_of_weights(zero_initial_model), 0)
        # With every weight set to 1, the element sum equals the number of
        # scalar parameters in the model's type.
        self.assertEqual(
            _sum_of_weights(one_initial_model),
            type_analysis.count_tensors_in_type(
                iterative_process_returning_ones.initialize.type_signature.
                result.member.model)['parameters'])
Beispiel #3
0
 def _check_aggregated_scalar_count(self, aggregator, max_scalars,
                                    min_scalars=0):
   """Asserts bounds on the scalar count in a MapReduceForm's `work` result.

   The aggregator is passed through `_mrfify_aggregator` and converted to a
   MapReduceForm. The number of scalar parameters in the result type of that
   form's `work` computation is then checked against the given bounds.

   Args:
     aggregator: The aggregator to inspect.
     max_scalars: Exclusive upper bound; the count must be strictly smaller.
     min_scalars: Inclusive lower bound. Defaults to 0.

   Returns:
     The MapReduceForm built from the converted aggregator.
   """
   converted = _mrfify_aggregator(aggregator)
   map_reduce_form = form_utils.get_map_reduce_form_for_iterative_process(
       converted)
   work_result_type = map_reduce_form.work.type_signature.result
   scalar_count = type_analysis.count_tensors_in_type(
       work_result_type)['parameters']
   # Note the asymmetry: the upper bound is strict, the lower one inclusive.
   self.assertLess(scalar_count, max_scalars)
   self.assertGreaterEqual(scalar_count, min_scalars)
   return map_reduce_form
  def test_skips_unspecified_params(self):
    """A tensor with a `None` dimension is counted but adds no parameters."""
    expected = collections.OrderedDict(
        num_tensors=2, parameters=4, num_unspecified_tensors=1)

    fully_defined = computation_types.TensorType(tf.int32, shape=[2, 2])
    partially_defined = computation_types.TensorType(tf.int32, shape=[None, 1])
    struct_type = computation_types.StructType([
        ('a', fully_defined),
        ('b', partially_defined),
    ])

    self.assertEqual(
        type_analysis.count_tensors_in_type(struct_type), expected)
  def test_tensor_filter_only_counts_matching_tensors(self):
    """Only tensors accepted by the filter contribute to the counts."""

    def is_float32(tensor_type):
      return tensor_type.dtype == tf.float32

    struct_type = computation_types.StructType([
        ('a', computation_types.TensorType(tf.float32, shape=[2, 2])),
        ('b', computation_types.TensorType(tf.int32, shape=[2, 1])),
    ])
    expected = collections.OrderedDict(
        num_tensors=1, parameters=4, num_unspecified_tensors=0)

    actual = type_analysis.count_tensors_in_type(struct_type, is_float32)
    self.assertEqual(actual, expected)
Beispiel #6
0
 def test_raises_non_type(self):
   """`count_tensors_in_type` rejects a non-type argument with TypeError."""
   self.assertRaises(TypeError, type_analysis.count_tensors_in_type, 0)