Example #1
 def setUp(self):
   super().setUp()
   test_utils.force_multi_devices(self.n_devices)
   self.devices = jax.local_devices()
   mention_batch_positions = [
       np.random.randint(self.batch_size, size=(self.n_mentions))
       for _ in range(self.n_devices)
   ]
   self.mention_batch_positions_sharded = jax.device_put_sharded(
       mention_batch_positions, self.devices)
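
For context, jax.device_put_sharded takes one array per device and returns a single array whose leading axis stacks the per-device shards. Below is a minimal, self-contained sketch of the same pattern, assuming test_utils.force_multi_devices works by forcing extra virtual CPU devices (e.g. via chex.set_n_cpu_devices, which must run before JAX initializes a backend; the real helper may differ):

import chex
import jax
import numpy as np

# Hypothetical stand-in for test_utils.force_multi_devices: on a CPU-only
# host, ask XLA for two virtual devices before any other JAX call.
chex.set_n_cpu_devices(2)

devices = jax.local_devices()
# One shard per device, mirroring mention_batch_positions above.
shards = [np.random.randint(4, size=(3,)) for _ in devices]
sharded = jax.device_put_sharded(shards, devices)
# The shards are stacked along a new leading device axis.
assert sharded.shape == (len(devices), 3)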
Example #2
 def setUp(self):
     super().setUp()
     test_utils.force_multi_devices(self.n_devices)
     self.devices = jax.local_devices()
     # pylint: disable=g-long-lambda
     (self.mention_encodings_stacked, self.mention_encodings_sharded
      ) = self._gen_array(lambda: 10.0 * np.random.random(
          (self.n_mentions, self.hidden_size)))
     (self.mention_target_ids_stacked, self.mention_target_ids_sharded
      ) = self._gen_array(lambda: np.random.randint(self.entity_vocab_size,
                                                    size=(self.n_mentions)))
     (self.mention_batch_positions_stacked,
      self.mention_batch_positions_sharded
      ) = self._gen_array(lambda: np.random.randint(self.batch_size,
                                                    size=(self.n_mentions)))
     (self.mention_target_is_masked_stacked,
      self.mention_target_is_masked_sharded) = self._gen_array(
          lambda: np.random.randint(2, size=(self.n_mentions)))
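
The _gen_array helper presumably draws one random array per device and returns both a host-side stacked copy and a device-side sharded copy; a plausible sketch inferred from how its outputs are used (hypothetical, not the actual test utility):

 def _gen_array(self, gen_fn):
     # Draw one array per device; keep a stacked copy for numpy-side
     # reference computations and a sharded copy for pmapped calls.
     arrays = [gen_fn() for _ in range(self.n_devices)]
     stacked = np.stack(arrays)
     sharded = jax.device_put_sharded(arrays, self.devices)
     return stacked, sharded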
Example #3
    def test_linking_layer(self):
        """Testing linking layer."""

        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()

        encoded_input = jnp.ones(shape=(self.n_devices, self.bsz, self.seq_len,
                                        self.hidden_size),
                                 dtype=self.dtype)
        encoded_input = jax.device_put_sharded(list(encoded_input), devices)
        mention_batch_positions = np.random.randint(self.bsz,
                                                    size=(self.n_devices,
                                                          self.n_mentions))
        mention_batch_positions = jax.device_put_sharded(
            list(mention_batch_positions), devices)
        mention_start_positions = np.random.randint(self.seq_len - 1,
                                                    size=(self.n_devices,
                                                          self.n_mentions))
        mention_end_positions = mention_start_positions + 1
        mention_start_positions = jax.device_put_sharded(
            list(mention_start_positions), devices)
        mention_end_positions = jax.device_put_sharded(
            list(mention_end_positions), devices)
        mention_mask = jnp.ones(shape=(self.n_devices, self.n_mentions))
        mention_mask = jax.device_put_sharded(list(mention_mask), devices)
        mention_entity_ids = jnp.arange(
            self.n_devices * self.n_mentions).reshape(self.n_devices,
                                                      self.n_mentions)
        mention_entity_ids = jax.device_put_sharded(list(mention_entity_ids),
                                                    devices)

        model = memory_extraction_layer.MemoryExtractionLayer(
            memory_key_dim=self.memory_key_dim,
            memory_value_dim=self.memory_value_dim,
            dtype=self.dtype,
        )
        pinit_with_output = jax.pmap(model.init_with_output, axis_name='batch')

        rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(rng, self.n_devices)
        result_dict, _ = pinit_with_output(
            split_rng,
            encoding=encoded_input,
            mention_batch_positions=mention_batch_positions,
            mention_start_positions=mention_start_positions,
            mention_end_positions=mention_end_positions,
            mention_mask=mention_mask,
            mention_entity_ids=mention_entity_ids,
        )

        # Check shapes are as expected
        self.assertSequenceEqual(result_dict['memory_keys'].shape,
                                 (self.n_devices, self.n_devices *
                                  self.n_mentions, self.memory_key_dim))
        self.assertSequenceEqual(result_dict['memory_values'].shape,
                                 (self.n_devices, self.n_devices *
                                  self.n_mentions, self.memory_value_dim))
        self.assertSequenceEqual(
            result_dict['memory_mask'].shape,
            (self.n_devices, self.n_devices * self.n_mentions))
        self.assertSequenceEqual(
            result_dict['memory_entity_ids'].shape,
            (self.n_devices, self.n_devices * self.n_mentions))

        # The memory mask and entity ids should simply have been all-gathered
        self.assertTrue(
            jnp.all(result_dict['memory_mask'][0].reshape(
                self.n_devices, self.n_mentions) == mention_mask))
        self.assertTrue(
            jnp.all(result_dict['memory_entity_ids'][0].reshape(
                self.n_devices, self.n_mentions) == mention_entity_ids))
Example #4
 def setUp(self):
     super().setUp()
     test_utils.force_multi_devices(self.n_devices)
     self.devices = jax.local_devices()
Example #5
 def test_multi_node_training(self):
     test_utils.force_multi_devices(8)
     trainer.train(self.test_config)
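
Every example here calls test_utils.force_multi_devices before reading jax.local_devices(). A plausible sketch of such a helper, assuming it forces the XLA host platform to expose virtual CPU devices (the real implementation may differ):

import os

def force_multi_devices(n_devices):
  # Hypothetical re-implementation: request n virtual CPU devices. The
  # flag must be set before JAX initializes its backends.
  flags = os.environ.get('XLA_FLAGS', '')
  os.environ['XLA_FLAGS'] = (
      flags + f' --xla_force_host_platform_device_count={n_devices}')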
Example #6
    def test_model_shape(
        self,
        separate_memory_values=False,
        num_intermediate_layers=None,
    ):
        """Test loss function runs and produces expected values."""
        config = copy.deepcopy(self.config)
        config['model_config']['encoder_config'][
            'separate_memory_values'] = separate_memory_values
        config['model_config']['encoder_config'][
            'num_intermediate_layers'] = num_intermediate_layers
        config = ml_collections.FrozenConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config

        rows = encoder_config.rows
        preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
            config)

        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()

        model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, devices)
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       devices)
        memory_entity_ids = memory_identifiers
        memory_text_entities = np.zeros(
            (self.table_size, encoder_config.n_memory_text_entities),
            dtype=np.int32)
        memory_text_entities = jax.device_put_replicated(
            memory_text_entities, devices)

        def model_init(*args, **kwargs):
            return model.init(*args, method=model.forward, **kwargs)

        initial_variables = jax.pmap(model_init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        }

        raw_example = test_utils.gen_mention_pretraining_sample(
            self.text_length,
            self.n_mentions,
            self.n_linked_mentions,
            max_length=encoder_config.max_length)
        processed_example = preprocess_fn(raw_example)
        batch = {
            key: np.tile(value, (config.per_device_batch_size, 1))
            for key, value in processed_example.items()
        }
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }
        batch = {
            key: jax.device_put_replicated(value, devices)
            for key, value in batch.items()
        }

        def model_apply(*args, **kwargs):
            return model.apply(*args, method=model.forward, **kwargs)

        papply = jax.pmap(model_apply, 'batch', static_broadcasted_argnums=(2,))
        encoded_output, loss_helpers, _ = papply(
            {
                'params': initial_variables['params'],
                'constants': initial_variables['constants'],
            },
            batch,
            True,
        )

        self.assertEqual(
            encoded_output.shape,
            (self.n_devices, config.per_device_batch_size,
             encoder_config.max_length, encoder_config.hidden_size))

        memory_value_dim = encoder_config.memory_value_dim
        memory_key_dim = encoder_config.memory_key_dim
        memory_size = memory_value_dim if memory_value_dim else memory_key_dim
        self.assertEqual(loss_helpers['target_mention_encodings'].shape,
                         (self.n_devices, config.max_mention_targets *
                          config.per_device_batch_size, memory_size))
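
Note the split in Example #6 between jax.device_put_replicated (the same full copy on every device, used for the memory tables and the batch) and jax.device_put_sharded (a distinct shard per device). A small sketch of the difference, under the same forced-multi-device assumption as before:

import jax
import numpy as np

devices = jax.local_devices()
x = np.ones((3, 4))

# Replicated: every device holds an identical copy of x.
replicated = jax.device_put_replicated(x, devices)
assert replicated.shape == (len(devices), 3, 4)

# Sharded: each device holds its own, possibly different, shard.
shards = [x + i for i in range(len(devices))]
sharded = jax.device_put_sharded(shards, devices)
assert sharded.shape == (len(devices), 3, 4)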
Example #7
    def test_load_weights(self,
                          separate_memory_values=False,
                          memory_only=False):
        """Test saving and loading model recovers original parameters."""

        config = copy.deepcopy(self.config)
        config['model_config']['encoder_config'][
            'separate_memory_values'] = separate_memory_values
        config = ml_collections.ConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config
        rows = encoder_config.rows
        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()
        model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, devices)
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       devices)
        memory_entity_ids = memory_identifiers
        memory_text_entities = np.zeros(
            (self.table_size, encoder_config.n_memory_text_entities),
            dtype=np.int32)
        memory_text_entities = jax.device_put_replicated(
            memory_text_entities, devices)

        def model_init(*args, **kwargs):
            return model.init(*args, method=model.forward, **kwargs)

        initial_variables = jax.pmap(model_init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        }
        n_shards = 4

        tempdir_obj = self.create_tempdir()
        tempdir = tempdir_obj.full_path

        memory_key_base = os.path.join(tempdir, 'memory_keys')
        memory_value_base = os.path.join(tempdir, 'memory_values')
        memory_id_base = os.path.join(tempdir, 'memory_id')
        memory_entity_id_base = os.path.join(tempdir, 'memory_entity_id')
        memory_text_entities_base = os.path.join(tempdir,
                                                 'memory_text_entities')

        unreplicated_variables = jax_utils.unreplicate(initial_variables)
        unreplicated_variables['params'] = unreplicated_variables[
            'params'].unfreeze()

        if memory_only:
            load_weights = 'memory_only'
        else:
            load_weights = os.path.join(tempdir, 'weights')
            checkpoint_utils.save_weights(load_weights,
                                          unreplicated_variables['params'])

        memory_keys = initial_variables['constants']['memory_keys']
        memory_keys = memory_keys.reshape(n_shards, -1,
                                          encoder_config.memory_key_dim)
        memory_values = initial_variables['constants']['memory_values']
        memory_values = memory_values.reshape(n_shards, -1,
                                              encoder_config.memory_key_dim)
        memory_ids = initial_variables['constants'][
            'memory_identifiers'].reshape(n_shards, -1)
        memory_entity_ids = initial_variables['constants'][
            'memory_entity_ids'].reshape(n_shards, -1)
        memory_text_entities = initial_variables['constants'][
            'memory_text_entities'].reshape(
                n_shards, -1, encoder_config.n_memory_text_entities)

        for shard in range(n_shards):
            np.save(memory_key_base + str(shard), memory_keys[shard])
            np.save(memory_value_base + str(shard), memory_values[shard])
            np.save(memory_id_base + str(shard), memory_ids[shard])
            np.save(memory_entity_id_base + str(shard),
                    memory_entity_ids[shard])
            np.save(memory_text_entities_base + str(shard),
                    memory_text_entities[shard])

        config.memory_key_pattern = memory_key_base + '*'
        config.memory_value_pattern = memory_value_base + '*'
        config.memory_id_pattern = memory_id_base + '*'
        config.memory_entity_id_pattern = memory_entity_id_base + '*'
        config.memory_text_entities_pattern = memory_text_entities_base + '*'
        config.load_weights = load_weights

        loaded_variables = mention_memory_encoder.MentionMemoryEncoder.load_weights(
            config)

        arrayeq = lambda x, y: jnp.all(x == y)
        constants = {
            key: value
            for key, value in initial_variables['constants'].items()
            if not (key == 'memory_values' and not separate_memory_values)
        }
        comparison_variables = {'constants': constants}
        if not memory_only:
            comparison_variables['params'] = initial_variables[
                'params'].unfreeze()

        # tree_all reduces the per-leaf comparisons to a single boolean;
        # asserting on the raw tree_map result (a non-empty dict) would
        # always pass.
        self.assertTrue(
            jax.tree_util.tree_all(
                jax.tree_map(arrayeq, loaded_variables, comparison_variables)))
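
The shard files above are written as base + str(shard) (np.save appends '.npy') and rediscovered through the base + '*' glob patterns. A sketch of how such a sharded table is typically reassembled (hypothetical helper, not the library's actual loader):

import glob
import numpy as np

def load_sharded_table(pattern):
  # Lexicographic sort is sufficient for single-digit shard counts like
  # the n_shards = 4 used above.
  files = sorted(glob.glob(pattern))
  return np.concatenate([np.load(f) for f in files], axis=0)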
Example #8
  def test_mention_memory_layer(self, separate_memory_values):
    """Testing memory attention layer."""

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    # Arg 9 (`deterministic`) is a Python bool and must be static; when
    # `memory_values` is None, arg 10 must be static too. Note that (9,)
    # is a one-element tuple, whereas (9) would just be the integer 9.
    static_argnums = (9,) if separate_memory_values else (9, 10)
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=static_argnums)

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    mention_batch_positions = jnp.tile(
        jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    n_mentions = mention_start_positions.shape[-1]

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)
    # Make sure id 0 or 1 will be highest scoring
    memory_table[0] = memory_table[0] * 2.0
    memory_table[1] = memory_table[1] * -2.0
    memory_table = jnp.asarray(memory_table, dtype=self.dtype)

    memory_keys = memory_table.reshape(self.n_devices, self.rows,
                                       self.table_size // self.rows,
                                       self.memory_key_dim)

    memory_keys_sharded = jax.device_put_sharded(list(memory_keys), devices)
    if separate_memory_values:
      memory_values = memory_table.reshape(self.n_devices, self.table_size,
                                           self.memory_key_dim)
      memory_values = jax.device_put_sharded(list(memory_values), devices)
    else:
      memory_values = None

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    (encoded_output, loss_helpers, _), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,
        text_identifiers=None,
    )

    attention_weights = loss_helpers['memory_attention_weights']
    entity_ids = loss_helpers['top_entity_ids']

    normed_input = encoded_input - 1.0

    # Check input was changed
    self.assertFalse(jnp.allclose(encoded_output, normed_input))

    # Check input was not changed where it should not be
    all_indices = set(
        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
    # Note that mention positions are the same across all devices
    start_indices = set(
        zip(mention_batch_positions[0].tolist(),
            mention_start_positions[0].tolist()))
    non_start_indices = all_indices.difference(start_indices)
    non_start_indices_1, non_start_indices_2 = zip(*non_start_indices)
    non_start_indices_1 = jnp.asarray(non_start_indices_1)
    non_start_indices_2 = jnp.asarray(non_start_indices_2)

    non_start_outputs = encoded_output[:, non_start_indices_1,
                                       non_start_indices_2]
    non_start_inputs = normed_input[:, non_start_indices_1, non_start_indices_2]
    self.assertTrue(jnp.allclose(non_start_outputs, non_start_inputs))

    # Check shapes as expected
    self.assertSequenceEqual(
        encoded_output.shape,
        (self.n_devices, self.bsz, self.seq_len, self.input_dim))

    self.assertSequenceEqual(
        attention_weights.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    self.assertSequenceEqual(
        entity_ids.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    # Check id 0 or 1 retrieved
    self.assertTrue(
        jnp.all((entity_ids[..., 0] == 0) + (entity_ids[..., 0] == 1)))

    # Set half of the text identifiers to 1 and the rest to 0 so that some
    # collide with retrieved memory identifiers and are disallowed
    text_identifiers = np.zeros((n_mentions), dtype=np.int32)
    text_identifiers[:n_mentions // 2] = 1
    text_identifiers = jax.device_put_replicated(text_identifiers, devices)

    # Initialize and run one forward pass of model
    (_, loss_helpers, logging_helpers), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,  # memory_values
        text_identifiers=text_identifiers,
    )
    attention_weights_wid = loss_helpers['memory_attention_weights']
    entity_ids_wid = loss_helpers['top_entity_ids']
    n_disallowed = logging_helpers['n_disallowed'][0]

    # Check no effect on ids
    self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

    # Check that retrievals whose ids match the query's text identifier
    # receive zero attention weight
    text_identifiers = jnp.expand_dims(text_identifiers, -1)
    score_masked = (text_identifiers == entity_ids_wid) * attention_weights_wid
    self.assertAlmostEqual(score_masked.sum(), 0.0)

    # Check the number of disallowed retrievals is as expected
    self.assertEqual(n_disallowed, n_mentions // 2)
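
The disallowed-retrieval assertions above hinge on the layer masking attention scores wherever a retrieved memory's identifier equals the query mention's text identifier. A minimal sketch of that masking step (the names here are hypothetical, not the layer's actual internals):

import jax.numpy as jnp

_LARGE_NUMBER = 1e10  # assumed magnitude for masking logits

def mask_same_passage(scores, retrieved_hashes, text_identifiers):
  # scores, retrieved_hashes: [n_mentions, k_top];
  # text_identifiers: [n_mentions].
  disallowed = retrieved_hashes == text_identifiers[:, None]
  masked_scores = jnp.where(disallowed, -_LARGE_NUMBER, scores)
  # After a softmax, disallowed entries get numerically zero weight.
  return masked_scores, disallowed.sum()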
Example #9
  def test_compare_retrievals_with_numpy(self, seed, k_top_post_selection,
                                         max_text_identifiers,
                                         same_passage_memory_policy):
    """Test whether retrieval results are correct."""
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()
    n_text_entities_per_memory = 3

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=(9, 10, 13))

    rng = jax.random.PRNGKey(seed)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jax.random.uniform(
        rng,
        shape=(self.n_devices, self.bsz, self.seq_len, self.input_dim),
        dtype=self.dtype)
    mention_batch_positions = jax.random.randint(
        rng, minval=0, maxval=self.bsz, shape=(self.n_devices, self.n_mentions))
    mention_start_positions = jax.random.randint(
        rng,
        minval=0,
        maxval=self.seq_len,
        shape=(self.n_devices, self.n_mentions))
    mention_end_positions = mention_start_positions
    mention_mask = jnp.ones(shape=(self.n_devices, self.n_mentions))

    memory_table = jax.random.uniform(
        rng,
        shape=(self.n_devices, self.rows, self.table_size // self.rows,
               self.memory_key_dim))
    memory_entity_ids = jax.random.randint(
        rng,
        minval=0,
        maxval=self.entity_vocab_size,
        shape=(self.n_devices, self.table_size))
    if max_text_identifiers is not None:
      memory_identifiers = jax.random.randint(
          rng,
          minval=0,
          maxval=max_text_identifiers,
          shape=(self.n_devices, self.table_size))
      text_identifiers = jax.random.randint(
          rng,
          minval=0,
          maxval=max_text_identifiers,
          shape=(self.n_devices, self.n_mentions))
    else:
      text_identifiers = None

    if n_text_entities_per_memory is not None:
      memory_text_entities = jax.random.randint(
          rng,
          minval=0,
          maxval=self.entity_vocab_size,
          shape=(self.n_devices, self.table_size, n_text_entities_per_memory))
    else:
      memory_text_entities = None

    encoded_input_sharded = jax.device_put_sharded(list(encoded_input), devices)
    mention_batch_positions_sharded = jax.device_put_sharded(
        list(mention_batch_positions), devices)
    mention_start_positions_sharded = jax.device_put_sharded(
        list(mention_start_positions), devices)
    mention_end_positions_sharded = jax.device_put_sharded(
        list(mention_end_positions), devices)
    mention_mask_sharded = jax.device_put_sharded(list(mention_mask), devices)
    memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)
    memory_entity_ids_sharded = jax.device_put_sharded(
        list(memory_entity_ids), devices)
    if max_text_identifiers is not None:
      memory_identifiers_sharded = jax.device_put_sharded(
          list(memory_identifiers), devices)
      text_identifiers_sharded = jax.device_put_sharded(
          list(text_identifiers), devices)
    else:
      memory_identifiers_sharded = None
      text_identifiers_sharded = None

    if memory_text_entities is not None:
      memory_text_entities_sharded = jax.device_put_sharded(
          list(memory_text_entities), devices)
    else:
      memory_text_entities_sharded = None

    memory_ids = jnp.arange(self.n_devices * self.table_size)
    memory_ids = memory_ids.reshape(self.n_devices, self.table_size)

    (_, loss_helpers, logging_helpers), params = pinit_with_output(
        split_rng,
        encoded_input_sharded,
        mention_batch_positions_sharded,
        mention_start_positions_sharded,
        mention_end_positions_sharded,
        mention_mask_sharded,
        memory_table_sharded,
        memory_identifiers_sharded,
        memory_entity_ids_sharded,
        True,
        None,  # memory_values
        text_identifiers_sharded,
        memory_text_entities_sharded,
        same_passage_memory_policy,
    )

    params = params.unfreeze()['params']

    mention_encodings = []
    for device_id in range(self.n_devices):
      mention_start_encodings = encoded_input[device_id][
          mention_batch_positions[device_id],
          mention_start_positions[device_id]]
      mention_end_encodings = encoded_input[device_id][
          mention_batch_positions[device_id], mention_end_positions[device_id]]
      mention_encodings_on_device = jnp.concatenate(
          [mention_start_encodings, mention_end_encodings], axis=-1)
      mention_encodings_on_device = np.matmul(
          mention_encodings_on_device,
          params['query_projector']['kernel'][device_id])
      mention_encodings_on_device += params['query_projector']['bias'][
          device_id]
      mention_encodings.append(mention_encodings_on_device)

    # [n_devices, n_mentions, memory_key_dim]
    mention_encodings_stacked = jnp.stack(mention_encodings)
    mention_encodings_stacked = mention_encodings_stacked.reshape(
        [self.n_devices * self.n_mentions, self.memory_key_dim])

    # Object which represents a single retrieval result with additional info.
    RetrievedMemory = collections.namedtuple('RetrievedMemory', [
        'device', 'row', 'rowwise_index', 'devicewise_index', 'global_index',
        'score', 'memory', 'entity_id', 'memory_hash',
        'memory_passage_text_entities'
    ])

    num_disallowed_per_device = [0 for _ in range(self.n_devices)]
    # Manually simulate retrieval for every query
    for query_id in range(self.n_devices * self.n_mentions):
      query = mention_encodings_stacked[query_id]
      top_memories_query = []
      # Collect retrievals for a single query on each device separately
      for device_id in range(self.n_devices):
        top_memories_per_device = []
        for row_id in range(self.rows):
          scores = np.einsum('mh,h->m', memory_table[device_id, row_id], query)
          top_index = np.argmax(scores)
          devicewise_index = row_id * (self.table_size // self.rows) + top_index
          global_index = memory_ids[device_id, devicewise_index]
          self.assertEqual(global_index,
                           devicewise_index + device_id * self.table_size)
          if max_text_identifiers is not None:
            memory_hash = memory_identifiers[device_id, devicewise_index].item()
          else:
            memory_hash = None
          if memory_text_entities is not None:
            memory_passage_text_entities = memory_text_entities[
                device_id, devicewise_index]
          else:
            memory_passage_text_entities = None
          top_memories_per_device.append(
              RetrievedMemory(
                  device=device_id,
                  row=row_id,
                  rowwise_index=top_index,
                  devicewise_index=devicewise_index,
                  global_index=global_index,
                  score=scores[top_index].item(),
                  memory=memory_table[device_id, row_id, top_index],
                  entity_id=memory_entity_ids[device_id,
                                              devicewise_index].item(),
                  memory_hash=memory_hash,
                  memory_passage_text_entities=memory_passage_text_entities,
              ))
        # Sort by score. In case two scores are equal (likely because both
        # were considered "disallowed") we compare by entity IDs.
        top_memories_per_device.sort(
            key=lambda x: (x.score, x.entity_id), reverse=True)
        top_memories_per_device = top_memories_per_device[:self.k_top_device]
        top_memories_query.extend(top_memories_per_device)

      top_memories_query.sort(
          key=lambda x: (x.score, x.entity_id), reverse=True)
      if k_top_post_selection is not None:
        top_memories_query = top_memories_query[:k_top_post_selection]

      if max_text_identifiers is not None:
        num_current_disallowed = 0
        text_id = text_identifiers[query_id // self.n_mentions,
                                   query_id % self.n_mentions].item()
        for i in range(len(top_memories_query)):
          if top_memories_query[i].memory_hash == text_id:
            num_current_disallowed += 1

          if ((same_passage_memory_policy == 'disallow' and
               top_memories_query[i].memory_hash == text_id) or
              (same_passage_memory_policy == 'only' and
               top_memories_query[i].memory_hash != text_id)):
            top_memories_query[i] = top_memories_query[i]._replace(
                score=-_LARGE_NUMBER)
        num_disallowed_per_device[query_id //
                                  self.n_mentions] += num_current_disallowed
        top_memories_query.sort(
            key=lambda x: (x.score, x.global_index), reverse=True)

      actual_entity_ids = loss_helpers['top_entity_ids'][query_id //
                                                         self.n_mentions,
                                                         query_id %
                                                         self.n_mentions]
      actual_memory_ids = loss_helpers['top_memory_ids'][query_id //
                                                         self.n_mentions,
                                                         query_id %
                                                         self.n_mentions]
      actual_attention_weights = loss_helpers['memory_attention_weights'][
          query_id // self.n_mentions, query_id % self.n_mentions]

      # We sort retrieved results first by the attention score and then
      # by memory ID.
      p = list(range(len(actual_attention_weights)))
      # pylint: disable=cell-var-from-loop
      p.sort(
          key=lambda i: (actual_attention_weights[i], actual_memory_ids[i]),
          reverse=True)
      # pylint: enable=cell-var-from-loop
      p = np.array(p)

      actual_entity_ids = list(actual_entity_ids[p])
      actual_attention_weights = list(actual_attention_weights[p])
      actual_memory_ids = list(actual_memory_ids[p])

      expected_entity_ids = [x.entity_id for x in top_memories_query]
      self.assertSequenceEqual(expected_entity_ids, actual_entity_ids)

      expected_attention_weights = scipy.special.softmax(
          [x.score for x in top_memories_query])
      self.assertSequenceAlmostEqual(expected_attention_weights,
                                     actual_attention_weights, 5)

      expected_memory_ids = [x.global_index for x in top_memories_query]
      self.assertSequenceEqual(expected_memory_ids, actual_memory_ids)

      actual_top_text_entities = loss_helpers['memory_top_text_entities'][
          query_id // self.n_mentions, query_id % self.n_mentions]
      actual_top_text_entities = actual_top_text_entities[p]

      expected_top_text_entities = [
          x.memory_passage_text_entities for x in top_memories_query
      ]
      self.assertEqual(
          len(actual_top_text_entities), len(expected_top_text_entities))

      # Comparing `actual_top_text_entities` and `expected_top_text_entities`
      # directly is troublesome since we cannot guarantee the order of
      # retrieval results with the same attention weights. Therefore, we sort
      # both `actual_top_text_entities` and `expected_top_text_entities`
      # first by the attention weight and then by their elements.
      def sort_entities(top_text_entities):
        result = [(expected_attention_weights[i], list(top_text_entities[i]))
                  for i in range(len(top_text_entities))]
        result.sort()
        return [x[1] for x in result]

      actual_top_text_entities = sort_entities(actual_top_text_entities)
      expected_top_text_entities = sort_entities(expected_top_text_entities)

      for i in range(len(actual_top_text_entities)):
        self.assertSequenceEqual(
            list(actual_top_text_entities[i]),
            list(expected_top_text_entities[i]))

    if max_text_identifiers is not None:
      self.assertSequenceEqual(num_disallowed_per_device,
                               logging_helpers['n_disallowed'])
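
The numpy reference in Example #9 recomputes attention weights as a softmax over the retained scores, with masked entries pushed down to -_LARGE_NUMBER so their weight vanishes. A compact illustration of why that comparison holds to several decimal places:

import numpy as np
import scipy.special

_LARGE_NUMBER = 1e10  # assumed, matching the constant in the test

scores = np.array([2.0, 1.0, -_LARGE_NUMBER])  # last entry is "disallowed"
weights = scipy.special.softmax(scores)
assert np.isclose(weights.sum(), 1.0)
assert weights[2] < 1e-12  # the masked entry underflows to ~zero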
Example #10
  def test_memory_attention_backward(self):
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    pinit = jax.pmap(
        model.init, axis_name='batch', static_broadcasted_argnums=(9, 10))

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    mention_batch_positions = jnp.tile(
        jnp.asarray([[0], [1], [2]]), (1, self.bsz)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)

    memory_table = jnp.asarray(memory_table, dtype=self.dtype)
    memory_table = memory_table.reshape(self.n_devices, self.rows,
                                        self.table_size // self.rows,
                                        self.memory_key_dim)
    memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    initial_parameters = pinit(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_table_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        None,  # memory_values
        text_identifiers=None,
    )

    def step_fn(
        params,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys,
        memory_identifiers,
        memory_entity_ids,
    ):

      def loss_fn(params):
        encoded_output, _, _ = model.apply(
            {'params': params},
            rngs=None,
            encoded_input=encoded_input,
            mention_batch_positions=mention_batch_positions,
            mention_start_positions=mention_start_positions,
            mention_end_positions=mention_end_positions,
            mention_mask=mention_mask,
            memory_keys=memory_keys,
            memory_identifiers=memory_identifiers,
            memory_entity_ids=memory_entity_ids,
            deterministic=True,
            text_identifiers=None,
        )
        return encoded_output.sum()

      loss, grad = jax.value_and_grad(loss_fn)(params)
      return loss, grad

    pstep = jax.pmap(step_fn, axis_name='batch')

    _ = pstep(
        initial_parameters['params'],
        encoded_input=encoded_input,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        memory_keys=memory_table_sharded,
        memory_identifiers=memory_identifiers,
        memory_entity_ids=memory_entity_ids,
    )
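
Example #10's backward pass follows a standard recipe: wrap a scalar loss in jax.value_and_grad, then pmap the resulting step function. A stripped-down sketch of the same pattern:

import jax
import jax.numpy as jnp

def loss_fn(w, x):
  return jnp.sum(w * x)

# One loss value and one gradient per device.
pstep = jax.pmap(jax.value_and_grad(loss_fn), axis_name='batch')

n = jax.local_device_count()
w = jnp.ones((n, 4))
x = 2.0 * jnp.ones((n, 4))
losses, grads = pstep(w, x)
assert losses.shape == (n,)
assert grads.shape == (n, 4)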
Example #11
 def setUp(self):
     super().setUp()
     test_utils.force_multi_devices(self.n_devices)
Example #12
    def test_loss_fn(
        self,
        k_top,
        num_intermediate_layers=None,
        shared_initial_encoder=True,
        shared_intermediate_encoder=True,
        shared_final_encoder=True,
        no_retrieval=False,
        same_passage_retrieval_policy='allow',
        extract_unlinked_mentions=False,
        no_retrieval_for_masked_mentions=False,
    ):
        """Test loss function runs and produces expected values."""
        config = copy.deepcopy(self.config)
        encoder_config = copy.deepcopy(self.encoder_config)
        encoder_config['k_top'] = k_top
        encoder_config['num_intermediate_layers'] = num_intermediate_layers
        encoder_config['shared_initial_encoder'] = shared_initial_encoder
        encoder_config[
            'shared_intermediate_encoder'] = shared_intermediate_encoder
        encoder_config['shared_final_encoder'] = shared_final_encoder
        encoder_config['no_retrieval'] = no_retrieval
        encoder_config[
            'same_passage_retrieval_policy'] = same_passage_retrieval_policy
        encoder_config['extract_unlinked_mentions'] = extract_unlinked_mentions
        encoder_config[
            'no_retrieval_for_masked_mentions'] = no_retrieval_for_masked_mentions
        config['model_config']['encoder_config'] = encoder_config
        if no_retrieval:
            config['el_im_weight'] = 0
        if num_intermediate_layers is not None:
            config['second_el_im_weight'] = 0.1
        config = ml_collections.FrozenConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config

        preprocess_fn = readtwice_task.ReadTwiceTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = readtwice_task.ReadTwiceTask.make_collater_fn(config)
        postprocess_fn = readtwice_task.ReadTwiceTask.make_output_postprocess_fn(
            config)

        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()

        model = readtwice_task.ReadTwiceTask.build_model(model_config)
        dummy_input = readtwice_task.ReadTwiceTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        initial_variables = jax.pmap(model.init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        raw_example = test_utils.gen_mention_pretraining_sample(
            self.text_length,
            self.n_mentions,
            self.n_linked_mentions,
            max_length=encoder_config.max_length)
        processed_example = preprocess_fn(raw_example)
        batch = {
            key: np.tile(value, (config.per_device_batch_size, 1))
            for key, value in processed_example.items()
        }
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }
        batch = {
            key: jax.device_put_replicated(value, devices)
            for key, value in batch.items()
        }

        loss_fn = jax.pmap(readtwice_task.ReadTwiceTask.make_loss_fn(config),
                           'batch',
                           static_broadcasted_argnums=(0, 4))
        _, metrics, auxiliary_output = loss_fn(
            model_config,
            initial_variables['params'],
            {},  # model vars
            batch,
            True,  # deterministic
        )

        take_first = lambda x: x[0]
        metrics = jax.tree_map(take_first, metrics)
        np_batch = jax.tree_map(take_first, batch)

        # mlm losses
        expected_mlm_denom = np_batch['mlm_target_weights'].sum()
        expected_mlm_mention_denom = (np_batch['mlm_target_weights'] *
                                      np_batch['mlm_target_is_mention']).sum()
        expected_mlm_non_mention_denom = (
            np_batch['mlm_target_weights'] *
            (1 - np_batch['mlm_target_is_mention'])).sum()
        self.assertEqual(metrics['mlm']['denominator'], expected_mlm_denom)
        self.assertEqual(metrics['mlm_mention']['denominator'],
                         expected_mlm_mention_denom)
        self.assertEqual(metrics['mlm_non_mention']['denominator'],
                         expected_mlm_non_mention_denom)
        self.assertEqual(metrics['mlm_first']['denominator'],
                         expected_mlm_denom)
        self.assertEqual(metrics['mlm_mention_first']['denominator'],
                         expected_mlm_mention_denom)
        self.assertEqual(metrics['mlm_non_mention_first']['denominator'],
                         expected_mlm_non_mention_denom)

        # same entity retrieval loss
        if not no_retrieval:
            expected_same_entity_denom = np_batch[
                'mention_target_weights'].sum()
            self.assertEqual(metrics['el_intermediate']['denominator'],
                             expected_same_entity_denom)
            if num_intermediate_layers is not None:
                self.assertEqual(
                    metrics['second_el_intermediate']['denominator'],
                    expected_same_entity_denom)

        # coref losses
        expected_coref_denom = np_batch['mention_target_weights'].sum()
        expected_coref_masked_denom = (
            np_batch['mention_target_weights'] *
            np_batch['mention_target_is_masked']).sum()
        expected_coref_non_masked_denom = (
            np_batch['mention_target_weights'] *
            (1 - np_batch['mention_target_is_masked'])).sum()

        for coref_type in {'key', 'value', 'final'}:
            self.assertEqual(
                metrics[coref_type + '_coref_resolution']['denominator'],
                expected_coref_denom)
            self.assertEqual(
                metrics[coref_type +
                        '_coref_resolution_masked']['denominator'],
                expected_coref_masked_denom)
            self.assertEqual(
                metrics[coref_type +
                        '_coref_resolution_non_masked']['denominator'],
                expected_coref_non_masked_denom)

        # mtb losses
        for mtb_type in {'key', 'value', 'final'}:
            self.assertIn(mtb_type + '_mtb', metrics)
            self.assertIn(mtb_type + '_mtb_masked', metrics)
            self.assertIn(mtb_type + '_mtb_non_masked', metrics)

        features = postprocess_fn(batch, auxiliary_output)
        # Check features are JSON-serializable
        json.dumps(features)
        # Check features match the original batch
        for key in batch.keys():
            self.assertArrayEqual(np.array(features[key]), batch[key])