Exemple #1
0
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Builds a clipping secure-sum aggregation process for `value_type`.

    Args:
      value_type: The (non-federated) type of each client value.

    Returns:
      An `aggregation_process.AggregationProcess` whose `initialize` is
      `self._init_fn`. Its `next` clips client values to the bounds held in
      the state, sums them securely, and updates the state from the
      unclipped per-client extrema.
    """
    self._check_value_type_compatible_with_config_mode(value_type)

    state_type = self._init_fn.type_signature.result
    client_value_type = computation_types.FederatedType(value_type,
                                                        placements.CLIENTS)

    @computations.federated_computation(state_type, client_value_type)
    def next_fn(state, value):
      # Read the current clipping range out of the server state.
      upper_bound, lower_bound = self._get_bounds_from_state(state)

      # The extrema are taken from the values *before* clipping; they feed
      # both the state update and the reported measurements.
      clients_max = intrinsics.federated_map(_reduce_nest_max, value)
      clients_min = intrinsics.federated_map(_reduce_nest_min, value)
      updated_state = self._update_state(state, clients_min, clients_max)

      # Clip to [lower_bound, upper_bound], then sum securely.
      secure_total = self._sum_securely(value, upper_bound, lower_bound)

      # TODO(b/163880757): pass upper_bound and lower_bound through clients.
      report = self._compute_measurements(upper_bound, lower_bound,
                                          clients_max, clients_min)
      return measured_process.MeasuredProcessOutput(updated_state,
                                                    secure_total, report)

    return aggregation_process.AggregationProcess(self._init_fn, next_fn)
 def _normalize_reference_bit(comp):
   """Rebuilds a federated reference's type from member and placement only.

   Args:
     comp: Presumably a `building_blocks.Reference`, since it reads `.name`
       and `.type_signature`.

   Returns:
     A `(new_comp, modified)` pair. A non-federated reference is returned
     unchanged with `False`. A federated one is replaced by a same-named
     reference whose `FederatedType` is constructed without an explicit
     `all_equal` argument, and returned with `True`.
   """
   ref_type = comp.type_signature
   if ref_type.is_federated():
     normalized_type = computation_types.FederatedType(ref_type.member,
                                                       ref_type.placement)
     return building_blocks.Reference(comp.name, normalized_type), True
   return comp, False
Exemple #3
0
def _create_complex_computation():
    """Builds a lambda that broadcasts, maps and averages its argument twice.

    The returned `building_blocks.Lambda` has a parameter named `arg`. Its
    body is a `Block` binding `broadcast_i`, `map_i` and `mean_i` for
    i in {0, 1}, and its result is the struct `<mean_0, mean_1>`.
    """
    compiled = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32), 'a')
    server_ref = building_blocks.Reference(
        'arg', computation_types.FederatedType(tf.int32, placements.SERVER))
    local_bindings = []

    def _bind(name, value):
        # Record `name = value` as a block local and return a reference to it.
        local_bindings.append((name, value))
        return building_blocks.Reference(name, value.type_signature)

    mean_refs = []
    for index in range(2):
        broadcast_ref = _bind(
            f'broadcast_{index}',
            building_block_factory.create_federated_broadcast(server_ref))
        map_ref = _bind(
            f'map_{index}',
            building_block_factory.create_federated_map(compiled,
                                                        broadcast_ref))
        mean_refs.append(
            _bind(f'mean_{index}',
                  building_block_factory.create_federated_mean(map_ref, None)))

    return building_blocks.Lambda(
        'arg', tf.int32,
        building_blocks.Block(local_bindings,
                              building_blocks.Struct(mean_refs)))
    def create(
        self, value_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
        """Creates a test process that sums client values and adds one.

        The state is a server-placed counter that starts at 0 and is
        incremented on each `next` call. The result is the federated sum of
        the client values with 1 added to every leaf. The measurements are
        `MEASUREMENT_CONSTANT` placed at the server.

        Args:
          value_type: Must be one of the types in `factory.ValueType`.
        """
        allowed_types = typing.get_args(factory.ValueType)
        py_typecheck.check_type(value_type, allowed_types)

        @federated_computation.federated_computation()
        def init_fn():
            return intrinsics.federated_value(0, placements.SERVER)

        state_type = init_fn.type_signature.result
        client_type = computation_types.FederatedType(value_type,
                                                      placements.CLIENTS)

        @federated_computation.federated_computation(state_type, client_type)
        def next_fn(state, value):
            new_state = intrinsics.federated_map(
                tensorflow_computation.tf_computation(lambda x: x + 1), state)
            add_one_to_leaves = tensorflow_computation.tf_computation(
                lambda x: tf.nest.map_structure(lambda y: y + 1, x))
            total = intrinsics.federated_sum(value)
            result = intrinsics.federated_map(add_one_to_leaves, total)
            measurements = intrinsics.federated_value(MEASUREMENT_CONSTANT,
                                                      placements.SERVER)
            return measured_process.MeasuredProcessOutput(
                new_state, result, measurements)

        return aggregation_process.AggregationProcess(init_fn, next_fn)
 async def _zip(self, arg, placement, all_equal):
     """Turns a structure of per-child lists into a list of per-child structs.

     Args:
       arg: A structure value. Each element must be a list with exactly one
         entry per child executor at `placement`.
       placement: The `placements.PlacementLiteral` of the values.
       all_equal: The `all_equal` bit for the resulting federated type.

     Returns:
       A `FederatedResolvingStrategyValue` holding one struct per child. Its
       type is the struct of the elements' member types, placed at
       `placement`.

     Raises:
       RuntimeError: If an element's list length differs from the number of
         child executors at `placement`.
     """
     self._check_arg_is_structure(arg)
     py_typecheck.check_type(placement, placements.PlacementLiteral)
     self._check_strategy_compatible_with_placement(placement)
     children = self._target_executors[placement]
     num_children = len(children)
     named_lists = structure.to_elements(arg.internal_representation)
     for _, per_child in named_lists:
         py_typecheck.check_type(per_child, list)
         if len(per_child) != num_children:
             raise RuntimeError('Expected {} items, found {}.'.format(
                 num_children, len(per_child)))
     # Transpose: struct i gathers the i-th entry of every named list.
     per_child_structs = [
         structure.Struct([(name, values[i]) for name, values in named_lists])
         for i in range(num_children)
     ]
     created = await asyncio.gather(*[
         child.create_struct(child_struct)
         for child, child_struct in zip(children, per_child_structs)
     ])
     member_type = computation_types.StructType(
         ((name, elem_type.member) if name else elem_type.member
          for name, elem_type in structure.iter_elements(arg.type_signature)))
     return FederatedResolvingStrategyValue(
         created,
         computation_types.FederatedType(member_type,
                                         placement,
                                         all_equal=all_equal))
Exemple #6
0
def _create_stateless_int_dataset_reduction_iterative_process():
  """Returns an `IterativeProcess` that sums client `tf.int64` datasets.

  `initialize` places an int64 zero at SERVER. `next` ignores that state,
  reduces each client's int64 sequence to its sum, and returns the federated
  sum of those per-client totals.
  """

  @tensorflow_computation.tf_computation()
  def make_zero():
    return tf.cast(0, tf.int64)

  @federated_computation.federated_computation()
  def init():
    return intrinsics.federated_eval(make_zero, placements.SERVER)

  @tensorflow_computation.tf_computation(
      computation_types.SequenceType(tf.int64))
  def reduce_dataset(x):
    zero = tf.cast(0, tf.int64)
    return x.reduce(zero, lambda total, element: total + element)

  client_data_type = computation_types.FederatedType(
      computation_types.SequenceType(tf.int64), placements.CLIENTS)

  @federated_computation.federated_computation(
      (init.type_signature.result, client_data_type))
  def next_fn(server_state, client_data):
    del server_state  # Stateless: the incoming state is ignored.
    client_totals = intrinsics.federated_map(reduce_dataset, client_data)
    return intrinsics.federated_sum(client_totals)

  return iterative_process.IterativeProcess(initialize_fn=init, next_fn=next_fn)
    async def reduce(
        self,
        val: List[executor_value_base.ExecutorValue],
        zero: executor_value_base.ExecutorValue,
        op: pb.Computation,
        op_type: computation_types.FunctionType,
    ) -> FederatedResolvingStrategyValue:
        """Folds `val` into a single value on the first SERVER executor.

        Each item of `val` is computed and re-created on the SERVER child
        executor. Starting from `zero`, the items are then combined one at a
        time as `op(<accumulator, item>)`.

        Args:
          val: The executor values to reduce.
          zero: The initial accumulator value.
          op: The binary reduction operator, as a `pb.Computation`.
          op_type: The type signature of `op`.

        Returns:
          A `FederatedResolvingStrategyValue` holding the single result. It is
          typed as an all-equal federated value at `placements.SERVER`.
        """
        server = self._target_executors[placements.SERVER][0]

        async def _move(v):
            # Materialize `v` and re-create it inside the server executor.
            return await server.create_value(await v.compute(),
                                             v.type_signature)

        # The moves are scheduled here, before `zero` and `op` are created,
        # so they overlap with that work.
        # NOTE(review): `as_completed` yields in *completion* order, not list
        # order. The result is therefore only deterministic when `op` is
        # associative and commutative; presumably callers guarantee this, so
        # confirm.
        item_futures = asyncio.as_completed([_move(v) for v in val])
        zero_at_server = await server.create_value(await zero.compute(),
                                                   zero.type_signature)
        op_at_server = await server.create_value(op, op_type)

        # Sequential fold: result = op(<result, item>) for each arrived item.
        result = zero_at_server
        for item_future in item_futures:
            item = await item_future
            result = await server.create_call(
                op_at_server, await server.create_struct(
                    structure.Struct([(None, result), (None, item)])))
        return FederatedResolvingStrategyValue([result],
                                               computation_types.FederatedType(
                                                   result.type_signature,
                                                   placements.SERVER,
                                                   all_equal=True))
 def _normalize_intrinsic_bit(comp):
   """Rewrites a `federated_map_all_equal` intrinsic as `federated_map`.

   Args:
     comp: An intrinsic building block (it reads `.uri` and
       `.type_signature`).

   Returns:
     A `(new_comp, modified)` pair. Intrinsics other than
     `FEDERATED_MAP_ALL_EQUAL` are returned unchanged with `False`. Otherwise
     a `FEDERATED_MAP` intrinsic is returned with `True`; its argument and
     result types are rebuilt as CLIENTS-placed types without an explicit
     `all_equal` argument.
   """
   if comp.uri == intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri:
     old_type = comp.type_signature
     fn_type = old_type.parameter[0]
     clients_arg_type = computation_types.FederatedType(
         old_type.parameter[1].member, placements.CLIENTS)
     clients_result_type = computation_types.FederatedType(
         old_type.result.member, placements.CLIENTS)
     new_type = computation_types.FunctionType([fn_type, clients_arg_type],
                                               clients_result_type)
     return building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                      new_type), True
   return comp, False
    async def _map(self, arg, all_equal=None):
        """Calls a function on every child's share of a federated value.

        Args:
          arg: A structure value of length 2: a `pb.Computation` with a
            `FunctionType`, and a list of per-child `ExecutorValue`s with a
            `FederatedType`.
          all_equal: The `all_equal` bit of the result. When `None`, the
            argument's `all_equal` bit is used.

        Returns:
          A `FederatedResolvingStrategyValue` with one call result per child.
          It is typed as the function's result at the argument's placement.

        Raises:
          ValueError: If an all_equal result is requested for a
            non-all_equal argument.
        """
        self._check_arg_is_structure(arg)
        py_typecheck.check_len(arg.internal_representation, 2)
        fn_type = arg.type_signature[0]
        py_typecheck.check_type(fn_type, computation_types.FunctionType)
        val_type = arg.type_signature[1]
        py_typecheck.check_type(val_type, computation_types.FederatedType)
        if all_equal is None:
            all_equal = val_type.all_equal
        elif all_equal and not val_type.all_equal:
            raise ValueError(
                'Cannot map a non-all_equal argument into an all_equal result.'
            )
        fn = arg.internal_representation[0]
        py_typecheck.check_type(fn, pb.Computation)
        val = arg.internal_representation[1]
        py_typecheck.check_type(val, list)
        for item in val:
            py_typecheck.check_type(item, executor_value_base.ExecutorValue)
        placement = val_type.placement
        self._check_strategy_compatible_with_placement(placement)
        children = self._target_executors[placement]

        async def _call_on_child(child, value):
            # Each child embeds its own copy of `fn` before invoking it.
            child_fn = await child.create_value(fn, fn_type)
            return await child.create_call(child_fn, value)

        # NOTE(review): `zip` stops at the shorter of `val` and `children`.
        # Unlike `_zip`, lengths are not checked here; confirm that they
        # always match.
        results = await asyncio.gather(*(
            _call_on_child(child, value)
            for value, child in zip(val, children)))
        result_type = computation_types.FederatedType(fn_type.result,
                                                      placement,
                                                      all_equal=all_equal)
        return FederatedResolvingStrategyValue(results, result_type)
def get_iterative_process_for_sum_example():
    """Returns an iterative process for a sum example.

    This iterative process contains all the components required to compile to
    `forms.MapReduceForm`: a broadcast, a client map, an unsecured
    `federated_sum` and a `federated_secure_sum_bitwidth`.
    """

    @federated_computation.federated_computation
    def init_fn():
        """The `init` function for `tff.templates.IterativeProcess`."""
        return intrinsics.federated_value([0, 0], placements.SERVER)

    @tensorflow_computation.tf_computation([tf.int32, tf.int32])
    def prepare(server_state):
        # Identity: the whole server state is what gets broadcast.
        return server_state

    @tensorflow_computation.tf_computation(tf.int32, [tf.int32, tf.int32])
    def work(client_data, client_input):
        del client_data, client_input  # Unused
        # Every client contributes 1 to each of the two sums.
        return 1, 1

    @tensorflow_computation.tf_computation([tf.int32, tf.int32],
                                           [tf.int32, tf.int32])
    def update(server_state, global_update):
        del server_state  # Unused
        # The aggregate becomes the new state; the server output is empty.
        return global_update, []

    server_state_type = computation_types.FederatedType([tf.int32, tf.int32],
                                                        placements.SERVER)
    client_data_type = computation_types.FederatedType(tf.int32,
                                                       placements.CLIENTS)

    @federated_computation.federated_computation(
        [server_state_type, client_data_type])
    def next_fn(server_state, client_data):
        """The `next` function for `tff.templates.IterativeProcess`."""
        broadcast_input = intrinsics.federated_broadcast(
            intrinsics.federated_map(prepare, server_state))
        work_args = intrinsics.federated_zip([client_data, broadcast_input])
        client_updates = intrinsics.federated_map(work, work_args)
        plain_sum = intrinsics.federated_sum(client_updates[0])
        secure_sum = intrinsics.federated_secure_sum_bitwidth(
            client_updates[1], 8)
        update_args = intrinsics.federated_zip(
            [server_state, [plain_sum, secure_sum]])
        new_server_state, server_output = intrinsics.federated_map(
            update, update_args)
        return new_server_state, server_output

    return iterative_process.IterativeProcess(init_fn, next_fn)
Exemple #11
0
    def create(
        self, value_type: factory.ValueType, weight_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
        """Creates a weighted-mean aggregation process.

        At clients, `_mul` is mapped over `(value, weight)`, presumably
        computing value * weight. The weighted values and the weights are
        summed by the processes built from `self._value_sum_factory` and
        `self._weight_sum_factory`. The server then maps `_div_no_nan` (when
        `self._no_nan_division` is set) or `_div` over the two sums.

        Args:
          value_type: The type of each client value.
          weight_type: The type of each client weight; must be one of the
            `factory.ValueType` alternatives.

        Returns:
          An `aggregation_process.AggregationProcess`. Its state and
          measurements are server-zipped `OrderedDict`s keyed by the two inner
          processes.
        """
        _check_value_type(value_type)
        py_typecheck.check_type(weight_type, factory.ValueType.__args__)

        value_sum_process = self._value_sum_factory.create(value_type)
        weight_sum_process = self._weight_sum_factory.create(weight_type)

        @computations.federated_computation()
        def init_fn():
            return intrinsics.federated_zip(
                collections.OrderedDict(
                    value_sum_process=value_sum_process.initialize(),
                    weight_sum_process=weight_sum_process.initialize()))

        state_type = init_fn.type_signature.result
        client_value_type = computation_types.FederatedType(
            value_type, placements.CLIENTS)
        client_weight_type = computation_types.FederatedType(
            weight_type, placements.CLIENTS)

        @computations.federated_computation(state_type, client_value_type,
                                            client_weight_type)
        def next_fn(state, value, weight):
            # Clients: combine each value with its weight.
            weighted_value = intrinsics.federated_map(_mul, (value, weight))

            # Delegate both sums to the inner aggregation processes.
            value_output = value_sum_process.next(state['value_sum_process'],
                                                  weighted_value)
            weight_output = weight_sum_process.next(
                state['weight_sum_process'], weight)

            # Server: divide the weighted sum by the summed weight.
            divide = _div_no_nan if self._no_nan_division else _div
            weighted_mean_value = intrinsics.federated_map(
                divide, (value_output.result, weight_output.result))

            new_state = intrinsics.federated_zip(
                collections.OrderedDict(
                    value_sum_process=value_output.state,
                    weight_sum_process=weight_output.state))
            measurements = intrinsics.federated_zip(
                collections.OrderedDict(
                    mean_value=value_output.measurements,
                    mean_weight=weight_output.measurements))
            return measured_process.MeasuredProcessOutput(
                new_state, weighted_mean_value, measurements)

        return aggregation_process.AggregationProcess(init_fn, next_fn)
Exemple #12
0
    def create(
        self, value_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
        """Creates an aggregation process driven by the DP query `self._query`.

        On each `next` call:
        1. Sample params are derived from the query state on the server and
           broadcast.
        2. Every client value is preprocessed into a query record.
        3. The records are aggregated by the process built from
           `self._record_aggregation_factory`.
        4. The query turns the aggregate into the noised result and a new
           query state, and metrics are derived from that new state.

        Args:
          value_type: One of the `factory.ValueType` alternatives.

        Returns:
          An `aggregation_process.AggregationProcess`. Its state is the
          server-zipped pair `(query_state, record_aggregation_state)`.
        """
        py_typecheck.check_type(value_type, factory.ValueType.__args__)

        # Wrap each query method as a TF computation. Every output type feeds
        # the input types of the next wrapper.
        query_initial_state_fn = computations.tf_computation(
            self._query.initial_global_state)
        query_state_type = query_initial_state_fn.type_signature.result
        derive_sample_params = computations.tf_computation(
            self._query.derive_sample_params, query_state_type)
        sample_params_type = derive_sample_params.type_signature.result
        get_query_record = computations.tf_computation(
            self._query.preprocess_record, sample_params_type, value_type)
        query_record_type = get_query_record.type_signature.result
        get_noised_result = computations.tf_computation(
            self._query.get_noised_result, query_record_type, query_state_type)
        derive_metrics = computations.tf_computation(
            self._query.derive_metrics, query_state_type)

        record_agg_process = self._record_aggregation_factory.create(
            query_record_type)

        @computations.federated_computation()
        def init_fn():
            initial_query_state = intrinsics.federated_eval(
                query_initial_state_fn, placements.SERVER)
            return intrinsics.federated_zip(
                (initial_query_state, record_agg_process.initialize()))

        client_value_type = computation_types.FederatedType(
            value_type, placements.CLIENTS)

        @computations.federated_computation(init_fn.type_signature.result,
                                            client_value_type)
        def next_fn(state, value):
            query_state, agg_state = state

            server_params = intrinsics.federated_map(derive_sample_params,
                                                     query_state)
            params = intrinsics.federated_broadcast(server_params)
            record = intrinsics.federated_map(get_query_record,
                                              (params, value))

            agg_output = record_agg_process.next(agg_state, record)
            new_agg_state, agg_result, agg_measurements = agg_output

            result, new_query_state = intrinsics.federated_map(
                get_noised_result, (agg_result, query_state))
            query_metrics = intrinsics.federated_map(derive_metrics,
                                                     new_query_state)

            measurements = collections.OrderedDict(
                dp_query_metrics=query_metrics, dp=agg_measurements)
            return measured_process.MeasuredProcessOutput(
                intrinsics.federated_zip((new_query_state, new_agg_state)),
                result, intrinsics.federated_zip(measurements))

        return aggregation_process.AggregationProcess(init_fn, next_fn)
 def test_single_element_selection_leaves_no_unbound_references(self):
     # A lambda over <int32@CLIENTS, int32@SERVER> that selects element 0.
     param_type = computation_types.StructType([
         computation_types.FederatedType(tf.int32, placements.CLIENTS),
         computation_types.FederatedType(tf.int32, placements.SERVER),
     ])
     selection = building_blocks.Selection(
         building_blocks.Reference('x', param_type), index=0)
     lam = building_blocks.Lambda('x', param_type, selection)

     new_lam = form_utils._as_function_of_some_federated_subparameters(
         lam, [(0,)])

     unbound = transformation_utils.get_map_of_unbound_references(new_lam)
     self.assertEmpty(unbound[new_lam])
Exemple #14
0
 def test_passes_unbound_type_signature_obscured_under_block(self):
     # `y` is bound to a federated reference to the unbound outer `x`. The
     # block then rebinds `x` to a non-federated value. The test passes as
     # long as `strip_placement` does not raise.
     server_int_type = computation_types.FederatedType(tf.int32,
                                                       placements.SERVER)
     fed_ref = building_blocks.Reference('x', server_int_type)
     block_locals = [
         ('y', fed_ref),
         ('x', building_blocks.Data('whimsy', tf.int32)),
         ('z', building_blocks.Reference('x', tf.int32)),
     ]
     block = building_blocks.Block(
         block_locals, building_blocks.Reference('y', fed_ref.type_signature))
     tree_transformations.strip_placement(block)
  def test_process_type_signature(self, private):
    shared_kwargs = dict(
        initial_estimate=1.0,
        target_quantile=0.5,
        learning_rate=1.0,
        geometric_update=True)
    if private:
      quantile_estimator_query = tfp.QuantileEstimatorQuery(
          below_estimate_stddev=0.5,
          expected_num_records=100,
          **shared_kwargs)
    else:
      quantile_estimator_query = tfp.NoPrivacyQuantileEstimatorQuery(
          **shared_kwargs)

    process = QEProcess(quantile_estimator_query)

    # The server state pairs the query state with an empty sum-process state.
    server_state_type = computation_types.FederatedType(
        type_conversions.type_from_tensors(
            (quantile_estimator_query.initial_global_state(), ())),
        placements.SERVER)
    estimate_type = computation_types.FederatedType(tf.float32,
                                                    placements.SERVER)
    client_value_type = computation_types.FederatedType(tf.float32,
                                                        placements.CLIENTS)

    expected_initialize = computation_types.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(expected_initialize, process.initialize.type_signature)

    expected_report = computation_types.FunctionType(
        parameter=server_state_type, result=estimate_type)
    self.assertEqual(expected_report, process.report.type_signature)

    expected_next = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=server_state_type, value=client_value_type),
        result=server_state_type)
    self.assertTrue(process.next.type_signature.is_equivalent_to(expected_next))
Exemple #16
0
 def test_federated_init_state_not_assignable(self):
     # `initialize` places its state at SERVER, while `next` takes an
     # int32@CLIENTS state, so construction must be rejected.
     zero = lambda: intrinsics.federated_value(0, placements.SERVER)
     clients_int_type = computation_types.FederatedType(tf.int32,
                                                        placements.CLIENTS)
     initialize_fn = computations.federated_computation()(zero)
     next_fn = computations.federated_computation(clients_int_type)(
         lambda state: MeasuredProcessOutput(state, zero(), zero()))

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         measured_process.MeasuredProcess(initialize_fn, next_fn)
 def test_adds_list_length_as_cardinality_at_clients(self):
     # One list entry per client, so the CLIENTS cardinality is the length.
     client_values = [0, 1, 2, 3, 4]
     clients_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, all_equal=False)
     inferred = cardinalities_utils.infer_cardinalities(client_values,
                                                        clients_type)
     self.assertEqual(inferred[placements.CLIENTS], len(client_values))
Exemple #18
0
    def test_init_raises_value_error_with_datasets_empty(self):
        # An empty list of datasets must be rejected at construction time.
        sequence_at_clients = computation_types.FederatedType(
            computation_types.SequenceType(tf.int32), placements.CLIENTS)

        with self.assertRaises(ValueError):
            native_platform.DatasetDataSourceIterator(
                datasets=[], federated_type=sequence_at_clients)
    def test_federated_sum_in_xla_execution_context(self):
        clients_int32 = computation_types.FederatedType(np.int32,
                                                        placements.CLIENTS)

        @computations.federated_computation(clients_int32)
        def comp(x):
            return intrinsics.federated_sum(x)

        execution_contexts.set_local_execution_context()
        client_values = [1, 2, 3]
        self.assertEqual(comp(client_values), sum(client_values))
    def test_unweighted_federated_mean_in_xla_execution_context(self):
        clients_float32 = computation_types.FederatedType(np.float32,
                                                          placements.CLIENTS)

        @computations.federated_computation(clients_float32)
        def comp(x):
            return intrinsics.federated_mean(x)

        execution_contexts.set_local_execution_context()
        client_values = [1.0, 2.0, 3.0]
        expected_mean = sum(client_values) / len(client_values)
        self.assertEqual(comp(client_values), expected_mean)
Exemple #21
0
 def test_ensure_federated_value_wrong_placement(self):
     clients_int_type = computation_types.FederatedType(tf.int32,
                                                        placements.CLIENTS)

     @computations.federated_computation(clients_int_type)
     def _(x):
         # Inside the trace `x` is CLIENTS-placed, so requiring SERVER fails.
         value = value_impl.to_value(x, None, _context_stack)
         with self.assertRaises(TypeError):
             value_utils.ensure_federated_value(value, placements.SERVER)
         return value
Exemple #22
0
 def _normalize_lambda_bit(comp):
     """Rebuilds a federated lambda parameter type from member and placement.

     Args:
       comp: A `building_blocks.Lambda` (it reads `parameter_type`,
         `parameter_name` and `result`).

     Returns:
       A `(new_comp, modified)` pair. A lambda with no parameter type, or with
       a non-federated one, is returned unchanged with `False`. Otherwise an
       equivalent lambda is returned with `True`; its parameter
       `FederatedType` is constructed without an explicit `all_equal`
       argument.
     """
     param_type = comp.parameter_type
     # Presumably `parameter_type` is `None` for a no-argument lambda. There is
     # nothing to normalize then, and calling `is_federated()` on `None` would
     # raise AttributeError.
     if param_type is None or not param_type.is_federated():
         return comp, False
     return building_blocks.Lambda(
         comp.parameter_name,
         computation_types.FederatedType(param_type.member,
                                         param_type.placement),
         comp.result), True
 def test_converts_all_equal_at_clients_lambda_parameter_to_not_equal(self):
   all_equal_type = computation_types.FederatedType(
       tf.int32, placements.CLIENTS, all_equal=True)
   expected_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
   identity = building_blocks.Lambda(
       'x', all_equal_type, building_blocks.Reference('x', all_equal_type))

   normalized = transformations.normalize_all_equal_bit(identity)

   # The input lambda keeps its original all_equal signature.
   self.assertEqual(
       identity.type_signature,
       computation_types.FunctionType(all_equal_type, all_equal_type))
   self.assertIsInstance(normalized, building_blocks.Lambda)
   self.assertEqual(str(normalized), '(x -> x)')
   self.assertEqual(
       normalized.type_signature,
       computation_types.FunctionType(expected_type, expected_type))
 def test_federated_init_state_not_assignable(self):
     # `initialize` places its state at SERVER, while `next` takes (and
     # returns) an int32@CLIENTS state, so the process must be rejected.
     initialize_fn = computations.federated_computation()(
         lambda: intrinsics.federated_value(0, placements.SERVER))
     clients_int_type = computation_types.FederatedType(tf.int32,
                                                        placements.CLIENTS)
     next_fn = computations.federated_computation(clients_int_type)(
         lambda state: state)

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         iterative_process.IterativeProcess(initialize_fn, next_fn)
Exemple #25
0
    def test_init_sets_federated_type(self, tensors, dtype):
        dataset = tf.data.Dataset.from_tensor_slices(tensors)
        data_source = native_platform.DatasetDataSource(datasets=[dataset] * 3)

        # The data source exposes its datasets as a sequence at CLIENTS.
        expected_type = computation_types.FederatedType(
            computation_types.SequenceType(dtype), placements.CLIENTS)
        self.assertEqual(data_source.federated_type, expected_type)
 def test_passes_federated_type_tuple(self):
   """Checks that a tuple of client values yields a CLIENTS cardinality."""
   client_values = tuple(range(5))
   federated_type = computation_types.FederatedType(
       tf.int32, placements.CLIENTS, all_equal=False)
   # A single call is enough; its result is what gets asserted on.
   five_client_cardinalities = cardinalities_utils.infer_cardinalities(
       client_values, federated_type)
   self.assertEqual(five_client_cardinalities[placements.CLIENTS], 5)
def get_iterative_process_for_sum_example_with_no_federated_secure_sum_bitwidth():
  """Returns an iterative process for a sum example.

  This iterative process does not have a call to
  `federated_secure_sum_bitwidth`; client contributions are combined only
  with `federated_sum`.
  """

  @computations.federated_computation
  def init_fn():
    """The `init` function for `tff.templates.IterativeProcess`."""
    return intrinsics.federated_value(0, placements.SERVER)

  @computations.tf_computation(tf.int32)
  def prepare(server_state):
    # Identity: the whole server state is what gets broadcast.
    return server_state

  @computations.tf_computation(tf.int32, tf.int32)
  def work(client_data, client_input):
    del client_data, client_input  # Unused
    return 1

  @computations.tf_computation([tf.int32, tf.int32])
  def update(server_state, global_update):
    del server_state  # Unused
    # The aggregate becomes the new state; the server output is empty.
    return global_update, []

  server_state_type = computation_types.FederatedType(tf.int32,
                                                      placements.SERVER)
  client_data_type = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)

  @computations.federated_computation([server_state_type, client_data_type])
  def next_fn(server_state, client_data):
    """The `next` function for `tff.templates.IterativeProcess`."""
    broadcast_input = intrinsics.federated_broadcast(
        intrinsics.federated_map(prepare, server_state))
    work_args = intrinsics.federated_zip([client_data, broadcast_input])
    client_updates = intrinsics.federated_map(work, work_args)
    # Only an unsecured sum: no call to `federated_secure_sum_bitwidth`.
    total = intrinsics.federated_sum(client_updates)
    update_args = intrinsics.federated_zip([server_state, total])
    new_server_state, server_output = intrinsics.federated_map(
        update, update_args)
    return new_server_state, server_output

  return iterative_process.IterativeProcess(init_fn, next_fn)
 def test_adds_list_length_as_cardinality_at_new_placement(self):
   # A custom placement literal is counted the same way as CLIENTS.
   agg_placement = placements.PlacementLiteral('Agg', 'Agg', False,
                                               'Intermediate aggregators')
   agg_type = computation_types.FederatedType(
       tf.int32, agg_placement, all_equal=False)
   aggregator_values = list(range(10))
   cardinalities = cardinalities_utils.infer_cardinalities(
       aggregator_values, agg_type)
   self.assertEqual(cardinalities[agg_placement], 10)
 def test_selects_single_federated_output_by_str_name(self):
   clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
   ref = building_blocks.Reference('x', [('a', clients_int)])
   identity = building_blocks.Lambda('x', ref.type_signature, ref)

   selected = transformations.select_output_from_lambda(identity, 'a')

   # The selection keeps the parameter and narrows the result to field `a`.
   expected_type = computation_types.FunctionType(
       identity.parameter_type, identity.type_signature.result['a'])
   self.assert_types_equivalent(selected.type_signature, expected_type)
 def test_infer_cardinalities_success_structure(self):
   clients = placements.CLIENTS
   # Every CLIENTS-placed leaf holds three entries, including nested ones.
   value = structure.Struct([
       ('A', [1, 2, 3]),
       ('B', structure.Struct([
           ('C', [[1, 2], [3, 4], [5, 6]]),
           ('D', [True, False, True]),
       ])),
   ])
   value_type = computation_types.StructType([
       ('A', computation_types.FederatedType(tf.int32, clients)),
       ('B', [
           ('C', computation_types.FederatedType(
               computation_types.SequenceType(tf.int32), clients)),
           ('D', computation_types.FederatedType(tf.bool, clients)),
       ]),
   ])
   cardinalities = cardinalities_utils.infer_cardinalities(value, value_type)
   self.assertDictEqual(cardinalities, {clients: 3})