def test_something(self):
    """CLIENTS and SERVER stringify differently and round-trip via their URIs."""
    clients_repr = str(placement_literals.CLIENTS)
    server_repr = str(placement_literals.SERVER)
    self.assertNotEqual(clients_repr, server_repr)
    candidates = (placement_literals.CLIENTS, placement_literals.SERVER)
    for placement in candidates:
        # Looking a literal up by its own URI must yield the very same object.
        recovered = placement_literals.uri_to_placement_literal(placement.uri)
        self.assertIs(recovered, placement)
def deserialize_type(type_proto):
    """Deserializes 'type_proto' as a computation_types.Type.

    NOTE: Deserialization is implemented for tensor, sequence, named tuple,
    function, placement, and federated types. Federated types are supported
    only when their placement spec is given through its 'value' field.

    Args:
      type_proto: An instance of pb.Type or None.

    Returns:
      The corresponding instance of computation_types.Type, or None if the
      argument was None or has no 'type' variant set.

    Raises:
      TypeError: if the argument is of the wrong type.
      NotImplementedError: for type variants (or federated placement specs)
        for which deserialization is not implemented.
    """
    # TODO(b/113112885): Implement deserialization of the remaining types.
    if type_proto is None:
        return None
    py_typecheck.check_type(type_proto, pb.Type)
    type_variant = type_proto.WhichOneof('type')
    if type_variant is None:
        # An empty pb.Type (no oneof arm set) is treated like a missing type.
        return None
    elif type_variant == 'tensor':
        return computation_types.TensorType(
            dtype=tf.DType(type_proto.tensor.dtype),
            shape=tf.TensorShape(type_proto.tensor.shape))
    elif type_variant == 'sequence':
        return computation_types.SequenceType(
            deserialize_type(type_proto.sequence.element))
    elif type_variant == 'tuple':
        elements = []
        for element in type_proto.tuple.element:
            element_type = deserialize_type(element.value)
            # Named elements become (name, type) pairs; unnamed ones (empty
            # name) are passed to NamedTupleType as a bare type.
            if element.name:
                elements.append((element.name, element_type))
            else:
                elements.append(element_type)
        return computation_types.NamedTupleType(elements)
    elif type_variant == 'function':
        return computation_types.FunctionType(
            parameter=deserialize_type(type_proto.function.parameter),
            result=deserialize_type(type_proto.function.result))
    elif type_variant == 'placement':
        return computation_types.PlacementType()
    elif type_variant == 'federated':
        placement_oneof = type_proto.federated.placement.WhichOneof(
            'placement')
        if placement_oneof == 'value':
            return computation_types.FederatedType(
                member=deserialize_type(type_proto.federated.member),
                placement=placement_literals.uri_to_placement_literal(
                    type_proto.federated.placement.value.uri),
                all_equal=type_proto.federated.all_equal)
        else:
            raise NotImplementedError(
                'Deserialization of federated types with placement spec as {} '
                'is not currently implemented yet.'.format(placement_oneof))
    else:
        raise NotImplementedError(
            'Unknown type variant {}.'.format(type_variant))
Example #3
0
 def from_proto(cls, computation_proto):
     """Builds an instance of `cls` from a placement computation proto.

     Args:
       computation_proto: A computation proto whose 'computation' oneof is
         expected to be 'placement' and whose `type` deserializes to a
         `computation_types.PlacementType`.

     Returns:
       `cls` constructed with the placement literal looked up from the
       proto's placement URI.
     """
     # Presumably raises if the proto is not a placement computation; the
     # helper is defined elsewhere — confirm.
     _check_computation_oneof(computation_proto, 'placement')
     declared_type = type_serialization.deserialize_type(computation_proto.type)
     py_typecheck.check_type(declared_type, computation_types.PlacementType)
     placement_uri = str(computation_proto.placement.uri)
     literal = placement_literals.uri_to_placement_literal(placement_uri)
     return cls(literal)
Example #4
0
 async def create_value(self, value, type_spec=None):
     """Embeds `value` (optionally of TFF type `type_spec`) in this executor.

     Intrinsic definitions and placement literals are wrapped directly.
     Computation implementations are unwrapped to their protos, and protos are
     dispatched on their 'computation' oneof. Any other value is treated as
     either a federated value, split across the child executors configured
     for its placement, or an unplaced value, delegated to the single child
     executor configured under the `None` key.

     Args:
       value: The value to embed.
       type_spec: An optional TFF type for `value`, or anything accepted by
         `computation_types.to_type`.

     Returns:
       A `FederatedExecutorValue`, or for 'call', 'tuple' and 'selection'
       protos the result of `create_call`, `create_tuple` or
       `create_selection` respectively.

     Raises:
       TypeError: if `type_spec` is incompatible with `value`.
       ValueError: if `value` cannot be embedded. Examples are an unbound
         reference, an unrecognized intrinsic or selection, an unsupported
         building block, a functional type in an unrecognized representation,
         an unconfigured placement, or a federated value whose size does not
         match the number of configured participants.
       RuntimeError: if the executor is not configured with exactly one child
         for unplaced values.
     """
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, intrinsic_defs.IntrinsicDef):
         if not type_utils.is_concrete_instance_of(type_spec,
                                                   value.type_signature):
             raise TypeError(
                 'Incompatible type {} used with intrinsic {}.'.format(
                     str(type_spec), value.uri))
         return FederatedExecutorValue(value, type_spec)
     if isinstance(value, placement_literals.PlacementLiteral):
         if type_spec is not None:
             py_typecheck.check_type(type_spec,
                                     computation_types.PlacementType)
         return FederatedExecutorValue(value,
                                       computation_types.PlacementType())
     elif isinstance(value, computation_impl.ComputationImpl):
         # Re-enter with the underlying proto and a reconciled type.
         return await self.create_value(
             computation_impl.ComputationImpl.get_proto(value),
             type_utils.reconcile_value_with_type_spec(value, type_spec))
     elif isinstance(value, pb.Computation):
         if type_spec is None:
             type_spec = type_serialization.deserialize_type(value.type)
         which_computation = value.WhichOneof('computation')
         if which_computation in ['tensorflow', 'lambda']:
             return FederatedExecutorValue(value, type_spec)
         elif which_computation == 'reference':
             raise ValueError(
                 'Encountered an unexpected unbound reference "{}".'.format(
                     value.reference.name))
         elif which_computation == 'intrinsic':
             intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
             if intr is None:
                 raise ValueError(
                     'Encountered an unrecognized intrinsic "{}".'.format(
                         value.intrinsic.uri))
             py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
             return await self.create_value(intr, type_spec)
         elif which_computation == 'placement':
             return await self.create_value(
                 placement_literals.uri_to_placement_literal(
                     value.placement.uri), type_spec)
         elif which_computation == 'call':
             parts = [value.call.function]
             # A call without an argument leaves the argument oneof unset.
             if value.call.argument.WhichOneof('computation'):
                 parts.append(value.call.argument)
             parts = await asyncio.gather(
                 *[self.create_value(x) for x in parts])
             return await self.create_call(
                 parts[0], parts[1] if len(parts) > 1 else None)
         elif which_computation == 'tuple':
             element_values = await asyncio.gather(
                 *[self.create_value(x.value) for x in value.tuple.element])
             return await self.create_tuple(
                 anonymous_tuple.AnonymousTuple([
                     (e.name if e.name else None, v)
                     for e, v in zip(value.tuple.element, element_values)
                 ]))
         elif which_computation == 'selection':
             which_selection = value.selection.WhichOneof('selection')
             if which_selection == 'name':
                 name = value.selection.name
                 index = None
             elif which_selection == 'index':
                 index = value.selection.index
                 name = None
             else:
                 raise ValueError(
                     'Unrecognized selection type: "{}".'.format(
                         which_selection))
             source = await self.create_value(value.selection.source)
             return await self.create_selection(source, index=index, name=name)
         else:
             raise ValueError(
                 'Unsupported computation building block of type "{}".'.
                 format(which_computation))
     else:
         py_typecheck.check_type(type_spec, computation_types.Type)
         if isinstance(type_spec, computation_types.FunctionType):
             raise ValueError(
                 'Encountered a value of a functional TFF type {} and Python type '
                 '{} that is not of one of the recognized representations.'.
                 format(str(type_spec),
                        py_typecheck.type_string(type(value))))
         elif isinstance(type_spec, computation_types.FederatedType):
             children = self._target_executors.get(type_spec.placement)
             if not children:
                 raise ValueError(
                     'Placement "{}" is not configured in this executor.'.
                     format(str(type_spec.placement)))
             py_typecheck.check_type(children, list)
             if not type_spec.all_equal:
                 # NOTE(review): set/frozenset inputs have no defined
                 # iteration order, so which child receives which element is
                 # arbitrary for them.
                 py_typecheck.check_type(value,
                                         (list, tuple, set, frozenset))
                 if not isinstance(value, list):
                     value = list(value)
             elif isinstance(value, list):
                 raise ValueError(
                     'An all_equal value should be passed directly, not as a list.'
                 )
             else:
                 # Replicate the single all_equal value to every participant.
                 value = [value for _ in children]
             if len(value) != len(children):
                 raise ValueError(
                     'Federated value contains {} items, but the placement {} in this '
                     'executor is configured with {} participants.'.format(
                         len(value), str(type_spec.placement),
                         len(children)))
             child_vals = await asyncio.gather(*[
                 c.create_value(v, type_spec.member)
                 for v, c in zip(value, children)
             ])
             return FederatedExecutorValue(child_vals, type_spec)
         else:
             child = self._target_executors.get(None)
             if not child or len(child) > 1:
                 raise RuntimeError(
                     'Executor is not configured for unplaced values.')
             return FederatedExecutorValue(
                 await child[0].create_value(value, type_spec), type_spec)