Esempio n. 1
0
 def test_next_computation_returning_tensor_fails_well(self):
     """A `next` that just echoes its argument cannot be converted.

     The replacement `next` is an identity lambda over the result type of
     `initialize`, rather than something producing a (state, output) pair;
     converting the process to a `MapReduceForm` must raise `TypeError`.
     """
     temperature_mrf = mapreduce_test_utils.get_temperature_sensor_example()
     process = form_utils.get_iterative_process_for_map_reduce_form(
         temperature_mrf)
     state_type = process.initialize.type_signature.result
     identity_bb = building_blocks.Lambda(
         'x', state_type, building_blocks.Reference('x', state_type))
     bad_process = iterative_process.IterativeProcess(
         process.initialize,
         computation_impl.ConcreteComputation.from_building_block(
             identity_bb))
     with self.assertRaises(TypeError):
         form_utils.get_map_reduce_form_for_iterative_process(bad_process)
Esempio n. 2
0
    def get_map_reduce_form_for_client_to_server_fn(
            self,
            client_to_server_fn,
            client_data_type=None) -> forms.MapReduceForm:
        """Produces a `MapReduceForm` for the provided `client_to_server_fn`.

        Creates an `iterative_process.IterativeProcess` whose `next` function
        returns the server state unchanged together with the result of
        applying `client_to_server_fn` to `client_data`, then passes that
        process through
        `form_utils.get_map_reduce_form_for_iterative_process`.

        Args:
          client_to_server_fn: A function from client-placed data to
            server-placed output.
          client_data_type: Optional member type of the client-placed data
            accepted by `next`. When omitted (or `None`), `tf.int32` is used,
            which matches the original behavior of this helper.

        Returns:
          A `forms.MapReduceForm` which uses the embedded `client_to_server_fn`.
        """
        # Resolved here rather than as a default argument so that `tf` is only
        # touched when the helper is called, exactly like the original.
        if client_data_type is None:
            client_data_type = tf.int32

        @federated_computation.federated_computation
        def init_fn():
            # The server state carries no data: an empty struct at SERVER.
            return intrinsics.federated_value((), placements.SERVER)

        @federated_computation.federated_computation([
            computation_types.at_server(()),
            computation_types.at_clients(client_data_type),
        ])
        def next_fn(server_state, client_data):
            server_output = client_to_server_fn(client_data)
            # State is passed through untouched; only the output is computed.
            return server_state, server_output

        ip = iterative_process.IterativeProcess(init_fn, next_fn)
        return form_utils.get_map_reduce_form_for_iterative_process(ip)
Esempio n. 3
0
    def test_returns_map_reduce_form_with_indirection_to_intrinsic(self):
        """Conversion succeeds for the lambda-returning-aggregation example."""
        build_process = (
            mapreduce_test_utils
            .get_iterative_process_for_example_with_lambda_returning_aggregation)
        process = build_process()

        self.assertIsInstance(
            form_utils.get_map_reduce_form_for_iterative_process(process),
            forms.MapReduceForm)
Esempio n. 4
0
    def test_returns_canonical_form_with_grappler_disabled(self, ip):
        """Conversion of `ip` succeeds with Grappler's meta-optimizer off.

        `ip` is presumably supplied by a parameterization decorator that is
        not visible here.
        """
        config = tf.compat.v1.ConfigProto()
        rewrite_options = config.graph_options.rewrite_options
        rewrite_options.disable_meta_optimizer = True

        converted = form_utils.get_map_reduce_form_for_iterative_process(
            ip, config)
        self.assertIsInstance(converted, forms.MapReduceForm)
Esempio n. 5
0
    def test_mnist_training_round_trip(self):
        """Round-trips the MNIST training example through `MapReduceForm`.

        The original and rebuilt processes must agree on initial state, on
        state and metrics after one round, and on the number of TensorFlow
        variables under their `next` computations.
        """
        original = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_mnist_training_example())

        # TODO(b/208887729): We disable grappler to work around attempting to hoist
        # transformed functions of the same name into the eager context. When this
        # execution is C++-backed, this can go away.
        config = tf.compat.v1.ConfigProto()
        config.graph_options.rewrite_options.disable_meta_optimizer = True
        rebuilt = form_utils.get_iterative_process_for_map_reduce_form(
            form_utils.get_map_reduce_form_for_iterative_process(
                original, config))

        original_state = original.initialize()
        rebuilt_state = rebuilt.initialize()
        self.assertAllClose(original_state, rebuilt_state)

        # One client whose data is a one-element list holding a single
        # (x, y) OrderedDict.
        example = collections.OrderedDict(
            x=np.array([[0.5] * 784], dtype=np.float32),
            y=np.array([1], dtype=np.int32))
        client_data = [example]

        original_out = original.next(original_state, [client_data])
        rebuilt_out = rebuilt.next(rebuilt_state, [client_data])
        # Index 0 is compared before index 1, as in the original test.
        for position in (0, 1):
            self.assertAllClose(original_out[position], rebuilt_out[position])

        def count_variables(process):
            return tree_analysis.count_tensorflow_variables_under(
                process.next.to_building_block())

        self.assertEqual(count_variables(original), count_variables(rebuilt))
Esempio n. 6
0
 def _check_aggregated_scalar_count(self,
                                    aggregator,
                                    max_scalars,
                                    min_scalars=0):
   """Bounds the scalar count aggregated by `aggregator`'s MapReduceForm.

   The count is the `'parameters'` entry of
   `type_analysis.count_tensors_in_type` applied to the result type of the
   form's `work` computation.

   Args:
     aggregator: Aggregator to check; it is passed through
       `_mrfify_aggregator` before conversion.
     max_scalars: Exclusive upper bound; the count must be strictly less.
     min_scalars: Inclusive lower bound on the count.

   Returns:
     The `forms.MapReduceForm` built from the converted aggregator.
   """
   process = _mrfify_aggregator(aggregator)
   mrf = form_utils.get_map_reduce_form_for_iterative_process(process)
   work_result_type = mrf.work.type_signature.result
   counts = type_analysis.count_tensors_in_type(work_result_type)
   aggregated = counts['parameters']
   self.assertLess(aggregated, max_scalars)
   self.assertGreaterEqual(aggregated, min_scalars)
   return mrf
  def test_broadcast_dependent_on_aggregate_fails_well(self):
    """Chaining `next` twice must be rejected during conversion.

    The new `next` calls the original `next` on its argument, then calls it
    again on (element 0 of that result, presumably the updated state;
    element 1 of the argument). Conversion must raise a `ValueError`
    mentioning 'broadcast dependent on aggregate'.
    """
    temperature_mrf = mapreduce_test_utils.get_temperature_sensor_example()
    process = form_utils.get_iterative_process_for_map_reduce_form(
        temperature_mrf)
    next_bb = process.next.to_building_block()
    arg_name = next_bb.parameter_name
    arg_type = next_bb.parameter_type
    arg_ref = building_blocks.Reference(arg_name, arg_type)

    first_call = building_blocks.Call(next_bb, arg_ref)
    second_arg = building_blocks.Struct([
        building_blocks.Selection(first_call, index=0),
        building_blocks.Selection(arg_ref, index=1),
    ])
    second_call = building_blocks.Call(next_bb, second_arg)
    chained_next = building_blocks.Lambda(arg_name, arg_type, second_call)

    chained_process = iterative_process.IterativeProcess(
        process.initialize,
        computation_wrapper_instances.building_block_to_computation(
            chained_next))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      form_utils.get_map_reduce_form_for_iterative_process(chained_process)
Esempio n. 8
0
    def test_temperature_example_round_trip(self):
        """Round-trips the temperature example and runs two rounds on it.

        The rebuilt process must count rounds from 0 and report the expected
        `ratio_over_threshold` metrics. It must also contain as many
        TensorFlow variables under `next` as the original process.
        """
        # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
        # to lose the python container annotations on the StructType.
        original = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_temperature_sensor_example())
        rebuilt = form_utils.get_iterative_process_for_map_reduce_form(
            form_utils.get_map_reduce_form_for_iterative_process(original))

        state = rebuilt.initialize()
        self.assertEqual(state['num_rounds'], 0)

        state, metrics = rebuilt.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertEqual(state['num_rounds'], 1)
        self.assertAllClose(
            metrics, collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = rebuilt.next(
            state, [[33.0], [34.0], [35.0], [36.0]])
        self.assertAllClose(
            metrics, collections.OrderedDict(ratio_over_threshold=0.75))

        variable_counts = [
            tree_analysis.count_tensorflow_variables_under(
                proc.next.to_building_block())
            for proc in (original, rebuilt)
        ]
        self.assertEqual(variable_counts[0], variable_counts[1])
 def test_mnist_training_round_trip(self):
   """Round-trips the MNIST training example through `MapReduceForm`.

   The original and rebuilt processes must agree on initial state, on the
   state and metrics after one round, and on the number of TensorFlow
   variables under their `next` computations.
   """
   original = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_mnist_training_example())
   mrf = form_utils.get_map_reduce_form_for_iterative_process(original)
   rebuilt = form_utils.get_iterative_process_for_map_reduce_form(mrf)

   original_state = original.initialize()
   rebuilt_state = rebuilt.initialize()
   self.assertAllClose(original_state, rebuilt_state)

   # One client whose data is a one-element list holding a single (x, y)
   # OrderedDict.
   client_data = [
       collections.OrderedDict(
           x=np.array([[0.5] * 784], dtype=np.float32),
           y=np.array([1], dtype=np.int32))
   ]
   original_round = original.next(original_state, [client_data])
   rebuilt_round = rebuilt.next(rebuilt_state, [client_data])
   self.assertAllClose(original_round[0], rebuilt_round[0])
   self.assertAllClose(original_round[1], rebuilt_round[1])

   original_vars = tree_analysis.count_tensorflow_variables_under(
       original.next.to_building_block())
   rebuilt_vars = tree_analysis.count_tensorflow_variables_under(
       rebuilt.next.to_building_block())
   self.assertEqual(original_vars, rebuilt_vars)
Esempio n. 10
0
 def test_returns_map_reduce_form_for_sum_example_with_no_aggregation(self):
     """Conversion succeeds for the no-aggregation sum example process."""
     process = get_iterative_process_for_sum_example_with_no_aggregation()
     self.assertIsInstance(
         form_utils.get_map_reduce_form_for_iterative_process(process),
         forms.MapReduceForm)
Esempio n. 11
0
    def test_returns_map_reduce_form(self, ip):
        """Conversion of `ip` yields a `forms.MapReduceForm`.

        `ip` is presumably supplied by a parameterization decorator that is
        not visible here.
        """
        self.assertIsInstance(
            form_utils.get_map_reduce_form_for_iterative_process(ip),
            forms.MapReduceForm)
Esempio n. 12
0
 def test_constructs_map_reduce_form_from_mnist_training_example(self):
     """A `MapReduceForm` can be re-derived from the MNIST example process."""
     mnist_mrf = mapreduce_test_utils.get_mnist_training_example()
     process = form_utils.get_iterative_process_for_map_reduce_form(mnist_mrf)
     self.assertIsInstance(
         form_utils.get_map_reduce_form_for_iterative_process(process),
         forms.MapReduceForm)
Esempio n. 13
0
 def test_gets_map_reduce_form_for_nested_broadcast(self):
     """Conversion succeeds for the nested-broadcasts example process."""
     nested_process = get_iterative_process_with_nested_broadcasts()
     converted = form_utils.get_map_reduce_form_for_iterative_process(
         nested_process)
     self.assertIsInstance(converted, forms.MapReduceForm)