Example #1
    def test_save_sharded_array(self, array_shape, num_shards, stride):
        shard_size_divisible = 3

        self.assertEqual(num_shards % stride, 0)

        arrays_per_offset = [
            np.random.random(array_shape) for _ in range(stride)
        ]

        tmp_dir = self.create_tempdir()
        prefix = os.path.join(tmp_dir.full_path, 'test')

        for offset in range(stride):
            data_utils.save_sharded_array(arrays_per_offset[offset], prefix,
                                          num_shards, stride, offset,
                                          shard_size_divisible)

        loaded_array_first_dim = None
        for offset in range(stride):
            loaded_array = data_utils.load_sharded_array(
                prefix + '-?????-of-%05d' % num_shards, stride, offset)
            all_axis_except_first = list(range(1, len(array_shape)))
            sum_all_except_first_axis = np.apply_over_axes(
                np.sum, np.abs(loaded_array), all_axis_except_first)
            sum_all_except_first_axis = sum_all_except_first_axis.reshape(-1)
            is_not_pad = sum_all_except_first_axis > 0
            actual_array = loaded_array[is_not_pad]
            self.assertTrue(np.all(actual_array == arrays_per_offset[offset]))

            if loaded_array_first_dim is None:
                loaded_array_first_dim = loaded_array.shape[0]
            else:
                self.assertEqual(loaded_array_first_dim, loaded_array.shape[0])
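From the assertions one can infer the on-disk layout this test expects: save_sharded_array writes the rows of the array contiguously across the shards owned by a given offset (shards offset, offset + stride, ...) out of num_shards files named prefix-SSSSS-of-NNNNN, zero-padding each shard along axis 0 to a multiple of shard_size_divisible, which is why the test drops rows whose absolute sum is 0. The following is a minimal sketch consistent with that behavior, not the actual data_utils implementation:

import numpy as np

def save_sharded_array_sketch(array, prefix, num_shards, stride, offset,
                              shard_size_divisible):
    """Hypothetical re-implementation, inferred from the test above."""
    my_shards = list(range(offset, num_shards, stride))
    # Split the rows contiguously across the shards owned by this offset.
    for shard_idx, chunk in zip(my_shards,
                                np.array_split(array, len(my_shards))):
        # Zero-pad axis 0 up to a multiple of shard_size_divisible; the test
        # recognizes pad rows by their all-zero content.
        pad = -len(chunk) % shard_size_divisible
        if pad:
            chunk = np.concatenate(
                [chunk, np.zeros((pad,) + chunk.shape[1:], chunk.dtype)])
        filename = '%s-%05d-of-%05d' % (prefix, shard_idx, num_shards)
        with open(filename, 'wb') as f:
            np.save(f, chunk)  # A file object avoids the automatic '.npy' suffix.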
Example #2
    def test_loaded_arrays_match_saved(self):
        workdir_obj = self.create_tempdir()
        workdir = workdir_obj.full_path
        pattern = os.path.join(workdir, 'array*')
        array = np.random.rand(self.n_splits * self.data_per_split)
        save_array = array.reshape(self.n_splits, self.data_per_split)
        for split in range(self.n_splits):
            path = os.path.join(workdir, 'array' + str(split))
            np.save(path, save_array[split])

        self.assertTrue(
            np.all(data_utils.load_sharded_array(pattern, 1, 0) == array))
        self.assertTrue(
            np.all(
                data_utils.load_sharded_array(pattern, 1, 1) ==
                save_array[1:].reshape(-1)))
        self.assertTrue(
            np.all(
                data_utils.load_sharded_array(pattern, self.n_splits, 0) ==
                save_array[0]))
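The stride/offset semantics of load_sharded_array follow directly from these assertions: files matching the glob pattern are sorted by name, every stride-th file starting at offset is kept, and the kept shards are concatenated along axis 0. A minimal sketch (again hypothetical, not the real data_utils code):

import glob

import numpy as np

def load_sharded_array_sketch(pattern, stride, offset):
    """Hypothetical re-implementation, inferred from the assertions above."""
    paths = sorted(glob.glob(pattern))
    # Keep every `stride`-th shard starting at `offset`; concatenate on axis 0.
    return np.concatenate(
        [np.load(path) for path in paths[offset::stride]], axis=0)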
Example #3
    def load_array(self, pattern: str):
        """Load a sharded array as if it were loaded from multiple processes."""
        process_count = jax.process_count()
        arrays = []
        for process_index in range(process_count):
            arrays.append(
                data_utils.load_sharded_array(
                    pattern, process_count * self.memory_reduction,
                    process_index))
        array = np.stack(arrays, axis=0)
        shape = (-1,) + arrays[0].shape[1:]
        array = array.reshape(shape)
        return array
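The stack-then-reshape at the end is equivalent to concatenating the per-process arrays along the first axis (process 0's rows first, then process 1's, and so on); it also implicitly requires every process to load the same number of rows. A tiny illustration with made-up data:

import numpy as np

a0 = np.array([[1, 1], [2, 2]])  # rows loaded by simulated process 0
a1 = np.array([[3, 3], [4, 4]])  # rows loaded by simulated process 1
stacked = np.stack([a0, a1], axis=0)            # shape (2, 2, 2)
merged = stacked.reshape((-1,) + a0.shape[1:])  # shape (4, 2)
# merged rows: [1, 1], [2, 2], [3, 3], [4, 4] -- i.e. np.concatenate([a0, a1])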
Example #4
    def load_array(suffix):
        return data_utils.load_sharded_array(
            os.path.join(tmp_dir.full_path,
                         suffix + '-?????-of-%05d' % num_shards), 1, 0)
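This helper closes over tmp_dir and num_shards from the enclosing test. With num_shards = 4, for example, the pattern expands to <tmp_dir>/<suffix>-?????-of-00004 and stride 1, offset 0 loads every matching shard. A hypothetical call, with a made-up suffix:

# 'memory_keys' is an illustrative suffix, not one taken from the test.
memory_keys = load_array('memory_keys')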
Example #5
    def load_memory(config: ml_collections.ConfigDict) -> Dict[str, Any]:
        """Load mention memory."""
        model_config = config.model_config
        encoder_config = model_config.encoder_config

        process_count = jax.process_count()
        # Reduce the number of loaded memory shards by this factor. The total
        # number of shards must be divisible by memory_reduction * process_count.
        memory_reduction = config.get('memory_reduction', 1)
        process_index = jax.process_index()
        local_devices = jax.local_devices()

        memory_prop = config.get('memory_prop', None)
        rows = encoder_config.rows
        memory_key_dim = encoder_config.memory_key_dim

        memory_arrays = {}
        memory_component_names = [
            'memory_keys', 'memory_identifiers', 'memory_entity_ids'
        ]
        # The following arrays should be converted to the 32-bit integer type.
        # The rest of the arrays will be converted to the model type (typically,
        # bfloat16 or float32).
        memory_component_int_dtypes = {
            'memory_identifiers', 'memory_entity_ids', 'memory_text_entities'
        }
        patterns = [
            config.memory_key_pattern, config.memory_id_pattern,
            config.memory_entity_id_pattern
        ]

        if encoder_config.separate_memory_values:
            memory_component_names.append('memory_values')
            patterns.append(config.memory_value_pattern)

        if config.get('same_entity_set_retrieval_weight', 0) > 0:
            memory_component_names.append('memory_text_entities')
            patterns.append(config.memory_text_entities_pattern)

        for key, pattern in zip(memory_component_names, patterns):
            memory_arrays[key] = data_utils.load_sharded_array(
                pattern, process_count * memory_reduction, process_index)

        memory_variables = {}

        cpu_device = jax.local_devices(backend='cpu')[0]
        dtype = encoder_config.dtype
        for key, array in memory_arrays.items():
            if memory_prop is not None:
                array = array[:int(memory_prop * array.shape[0])]
            if key == 'memory_keys':
                array = array.reshape(len(local_devices), rows, -1,
                                      memory_key_dim)
            else:
                array = array.reshape((len(local_devices), -1) +
                                      array.shape[1:])
            array = jax.device_put(array, cpu_device)
            if key in memory_component_int_dtypes:
                array = jnp.asarray(array, dtype=jnp.int32)
            else:
                array = jnp.asarray(array, dtype=dtype)
            array = jax.device_put_sharded(list(array), local_devices)
            memory_variables[key] = array
        return memory_variables
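The per-key loop implements a common JAX host-to-device pattern: reshape the host array so its leading axis equals the local device count, then scatter slice i onto device i with jax.device_put_sharded, yielding an array with a device-sharded leading axis as expected by pmapped functions. A minimal standalone sketch with a made-up array:

import jax
import numpy as np

local_devices = jax.local_devices()
host_array = np.arange(len(local_devices) * 6, dtype=np.float32)

# One leading slice per local device.
per_device = host_array.reshape((len(local_devices), -1))
# Place slice i on device i; the result behaves like pmap input.
sharded = jax.device_put_sharded(list(per_device), local_devices)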
Example #6
    def make_collater_fn(
        config: ml_collections.ConfigDict
    ) -> Callable[[Dict[str, tf.Tensor]], Dict[str, tf.Tensor]]:
        """Produces function to preprocess batches.

    For a selected subset of mentions in the batch, we retrieve the
    corresponding mention from the mention memory and include it in the batch.
    These retrieved mentions are then incorporated into the Transformer model
    like retrieved mentions in the Mention Memory encoder would be.

    Args:
      config: contains experiment hyperparameters.

    Returns:
      Function that preprocesses batches to be usable for the model
      (mod casting from tf to jnp dtype).
    """
        mm_collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
            config)
        if config.model_config.encoder_config.get('no_retrieval', False):
            return mm_collater_fn
        max_retrieval_indices = config.max_retrieval_indices

        memory_table = data_utils.load_sharded_array(
            pattern=config.memory_pattern,
            stride=config.memory_reduction,
            offset=0)
        memory_hash = data_utils.load_sharded_array(
            pattern=config.memory_hash_pattern,
            stride=config.memory_reduction,
            offset=0)

        logging.info('Sorting hash array')
        hash_sorted_idx = np.argsort(memory_hash)
        memory_hash_sorted = memory_hash[hash_sorted_idx]

        # Append the maximal int32 value so that, if a hash is greater than the
        # largest hash in memory, we just take the first vector. We set its
        # weight to zero later so it does not affect the results.
        memory_hash_sorted = np.concatenate(
            (memory_hash_sorted, [np.iinfo(np.int32).max])).astype(np.int32)

        hash_sorted_idx = np.concatenate(
            (hash_sorted_idx, [0])).astype(np.int32)

        memory_table = tf.constant(memory_table)
        memory_hash_sorted = tf.constant(memory_hash_sorted)
        hash_sorted_idx = tf.constant(hash_sorted_idx)

        memory_entity_pattern = config.get('memory_entity_pattern', None)
        if memory_entity_pattern:
            memory_entity_ids = data_utils.load_sharded_array(
                pattern=config.memory_entity_pattern,
                stride=config.memory_reduction,
                offset=0)

        def collater_fn(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
            batch = mm_collater_fn(batch)

            retrieve_masked = config.get('retrieve_masked', False)

            # Subselect mentions for which to retrieve the corresponding
            # memory. We want to sample mentions that are linked, not masked,
            # and not padded.
            scores = tf.random.uniform(
                tf.shape(batch['mention_target_is_masked'])) + 2 * tf.cast(
                    batch['mention_target_weights'], tf.float32)

            if not retrieve_masked:
                scores -= tf.cast(batch['mention_target_is_masked'],
                                  tf.float32)

            _, mention_target_retrieval_indices = tf.math.top_k(
                scores, k=max_retrieval_indices)

            mention_retrieval_indices = tf.gather(
                batch['mention_target_indices'],
                mention_target_retrieval_indices)
            retrieval_mention_mask = tf.gather(
                batch['mention_target_weights'],
                mention_target_retrieval_indices)
            # Set the weight to 0 for masked retrievals if we do not want to
            # include them.
            if not retrieve_masked:
                retrieval_mention_mask *= tf.gather(
                    1 - tf.cast(batch['mention_target_is_masked'], tf.int32),
                    mention_target_retrieval_indices)

            retrieval_mention_start_positions = tf.gather(
                batch['mention_start_positions'], mention_retrieval_indices)
            retrieval_text_identifiers = tf.gather(batch['text_identifiers'],
                                                   mention_retrieval_indices)
            retrieval_mention_hash = mention_preprocess_utils.modified_cantor_pairing(
                tf.cast(retrieval_mention_start_positions, tf.int64),
                retrieval_text_identifiers)
            retrieval_mention_hash = tf.cast(retrieval_mention_hash, tf.int32)

            retrieval_mention_sort_ids = tf.searchsorted(
                memory_hash_sorted, retrieval_mention_hash)

            # Searchsorted does not check whether the value is present in the
            # array; it just finds the insertion point. Here we check, and fall
            # back to the default retrieval if the value is not present.
            hash_not_present_mask = tf.not_equal(
                retrieval_mention_hash,
                tf.gather(memory_hash_sorted, retrieval_mention_sort_ids))
            hash_not_present = tf.where(hash_not_present_mask)
            update_values = tf.fill((tf.shape(hash_not_present)[0], ),
                                    tf.shape(hash_sorted_idx)[0] - 1)
            retrieval_mention_sort_ids = tf.tensor_scatter_nd_update(
                retrieval_mention_sort_ids, hash_not_present, update_values)

            # Set the mask to 0 if no mention is found.
            batch['retrieval_mention_mask'] = retrieval_mention_mask * (
                1 - tf.cast(hash_not_present_mask, tf.int32))

            retrieval_mention_ids = tf.gather(hash_sorted_idx,
                                              retrieval_mention_sort_ids)
            retrieval_mention_values = tf.gather(memory_table,
                                                 retrieval_mention_ids)
            # Match passage entity_ids with memory entity ids as sanity check.
            if memory_entity_pattern:
                retrieval_memory_entity_ids = tf.gather(
                    memory_entity_ids, retrieval_mention_ids)
                retrieval_passage_entity_ids = tf.gather(
                    tf.cast(batch['mention_target_ids'], tf.int32),
                    mention_target_retrieval_indices)
                entity_does_not_match = tf.not_equal(
                    retrieval_memory_entity_ids, retrieval_passage_entity_ids)

                batch['entity_does_not_match'] = tf.logical_and(
                    entity_does_not_match,
                    tf.cast(batch['retrieval_mention_mask'], tf.bool))

            batch['retrieval_mention_values'] = retrieval_mention_values
            batch['retrieval_mention_scores'] = tf.ones_like(
                batch['retrieval_mention_mask'])
            batch['retrieval_mention_batch_positions'] = tf.gather(
                batch['mention_batch_positions'], mention_retrieval_indices)
            batch['retrieval_mention_start_positions'] = (
                retrieval_mention_start_positions)
            batch['retrieval_mention_end_positions'] = tf.gather(
                batch['mention_end_positions'], mention_retrieval_indices)
            batch['mention_retrieval_indices'] = mention_retrieval_indices

            return batch

        return collater_fn
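The searchsorted handling inside collater_fn is a general TensorFlow trick for exact-match lookup against a sorted array with a sentinel fallback: find the insertion point, compare the value found there with the query, and route misses to the sentinel slot (whose weight is zeroed out). A standalone sketch of the same trick, with made-up data:

import numpy as np
import tensorflow as tf

# Sorted lookup table with a sentinel (int32 max) appended at the end.
sorted_vals = tf.constant([3, 7, 9, np.iinfo(np.int32).max], dtype=tf.int32)
queries = tf.constant([7, 8], dtype=tf.int32)

idx = tf.searchsorted(sorted_vals, queries)           # insertion points
hit = tf.equal(tf.gather(sorted_vals, idx), queries)  # exact match?
# Route misses to the sentinel slot instead of an arbitrary neighbor.
sentinel = tf.shape(sorted_vals)[0] - 1
idx = tf.where(hit, idx, tf.fill(tf.shape(idx), sentinel))
# idx == [1, 3]: query 7 is found at position 1; query 8 falls back to the
# sentinel slot.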