    def test_from_keras_model_succeeds_from_set(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()
        keras_utils.from_keras_model(keras_model=keras_model,
                                     global_layers=set(keras_model.layers),
                                     local_layers=set(),
                                     input_spec=input_spec)
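# `_create_input_spec` is used throughout these tests but is not shown in this
# excerpt. A plausible sketch, inferred from the MNIST batches used below
# (784-dim float features, [N, 1] integer labels); the actual helper may differ.
import collections

import tensorflow as tf


def _create_input_spec():
    return collections.namedtuple('Batch', ['x', 'y'])(
        x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
        y=tf.TensorSpec(shape=[None, 1], dtype=tf.int32))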
    def test_from_keras_model_fails_compiled(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=True)
        input_spec = _create_input_spec()

        with self.assertRaisesRegex(ValueError, 'compiled'):
            keras_utils.from_keras_model(keras_model=keras_model,
                                         global_layers=keras_model.layers,
                                         local_layers=[],
                                         input_spec=input_spec)
    def test_from_keras_model_fails_missing_variables(self):
        """Ensures failure if global/local layers are missing variables."""
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()

        with self.assertRaisesRegex(ValueError, 'variables'):
            keras_utils.from_keras_model(keras_model=keras_model,
                                         global_layers=keras_model.layers[:-1],
                                         local_layers=[],
                                         input_spec=input_spec)
    def test_from_keras_model_fails_bad_input_spec(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = collections.namedtuple(
            'Batch',
            ['x'])(x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32))

        with self.assertRaisesRegex(ValueError, 'input_spec'):
            keras_utils.from_keras_model(keras_model=keras_model,
                                         global_layers=keras_model.layers,
                                         local_layers=[],
                                         input_spec=input_spec)
    def test_from_keras_model_forward_pass(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()

        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers[:-1],
            local_layers=keras_model.layers[-1:],
            input_spec=input_spec)

        batch_input = collections.namedtuple('Batch', ['x', 'y'])(
            x=tf.ones(shape=[10, 784], dtype=tf.float32),
            y=tf.zeros(shape=[10, 1], dtype=tf.int32))

        batch_output = recon_model.forward_pass(batch_input)

        self.assertIsInstance(batch_output, reconstruction_model.BatchOutput)
        self.assertEqual(batch_output.num_examples, 10)
        self.assertAllEqual(batch_output.labels,
                            tf.zeros(shape=[10, 1], dtype=tf.int32))

        # Change num_examples and labels.
        batch_input = collections.namedtuple('Batch', ['x', 'y'])(
            x=tf.zeros(shape=[5, 784], dtype=tf.float32),
            y=tf.ones(shape=[5, 1], dtype=tf.int32))

        batch_output = recon_model.forward_pass(batch_input)

        self.assertIsInstance(batch_output, reconstruction_model.BatchOutput)
        self.assertEqual(batch_output.num_examples, 5)
        self.assertAllEqual(batch_output.labels,
                            tf.ones(shape=[5, 1], dtype=tf.int32))
    def test_has_only_global_variables_true(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()
        model = keras_utils.from_keras_model(keras_model=keras_model,
                                             global_layers=keras_model.layers,
                                             local_layers=[],
                                             input_spec=input_spec)
        self.assertTrue(reconstruction_utils.has_only_global_variables(model))
def local_recon_model_fn():
    """Keras MNIST model with final dense layer local."""
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=False)
    input_spec = _create_input_spec()
    return keras_utils.from_keras_model(keras_model=keras_model,
                                        global_layers=keras_model.layers[:-1],
                                        local_layers=keras_model.layers[-1:],
                                        input_spec=input_spec)
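# Usage sketch for `local_recon_model_fn`, relying only on helpers exercised
# elsewhere in these tests (`reconstruction_utils`); illustrative only.
model = local_recon_model_fn()
# The final Dense layer is local, so the model is not "global variables only";
# its local weights are the last layer's kernel and bias.
assert not reconstruction_utils.has_only_global_variables(model)
local_weights = reconstruction_utils.get_local_variables(model)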
def global_recon_model_fn():
    """Keras MNIST model with no local variables."""
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=False)
    input_spec = _create_input_spec()
    return keras_utils.from_keras_model(keras_model=keras_model,
                                        global_layers=keras_model.layers,
                                        local_layers=[],
                                        input_spec=input_spec)
def keras_linear_model_fn():
    """Should produce the same results as `LinearModel`."""
    inputs = tf.keras.layers.Input(shape=[1])
    scaled_input = tf.keras.layers.Dense(1,
                                         use_bias=False,
                                         kernel_initializer='zeros')(inputs)
    outputs = BiasLayer()(scaled_input)
    keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    input_spec = _create_input_spec()
    return keras_utils.from_keras_model(keras_model=keras_model,
                                        global_layers=keras_model.layers[:-1],
                                        local_layers=keras_model.layers[-1:],
                                        input_spec=input_spec)
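# `BiasLayer` is used in `keras_linear_model_fn` above but is not defined in
# this excerpt. A minimal sketch of a trainable additive-bias Keras layer it
# could correspond to (hypothetical; the actual implementation may differ):
class BiasLayer(tf.keras.layers.Layer):
    """Adds a trainable bias, initialized to zeros, to its input."""

    def build(self, input_shape):
        self.bias = self.add_weight(
            name='bias', shape=input_shape[1:], initializer='zeros',
            trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        return inputs + self.bias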
    def reconstruction_model_fn() -> reconstruction_model.ReconstructionModel:
        matrix_factorization_model = model_builder()

        global_layers = matrix_factorization_model.global_layers
        local_layers = matrix_factorization_model.local_layers
        # Merge local layers into global layers if needed.
        if global_variables_only:
            global_layers.extend(local_layers)
            local_layers = []

        return keras_utils.from_keras_model(
            keras_model=matrix_factorization_model.model,
            global_layers=global_layers,
            local_layers=local_layers,
            input_spec=matrix_factorization_model.input_spec)
    def test_from_keras_model_forward_pass_fails_bad_input_keys(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()

        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers,
            local_layers=[],
            input_spec=input_spec)

        batch_input = collections.namedtuple('Batch', ['a', 'b'])(
            a=tf.ones(shape=[10, 784], dtype=tf.float32),
            b=tf.zeros(shape=[10, 1], dtype=tf.int32))

        with self.assertRaisesRegex(KeyError, 'keys'):
            recon_model.forward_pass(batch_input)
  def test_get_local_variables(self):
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=False)
    input_spec = _create_input_spec()
    model = keras_utils.from_keras_model(
        keras_model=keras_model,
        global_layers=keras_model.layers[:-1],
        local_layers=keras_model.layers[-1:],
        input_spec=input_spec)

    local_weights = reconstruction_utils.get_local_variables(model)

    self.assertIsInstance(local_weights, tff.learning.ModelWeights)
    # The last layer of the Keras model, which is a local Dense layer, contains
    # 2 trainable variables for the weights and bias.
    self.assertEqual(local_weights.trainable,
                     keras_model.trainable_variables[-2:])
    self.assertEmpty(local_weights.non_trainable)
    def test_from_keras_model_properties(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()

        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers,
            local_layers=[],
            input_spec=input_spec)

        # Global trainable/non_trainable should include all the variables, and
        # local should be empty.
        self.assertEqual(recon_model.global_trainable_variables,
                         keras_model.trainable_variables)
        self.assertEqual(recon_model.global_non_trainable_variables,
                         keras_model.non_trainable_variables)
        self.assertEmpty(recon_model.local_trainable_variables)
        self.assertEmpty(recon_model.local_non_trainable_variables)
        self.assertEqual(input_spec, recon_model.input_spec)
    def test_from_keras_model_local_layers_properties(self):
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()

        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            # The last Dense layer is local.
            global_layers=keras_model.layers[:-1],
            local_layers=keras_model.layers[-1:],
            input_spec=input_spec)

        # Expect last two variables, the weights and bias for the final Dense layer,
        # to be local trainable, and the rest global.
        self.assertEqual(recon_model.global_trainable_variables,
                         keras_model.trainable_variables[:-2])
        self.assertEqual(recon_model.global_non_trainable_variables,
                         keras_model.non_trainable_variables)
        self.assertEqual(recon_model.local_trainable_variables,
                         keras_model.trainable_variables[-2:])
        self.assertEmpty(recon_model.local_non_trainable_variables)
        self.assertEqual(input_spec, recon_model.input_spec)
    def test_from_keras_model_forward_pass_list_input(self):
        """Forward pass still works with a 2-element list batch input."""
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=False)
        input_spec = _create_input_spec()

        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers[:-1],
            local_layers=keras_model.layers[-1:],
            input_spec=input_spec)

        batch_input = [
            tf.ones(shape=[10, 784], dtype=tf.float32),
            tf.zeros(shape=[10, 1], dtype=tf.int32)
        ]

        batch_output = recon_model.forward_pass(batch_input)

        self.assertIsInstance(batch_output, reconstruction_model.BatchOutput)
        self.assertEqual(batch_output.num_examples, 10)
        self.assertAllEqual(batch_output.labels,
                            tf.zeros(shape=[10, 1], dtype=tf.int32))
def create_recurrent_reconstruction_model(
    vocab_size: int = 10000,
    num_oov_buckets: int = 1,
    embedding_size: int = 96,
    latent_size: int = 670,
    num_layers: int = 1,
    input_spec=None,
    global_variables_only: bool = False,
    name: str = 'rnn_recon_embeddings',
) -> reconstruction_model.ReconstructionModel:
  """Creates a recurrent model with a partially reconstructed embedding layer.

  Constructs a recurrent model for next word prediction, with the embedding
  layer divided in two parts:
    - A global_embedding, which shares its parameter updates with the server.
    - A local_embedding layer, reconstructed at the beginning of every round,
      that never leaves the device. This local embedding layer corresponds to
      the out of vocabulary buckets.

  Args:
    vocab_size: Size of vocabulary to use.
    num_oov_buckets: Number of out of vocabulary buckets.
    embedding_size: The size of the embedding.
    latent_size: The size of the recurrent state.
    num_layers: The number of layers.
    input_spec: A structure of `tf.TensorSpec`s specifying the type of arguments
      the model expects. Note that this must be a compound structure of two
      elements: the data fed into the model to generate predictions as its
      first element, and the expected type of the ground truth as its second.
    global_variables_only: If True, the returned `ReconstructionModel` contains
      all model variables as global variables. This can be useful for
      baselines involving aggregating all variables.
    name: (Optional) string to name the returned `tf.keras.Model`.

  Returns:
    `ReconstructionModel` tracking global and local variables for a recurrent
    model.
  """

  if vocab_size < 0:
    raise ValueError('The vocab_size is expected to be greater than or equal '
                     'to 0. Got {}'.format(vocab_size))

  if num_oov_buckets <= 0:
    raise ValueError('The number of out of vocabulary buckets is expected to '
                     'be greater than 0. Got {}'.format(num_oov_buckets))

  global_layers = []
  local_layers = []

  total_vocab_size = vocab_size + 3  # pad/bos/eos.
  extended_vocab_size = total_vocab_size + num_oov_buckets  # pad/bos/eos + oov.
  inputs = tf.keras.layers.Input(shape=(None,), dtype=tf.int64)

  global_embedding = GlobalEmbedding(
      total_vocab_size=total_vocab_size,
      embedding_dim=embedding_size,
      mask_zero=True,
      name='global_embedding_layer')
  global_layers.append(global_embedding)

  local_embedding = LocalEmbedding(
      input_dim=num_oov_buckets,
      embedding_dim=embedding_size,
      total_vocab_size=total_vocab_size,
      mask_zero=True,
      name='local_embedding_layer')
  local_layers.append(local_embedding)

  projected = tf.keras.layers.Add()(
      [global_embedding(inputs),
       local_embedding(inputs)])

  for i in range(num_layers):
    layer = tf.keras.layers.LSTM(
        latent_size, return_sequences=True, name='lstm_' + str(i))
    global_layers.append(layer)
    processed = layer(projected)
    # A projection changes dimension from rnn_layer_size to
    # input_embedding_size.
    projection = tf.keras.layers.Dense(
        embedding_size, name='projection_' + str(i))
    global_layers.append(projection)
    projected = projection(processed)

  # We predict the OOV tokens as part of the output vocabulary.
  last_layer = tf.keras.layers.Dense(
      extended_vocab_size, activation=None, name='last_layer')
  global_layers.append(last_layer)
  logits = last_layer(projected)

  model = tf.keras.Model(inputs=inputs, outputs=logits, name=name)

  if input_spec is None:
    input_spec = collections.OrderedDict(
        x=tf.TensorSpec(shape=(None,), dtype=tf.int64),
        y=tf.TensorSpec(shape=(None,), dtype=tf.int64))

  # Merge local layers into global layers if needed.
  if global_variables_only:
    global_layers.extend(local_layers)
    local_layers = []

  return keras_utils.from_keras_model(
      keras_model=model,
      global_layers=global_layers,
      local_layers=local_layers,
      input_spec=input_spec)
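# Usage sketch: with the defaults above, the OOV embedding is the only local
# (never-aggregated) layer; all other layers are global. The property names
# below follow those exercised in the tests earlier in this listing.
recon_model = create_recurrent_reconstruction_model(
    vocab_size=10000, num_oov_buckets=1)
print([v.name for v in recon_model.local_trainable_variables])
print(len(recon_model.global_trainable_variables))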