Example #1
def _build_test_map_reduce_form_with_computations(initialize=None,
                                                  prepare=None,
                                                  work=None,
                                                  zero=None,
                                                  accumulate=None,
                                                  merge=None,
                                                  report=None,
                                                  bitwidth=None,
                                                  max_input=None,
                                                  modulus=None,
                                                  update=None):
    (test_initialize, test_prepare, test_work, test_zero, test_accumulate,
     test_merge, test_report, test_bitwidth, test_max_input, test_modulus,
     test_update) = _test_map_reduce_form_computations()
    return forms.MapReduceForm(
        initialize if initialize else test_initialize,
        prepare if prepare else test_prepare,
        work if work else test_work,
        zero if zero else test_zero,
        accumulate if accumulate else test_accumulate,
        merge if merge else test_merge,
        report if report else test_report,
        bitwidth if bitwidth else test_bitwidth,
        max_input if max_input else test_max_input,
        modulus if modulus else test_modulus,
        update if update else test_update,
    )
Example #2
def get_map_reduce_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.MapReduceForm:
    """Constructs `tff.backends.mapreduce.MapReduceForm` given iterative process.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` that is compatible with
      MapReduce form. Iterative processes are only compatible if:
      - `initialize_fn` returns a single federated value placed at `SERVER`.
      - `next` takes exactly two arguments, the first of which must be the
        state value placed at `SERVER`.
      - `next` returns exactly two values.
    grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the TensorFlow graphs backing the
      resulting `tff.backends.mapreduce.MapReduceForm`. These options are
      combined with a set of defaults that aggressively configure Grappler. If
      the input `grappler_config` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
      bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.MapReduceForm` equivalent to the
    provided `tff.templates.IterativeProcess`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.MapReduceFormCompilationError: If the compilation
      process fails.
  """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    initialize_bb, next_bb = (
        check_iterative_process_compatible_with_map_reduce_form(ip))
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    next_bb, _ = tree_transformations.uniquify_reference_names(next_bb)
    before_broadcast, after_broadcast = _split_ast_on_broadcast(next_bb)
    before_aggregate, after_aggregate = _split_ast_on_aggregate(
        after_broadcast)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_bb, grappler_config)
    prepare = _extract_prepare(before_broadcast, grappler_config)
    work = _extract_work(before_aggregate, grappler_config)
    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    bitwidth = _extract_federated_secure_sum_bitwidth_functions(
        before_aggregate, grappler_config)
    update = _extract_update(after_aggregate, grappler_config)

    next_parameter_names = structure.name_list_with_nones(
        ip.next.type_signature.parameter)
    server_state_label, client_data_label = next_parameter_names
    comps = (computation_wrapper_instances.building_block_to_computation(bb)
             for bb in (initialize, prepare, work, zero, accumulate, merge,
                        report, bitwidth, update))
    return forms.MapReduceForm(*comps,
                               server_state_label=server_state_label,
                               client_data_label=client_data_label)
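
A minimal usage sketch (not part of the source), assuming `ip` is a compatible iterative process defined elsewhere; the `work` and `update` property names follow `tff.backends.mapreduce.MapReduceForm`:

# Hedged sketch: convert a compatible iterative process and inspect the result.
mrf = get_map_reduce_form_for_iterative_process(ip)
print(mrf.work.type_signature)    # the client-side map stage
print(mrf.update.type_signature)  # the server-side state update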
Example #3
    def test_init_does_not_raise_type_error(self):
        (initialize, prepare, work, zero, accumulate, merge, report, bitwidth,
         update) = _test_map_reduce_form_computations()

        try:
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')
Example #4
    def test_init_raises_type_error_with_bad_zero_result_type(self):
        (initialize, prepare, work, _, accumulate, merge, report, bitwidth,
         update) = _test_map_reduce_form_computations()

        @computations.tf_computation
        def zero():
            return tf.constant(0.0), tf.constant(0)

        with self.assertRaises(TypeError):
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
Example #5
    def test_init_raises_type_error_with_bad_report_result_type(self):
        (initialize, prepare, work, zero, accumulate, merge, _, bitwidth,
         update) = _test_map_reduce_form_computations()

        @computations.tf_computation(tf.int32, tf.int32)
        def report(accumulator):
            del accumulator  # Unused
            return tf.constant(1)

        with self.assertRaises(TypeError):
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
Example #6
    def test_init_raises_type_error_with_bad_prepare_parameter_type(self):
        (initialize, _, work, zero, accumulate, merge, report, bitwidth,
         update) = _test_map_reduce_form_computations()

        @computations.tf_computation(tf.float32)
        def prepare(server_state):
            del server_state  # Unused
            return tf.constant(1.0)

        with self.assertRaises(TypeError):
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
Example #7
    def test_init_raises_type_error_with_bad_accumulate_result_type(self):
        (initialize, prepare, work, zero, _, merge, report, bitwidth,
         update) = _test_map_reduce_form_computations()

        @computations.tf_computation((tf.float32, tf.float32), tf.bool)
        def accumulate(accumulator, client_update):
            del accumulator  # Unused
            del client_update  # Unused
            return tf.constant(1.0), tf.constant(1)

        with self.assertRaises(TypeError):
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
Example #8
    def test_init_raises_type_error_with_bad_update_result_type(self):
        (initialize, prepare, work, zero, accumulate, merge, report, bitwidth,
         _) = _test_map_reduce_form_computations()

        @computations.tf_computation(
            tf.int32, (tf.float32, computation_types.StructType([])))
        def update(server_state, global_update):
            del server_state  # Unused
            del global_update  # Unused
            return tf.constant(1.0), []

        with self.assertRaises(TypeError):
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
Example #9
    def test_init_raises_type_error_with_bad_merge_second_parameter_type(self):
        (initialize, prepare, work, zero, accumulate, _, report, bitwidth,
         update) = _test_map_reduce_form_computations()

        @computations.tf_computation((tf.int32, tf.int32),
                                     (tf.float32, tf.int32))
        def merge(accumulator1, accumulator2):
            del accumulator1  # Unused
            del accumulator2  # Unused
            return tf.constant(1), tf.constant(1)

        with self.assertRaises(TypeError):
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
Example #10
    def test_init_raises_type_error_with_bad_work_result_type(self):
        (initialize, prepare, _, zero, accumulate, merge, report, bitwidth,
         update) = _test_map_reduce_form_computations()

        @computations.tf_computation(computation_types.SequenceType(
            tf.float32), tf.float32)
        def work(client_data, client_input):
            del client_data  # Unused
            del client_input  # Unused
            return tf.constant('abc'), []

        with self.assertRaises(TypeError):
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, update)
Example #11
    def test_init_does_not_raise_type_error_with_unknown_dimensions(self):
        server_state_type = computation_types.TensorType(shape=[None],
                                                         dtype=tf.int32)

        @tensorflow_computation.tf_computation
        def initialize():
            # Return a value of a type assignable to, but not equal to
            # `server_state_type`
            return tf.constant([1, 2, 3])

        @tensorflow_computation.tf_computation(server_state_type)
        def prepare(server_state):
            del server_state  # Unused
            return tf.constant(1.0)

        @tensorflow_computation.tf_computation(
            computation_types.SequenceType(tf.float32), tf.float32)
        def work(client_data, client_input):
            del client_data  # Unused
            del client_input  # Unused
            return True, [], [], []

        @tensorflow_computation.tf_computation
        def zero():
            return tf.constant([], dtype=tf.string)

        @tensorflow_computation.tf_computation(
            computation_types.TensorType(shape=[None], dtype=tf.string),
            tf.bool)
        def accumulate(accumulator, client_update):
            del accumulator  # Unused
            del client_update  # Unused
            return tf.constant(['abc'])

        @tensorflow_computation.tf_computation(
            computation_types.TensorType(shape=[None], dtype=tf.string),
            computation_types.TensorType(shape=[None], dtype=tf.string))
        def merge(accumulator1, accumulator2):
            del accumulator1  # Unused
            del accumulator2  # Unused
            return tf.constant(['abc'])

        @tensorflow_computation.tf_computation(
            computation_types.TensorType(shape=[None], dtype=tf.string))
        def report(accumulator):
            del accumulator  # Unused
            return tf.constant(1.0)

        unit_comp = tensorflow_computation.tf_computation(lambda: [])
        bitwidth = unit_comp
        max_input = unit_comp
        modulus = unit_comp
        unit_type = computation_types.to_type([])

        @tensorflow_computation.tf_computation(
            server_state_type, (tf.float32, unit_type, unit_type, unit_type))
        def update(server_state, global_update):
            del server_state  # Unused
            del global_update  # Unused
            # Return a new server state value whose type is assignable but not equal
            # to `server_state_type`, and which is different from the type returned
            # by `initialize`.
            return tf.constant([1]), []

        try:
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, max_input, modulus,
                                update)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')
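
The test above relies on type assignability: a tensor with a fixed dimension may be used where an unknown dimension is expected, but not vice versa. A minimal sketch of that relation (not part of the source), assuming the `computation_types` and `tf` imports used above:

# Hedged sketch of the assignability relation the comments above rely on.
unknown_dim = computation_types.TensorType(shape=[None], dtype=tf.int32)
fixed_dim = computation_types.TensorType(shape=[3], dtype=tf.int32)
assert unknown_dim.is_assignable_from(fixed_dim)      # [3] fits where [None] is expected
assert not fixed_dim.is_assignable_from(unknown_dim)  # the reverse does not hold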
Example #12
def get_temperature_sensor_example():
    """Constructs `forms.MapReduceForm` for temperature sensors example.

  The temperature sensor example computes the fraction of sensors that report
  temperatures over the threshold.

  Returns:
    An instance of `forms.MapReduceForm`.
  """
    @tensorflow_computation.tf_computation
    def initialize():
        return collections.OrderedDict(num_rounds=tf.constant(0))

    # The state of the server is a singleton tuple containing just the integer
    # counter `num_rounds`.
    server_state_type = collections.OrderedDict(num_rounds=tf.int32)

    @tensorflow_computation.tf_computation(server_state_type)
    def prepare(state):
        return collections.OrderedDict(
            max_temperature=32.0 + tf.cast(state['num_rounds'], tf.float32))

    # The initial state of the client is a singleton tuple containing a single
    # float `max_temperature`, which is the threshold received from the server.
    client_state_type = collections.OrderedDict(max_temperature=tf.float32)

    # The client data is a sequence of floats.
    client_data_type = computation_types.SequenceType(tf.float32)

    @tensorflow_computation.tf_computation(client_data_type, client_state_type)
    def work(data, state):
        """See the `forms.MapReduceForm` definition of `work`."""
        def fn(s, x):
            return {
                'num': s['num'] + 1,
                'max': tf.maximum(s['max'], x),
            }

        reduce_result = data.reduce(
            {
                'num': np.int32(0),
                'max': np.float32(-459.67)
            }, fn)
        client_updates = collections.OrderedDict(
            is_over=reduce_result['max'] > state['max_temperature'])
        return client_updates, [], [], []

    # The client update is a singleton tuple with a Boolean-typed `is_over`.
    client_update_type = collections.OrderedDict(is_over=tf.bool)

    # The accumulator for client updates is a pair of counters, one for the
    # number of clients over threshold, and the other for the total number of
    # client updates processed so far.
    accumulator_type = collections.OrderedDict(num_total=tf.int32,
                                               num_over=tf.int32)

    @tensorflow_computation.tf_computation
    def zero():
        return collections.OrderedDict(num_total=tf.constant(0),
                                       num_over=tf.constant(0))

    @tensorflow_computation.tf_computation(accumulator_type,
                                           client_update_type)
    def accumulate(accumulator, update):
        return collections.OrderedDict(num_total=accumulator['num_total'] + 1,
                                       num_over=accumulator['num_over'] +
                                       tf.cast(update['is_over'], tf.int32))

    @tensorflow_computation.tf_computation(accumulator_type, accumulator_type)
    def merge(accumulator1, accumulator2):
        return collections.OrderedDict(
            num_total=accumulator1['num_total'] + accumulator2['num_total'],
            num_over=accumulator1['num_over'] + accumulator2['num_over'])

    @tensorflow_computation.tf_computation(merge.type_signature.result)
    def report(accumulator):
        return collections.OrderedDict(ratio_over_threshold=(
            tf.cast(accumulator['num_over'], tf.float32) /
            tf.cast(accumulator['num_total'], tf.float32)))

    unit_comp = tensorflow_computation.tf_computation(lambda: [])
    bitwidth = unit_comp
    max_input = unit_comp
    modulus = unit_comp

    update_type = (collections.OrderedDict(ratio_over_threshold=tf.float32),
                   (), (), ())

    @tensorflow_computation.tf_computation(server_state_type, update_type)
    def update(state, update):
        return (collections.OrderedDict(num_rounds=state['num_rounds'] + 1),
                update[0])

    return forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                               merge, report, bitwidth, max_input, modulus,
                               update)
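
A hedged sketch (not part of the source) of how one round of this form composes when driven by hand, assuming a local TFF execution context is installed and that Python lists are accepted for sequence-typed arguments:

# Hedged sketch: run one round of the temperature-sensor MapReduceForm manually.
mrf = get_temperature_sensor_example()
state = mrf.initialize()                    # OrderedDict(num_rounds=0)
client_input = mrf.prepare(state)           # threshold broadcast to clients (32.0 in round 0)
client_data = [[28.0, 30.5, 33.0], [29.0, 31.0]]
client_updates = [mrf.work(ds, client_input)[0] for ds in client_data]
accumulator = mrf.zero()
for update in client_updates:
    accumulator = mrf.accumulate(accumulator, update)
global_update = mrf.report(accumulator)     # ratio_over_threshold
state, metrics = mrf.update(state, (global_update, [], [], []))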
Example #13
def get_mnist_training_example():
    """Constructs `forms.MapReduceForm` for mnist training.

  Returns:
    An instance of `forms.MapReduceForm`.
  """
    model_nt = collections.namedtuple('Model', 'weights bias')
    server_state_nt = (collections.namedtuple('ServerState',
                                              'model num_rounds'))

    # Start with a model filled with zeros, and the round counter set to zero.
    @tensorflow_computation.tf_computation
    def initialize():
        return server_state_nt(model=model_nt(weights=tf.zeros([784, 10]),
                                              bias=tf.zeros([10])),
                               num_rounds=tf.constant(0))

    server_state_tff_type = server_state_nt(model=model_nt(weights=(tf.float32,
                                                                    [784, 10]),
                                                           bias=(tf.float32,
                                                                 [10])),
                                            num_rounds=tf.int32)
    client_state_nt = (collections.namedtuple('ClientState',
                                              'model learning_rate'))

    # Pass the model to the client, along with a dynamically adjusted learning
    # rate that starts at 0.1 and decays exponentially by a factor of 0.9.
    @tensorflow_computation.tf_computation(server_state_tff_type)
    def prepare(state):
        learning_rate = 0.1 * tf.pow(0.9, tf.cast(state.num_rounds,
                                                  tf.float32))
        return client_state_nt(model=state.model, learning_rate=learning_rate)

    batch_nt = collections.namedtuple('Batch', 'x y')
    batch_tff_type = batch_nt(x=(tf.float32, [None, 784]),
                              y=(tf.int32, [None]))
    dataset_tff_type = computation_types.SequenceType(batch_tff_type)
    model_tff_type = model_nt(weights=(tf.float32, [784, 10]),
                              bias=(tf.float32, [10]))
    client_state_tff_type = client_state_nt(model=model_tff_type,
                                            learning_rate=tf.float32)
    loop_state_nt = collections.namedtuple('LoopState',
                                           'num_examples total_loss')
    update_nt = collections.namedtuple('Update', 'model num_examples loss')

    # Train the model locally, emit the locally-trained model and the number of
    # examples as an update, and the average loss and the number of examples as
    # local client stats.
    @tensorflow_computation.tf_computation(dataset_tff_type,
                                           client_state_tff_type)
    def work(data, state):  # pylint: disable=missing-docstring
        model_vars = model_nt(weights=tf.Variable(
            initial_value=state.model.weights, name='weights'),
                              bias=tf.Variable(initial_value=state.model.bias,
                                               name='bias'))
        init_model = tf.compat.v1.global_variables_initializer()

        optimizer = tf.keras.optimizers.SGD(state.learning_rate)

        @tf.function
        def reduce_fn(loop_state, batch):
            """Compute a single gradient step on an given batch of examples."""
            with tf.GradientTape() as tape:
                pred_y = tf.nn.softmax(
                    tf.matmul(batch.x, model_vars.weights) + model_vars.bias)
                loss = -tf.reduce_mean(
                    tf.reduce_sum(
                        tf.one_hot(batch.y, 10) * tf.math.log(pred_y),
                        axis=[1]))
            grads = tape.gradient(loss, model_vars)
            optimizer.apply_gradients(
                zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
            return loop_state_nt(num_examples=loop_state.num_examples + 1,
                                 total_loss=loop_state.total_loss + loss)

        with tf.control_dependencies([init_model]):
            loop_state = data.reduce(
                loop_state_nt(num_examples=0, total_loss=np.float32(0.0)),
                reduce_fn)
            num_examples = loop_state.num_examples
            total_loss = loop_state.total_loss
            with tf.control_dependencies([num_examples, total_loss]):
                loss = total_loss / tf.cast(num_examples, tf.float32)

        return update_nt(model=model_vars,
                         num_examples=num_examples,
                         loss=loss), [], [], []

    accumulator_nt = update_nt

    # Initialize accumulators for aggregation with zero model and zero examples.
    @tensorflow_computation.tf_computation
    def zero():
        return accumulator_nt(model=model_nt(weights=tf.zeros([784, 10]),
                                             bias=tf.zeros([10])),
                              num_examples=tf.constant(0),
                              loss=tf.constant(0.0, dtype=tf.float32))

    update_tff_type = update_nt(model=model_tff_type,
                                num_examples=tf.int32,
                                loss=tf.float32)
    accumulator_tff_type = update_tff_type

    # We add an update to an accumulator with the update's model multiplied by the
    # number of examples, so we can compute a weighted average in the end.
    @tensorflow_computation.tf_computation(accumulator_tff_type,
                                           update_tff_type)
    def accumulate(accumulator, update):
        scaling_factor = tf.cast(update.num_examples, tf.float32)
        scaled_model = tf.nest.map_structure(lambda x: x * scaling_factor,
                                             update.model)
        return accumulator_nt(
            model=tf.nest.map_structure(tf.add, accumulator.model,
                                        scaled_model),
            num_examples=accumulator.num_examples + update.num_examples,
            loss=accumulator.loss + update.loss * scaling_factor)

    # Merging accumulators does not involve scaling.
    @tensorflow_computation.tf_computation(accumulator_tff_type,
                                           accumulator_tff_type)
    def merge(accumulator1, accumulator2):
        return accumulator_nt(
            model=tf.nest.map_structure(tf.add, accumulator1.model,
                                        accumulator2.model),
            num_examples=accumulator1.num_examples + accumulator2.num_examples,
            loss=accumulator1.loss + accumulator2.loss)

    report_nt = accumulator_nt

    # The result of aggregation is produced by dividing the accumulated model by
    # the total number of examples. Same for loss.
    @tensorflow_computation.tf_computation(accumulator_tff_type)
    def report(accumulator):
        scaling_factor = 1.0 / tf.cast(accumulator.num_examples, tf.float32)
        scaled_model = model_nt(weights=accumulator.model.weights *
                                scaling_factor,
                                bias=accumulator.model.bias * scaling_factor)
        return report_nt(model=scaled_model,
                         num_examples=accumulator.num_examples,
                         loss=accumulator.loss * scaling_factor)

    unit_computation = tensorflow_computation.tf_computation(lambda: [])
    secure_sum_bitwidth = unit_computation
    secure_sum_max_input = unit_computation
    secure_sum_modulus = unit_computation

    update_type = (accumulator_tff_type, (), (), ())
    metrics_nt = collections.namedtuple('Metrics',
                                        'num_rounds num_examples loss')

    # Pass the newly averaged model along with an incremented round counter over
    # to the next round, and output the counters and loss as server metrics.
    @tensorflow_computation.tf_computation(server_state_tff_type, update_type)
    def update(state, update):
        report = update[0]
        num_rounds = state.num_rounds + 1
        return (server_state_nt(model=report.model, num_rounds=num_rounds),
                metrics_nt(num_rounds=num_rounds,
                           num_examples=report.num_examples,
                           loss=report.loss))

    return forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                               merge, report, secure_sum_bitwidth,
                               secure_sum_max_input, secure_sum_modulus,
                               update)
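
A small illustration (not part of the source) of the weighted averaging that `accumulate` and `report` implement, using plain Python numbers in place of model tensors:

# Hedged illustration: two clients with 10 and 30 examples and scalar "models" 1.0 and 2.0.
acc_model, acc_examples = 0.0, 0
for model, num_examples in [(1.0, 10), (2.0, 30)]:
    acc_model += model * num_examples   # accumulate: scale each update by its example count
    acc_examples += num_examples
avg_model = acc_model / acc_examples    # report: divide by total examples -> 1.75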
Example #14
def get_federated_sum_example(*,
                              secure_sum: bool = False) -> forms.MapReduceForm:
    """Constructs `forms.MapReduceForm` which performs a sum aggregation.

  Args:
    secure_sum: Whether to aggregate using `federated_secure_sum_bitwidth`
      instead of `federated_sum`. Defaults to `False`.

  Returns:
    An instance of `forms.MapReduceForm`.
  """
    @tensorflow_computation.tf_computation
    def initialize():
        return ()

    server_state_type = initialize.type_signature.result

    @tensorflow_computation.tf_computation(server_state_type)
    def prepare(state):
        return state

    @tensorflow_computation.tf_computation(
        computation_types.SequenceType(tf.int32),
        prepare.type_signature.result)
    def work(data, _):
        client_sum = data.reduce(initial_state=0, reduce_func=tf.add)
        if secure_sum:
            return [], client_sum, [], []
        else:
            return client_sum, [], [], []

    @tensorflow_computation.tf_computation
    def zero():
        if secure_sum:
            return ()
        else:
            return 0

    client_update_type = work.type_signature.result[0]
    accumulator_type = zero.type_signature.result

    @tensorflow_computation.tf_computation(accumulator_type,
                                           client_update_type)
    def accumulate(accumulator, update):
        if secure_sum:
            return ()
        else:
            return accumulator + update

    @tensorflow_computation.tf_computation(accumulator_type, accumulator_type)
    def merge(accumulator1, accumulator2):
        if secure_sum:
            return ()
        else:
            return accumulator1 + accumulator2

    @tensorflow_computation.tf_computation(merge.type_signature.result)
    def report(accumulator):
        return accumulator

    bitwidth = tensorflow_computation.tf_computation(lambda: 32)
    max_input = tensorflow_computation.tf_computation(lambda: 0)
    modulus = tensorflow_computation.tf_computation(lambda: 0)

    update_type = (
        merge.type_signature.result,
        work.type_signature.result[1],
        work.type_signature.result[2],
        work.type_signature.result[3],
    )

    @tensorflow_computation.tf_computation(server_state_type, update_type)
    def update(state, update):
        if secure_sum:
            return state, update[1]
        else:
            return state, update[0]

    return forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                               merge, report, bitwidth, max_input, modulus,
                               update)
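
A hedged usage sketch (not part of the source) showing which output slot of `work` carries the client sum in each variant:

# Hedged sketch: build both variants and inspect the relevant output slot of `work`.
plain = get_federated_sum_example(secure_sum=False)
secure = get_federated_sum_example(secure_sum=True)
print(plain.work.type_signature.result[0])   # int32, aggregated with federated_sum
print(secure.work.type_signature.result[1])  # int32, routed to federated_secure_sum_bitwidth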