Exemplo n.º 1
0
 def test_returns_sequence_map(self):
     """Mapping an identity lambda over a data sequence yields `int32*`."""
     param = computation_building_blocks.Reference('x', tf.int32)
     identity = computation_building_blocks.Lambda(
         param.name, param.type_signature, param)
     sequence = computation_building_blocks.Data(
         'y', computation_types.SequenceType(tf.int32))
     result = computation_constructing_utils.create_sequence_map(
         identity, sequence)
     # The call's argument is rendered as a tuple `<fn,arg>`.
     self.assertEqual(result.tff_repr, 'sequence_map(<(x -> x),y>)')
     self.assertEqual(str(result.type_signature), 'int32*')
Exemplo n.º 2
0
def create_dummy_called_sequence_map(parameter_name, parameter_type=tf.int32):
    r"""Returns a dummy called sequence map.

    The returned building block has the following shape:

                    Call
                   /    \
       sequence_map      Tuple
                        /     \
                      fn       data

    where `fn` is whatever `create_identity_function(parameter_name,
    parameter_type)` builds (presumably the identity lambda
    `(parameter_name -> parameter_name)`; confirm against that helper), and
    `data` is a `computation_building_blocks.Data` named `'data'` whose type
    is a sequence of `parameter_type`.

    Args:
      parameter_name: The name of the parameter of the identity function.
      parameter_type: The type of the identity function's parameter, and the
        element type of the `data` sequence. Defaults to `tf.int32`.

    Returns:
      The result of `computation_constructing_utils.create_sequence_map(fn,
      data)`, i.e. a called `sequence_map` intrinsic over the data sequence.
    """
    fn = create_identity_function(parameter_name, parameter_type)
    arg_type = computation_types.SequenceType(parameter_type)
    arg = computation_building_blocks.Data('data', arg_type)
    return computation_constructing_utils.create_sequence_map(fn, arg)
Exemplo n.º 3
0
    def sequence_map(self, fn, arg):
        """Implements `sequence_map` as defined in `api/intrinsics.py`.

        Both inputs are first converted to values in `self._context_stack`.
        An unplaced sequence argument is mapped directly, by building a
        `sequence_map` call on the underlying computations. A federated
        argument is handled by currying the `SEQUENCE_MAP` intrinsic with `fn`
        and applying the resulting local function to the federated value:
        via `federated_apply` when it is placed at `SERVER`, or via
        `federated_map` when it is placed at `CLIENTS`.

        Args:
          fn: As in `api/intrinsics.py`. Its converted value must have a
            `computation_types.FunctionType` type signature (enforced with
            `py_typecheck.check_type`).
          arg: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`. Also raised here when `arg` is
            federated with a placement other than `SERVER` or `CLIENTS`, or
            when `arg` is neither a sequence nor a federated value.
        """
        fn = value_impl.to_value(fn, None, self._context_stack)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        arg = value_impl.to_value(arg, None, self._context_stack)
        arg_type = arg.type_signature

        # Unplaced sequence: construct the call on the raw building blocks.
        if isinstance(arg_type, computation_types.SequenceType):
            return computation_constructing_utils.create_sequence_map(
                value_impl.ValueImpl.get_comp(fn),
                value_impl.ValueImpl.get_comp(arg))

        if not isinstance(arg_type, computation_types.FederatedType):
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a value of type {}.'.
                format(arg_type))

        # Federated value: build a `(fn, fn_param*) -> fn_result*` intrinsic,
        # curry it with `fn`, and map the resulting local function over the
        # federated value.
        # NOTE(review): the federated member type is not checked to be a
        # sequence here; a mismatch presumably surfaces inside
        # federated_apply/federated_map. Confirm.
        fn_type = fn.type_signature
        sequence_map_type = computation_types.FunctionType(
            (fn_type, computation_types.SequenceType(fn_type.parameter)),
            computation_types.SequenceType(fn_type.result))
        intrinsic_value = value_impl.ValueImpl(
            computation_building_blocks.Intrinsic(
                intrinsic_defs.SEQUENCE_MAP.uri, sequence_map_type),
            self._context_stack)
        mapper = value_utils.get_curried(intrinsic_value)(fn)

        placement = arg_type.placement
        if placement is placements.SERVER:
            return self.federated_apply(mapper, arg)
        if placement is placements.CLIENTS:
            return self.federated_map(mapper, arg)
        raise TypeError('Unsupported placement {}.'.format(placement))
Exemplo n.º 4
0
 def test_raises_type_error_with_none_arg(self):
     """`create_sequence_map` rejects a `None` sequence argument."""
     param = computation_building_blocks.Reference('x', tf.int32)
     identity = computation_building_blocks.Lambda(
         param.name, param.type_signature, param)
     self.assertRaises(TypeError,
                       computation_constructing_utils.create_sequence_map,
                       identity, None)
Exemplo n.º 5
0
 def test_raises_type_error_with_none_fn(self):
     """`create_sequence_map` rejects a `None` mapping function."""
     sequence = computation_building_blocks.Data(
         'y', computation_types.SequenceType(tf.int32))
     self.assertRaises(TypeError,
                       computation_constructing_utils.create_sequence_map,
                       None, sequence)