def extract_nodes_consuming(tree, predicate):
  """Collects the nodes of `tree` whose value depends on a `predicate` match.

  By convention a node that itself satisfies `predicate` is part of the
  result. Beyond that, a node is included when one of its direct dependencies
  is already included:

  * a `Lambda` whose result is included;
  * a `Block` with an included local value or an included result;
  * a `Tuple` with an included element;
  * a `Selection` whose source is included;
  * a `Call` whose function or argument is included;
  * a `Reference` whose bound value (looked up in the symbol tree) is included.

  `Intrinsic`, `Data`, `Placement` and `CompiledComputation` nodes are only
  included when they satisfy `predicate` directly.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` viewed as an
      abstract syntax tree to search.
    predicate: One-arg callable accepting a
      `building_blocks.ComputationBuildingBlock` and returning a `bool` that
      indicates a match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances from
    `tree` that depend on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)

  consumers = set()
  # Node kinds with no child dependencies to inspect.
  leaf_types = (building_blocks.Intrinsic, building_blocks.Data,
                building_blocks.Placement,
                building_blocks.CompiledComputation)

  def _reference_is_dependent(ref, symbol_tree):
    """Checks whether the value bound to `ref` has already been collected."""
    binding = symbol_tree.get_payload_with_name(ref.name)
    if binding is None:
      return False
    # The postorder traversal processes a binding before any reference to
    # it, so the bound value's membership is already settled here.
    return binding.value in consumers

  def _depends_on_consumer(comp, symbol_tree):
    """Checks whether a direct dependency of `comp` is in `consumers`."""
    if isinstance(comp, leaf_types):
      return False
    if isinstance(comp, building_blocks.Lambda):
      return comp.result in consumers
    if isinstance(comp, building_blocks.Block):
      local_hit = any(local[1] in consumers for local in comp.locals)
      return local_hit or comp.result in consumers
    if isinstance(comp, building_blocks.Tuple):
      return any(element in consumers for element in comp)
    if isinstance(comp, building_blocks.Selection):
      return comp.source in consumers
    if isinstance(comp, building_blocks.Call):
      return comp.function in consumers or comp.argument in consumers
    if isinstance(comp, building_blocks.Reference):
      return _reference_is_dependent(comp, symbol_tree)
    # Any other node kind is treated as not dependent.
    return False

  def _record_if_consumer(comp, symbol_tree):
    """Adds `comp` to `consumers` if it matches or depends on a match."""
    if predicate(comp) or _depends_on_consumer(comp, symbol_tree):
      consumers.add(comp)
    # Leave the tree untouched; this pass only gathers information.
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _record_if_consumer, symbol_tree)
  return consumers
def to_representation_for_type(value, type_spec=None, device=None):
  """Verifies or converts `value` to an eager object matching `type_spec`.

  WARNING: This function is only partially implemented.

  The result is an eager tensor, an eager dataset, a representation of a
  TensorFlow computation, or a nested structure of those that matches
  `type_spec`. When `device` is given, the conversion runs under
  `tf.device(device)`, so placement is best-effort.

  TensorFlow computations are represented here as zero- or one-argument Python
  callables that accept their entire argument bundle as a single Python object.

  Args:
    value: The raw representation of a value to compare against `type_spec`
      and potentially to be converted.
    type_spec: An instance of `tff.Type`, can be `None` for values that derive
      from `typed_object.TypedObject`.
    device: The optional device to place the value on (for tensor-level
      values).

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`.
  """
  if device is not None:
    py_typecheck.check_type(device, six.string_types)
    with tf.device(device):
      return to_representation_for_type(value, type_spec=type_spec, device=None)

  type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)

  # Values that already have (or cannot have) an eager representation.
  if isinstance(value, EagerValue):
    return value.internal_representation
  if isinstance(value, executor_value_base.ExecutorValue):
    raise TypeError(
        'Cannot accept a value embedded within a non-eager executor.')
  if isinstance(value, computation_base.Computation):
    proto = computation_impl.ComputationImpl.get_proto(value)
    return to_representation_for_type(proto, type_spec, device)
  if isinstance(value, pb.Computation):
    return embed_tensorflow_computation(value, type_spec, device)

  if isinstance(type_spec, computation_types.TensorType):
    if isinstance(value, np.ndarray):
      value = tf.constant(value, dtype=type_spec.dtype)
    elif not isinstance(value, tf.Tensor):
      value = tf.constant(value, dtype=type_spec.dtype, shape=type_spec.shape)
    apparent_type = computation_types.TensorType(value.dtype.base_dtype,
                                                 value.shape)
    if not type_utils.is_assignable_from(type_spec, apparent_type):
      raise TypeError(
          'The apparent type {} of a tensor {} does not match the expected '
          'type {}.'.format(str(apparent_type), str(value), str(type_spec)))
    return value

  if isinstance(type_spec, computation_types.NamedTupleType):
    expected_elements = anonymous_tuple.to_elements(type_spec)
    actual_elements = anonymous_tuple.to_elements(
        anonymous_tuple.from_container(value))
    if len(expected_elements) != len(actual_elements):
      raise TypeError('Expected a {}-element tuple, found {} elements.'.format(
          str(len(expected_elements)), str(len(actual_elements))))
    converted = []
    for (t_name, el_type), (v_name, el_val) in zip(expected_elements,
                                                   actual_elements):
      if t_name != v_name:
        raise TypeError(
            'Mismatching element names in type vs. value: {} vs. {}.'.format(
                t_name, v_name))
      converted.append(
          (t_name, to_representation_for_type(el_val, el_type, device)))
    return anonymous_tuple.AnonymousTuple(converted)

  if isinstance(type_spec, computation_types.SequenceType):
    if isinstance(value, list):
      value = graph_utils.make_data_set_from_elements(None, value,
                                                      type_spec.element)
    py_typecheck.check_type(
        value,
        (tf.data.Dataset, tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset))
    element_type = type_utils.tf_dtypes_and_shapes_to_type(
        tf.compat.v1.data.get_output_types(value),
        tf.compat.v1.data.get_output_shapes(value))
    type_utils.check_assignable_from(
        type_spec, computation_types.SequenceType(element_type))
    return value

  raise TypeError('Unexpected type {}.'.format(str(type_spec)))
def __init__(self, identifier):
  """Keeps `identifier` on the instance after validating its type.

  Args:
    identifier: The `str` identifier to store as `self._identifier`.

  Raises:
    TypeError: If `identifier` is not a `str`.
  """
  py_typecheck.check_type(identifier, str)
  self._identifier = identifier
def zeroing_factory(
    zeroing_norm: Union[float, estimation_process.EstimationProcess],
    inner_agg_factory: factory.AggregationFactory,
    norm_order: float = math.inf) -> factory.AggregationFactory:
  """Creates an aggregation factory to perform zeroing.

  The created `tff.templates.AggregationProcess` zeroes out any values whose
  norm is greater than that determined by the provided `zeroing_norm`, before
  aggregating the values as specified by `inner_agg_factory`. Note that for
  weighted aggregation if some value is zeroed, the weight is unchanged. So for
  example if you have a zeroed weighted mean and a lot of zeroing occurs, the
  average will tend to be pulled toward zero. This is for consistency between
  weighted and unweighted aggregation.

  The provided `zeroing_norm` can either be a constant (for fixed norm), or an
  instance of `tff.templates.EstimationProcess` (for adaptive norm). If it is an
  estimation process, the value returned by its `report` method will be used as
  the zeroing norm. Its `next` method needs to accept a scalar float32 at
  clients, corresponding to the norm of value being aggregated. The process can
  thus adaptively determine the zeroing norm based on the set of aggregated
  values. For example if a `tff.aggregators.PrivateQuantileEstimationProcess` is
  used, the zeroing norm will be an estimate of a quantile of the norms of the
  values being aggregated.

  The returned `AggregationFactory` takes its weightedness
  (`UnweightedAggregationFactory` vs. `WeightedAggregationFactory`) from
  `inner_agg_factory`.

  Args:
    zeroing_norm: Either a float (for fixed norm) or an `EstimationProcess`
      (for adaptive norm) that specifies the norm over which the values should
      be zeroed.
    inner_agg_factory: A factory specifying the type of aggregation to be done
      after zeroing.
    norm_order: A float for the order of the norm. Must be 1., 2., or positive
      infinity.

  Returns:
    An aggregation factory to perform zeroing.

  Raises:
    TypeError: If `norm_order` is not a `float`.
    ValueError: If `norm_order` is not 1.0, 2.0 or positive infinity.
  """
  py_typecheck.check_type(norm_order, float)
  # `math.isinf` would also accept -inf, which the zeroing function below would
  # then silently treat as the +inf (max-abs) norm; accept only +inf.
  if norm_order not in (1.0, 2.0, math.inf):
    raise ValueError('norm_order must be 1.0, 2.0 or infinity')

  def make_zero_fn(value_type):
    """Creates a zeroing function for the value_type."""

    # NOTE(review): the inner parameter shadows the outer `zeroing_norm`; it is
    # kept as-is because renaming a computation's parameters may change its
    # type signature — TODO confirm before renaming.
    @computations.tf_computation(value_type, NORM_TF_TYPE)
    def zero_fn(value, zeroing_norm):
      if norm_order == 1.0:
        global_norm = _global_l1_norm(value)
      elif norm_order == 2.0:
        global_norm = tf.linalg.global_norm(tf.nest.flatten(value))
      else:
        assert norm_order == math.inf
        global_norm = _global_inf_norm(value)
      should_zero = (global_norm > zeroing_norm)
      # Replace the whole structure with zeros when its norm is too large.
      zeroed_value = tf.cond(
          should_zero, lambda: tf.nest.map_structure(tf.zeros_like, value),
          lambda: value)
      was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE)
      return zeroed_value, global_norm, was_zeroed

    return zero_fn

  return _make_wrapper(zeroing_norm, inner_agg_factory, make_zero_fn, 'zero')
def _check_value_type(value_type):
  """Ensures `value_type` is an accepted value type made only of floats.

  Args:
    value_type: The type to validate; must be an instance of one of the types
      in `factory.ValueType`.

  Raises:
    TypeError: If `value_type` is not an accepted value type, or if it is not
      a structure of floating-point dtypes.
  """
  accepted_types = factory.ValueType.__args__
  py_typecheck.check_type(value_type, accepted_types)
  if type_analysis.is_structure_of_floats(value_type):
    return
  raise TypeError(f'All values in provided value_type must be of floating '
                  f'dtype. Provided value_type: {value_type}')
def append_to_list_structure_for_element_type_spec(nested, value, type_spec):
  """Appends the tensor-level pieces of `value` to the lists in `nested`.

  `nested` is the list structure built by
  `make_empty_list_structure_for_element_type_spec`. Each tensor-level
  constituent of `value` is converted to a tensor of the matching dtype and
  appended to the corresponding list. The structure of `value` must mirror the
  one in `nested` and be consistent with `type_spec`. A `None` value is
  ignored.

  Args:
    nested: Output of `make_empty_list_structure_for_element_type_spec`.
    value: A value (Python object) that is a hierarchical structure of
      dictionaries, lists, and other containers holding tensor-like items that
      matches the hierarchy of `type_spec`.
    type_spec: An instance of `tff.Type` or something convertible to it, as in
      `make_empty_list_structure_for_element_type_spec`.

  Raises:
    TypeError: If the `type_spec` is not of a form described above, or the
      value is not of a type compatible with `type_spec`.
  """
  if value is None:
    return
  type_spec = computation_types.to_type(type_spec)

  # TODO(b/113116813): This could be made more efficient, but for now we won't
  # need to worry about it as this is an odd corner case.
  if isinstance(value, structure.Struct):
    struct_elements = structure.to_elements(value)
    names = [name for name, _ in struct_elements]
    if all(name is not None for name in names):
      value = collections.OrderedDict(struct_elements)
    elif all(name is None for name in names):
      value = tuple(member for _, member in struct_elements)
    else:
      raise TypeError(
          'Expected an anonymous tuple to either have all elements named or '
          'all unnamed, got {}.'.format(value))

  if type_spec.is_tensor():
    py_typecheck.check_type(nested, list)
    # Convert the members to tensors to ensure that they are properly
    # typed and grouped before being passed to
    # tf.data.Dataset.from_tensor_slices.
    nested.append(tf.convert_to_tensor(value, type_spec.dtype))  # pytype: disable=attribute-error
    return

  if not type_spec.is_struct():
    raise TypeError(
        'Expected a tensor or named tuple type, found {}.'.format(type_spec))

  type_elements = structure.to_elements(type_spec)

  if isinstance(nested, collections.OrderedDict):
    if py_typecheck.is_named_tuple(value):
      value = value._asdict()  # pytype: disable=attribute-error
    if isinstance(value, dict):
      if set(value.keys()) != set(name for name, _ in type_elements):
        raise TypeError('Value {} does not match type {}.'.format(
            value, type_spec))
      for elem_name, elem_type in type_elements:
        append_to_list_structure_for_element_type_spec(
            nested[elem_name], value[elem_name], elem_type)
    elif isinstance(value, (list, tuple)):
      if len(value) != len(type_elements):
        raise TypeError('Value {} does not match type {}.'.format(
            value, type_spec))
      for (elem_name, elem_type), member in zip(type_elements, value):
        append_to_list_structure_for_element_type_spec(
            nested[elem_name], member, elem_type)
    else:
      raise TypeError(
          'Unexpected type of value {} for TFF type {}.'.format(
              py_typecheck.type_string(type(value)), type_spec))
  elif isinstance(nested, tuple):
    py_typecheck.check_type(value, (list, tuple))
    if len(value) != len(type_elements):
      raise TypeError('Value {} does not match type {}.'.format(
          value, type_spec))
    for position, ((_, elem_type), member) in enumerate(
        zip(type_elements, value)):
      append_to_list_structure_for_element_type_spec(
          nested[position], member, elem_type)
  else:
    raise TypeError(
        'Invalid nested structure, unexpected container type {}.'.format(
            py_typecheck.type_string(type(nested))))
def fetch_value_in_session(sess, value):
  """Fetches `value` in the session `sess`.

  Args:
    sess: The `tf.compat.v1.Session` in which to perform the fetch.
    value: A Python object of a form analogous to that constructed by the
      function `assemble_result_from_graph`, made of tensors and anonymous
      tuples, or a `tf.data.Dataset`.

  Returns:
    A Python object with structure similar to `value`, but with tensors
    replaced with their values, and data sets replaced with lists of their
    elements. All tensors outside of data sets are fetched with a single
    `sess.run()` call; each data set is drained separately, with one
    `sess.run()` call per element. An empty data set nested in a structure is
    replaced by a one-element list holding a dummy element, so that its shape
    information is preserved.

  Raises:
    ValueError: If `value` is not a `tf.data.Dataset` or not a structure of
      tensors and anonymous tuples.
  """
  py_typecheck.check_type(sess, tf.compat.v1.Session)
  # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
  # `tf.data.Dataset`s and what the API would look like to support this.
  if isinstance(value, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
    with sess.graph.as_default():
      iterator = tf.compat.v1.data.make_one_shot_iterator(value)
      next_element = iterator.get_next()
    elements = []
    while True:
      try:
        elements.append(sess.run(next_element))
      except tf.errors.OutOfRangeError:
        break
    return elements

  flattened_value = structure.flatten(value)
  dataset_results = {}
  flat_tensors = []
  for idx, v in enumerate(flattened_value):
    if isinstance(v, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
      dataset_tensors = fetch_value_in_session(sess, v)
      if not dataset_tensors:
        # An empty list has been returned; we must pack the shape information
        # back in or the result won't typecheck.
        element_structure = v.element_spec
        dummy_elem = make_dummy_element_for_type_spec(element_structure)
        dataset_tensors = [dummy_elem]
      dataset_results[idx] = dataset_tensors
    elif tf.is_tensor(v):
      flat_tensors.append(v)
    else:
      raise ValueError('Unsupported value type {}.'.format(v))

  # `flat_tensors` only ever holds tensors (see the loop above), so emptiness
  # is the only case to special-case: skip `sess.run` when nothing is fetched.
  if flat_tensors:
    flat_computed_tensors = sess.run(flat_tensors)
  else:
    flat_computed_tensors = flat_tensors
  flattened_results = _interleave_dataset_results_and_tensors(
      dataset_results, flat_computed_tensors)

  def _to_unicode(v):
    if isinstance(v, bytes):
      return v.decode('utf-8')
    return v

  # Only a top-level string tensor has its fetched bytes decoded to `str`.
  if tf.is_tensor(value) and value.dtype == tf.string:
    flattened_results = [
        _to_unicode(result) for result in flattened_results
    ]
  return structure.pack_sequence_as(value, flattened_results)
def from_keras_model(
    keras_model: tf.keras.Model,
    loss: Loss,
    input_spec,
    loss_weights: Optional[List[float]] = None,
    metrics: Optional[List[tf.keras.metrics.Metric]] = None
) -> model_lib.Model:
  """Builds a `tff.learning.Model` from a `tf.keras.Model`.

  The `tff.learning.Model` returned by this function uses `keras_model` for
  its forward pass and autodifferentiation steps.

  Notice that since TFF couples the `tf.keras.Model` and `loss`,
  TFF needs a slightly different notion of "fully specified type" than
  pure Keras does. That is, the model `M` takes inputs of type `x` and
  produces predictions of type `p`; the loss function `L` takes inputs of type
  `<p, y>` and produces a scalar. Therefore in order to fully specify the type
  signatures for computations in which the generated `tff.learning.Model` will
  appear, TFF needs the type `y` in addition to the type `x`.

  Args:
    keras_model: A `tf.keras.Model` object that is not compiled.
    loss: A `tf.keras.losses.Loss`, or a list of losses-per-output if the model
      has multiple outputs. If multiple outputs are present, the model will
      attempt to minimize the sum of all individual losses (optionally weighted
      using the `loss_weights` argument).
    input_spec: A structure of `tf.TensorSpec`s or `tff.Type` specifying the
      type of arguments the model expects. Notice this must be a compound
      structure of two elements, specifying both the data fed into the model (x)
      to generate predictions as well as the expected type of the ground truth
      (y). If provided as a list, it must be in the order [x, y]. If provided as
      a dictionary, the keys must explicitly be named `'x'` and `'y'`.
    loss_weights: (Optional) A list of Python floats used to weight the loss
      contribution of each model output.
    metrics: (Optional) a list of `tf.keras.metrics.Metric` objects.

  Returns:
    A `tff.learning.Model` object.

  Raises:
    TypeError: If `keras_model` is not instance of `tf.keras.Model`, if
      `keras_model` has a single output and `loss` is not instance of
      `tf.keras.losses.Loss`, or if `keras_model` has multiple outputs and
      `loss` is not a list of instances of `tf.keras.losses.Loss`.
    ValueError: If `keras_model` was compiled, if `keras_model` has multiple
      outputs and `loss` is not list of equal length, if `input_spec` does not
      contain exactly two elements, or if `input_spec` is a dictionary and does
      not contain keys `'x'` and `'y'`.
  """
  # The `collections.Mapping` alias was removed in Python 3.10; the abstract
  # base class lives in `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  # Validate `keras_model`
  py_typecheck.check_type(keras_model, tf.keras.Model)
  if keras_model._is_compiled:  # pylint: disable=protected-access
    raise ValueError('`keras_model` must not be compiled')

  # Validate and normalize `loss` and `loss_weights`
  if len(keras_model.outputs) == 1:
    py_typecheck.check_type(loss, tf.keras.losses.Loss)
    if loss_weights is not None:
      raise ValueError('`loss_weights` cannot be used if `keras_model` has '
                       'only one output.')
    loss = [loss]
    loss_weights = [1.0]
  else:
    py_typecheck.check_type(loss, list)
    if len(loss) != len(keras_model.outputs):
      raise ValueError('`keras_model` must have equal number of '
                       'outputs and losses.\nloss: {}\nof length: {}.'
                       '\noutputs: {}\nof length: {}.'.format(
                           loss, len(loss), keras_model.outputs,
                           len(keras_model.outputs)))
    for loss_fn in loss:
      py_typecheck.check_type(loss_fn, tf.keras.losses.Loss)

    if loss_weights is None:
      loss_weights = [1.0] * len(loss)
    else:
      if len(loss) != len(loss_weights):
        raise ValueError(
            '`keras_model` must have equal number of losses and loss_weights.'
            '\nloss: {}\nof length: {}.'
            '\nloss_weights: {}\nof length: {}.'.format(
                loss, len(loss), loss_weights, len(loss_weights)))
      for loss_weight in loss_weights:
        py_typecheck.check_type(loss_weight, float)

  # Validate `input_spec`: exactly two top-level members (x and y).
  if len(input_spec) != 2:
    raise ValueError('The top-level structure in `input_spec` must contain '
                     'exactly two top-level elements, as it must specify type '
                     'information for both inputs to and predictions from the '
                     'model. You passed input spec {}.'.format(input_spec))
  if not isinstance(input_spec, computation_types.Type):
    for input_spec_member in tf.nest.flatten(input_spec):
      py_typecheck.check_type(input_spec_member, tf.TensorSpec)
  else:
    for type_elem in input_spec:
      py_typecheck.check_type(type_elem, computation_types.TensorType)
  if isinstance(input_spec, collections_abc.Mapping):
    if 'x' not in input_spec:
      raise ValueError(
          'The `input_spec` is a collections.Mapping (e.g., a dict), so it '
          'must contain an entry with key `\'x\'`, representing the input(s) '
          'to the Keras model.')
    if 'y' not in input_spec:
      raise ValueError(
          'The `input_spec` is a collections.Mapping (e.g., a dict), so it '
          'must contain an entry with key `\'y\'`, representing the label(s) '
          'to be used in the Keras loss(es).')

  if metrics is None:
    metrics = []
  else:
    py_typecheck.check_type(metrics, list)
    for metric in metrics:
      py_typecheck.check_type(metric, tf.keras.metrics.Metric)

  return model_utils.enhance(
      _KerasModel(
          keras_model,
          input_spec=input_spec,
          loss_fns=loss,
          loss_weights=loss_weights,
          metrics=metrics))
def _client_fn(model,
               initial_model_weights,
               train_data,
               test_data,
               personalize_fn_dict,
               baseline_evaluate_fn,
               context=None):
  """The main `tf.function` that runs on device.

  This function first evaluates the initial model and gets the baseline
  metrics. Then starting from the same initial model, this function iterates
  over the personalization strategies defined in `personalize_fn_dict`, trains
  and evaluates the personalized models, and returns the evaluation metrics.

  Args:
    model: A `tff.learning.Model`.
    initial_model_weights: A `tff.learning.framework.ModelWeights` containing
      `tf.Tensor`s that hold trainable and non-trainable weights.
    train_data: A `tf.data.Dataset` used for training.
    test_data: A `tf.data.Dataset` used for evaluation.
    personalize_fn_dict: This is the same argument specified in the function
      `build_personalization_eval` above; see its documentation for details.
    baseline_evaluate_fn: This is the same argument specified in the function
      `build_personalization_eval` above; see its documentation for details.
    context: An optional object used in `personalize_fn_dict`. If used, its
      `tff.Type` must be provided by passing the correct `context_tff_type`
      argument to the `build_personalization_eval` function.

  Returns:
    An `OrderedDict` that maps a string 'baseline_metrics' to the evaluation
    metrics of the initial model (computed by `baseline_evaluate_fn`), and maps
    keys (strategy names) in `personalize_fn_dict` to the evaluation metrics of
    the corresponding personalization strategies.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
  """
  # Validate the container arguments before touching any model state, so that
  # invalid inputs fail fast instead of after the baseline evaluation ran.
  py_typecheck.check_callable(baseline_evaluate_fn)
  py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
  if 'baseline_metrics' in personalize_fn_dict:
    raise ValueError('baseline_metrics should not be used as a key in '
                     'personalize_fn_dict.')

  # Wrap the input model as an `EnhancedModel` for easy access of its weights.
  model = model_utils.enhance(model)

  final_metrics = collections.OrderedDict()
  tff.utils.assign(model.weights, initial_model_weights)
  final_metrics['baseline_metrics'] = baseline_evaluate_fn(model, test_data)

  for name, personalize_fn_builder in personalize_fn_dict.items():
    py_typecheck.check_type(name, str)
    # Every strategy starts again from the same initial weights.
    tff.utils.assign(model.weights, initial_model_weights)

    # Construct the `personalize_fn` (and the associated `tf.Variable`s) here.
    # Once `_client_fn` is decorated with `tff.tf_computation`, construction of
    # the new variables will happen in a scope controlled by TFF. Ensuring
    # `tf.Variable`s are created in the graphs that TFF controls is the reason
    # we need `personalize_fn_dict` to contain no-argument functions that build
    # the desired `tf.function`s, rather than already built `tf.function`s.
    py_typecheck.check_callable(personalize_fn_builder)
    personalize_fn = personalize_fn_builder()

    py_typecheck.check_callable(personalize_fn)
    final_metrics[name] = personalize_fn(model, train_data, test_data, context)

  return final_metrics
async def create_value(self, value, type_spec=None):
  """Embeds `value` in this executor.

  Accepted inputs:

  * A TFF computation proto whose sole body is one of the supported sequence
    intrinsics.
  * A value of TFF sequence type (e.g. an eager TF dataset), which is wrapped
    in a `_SequenceFromPayload`.
  * Anything that the target executor supports, passed through to it.
  * A nested structure of any of the above.

  Args:
    value: The input to embed.
    type_spec: An optional TFF type. It is required unless `value` is an
      instance of `typed_object.TypedObject`, in which case it may be `None`.

  Returns:
    An instance of `SequenceExecutorValue` that represents the embedded value.
    For struct types, the result of `self.create_struct` over the individually
    embedded members is returned instead.
  """
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
  else:
    py_typecheck.check_type(value, typed_object.TypedObject)
    type_spec = value.type_signature

  if isinstance(type_spec, computation_types.SequenceType):
    payload = _SequenceFromPayload(value, type_spec)
    return SequenceExecutorValue(payload, type_spec)

  if isinstance(value, pb.Computation):
    proto_type = type_serialization.deserialize_type(value.type)
    proto_type.check_equivalent_to(type_spec)
    # NOTE: If not a supported type of intrinsic, we let it fall through and
    # be handled by embedding in the target executor (below).
    if value.WhichOneof('computation') == 'intrinsic':
      uri = value.intrinsic.uri
      intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(uri)
      if intrinsic_def is None:
        raise ValueError(
            'Encountered an unrecognized intrinsic "{}".'.format(uri))
      op_type = SequenceExecutor._SUPPORTED_INTRINSIC_TO_SEQUENCE_OP.get(
          intrinsic_def.uri)
      if op_type is not None:
        type_analysis.check_concrete_instance_of(type_spec,
                                                 intrinsic_def.type_signature)
        return SequenceExecutorValue(op_type(type_spec), type_spec)

  if isinstance(type_spec, computation_types.StructType):
    if not isinstance(value, structure.Struct):
      value = structure.from_container(value)
    flat_members = structure.flatten(value)
    flat_member_types = structure.flatten(type_spec)
    # Embed every leaf concurrently, then reassemble the original structure.
    embedded_members = await asyncio.gather(*[
        self.create_value(member, member_type)
        for member, member_type in zip(flat_members, flat_member_types)
    ])
    return await self.create_struct(
        structure.pack_sequence_as(value, embedded_members))

  target_value = await self._target_executor.create_value(value, type_spec)
  return SequenceExecutorValue(target_value, type_spec)
def __init__(self, type_spec: computation_types.SequenceType):
  """Stores the sequence type after validating it.

  Args:
    type_spec: The `computation_types.SequenceType` kept as
      `self._type_signature`.

  Raises:
    TypeError: If `type_spec` is not a `computation_types.SequenceType`.
  """
  py_typecheck.check_type(type_spec, computation_types.SequenceType)
  self._type_signature = type_spec
def building_block_to_computation(building_block):
  """Wraps a computation building block in a `ComputationImpl`.

  Args:
    building_block: A `building_blocks.ComputationBuildingBlock` to convert.

  Returns:
    A `computation_impl.ComputationImpl` constructed from the block's proto
    and `context_stack_impl.context_stack`.

  Raises:
    TypeError: If `building_block` is not a
      `building_blocks.ComputationBuildingBlock`.
  """
  py_typecheck.check_type(building_block,
                          building_blocks.ComputationBuildingBlock)
  proto = building_block.proto
  return computation_impl.ComputationImpl(proto,
                                          context_stack_impl.context_stack)
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationFactory] = None,
) -> iterative_process.IterativeProcess:
  """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENT)`. If `None`, a stateless
      broadcaster is built.
    aggregation_process: A `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS -> output_values@SERVER)`. Must be
      `None` if `model_update_aggregation_factory` is not `None`.
    model_update_aggregation_factory: An optional
      `tff.aggregators.WeightedAggregationFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation. Must
      be `None` if `aggregation_process` is not `None`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
    DisjointArgumentError: if both `aggregation_process` and
      `model_update_aggregation_factory` are not `None`.
    TypeError: if the process created by `model_update_aggregation_factory`
      does not preserve the member type of its client input in its server
      result.
    ValueError: if the `next` function of a provided `aggregation_process`
      does not take two or three arguments.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  model_weights_type = model_utils.weights_type_from_model(model_fn)

  if broadcast_process is None:
    broadcast_process = build_stateless_broadcaster(
        model_weights_type=model_weights_type)
  if not _is_valid_broadcast_process(broadcast_process):
    raise ProcessTypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))

  if (model_update_aggregation_factory is not None and
      aggregation_process is not None):
    raise DisjointArgumentError(
        'Must specify only one of `model_update_aggregation_factory` and '
        '`AggregationProcess`.')

  if aggregation_process is None:
    # Build the aggregation process from the (possibly default) factory.
    if model_update_aggregation_factory is None:
      model_update_aggregation_factory = mean.MeanFactory()
    py_typecheck.check_type(model_update_aggregation_factory,
                            factory.AggregationFactory.__args__)
    if isinstance(model_update_aggregation_factory,
                  factory.WeightedAggregationFactory):
      aggregation_process = model_update_aggregation_factory.create(
          model_weights_type.trainable,
          computation_types.TensorType(tf.float32))
    else:
      aggregation_process = model_update_aggregation_factory.create(
          model_weights_type.trainable)
    process_signature = aggregation_process.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
      raise TypeError('`model_update_aggregation_factory` does not produce a '
                      'compatible `AggregationProcess`. The processes must '
                      'retain the type structure of the inputs on the '
                      f'server, but got {input_client_value_type.member} != '
                      f'{result_server_value_type.member}.')
  else:
    # A user-provided process: check its arity and overall signature.
    next_num_args = len(aggregation_process.next.type_signature.parameter)
    if next_num_args not in [2, 3]:
      raise ValueError(
          f'`next` function of `aggregation_process` must take two (for '
          f'unweighted aggregation) or three (for weighted aggregation) '
          f'arguments. Found {next_num_args}.')
    if not _is_valid_model_update_aggregation_process(aggregation_process):
      raise ProcessTypeError(
          'aggregation_process type signature does not conform to expected '
          'signature (<state@S, model_update@C> -> <state@S, model_update@S, '
          'measurements@S>). Got: {t}'.format(
              t=aggregation_process.next.type_signature))

  initialize_computation = _build_initialize_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  run_one_round_computation = _build_one_round_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      model_to_client_delta_fn=model_to_client_delta_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  return iterative_process.IterativeProcess(
      initialize_fn=initialize_computation,
      next_fn=run_one_round_computation)
def _trees_equal(comp_1, comp_2):
  """Returns `True` if the computations are entirely identical.

  Structurally equivalent computations with different variable names or
  different operation orderings are not considered to be equal. Two nodes are
  only equal if they have exactly the same concrete class (a subclass is never
  equal to its base class), the same type signature, and recursively identical
  contents.

  Args:
    comp_1: A `building_blocks.ComputationBuildingBlock` to test.
    comp_2: A `building_blocks.ComputationBuildingBlock` to test.

  Returns:
    `True` if `comp_1` and `comp_2` are identical trees, otherwise `False`.

  Raises:
    TypeError: If `comp_1` or `comp_2` is not an instance of
      `building_blocks.ComputationBuildingBlock`.
    NotImplementedError: If `comp_1` and `comp_2` are an unexpected subclass of
      `building_blocks.ComputationBuildingBlock`.
  """
  # TODO(b/146892021): TFF needs a structural AST equality function, which
  # needs to be public. There is a necessary dependency on this function from
  # the TFF-to-TF code generation pipeline, in order to detect some structural
  # equivalence while generating TensorFlow. It was decided that it is
  # preferable to expose a dependency on this "private" function, and file the
  # bug here, rather than effectively duplicate the logic elsewhere.
  py_typecheck.check_type(comp_1, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(comp_2, building_blocks.ComputationBuildingBlock)
  if comp_1 is comp_2:
    return True
  # The unidiomatic-typecheck is intentional, for the purposes of equality this
  # function requires that the types are identical and that a subclass will not
  # be equal to its baseclass.
  if type(comp_1) != type(comp_2):  # pylint: disable=unidiomatic-typecheck
    return False
  if comp_1.type_signature != comp_2.type_signature:
    return False

  if isinstance(comp_1, building_blocks.Block):
    # Compare the cheap length first before recursing into subtrees.
    if len(comp_1.locals) != len(comp_2.locals):
      return False
    if not _trees_equal(comp_1.result, comp_2.result):
      return False
    return all(
        name_1 == name_2 and _trees_equal(value_1, value_2)
        for (name_1, value_1), (name_2, value_2) in zip(comp_1.locals,
                                                        comp_2.locals))
  elif isinstance(comp_1, building_blocks.Call):
    if not _trees_equal(comp_1.function, comp_2.function):
      return False
    # Handle a missing argument explicitly: recursing with `None` would raise
    # `TypeError` from the type check above instead of reporting inequality.
    if comp_1.argument is None or comp_2.argument is None:
      return comp_1.argument is None and comp_2.argument is None
    return _trees_equal(comp_1.argument, comp_2.argument)
  elif isinstance(comp_1, building_blocks.CompiledComputation):
    return _compiled_comp_equal(comp_1, comp_2)
  elif isinstance(comp_1, (building_blocks.Data, building_blocks.Intrinsic,
                           building_blocks.Placement)):
    # These leaves are fully identified by their URI (types already matched).
    return comp_1.uri == comp_2.uri
  elif isinstance(comp_1, building_blocks.Lambda):
    return (comp_1.parameter_name == comp_2.parameter_name and
            comp_1.parameter_type == comp_2.parameter_type and
            _trees_equal(comp_1.result, comp_2.result))
  elif isinstance(comp_1, building_blocks.Reference):
    return comp_1.name == comp_2.name
  elif isinstance(comp_1, building_blocks.Selection):
    return (comp_1.name == comp_2.name and comp_1.index == comp_2.index and
            _trees_equal(comp_1.source, comp_2.source))
  elif isinstance(comp_1, building_blocks.Tuple):
    # The element names are checked as part of the `type_signature`.
    if len(comp_1) != len(comp_2):
      return False
    return all(
        _trees_equal(element_1, element_2)
        for element_1, element_2 in zip(comp_1, comp_2))
  raise NotImplementedError('Unexpected type found: {}.'.format(
      type(comp_1)))
def assemble_result_from_graph(type_spec, binding, output_map):
  """Assembles a result stamped into a `tf.Graph` given type signature/binding.

  This method does roughly the opposite of `capture_result_from_graph`, in that
  whereas `capture_result_from_graph` starts with a single structured object
  made up of tensors and computes its type and bindings, this method starts
  with the type/bindings and constructs a structured object made up of tensors.

  Args:
    type_spec: The type signature of the result to assemble, an instance of
      `types.Type` or something convertible to it.
    binding: The binding that relates the type signature to names of tensors in
      the graph, an instance of `pb.TensorFlow.Binding`.
    output_map: The mapping from tensor names that appear in the binding to
      actual stamped tensors (possibly renamed during import).

  Returns:
    The assembled result, a Python object that is composed of tensors, possibly
    nested within Python structures such as `structure.Struct` or the type's
    `python_container`.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
    ValueError: If the arguments are invalid or inconsistent with each other,
      e.g., the type and binding don't match, or the tensor is not found in the
      map.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  py_typecheck.check_type(binding, pb.TensorFlow.Binding)
  py_typecheck.check_type(output_map, dict)
  # Validate the output map once here; the recursive helper below reuses it
  # without re-checking every entry at every level of nesting.
  for k, v in output_map.items():
    py_typecheck.check_type(k, str)
    if not tf.is_tensor(v):
      raise TypeError(
          'Element with key {} in the output map is {}, not a tensor.'.format(
              k, py_typecheck.type_string(type(v))))

  def _lookup(tensor_name):
    """Returns `output_map[tensor_name]`, raising `ValueError` if absent."""
    if tensor_name not in output_map:
      raise ValueError(
          'Tensor named {} not found in the output map.'.format(tensor_name))
    return output_map[tensor_name]

  def _assemble(sub_type, sub_binding):
    """Recursively assembles the value for `sub_type` bound by `sub_binding`."""
    binding_oneof = sub_binding.WhichOneof('binding')
    if sub_type.is_tensor():
      if binding_oneof != 'tensor':
        raise ValueError(
            'Expected a tensor binding, found {}.'.format(binding_oneof))
      return _lookup(sub_binding.tensor.tensor_name)
    elif sub_type.is_struct():
      if binding_oneof != 'struct':
        raise ValueError(
            'Expected a struct binding, found {}.'.format(binding_oneof))
      type_elements = structure.to_elements(sub_type)
      element_bindings = sub_binding.struct.element
      if len(element_bindings) != len(type_elements):
        raise ValueError(
            'Mismatching tuple sizes in type ({}) and binding ({}).'.format(
                len(type_elements), len(element_bindings)))
      result_elements = [
          (element_name, _assemble(element_type, element_binding))
          for (element_name, element_type), element_binding in zip(
              type_elements, element_bindings)
      ]
      container_type = sub_type.python_container
      if container_type is None:
        return structure.Struct(result_elements)
      if (py_typecheck.is_named_tuple(container_type) or
          py_typecheck.is_attrs(container_type)):
        return container_type(**dict(result_elements))
      return container_type(result_elements)
    elif sub_type.is_sequence():
      if binding_oneof != 'sequence':
        raise ValueError(
            'Expected a sequence binding, found {}.'.format(binding_oneof))
      sequence_oneof = sub_binding.sequence.WhichOneof('binding')
      if sequence_oneof != 'variant_tensor_name':
        raise ValueError('Unsupported sequence binding \'{}\'.'.format(
            sequence_oneof))
      # Previously a missing name surfaced as a bare `KeyError`; use the same
      # documented `ValueError` as the tensor branch.
      variant_tensor = _lookup(sub_binding.sequence.variant_tensor_name)
      return make_dataset_from_variant_tensor(variant_tensor, sub_type.element)
    else:
      raise ValueError('Unsupported type \'{}\'.'.format(sub_type))

  return _assemble(type_spec, binding)
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a training
      `tf.data.Dataset`, a test `tf.data.Dataset`, and an arbitrary context
      object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and a `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in
      a round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the
      training dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.data.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the personalization_metrics shown above), where
      each metric is mapped to a list of scalars (each scalar comes from one
      client). Metric values at the same position, e.g., metric_1[i],
      metric_2[i]..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in
      `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Validate the cheap arguments before constructing the model and tracing any
  # TFF computations, so invalid inputs fail fast.
  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')
  py_typecheck.check_callable(model_fn)
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)

  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)
    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
def get_tf_typespec_and_binding(parameter_type, arg_names, unpack=None):
  """Computes a `TensorSpec` input_signature and bindings for parameter_type.

  This is the TF2 analog to `stamp_parameter_in_graph`.

  Args:
    parameter_type: The TFF type of the input to a tensorflow function. Must be
      either an instance of computation_types.Type (or convertible to it), or
      None in the case of a no-arg function.
    arg_names: String names for any positional arguments to the tensorflow
      function.
    unpack: Whether or not to unpack parameter_type into args and kwargs. See
      e.g. `function_utils.pack_args_into_struct`.

  Returns:
    A tuple (args_typespec, kwargs_typespec, binding), where args_typespec is a
    list and kwargs_typespec is a dict, both containing `tf.TensorSpec`
    objects. These structures are intended to be passed to the
    `get_concrete_function` method of a `tf.function`.
    Note the "binding" is "preliminary" in that it includes the names embedded
    in the TensorSpecs produced; these must be converted to the names of actual
    tensors based on the SignatureDef of the SavedModel before the binding is
    finalized.

  Raises:
    ValueError: If `arg_names` has fewer entries than there are positional
      arguments, an entry of `arg_names` is not a string, a struct mixes named
      and unnamed elements, or a type component cannot become a `TensorSpec`.
    NotImplementedError: If a sequence type is encountered.
  """
  # `collections.Iterable` was removed in Python 3.10; the ABC lives in
  # `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  if parameter_type is None:
    return ([], {}, None)
  if unpack:
    arg_types, kwarg_types = function_utils.unpack_args_from_struct(
        parameter_type)
    pack_in_struct = True
  else:
    pack_in_struct = False
    arg_types, kwarg_types = [parameter_type], {}

  py_typecheck.check_type(arg_names, collections_abc.Iterable)
  if len(arg_names) < len(arg_types):
    raise ValueError(
        'If provided, arg_names must be a list of at least {} strings to '
        'match the number of positional arguments. Found: {}'.format(
            len(arg_types), arg_names))

  get_unique_name = UniqueNameFn()

  def _get_one_typespec_and_binding(parameter_name, parameter_type):
    """Returns a (tf.TensorSpec, binding) pair."""
    parameter_type = computation_types.to_type(parameter_type)
    if parameter_type.is_tensor():
      name = get_unique_name(parameter_name)
      tf_spec = tf.TensorSpec(
          shape=parameter_type.shape, dtype=parameter_type.dtype, name=name)
      binding = pb.TensorFlow.Binding(
          tensor=pb.TensorFlow.TensorBinding(tensor_name=name))
      return (tf_spec, binding)
    elif parameter_type.is_struct():
      element_typespec_pairs = []
      element_bindings = []
      have_names = False
      have_nones = False
      for e_name, e_type in structure.iter_elements(parameter_type):
        if e_name is None:
          have_nones = True
        else:
          have_names = True
        # Nested names are joined with '_', skipping missing components.
        name = '_'.join([n for n in [parameter_name, e_name] if n])
        e_typespec, e_binding = _get_one_typespec_and_binding(
            name if name else None, e_type)
        element_typespec_pairs.append((e_name, e_typespec))
        element_bindings.append(e_binding)
      # For a given argument or kwarg, we shouldn't have both:
      if have_names and have_nones:
        raise ValueError(
            'A mix of named and unnamed entries are not supported inside a '
            'nested structure representing a single argument in a call to a '
            'TensorFlow or Python function.\n{}'.format(parameter_type))
      tf_typespec = structure.Struct(element_typespec_pairs)
      return (tf_typespec,
              pb.TensorFlow.Binding(
                  struct=pb.TensorFlow.StructBinding(
                      element=element_bindings)))
    elif parameter_type.is_sequence():
      raise NotImplementedError('Sequence inputs not yet supported for TF 2.0.')
    else:
      raise ValueError(
          'Parameter type component {!r} cannot be converted to a TensorSpec'
          .format(parameter_type))

  def get_arg_name(i):
    name = arg_names[i]
    if not isinstance(name, str):
      raise ValueError('arg_names must be strings, but got: {}'.format(name))
    return name

  # Main logic --- process arg_types and kwarg_types:
  arg_typespecs = []
  kwarg_typespecs = {}
  bindings = []
  for i, arg_type in enumerate(arg_types):
    name = get_arg_name(i)
    typespec, binding = _get_one_typespec_and_binding(name, arg_type)
    typespec = type_conversions.type_to_py_container(typespec, arg_type)
    arg_typespecs.append(typespec)
    bindings.append(binding)
  for name, kwarg_type in kwarg_types.items():
    typespec, binding = _get_one_typespec_and_binding(name, kwarg_type)
    typespec = type_conversions.type_to_py_container(typespec, kwarg_type)
    kwarg_typespecs[name] = typespec
    bindings.append(binding)

  assert bindings, 'Given parameter_type {}, but produced no bindings.'.format(
      parameter_type)
  if pack_in_struct:
    final_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(element=bindings))
  else:
    final_binding = bindings[0]

  return (arg_typespecs, kwarg_typespecs, final_binding)
def deserialize_and_call_tf_computation(computation_proto, arg, graph):
  """Imports a serialized TensorFlow computation into `graph` and invokes it.

  This is roughly the inverse of
  `tensorflow_serialization.serialize_py_func_as_tf_computation`: the
  computation's `GraphDef` is imported into `graph` under the
  `subcomputation` name scope, with its parameter tensors rewired to the
  tensors that make up `arg`. Callers should treat the import mechanism as an
  implementation detail.

  It is intended to be called from within the body of another TF computation
  (with a `tf_computation_context.TensorFlowComputationContext` at the top of
  the context stack), and possibly from some interpreted execution contexts.

  Args:
    computation_proto: A `pb.Computation` whose `computation` oneof is set to
      `tensorflow`.
    arg: The argument to invoke the computation with, or `None` if the
      computation declares no parameter.
    graph: The `tf.Graph` to stamp the computation into.

  Returns:
    A pair `(init_op, result)`:
      init_op: The object `tf.import_graph_def` returned for the computation's
        `initialize_op` name, or `None` if the computation has no initialize
        op.
      result: The computation's result, reassembled by
        `graph_utils.assemble_result_from_graph` from the imported tensors.

  Raises:
    TypeError: If an argument has the wrong type, an argument is supplied to a
      parameterless computation, a required argument is missing, or the
      argument's type is not assignable to the declared parameter type.
    ValueError: If `computation_proto` is not a TensorFlow computation proto.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  which_computation = computation_proto.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise ValueError('Expected a TensorFlow computation, got {}.'.format(
        which_computation))
  py_typecheck.check_type(graph, tf.Graph)
  tf_proto = computation_proto.tensorflow

  with graph.as_default():
    function_type = type_serialization.deserialize_type(computation_proto.type)
    parameter_type = function_type.parameter

    # Work out how the computation's parameter tensors get rewired, if at all.
    if not parameter_type:
      if arg is not None:
        raise TypeError(
            'The computation declared no parameters; encountered an unexpected '
            'argument {}.'.format(str(arg)))
      input_map = None
    else:
      if arg is None:
        raise TypeError(
            'The computation declared a parameter of type {}, but the argument '
            'was not supplied.'.format(str(parameter_type)))
      arg_type, arg_binding = graph_utils.capture_result_from_graph(arg, graph)
      if not type_utils.is_assignable_from(parameter_type, arg_type):
        raise TypeError(
            'The computation declared a parameter of type {}, but the argument '
            'is of a mismatching type {}.'.format(
                str(parameter_type), str(arg_type)))
      name_mapping = graph_utils.compute_map_from_bindings(
          tf_proto.parameter, arg_binding)
      input_map = {
          param_name: graph.get_tensor_by_name(tensor_name)
          for param_name, tensor_name in name_mapping.items()
      }

    fetch_names = graph_utils.extract_tensor_names_from_binding(
        tf_proto.result)
    init_op_name = tf_proto.initialize_op
    if init_op_name:
      fetch_names.append(init_op_name)

    # N. B. Unlike MetaGraphDef, the GraphDef alone contains no information
    # about collections, and hence, when we import a graph with Variables,
    # those Variables are not added to global collections, and hence
    # functions like tf.global_variables_initializers() will not
    # contain their initialization ops.
    imported = tf.import_graph_def(
        tf_proto.graph_def,
        input_map,
        fetch_names,
        # N. B. It is very important not to return any names from the original
        # computation_proto.tensorflow.graph_def, those names might or might
        # not be valid in the current graph. Using a different scope makes the
        # graph somewhat more readable, since _N style de-duplication of graph
        # node names is less likely to be needed.
        name='subcomputation')

    output_map = dict(zip(fetch_names, imported))
    # Pull the init op out so only result tensors remain for reassembly.
    init_op = output_map.pop(init_op_name, None)
    result = graph_utils.assemble_result_from_graph(
        function_type.result, tf_proto.result, output_map)
    return (init_op, result)
def make_data_set_from_elements(graph, elements, element_type):
  """Creates a `tf.data.Dataset` in `graph` from explicitly listed `elements`.

  Note: The underlying implementation attempts to use the
  `tf.data.Dataset.from_tensor_slices()` method to build the data set quickly,
  but this doesn't always work. The typical scenario where it breaks is one
  with data set being composed of unequal batches. Typically, only the last
  batch is odd, so on the first attempt, we try to construct two data sets,
  one from all elements but the last one, and one from the last element, then
  concatenate the two. In the unlikely case that this fails (e.g., because
  all data set elements are batches of unequal sizes), we revert to the slow,
  but reliable method of constructing data sets from singleton elements, and
  then concatenating them all.

  Args:
    graph: The graph in which to construct the `tf.data.Dataset`, or `None` if
      the construction is to happen in the eager context.
    elements: A list of elements.
    element_type: The type of elements.

  Returns:
    The constructed `tf.data.Dataset` instance.

  Raises:
    TypeError: If element types do not match `element_type`.
    ValueError: If the elements are of incompatible types and shapes, or if
      no graph was specified outside of the eager context.
  """
  # Note: We allow the graph to be `None` to allow this function to be used in
  # the eager context.
  if graph is not None:
    py_typecheck.check_type(graph, tf.Graph)
  elif not tf.executing_eagerly():
    raise ValueError('Only in eager context may the graph be `None`.')
  py_typecheck.check_type(elements, list)
  element_type = computation_types.to_type(element_type)
  py_typecheck.check_type(element_type, computation_types.Type)

  def _make(element_subset):
    """Builds a dataset from `element_subset` via `from_tensor_slices`."""
    lists = make_empty_list_structure_for_element_type_spec(element_type)
    for el in element_subset:
      append_to_list_structure_for_element_type_spec(lists, el, element_type)
    tensor_slices = replace_empty_leaf_lists_with_numpy_arrays(
        lists, element_type)
    return tf.data.Dataset.from_tensor_slices(tensor_slices)

  def _make_from_singletons():
    """Slow fallback: concatenates one single-element dataset per element."""
    # Only called with a non-empty `elements`, so `elements[:1]` is non-empty.
    ds = _make(elements[:1])
    for element in elements[1:]:
      ds = ds.concatenate(_make([element]))
    return ds

  def _work():
    """Builds the dataset, then checks its element type against the spec."""
    if not elements:
      # Just return an empty data set with the appropriate types.
      dummy_element = make_dummy_element_for_type_spec(element_type)
      ds = _make([dummy_element]).take(0)
    elif len(elements) == 1:
      ds = _make(elements)
    else:
      try:
        # It is common for the last element to be a batch of a size different
        # from all the preceding batches. With this in mind, we proactively
        # single out the last element (optimizing for the common case).
        ds = _make(elements[:-1]).concatenate(_make(elements[-1:]))
      except ValueError:
        # In case elements beyond just the last one are of unequal shapes, we
        # may have failed (the most likely cause), so fall back onto the slow
        # process of constructing and joining data sets from singletons. Not
        # optimizing this for now, as it's very unlikely in scenarios
        # we're targeting.
        ds = _make_from_singletons()
    ds_element_type = computation_types.to_type(ds.element_spec)
    if not element_type.is_assignable_from(ds_element_type):
      raise TypeError(
          'Failure during data set construction, expected elements of type {}, '
          'but the constructed data set has elements of type {}.'.format(
              element_type, ds_element_type))
    return ds

  if graph is not None:
    with graph.as_default():
      return _work()
  return _work()
async def create_value(self, value, type_spec=None):
  """Embeds `value` (of TFF type `type_spec`) in this federating executor.

  Dispatches on the Python type of `value`:
    * `intrinsic_defs.IntrinsicDef`: validated against `type_spec` with
      `type_utils.is_concrete_instance_of` and wrapped as-is.
    * `placement_literals.PlacementLiteral`: wrapped with a `PlacementType`.
    * `computation_impl.ComputationImpl`: unwrapped to its proto and embedded
      recursively with a reconciled type.
    * `pb.Computation`: `tensorflow`/`lambda` protos are wrapped as-is;
      `intrinsic`/`placement` protos are resolved by URI and re-embedded;
      `call`/`tuple`/`selection` protos are built from recursively embedded
      parts via `create_call`/`create_tuple`/`create_selection`.
    * anything else: `type_spec` is required. Federated values are spread
      across the child executors registered for the placement (replicated
      once per child when `all_equal`); unplaced values are delegated to the
      single executor registered under `None`.

  Args:
    value: The value to embed.
    type_spec: An optional TFF type (or something convertible to one).

  Returns:
    A `FederatingExecutorValue`, or whatever `create_call`, `create_tuple` or
    `create_selection` returns for the corresponding computation protos.

  Raises:
    TypeError: If `type_spec` is incompatible with an intrinsic, or a
      `py_typecheck` check fails.
    ValueError: For unbound references, unrecognized intrinsics, selections or
      computation kinds, functional values in an unrecognized representation,
      unconfigured placements, or a federated value of the wrong size/shape.
    RuntimeError: If unplaced values are not configured for this executor.
  """
  type_spec = computation_types.to_type(type_spec)

  if isinstance(value, intrinsic_defs.IntrinsicDef):
    if type_utils.is_concrete_instance_of(type_spec, value.type_signature):
      return FederatingExecutorValue(value, type_spec)
    raise TypeError('Incompatible type {} used with intrinsic {}.'.format(
        type_spec, value.uri))

  if isinstance(value, placement_literals.PlacementLiteral):
    if type_spec is not None:
      py_typecheck.check_type(type_spec, computation_types.PlacementType)
    return FederatingExecutorValue(value, computation_types.PlacementType())

  if isinstance(value, computation_impl.ComputationImpl):
    comp_proto = computation_impl.ComputationImpl.get_proto(value)
    reconciled_type = type_utils.reconcile_value_with_type_spec(
        value, type_spec)
    return await self.create_value(comp_proto, reconciled_type)

  if isinstance(value, pb.Computation):
    if type_spec is None:
      type_spec = type_serialization.deserialize_type(value.type)
    comp_kind = value.WhichOneof('computation')

    if comp_kind in ('tensorflow', 'lambda'):
      return FederatingExecutorValue(value, type_spec)

    if comp_kind == 'reference':
      raise ValueError(
          'Encountered an unexpected unbound references "{}".'.format(
              value.reference.name))

    if comp_kind == 'intrinsic':
      intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
      if intrinsic_def is None:
        raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
            value.intrinsic.uri))
      py_typecheck.check_type(intrinsic_def, intrinsic_defs.IntrinsicDef)
      return await self.create_value(intrinsic_def, type_spec)

    if comp_kind == 'placement':
      literal = placement_literals.uri_to_placement_literal(
          value.placement.uri)
      return await self.create_value(literal, type_spec)

    if comp_kind == 'call':
      to_embed = [value.call.function]
      if value.call.argument.WhichOneof('computation'):
        to_embed.append(value.call.argument)
      embedded = await asyncio.gather(
          *(self.create_value(piece) for piece in to_embed))
      fn_value = embedded[0]
      arg_value = embedded[1] if len(embedded) > 1 else None
      return await self.create_call(fn_value, arg_value)

    if comp_kind == 'tuple':
      proto_elements = value.tuple.element
      element_values = await asyncio.gather(
          *(self.create_value(elem.value) for elem in proto_elements))
      named_values = ((elem.name if elem.name else None, val)
                      for elem, val in zip(proto_elements, element_values))
      return await self.create_tuple(
          anonymous_tuple.AnonymousTuple(named_values))

    if comp_kind == 'selection':
      selection = value.selection
      selection_kind = selection.WhichOneof('selection')
      if selection_kind == 'name':
        name, index = selection.name, None
      elif selection_kind == 'index':
        name, index = None, selection.index
      else:
        raise ValueError(
            'Unrecognized selection type: "{}".'.format(selection_kind))
      source_value = await self.create_value(selection.source)
      return await self.create_selection(source_value, index=index, name=name)

    raise ValueError(
        'Unsupported computation building block of type "{}".'.format(
            comp_kind))

  # Plain Python values: the caller must tell us the type.
  py_typecheck.check_type(type_spec, computation_types.Type)

  if isinstance(type_spec, computation_types.FunctionType):
    raise ValueError(
        'Encountered a value of a functional TFF type {} and Python type '
        '{} that is not of one of the recognized representations.'.format(
            type_spec, py_typecheck.type_string(type(value))))

  if isinstance(type_spec, computation_types.FederatedType):
    placement = type_spec.placement
    children = self._target_executors.get(placement)
    if not children:
      raise ValueError(
          'Placement "{}" is not configured in this executor.'.format(
              placement))
    py_typecheck.check_type(children, list)
    if type_spec.all_equal:
      if isinstance(value, list):
        raise ValueError(
            'An all_equal value should be passed directly, not as a list.')
      member_values = [value] * len(children)
    else:
      py_typecheck.check_type(value, (list, tuple, set, frozenset))
      member_values = value if isinstance(value, list) else list(value)
    if len(member_values) != len(children):
      raise ValueError(
          'Federated value contains {} items, but the placement {} in this '
          'executor is configured with {} participants.'.format(
              len(member_values), placement, len(children)))
    child_values = await asyncio.gather(*[
        child.create_value(member, type_spec.member)
        for member, child in zip(member_values, children)
    ])
    return FederatingExecutorValue(child_values, type_spec)

  unplaced_executors = self._target_executors.get(None)
  if not unplaced_executors or len(unplaced_executors) > 1:
    raise RuntimeError('Executor is not configured for unplaced values.')
  embedded_value = await unplaced_executors[0].create_value(value, type_spec)
  return FederatingExecutorValue(embedded_value, type_spec)
def local_executor_factory(
    num_clients=None,
    max_fanout=100,
    num_client_executors=32,
    server_tf_device=None,
    client_tf_devices=tuple()
) -> executor_factory.ExecutorFactory:
  """Constructs an executor factory to execute computations locally.

  Note: The `tff.federated_secure_sum()` intrinsic is not implemented by this
  executor.

  Args:
    num_clients: The number of clients. If specified, the executor factory
      function returned by `local_executor_factory` will be configured to have
      exactly `num_clients` clients. If unspecified (`None`), then the function
      returned will attempt to infer cardinalities of all placements for which
      it is passed values.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy. If
      `num_clients > max_fanout`, the constructed executor stack will consist
      of multiple levels of aggregators. The height of the stack will be on the
      order of `log(num_clients) / log(max_fanout)`.
    num_client_executors: The number of distinct client executors to run
      concurrently; executing more clients than this number results in
      multiple clients having their work pinned on a single executor in a
      synchronous fashion.
    server_tf_device: A `tf.config.LogicalDevice` to place server and other
      computation without explicit TFF placement.
    client_tf_devices: List/tuple of `tf.config.LogicalDevice` to place clients
      for simulation. Possibly accelerators returned by
      `tf.config.list_logical_devices()`.

  Returns:
    An instance of `executor_factory.ExecutorFactory` encapsulating the
    executor construction logic specified above.

  Raises:
    TypeError: If an argument has the wrong type.
    ValueError: If the number of clients is specified and not one or larger,
      or if `max_fanout` is less than 2.
  """
  if server_tf_device is not None:
    py_typecheck.check_type(server_tf_device, tf.config.LogicalDevice)
  py_typecheck.check_type(client_tf_devices, (tuple, list))
  py_typecheck.check_type(max_fanout, int)
  py_typecheck.check_type(num_client_executors, int)
  if num_clients is not None:
    py_typecheck.check_type(num_clients, int)
    # Enforce the lower bound promised in the docstring up front.
    if num_clients < 1:
      raise ValueError('If specified, num_clients must be one or larger.')
  if max_fanout < 2:
    raise ValueError('Max fanout must be greater than 1.')

  unplaced_ex_factory = UnplacedExecutorFactory(
      use_caching=True,
      server_device=server_tf_device,
      client_devices=client_tf_devices)
  federating_executor_factory = FederatingExecutorFactory(
      num_client_executors=num_client_executors,
      unplaced_ex_factory=unplaced_ex_factory,
      num_clients=num_clients,
      use_sizing=False)

  def _factory_fn(
      cardinalities: executor_factory.CardinalitiesType
  ) -> executor_base.Executor:
    return _create_full_stack(
        cardinalities,
        max_fanout,
        stack_func=federating_executor_factory.create_executor,
        unplaced_ex_factory=unplaced_ex_factory)

  return executor_factory.ExecutorFactoryImpl(_factory_fn)
def _check_arg_is_anonymous_tuple(self, arg):
  """Validates that `arg` is a named-tuple-typed value backed by a tuple.

  Args:
    arg: A value exposing `type_signature` and `internal_representation`
      attributes (presumably an executor value -- TODO confirm with callers).

  Raises:
    TypeError: If `arg.type_signature` is not a
      `computation_types.NamedTupleType`, or if `arg.internal_representation`
      is not an `anonymous_tuple.AnonymousTuple`.
  """
  signature = arg.type_signature
  py_typecheck.check_type(signature, computation_types.NamedTupleType)
  representation = arg.internal_representation
  py_typecheck.check_type(representation, anonymous_tuple.AnonymousTuple)
def _make_wrapper(clipping_norm: Union[float,
                                       estimation_process.EstimationProcess],
                  inner_agg_factory: factory.AggregationFactory,
                  make_clip_fn: Callable[[factory.ValueType],
                                         computation_base.Computation],
                  attribute_prefix: str) -> factory.AggregationFactory:
  """Constructs an aggregation factory that applies clip_fn before aggregation.

  The returned factory is weighted or unweighted to match `inner_agg_factory`.
  Its state holds the clipping-norm process state, the inner aggregation state
  and the state of a sum over "was clipped" indicators. Its measurements report
  the inner measurements, the clipping norm in use, and the clipped count.

  Args:
    clipping_norm: Either a float (for fixed norm) or an `EstimationProcess`
      (for adaptive norm) that specifies the norm over which the values should
      be clipped.
    inner_agg_factory: A factory specifying the type of aggregation to be done
      after the clip operation produced by `make_clip_fn` has been applied.
    make_clip_fn: A callable that takes a value type and returns a
      tff.computation specifying the clip operation to apply before
      aggregation. The computation is mapped over `(value, clipping_norm)` at
      clients and must return a `(clipped_value, global_norm, was_clipped)`
      triple.
    attribute_prefix: A str for prefixing state and measurement names.

  Returns:
    An aggregation factory that applies clip_fn before aggregation.
  """
  py_typecheck.check_type(inner_agg_factory,
                          (factory.UnweightedAggregationFactory,
                           factory.WeightedAggregationFactory))
  py_typecheck.check_type(clipping_norm,
                          (float, estimation_process.EstimationProcess))
  # A constant float norm is wrapped into a process so that both cases share
  # the same initialize/report/next code path below.
  if isinstance(clipping_norm, float):
    clipping_norm_process = _constant_process(clipping_norm)
  else:
    clipping_norm_process = clipping_norm
  _check_norm_process(clipping_norm_process, 'clipping_norm_process')

  # The aggregation factory that will be used to count the number of clipped
  # values at each iteration. For now we are just creating it here, but in
  # the future we may make this customizable to allow DP measurements.
  clipped_count_agg_factory = sum_factory.SumFactory()

  clipped_count_agg_process = clipped_count_agg_factory.create(
      computation_types.to_type(COUNT_TF_TYPE))

  # Builds state/measurement keys by appending a suffix to `attribute_prefix`,
  # e.g. prefix('ing_norm') == attribute_prefix + 'ing_norm'.
  prefix = lambda s: attribute_prefix + s

  def init_fn_impl(inner_agg_process):
    # Initial state: (norm process state, inner agg state, count agg state),
    # zipped into a single federated value.
    state = collections.OrderedDict([
        (prefix('ing_norm'), clipping_norm_process.initialize()),
        ('inner_agg', inner_agg_process.initialize()),
        (prefix('ed_count_agg'), clipped_count_agg_process.initialize())
    ])
    return intrinsics.federated_zip(state)

  def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
    # Shared `next` body for the weighted (weight given) and unweighted
    # (weight is None) factories.
    clipping_norm_state, agg_state, clipped_count_state = state

    clipping_norm = clipping_norm_process.report(clipping_norm_state)

    clients_clipping_norm = intrinsics.federated_broadcast(clipping_norm)

    # TODO(b/163880757): Remove this when server-only metrics are supported.
    clipping_norm = intrinsics.federated_mean(clients_clipping_norm)

    clipped_value, global_norm, was_clipped = intrinsics.federated_map(
        clip_fn, (value, clients_clipping_norm))

    # The norm process is advanced using the pre-clipping norms of the values.
    new_clipping_norm_state = clipping_norm_process.next(
        clipping_norm_state, global_norm)

    if weight is None:
      agg_output = inner_agg_process.next(agg_state, clipped_value)
    else:
      agg_output = inner_agg_process.next(agg_state, clipped_value, weight)

    clipped_count_output = clipped_count_agg_process.next(
        clipped_count_state, was_clipped)

    new_state = collections.OrderedDict([
        (prefix('ing_norm'), new_clipping_norm_state),
        ('inner_agg', agg_output.state),
        (prefix('ed_count_agg'), clipped_count_output.state)
    ])
    measurements = collections.OrderedDict([
        (prefix('ing'), agg_output.measurements),
        (prefix('ing_norm'), clipping_norm),
        (prefix('ed_count'), clipped_count_output.result)
    ])

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=agg_output.result,
        measurements=intrinsics.federated_zip(measurements))

  if isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

    class WeightedRobustFactory(factory.WeightedAggregationFactory):
      """`WeightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType, weight_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        """Builds the weighted clipping `AggregationProcess` for the types."""
        _check_value_type(value_type)
        py_typecheck.check_type(weight_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type, weight_type)

        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(weight_type))
        def next_fn(state, value, weight):
          return next_fn_impl(state, value, clip_fn, inner_agg_process, weight)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return WeightedRobustFactory()
  else:

    class UnweightedRobustFactory(factory.UnweightedAggregationFactory):
      """`UnweightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        """Builds the unweighted clipping `AggregationProcess` for the type."""
        _check_value_type(value_type)

        inner_agg_process = inner_agg_factory.create(value_type)

        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
          return next_fn_impl(state, value, clip_fn, inner_agg_process)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return UnweightedRobustFactory()
async def compute(self):
  """Returns the wrapped value after validating its type signature.

  Raises:
    TypeError: If the stored type signature is not a `TensorType`.
  """
  # TODO(b/153499219): Add support for values of other types than tensors.
  expected_type = computation_types.TensorType
  py_typecheck.check_type(self._type_signature, expected_type)
  return self._value
def __init__(self,
             upper_bound_threshold: ThresholdEstType,
             lower_bound_threshold: Optional[ThresholdEstType] = None):
  """Initializes `SecureSumFactory`.

  For constant thresholds, omitting `lower_bound_threshold` yields the
  symmetric range `[-upper_bound_threshold, upper_bound_threshold]`.

  Args:
    upper_bound_threshold: Either a `int` or `float` Python constant, a Numpy
      scalar, or a `tff.templates.EstimationProcess`, used for determining the
      upper bound before summation.
    lower_bound_threshold: Optional. Either a `int` or `float` Python constant,
      a Numpy scalar, or a `tff.templates.EstimationProcess`, used for
      determining the lower bound before summation. If specified, must be the
      same type as `upper_bound_threshold`.

  Raises:
    TypeError: If `upper_bound_threshold` and `lower_bound_threshold` are not
      instances of one of (`int`, `float` or
      `tff.templates.EstimationProcess`), or if `lower_bound_threshold` is
      specified with a different type than `upper_bound_threshold`.
    ValueError: If `upper_bound_threshold` is provided as a negative constant.
      (Presumably also when a constant `upper_bound_threshold` is not larger
      than `lower_bound_threshold`; see `_check_upper_larger_than_lower`.)
  """
  py_typecheck.check_type(upper_bound_threshold, ThresholdEstType.__args__)
  if lower_bound_threshold is not None:
    if not isinstance(lower_bound_threshold, type(upper_bound_threshold)):
      # Report the offending *types*, as the message labels promise.
      raise TypeError(
          f'Provided upper_bound_threshold and lower_bound_threshold '
          f'must have the same types, but found:\n'
          f'type(upper_bound_threshold): {type(upper_bound_threshold)}\n'
          f'type(lower_bound_threshold): {type(lower_bound_threshold)}')

  # Configuration specific for aggregating integer types.
  if _is_integer(upper_bound_threshold):
    self._config_mode = _Config.INT
    if lower_bound_threshold is None:
      _check_positive(upper_bound_threshold)
      lower_bound_threshold = -1 * upper_bound_threshold
    else:
      _check_upper_larger_than_lower(upper_bound_threshold,
                                     lower_bound_threshold)
    self._init_fn = _empty_state
    self._get_bounds_from_state = _create_get_bounds_const(
        upper_bound_threshold, lower_bound_threshold)
    self._update_state = lambda _, __, ___: _empty_state()
    # The inclusive integer range [lower, upper] holds (upper - lower + 1)
    # distinct values; enough bits are needed to represent all of them.
    # Without the +1 a range of width 1 would get 0 bits, and a width that is
    # a power of two would be one bit short.
    self._secagg_bitwidth = math.ceil(
        math.log2(upper_bound_threshold - lower_bound_threshold + 1))

  # Configuration specific for aggregating floating point types.
  else:
    self._config_mode = _Config.FLOAT
    if _is_float(upper_bound_threshold):
      # Bounds specified as Python constants.
      if lower_bound_threshold is None:
        _check_positive(upper_bound_threshold)
        lower_bound_threshold = -1.0 * upper_bound_threshold
      else:
        _check_upper_larger_than_lower(upper_bound_threshold,
                                       lower_bound_threshold)
      self._get_bounds_from_state = _create_get_bounds_const(
          upper_bound_threshold, lower_bound_threshold)
      self._init_fn = _empty_state
      self._update_state = lambda _, __, ___: _empty_state()
    else:
      # Bounds specified as an EstimationProcess.
      _check_bound_process(upper_bound_threshold, 'upper_bound_threshold')
      if lower_bound_threshold is None:
        self._get_bounds_from_state = _create_get_bounds_single_process(
            upper_bound_threshold)
        self._init_fn = upper_bound_threshold.initialize
        self._update_state = _create_update_state_single_process(
            upper_bound_threshold)
      else:
        _check_bound_process(lower_bound_threshold, 'lower_bound_threshold')
        self._get_bounds_from_state = _create_get_bounds_two_processes(
            upper_bound_threshold, lower_bound_threshold)
        self._init_fn = _create_initial_state_two_processes(
            upper_bound_threshold, lower_bound_threshold)
        self._update_state = _create_update_state_two_processes(
            upper_bound_threshold, lower_bound_threshold)
def coerce_dataset_elements_to_tff_type_spec(dataset, element_type):
  """Map the elements of a dataset to a specified type.

  This is used to coerce a `tf.data.Dataset` that may have lost the ordering
  of dictionary keys back into a `collections.OrderedDict` (required by TFF).

  Args:
    dataset: a `tf.data.Dataset` instance.
    element_type: a `tff.Type` specifying the type of the elements of `dataset`.
      Must be a `tff.TensorType` or `tff.StructType`.

  Returns:
    A `tf.data.Dataset` whose output types are compatible with
    `element_type`. If `element_type` is a tensor type, `dataset` is returned
    unchanged.

  Raises:
    ValueError: if the elements of `dataset` cannot be coerced into
      `element_type`.
  """
  py_typecheck.check_type(dataset,
                          type_conversions.TF_DATASET_REPRESENTATION_TYPES)
  py_typecheck.check_type(element_type, computation_types.Type)
  if element_type.is_tensor():
    return dataset

  # This is similar to `reference_executor.to_representation_for_type`;
  # look for opportunities to consolidate?
  def _to_representative_value(type_spec, elements):
    """Converts `elements` to a container understood by TF and TFF."""
    if type_spec.is_tensor():
      return elements
    elif type_spec.is_struct():
      field_types = structure.to_elements(type_spec)
      is_all_named = all(name is not None for name, _ in field_types)
      if is_all_named:
        if py_typecheck.is_named_tuple(elements):
          # Preserve the caller's namedtuple class, rebuilt field-by-field.
          values = collections.OrderedDict(
              (name, _to_representative_value(field_type, e))
              for (name, field_type), e in zip(field_types, elements))
          return type(elements)(**values)
        else:
          # Look fields up by name so the key order follows `type_spec`,
          # restoring any ordering lost by `tf.data`.
          values = [
              (name, _to_representative_value(field_type, elements[name]))
              for name, field_type in field_types
          ]
          return collections.OrderedDict(values)
      else:
        # Unnamed (or partially named) structs are matched positionally.
        return tuple(
            _to_representative_value(t, e) for t, e in zip(type_spec, elements))
    else:
      raise ValueError(
          'Coercing a dataset with elements of expected type {!s}, '
          'produced a value with incompatible type `{!s}`. Value: '
          '{!s}'.format(type_spec, type(elements), elements))

  # tf.data.Dataset of tuples will unwrap the tuple in the `map()` call, so we
  # must pass a function taking *args. However, if the call was originally only
  # a single tuple, it is now "double wrapped" and must be unwrapped before
  # traversing.
  def _unwrap_args(*args):
    if len(args) == 1:
      return _to_representative_value(element_type, args[0])
    else:
      return _to_representative_value(element_type, args)

  return dataset.map(_unwrap_args)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  The serialized graph is imported via `tf.compat.v1.wrap_function`. Sequence
  parameters are passed as variant tensors (`to_variant`), and sequence
  results are converted back with `from_variant`.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type recorded in `comp`;
      otherwise the type recorded in `comp` is used.
    device: An optional device name. Currently only `None` is supported.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode. It is one-argument iff the computation's type
    is a function type with a parameter. The callable raises `RuntimeError` if
    the number of flattened argument parts does not match the parameter.

  Raises:
    NotImplementedError: If `device` is not `None`.
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation, or `type_spec` is not equivalent to its type.
  """
  # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
  # it deals exclusively with eager mode. Incubate here, and potentially move
  # there, once stable.

  if device is not None:
    raise NotImplementedError('Unable to embed TF code on a specific device.')

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_utils.are_equivalent_types(type_spec, comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          str(type_spec), str(comp_type)))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  # Non-function types are treated as zero-argument computations whose result
  # is the whole type.
  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = graph_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  # Imports the serialized graph, feeding `args` into the parameter tensors
  # and returning the result tensors.
  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          str(len(input_tensor_names)), str(len(args))))
    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    if init_op:
      # Make the outputs depend on the init op so it runs before them.
      graph_def = graph_utils.add_control_deps_for_init_op(graph_def, init_op)
    # NOTE(review): uniquify_shared_names presumably avoids shared-name
    # collisions between multiple imports of the same graph; confirm.
    return tf.import_graph_def(
        graph_merge.uniquify_shared_names(graph_def),
        input_map=dict(zip(input_tensor_names, args)),
        return_elements=output_tensor_names)

  # Builds one TensorSpec and one input converter per flattened parameter part:
  # tensors pass through unchanged; sequences become scalar variant tensors.
  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  # One output converter per flattened result part: tensors pass through;
  # variant tensors are turned back into datasets with the element structure.
  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_utils.type_to_tf_structure(spec.element)

      # `structure` is bound as a default argument so each `fn` keeps the
      # structure of its own loop iteration (avoids late binding).
      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  # Flattens `arg`, converts each part, invokes the wrapped graph and repacks
  # the converted outputs into the shape of `result_type`.
  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            str(len(param_fns)), str(len(arg_parts))))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)
    # NOTE(review): `zip` silently truncates if the number of result parts
    # and `result_fns` ever differ; they are expected to match.
    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  # Default arguments capture the current `param_fns`/`wrapped_fn` objects.
  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(arg, p, w)

  # Expose exactly one positional argument (or none) to callers.
  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
def stamp_parameter_in_graph(parameter_name, parameter_type, graph):
  """Stamps a parameter of a given type in the given tf.Graph instance.

  Tensors are stamped as placeholders, sequences are stamped as data sets
  constructed from string tensor handles, and named tuples are stamped by
  independently stamping their elements.

  Args:
    parameter_name: The suggested (string) name of the parameter to use in
      determining the names of the graph components to construct. The names
      that will actually appear in the graph are not guaranteed to be based on
      this suggested name, and may vary, e.g., due to existing naming
      conflicts, but a best-effort attempt will be made to make them similar
      for ease of debugging.
    parameter_type: The type of the parameter to stamp. Must be either an
      instance of computation_types.Type (or convertible to it), or None.
    graph: The instance of tf.Graph to stamp in.

  Returns:
    A tuple (val, binding), where 'val' is a Python object (such as a dataset,
    a placeholder, or a `structure.Struct` that represents a named tuple) that
    represents the stamped parameter for use in the body of a Python function
    that consumes this parameter, and the 'binding' is an instance of
    TensorFlow.Binding that indicates how parts of the type signature relate
    to the tensors and ops stamped into the graph. Returns `(None, None)` if
    `parameter_type` is `None`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the parameter type cannot be stamped in a TensorFlow graph.
  """
  py_typecheck.check_type(parameter_name, str)
  py_typecheck.check_type(graph, tf.Graph)
  if parameter_type is None:
    return (None, None)
  parameter_type = computation_types.to_type(parameter_type)
  if parameter_type.is_tensor():
    with graph.as_default():
      placeholder = tf.compat.v1.placeholder(
          dtype=parameter_type.dtype,
          shape=parameter_type.shape,
          name=parameter_name)
      binding = pb.TensorFlow.Binding(
          tensor=pb.TensorFlow.TensorBinding(tensor_name=placeholder.name))
      return (placeholder, binding)
  elif parameter_type.is_struct():
    # The parameter_type could be a StructTypeWithPyContainer, however, we
    # ignore that for now. Instead, the proper containers will be inserted at
    # call time by function_utils.wrap_as_zero_or_one_arg_callable.
    if not parameter_type:
      # Stamps a dummy op to "populate" the graph, as TensorFlow does not
      # support empty graphs. It must be created inside `graph`, not in
      # whatever graph happens to be the default when this is called.
      with graph.as_default():
        tf.no_op()
    element_name_value_pairs = []
    element_bindings = []
    for name, element_type in structure.iter_elements(parameter_type):
      # Each element is stamped recursively under a derived name.
      e_val, e_binding = stamp_parameter_in_graph(
          '{}_{}'.format(parameter_name, name), element_type, graph)
      element_name_value_pairs.append((name, e_val))
      element_bindings.append(e_binding)
    return (structure.Struct(element_name_value_pairs),
            pb.TensorFlow.Binding(
                struct=pb.TensorFlow.StructBinding(element=element_bindings)))
  elif parameter_type.is_sequence():
    with graph.as_default():
      # Sequences are fed as a scalar variant tensor and rebuilt as a dataset.
      variant_tensor = tf.compat.v1.placeholder(tf.variant, shape=[])
      ds = make_dataset_from_variant_tensor(variant_tensor,
                                            parameter_type.element)
    return (ds,
            pb.TensorFlow.Binding(
                sequence=pb.TensorFlow.SequenceBinding(
                    variant_tensor_name=variant_tensor.name)))
  else:
    raise ValueError(
        'Parameter type component {!r} cannot be stamped into a TensorFlow '
        'graph.'.format(parameter_type))
def _check_arg_is_structure(self, arg):
  """Validates that `arg` carries a struct type and a `Struct` payload.

  Args:
    arg: A value object exposing `type_signature` and
      `internal_representation` attributes.

  Raises:
    TypeError: If `arg.type_signature` is not a `StructType`, or if
      `arg.internal_representation` is not a `structure.Struct`.
  """
  signature = arg.type_signature
  py_typecheck.check_type(signature, computation_types.StructType)
  payload = arg.internal_representation
  py_typecheck.check_type(payload, structure.Struct)
def pack_args_into_anonymous_tuple(args, kwargs, type_spec=None, context=None):
  """Packs positional and keyword arguments into an anonymous tuple.

  If 'type_spec' is not None, it must be a tuple type or something that's
  convertible to it by computation_types.to_type(). The assignment of arguments
  to fields of the tuple follows the same rule as during function calls. If
  'type_spec' is None, the positional arguments precede any of the keyword
  arguments, and the ordering of the keyword arguments matches the ordering in
  which they appear in kwargs. If the latter is an OrderedDict, the ordering
  will be preserved. On the other hand, if the latter is an ordinary unordered
  dict, the ordering is arbitrary.

  Args:
    args: Positional arguments.
    kwargs: Keyword arguments.
    type_spec: The optional type specification (either an instance of
      computation_types.NamedTupleType or something convertible to it), or None
      if there's no type. Used to drive the arrangements of args into fields of
      the constructed anonymous tuple, as noted in the description.
    context: The optional context (an instance of `context_base.Context`) in
      which the arguments are being packed. Required if and only if the
      `type_spec` is not `None`.

  Returns:
    An anonymous tuple containing all the arguments.

  Raises:
    TypeError: if the arguments are of the wrong types, or if they cannot be
      matched one-to-one with the fields of `type_spec`.
  """
  type_spec = computation_types.to_type(type_spec)
  # Compare against None explicitly: an empty tuple type is falsy but is still
  # a type, and must go through the typed path below.
  if type_spec is None:
    return anonymous_tuple.AnonymousTuple([(None, arg) for arg in args] +
                                          list(six.iteritems(kwargs)))

  py_typecheck.check_type(type_spec, computation_types.NamedTupleType)
  py_typecheck.check_type(context, context_base.Context)
  if not is_argument_tuple(type_spec):
    raise TypeError(
        'Parameter type {} does not have a structure of an argument tuple, '
        'and cannot be populated from multiple positional and keyword '
        'arguments'.format(type_spec))

  result_elements = []
  positions_used = set()
  keywords_used = set()
  for index, (name, elem_type) in enumerate(
      anonymous_tuple.to_elements(type_spec)):
    if index < len(args):
      # Positional argument fills this field; it must not also be a keyword.
      if name is not None and name in kwargs:
        raise TypeError('Argument {} specified twice.'.format(name))
      arg_value = args[index]
      result_elements.append((name, context.ingest(arg_value, elem_type)))
      positions_used.add(index)
    elif name is not None and name in kwargs:
      arg_value = kwargs[name]
      result_elements.append((name, context.ingest(arg_value, elem_type)))
      keywords_used.add(name)
    elif name:
      raise TypeError('Argument named {} is missing.'.format(name))
    else:
      raise TypeError('Argument at position {} is missing.'.format(index))

  # Any argument not consumed by a field is an error.
  positions_missing = set(range(len(args))).difference(positions_used)
  if positions_missing:
    raise TypeError(
        'Positional arguments at {} not used.'.format(positions_missing))
  keywords_missing = set(kwargs.keys()).difference(keywords_used)
  if keywords_missing:
    raise TypeError(
        'Keyword arguments at {} not used.'.format(keywords_missing))
  return anonymous_tuple.AnonymousTuple(result_elements)