Example #1
    def test_get_from_first_device(self):
        sharded = {
            'a':
            jax.device_put_sharded(
                list(jnp.arange(16).reshape([jax.local_device_count(), 4])),
                jax.local_devices()),
            'b':
            jax.device_put_sharded(
                list(jnp.arange(8).reshape([jax.local_device_count(), 2])),
                jax.local_devices(),
            ),
        }

        want = {
            'a': jnp.arange(4),
            'b': jnp.arange(2),
        }

        # Get zeroth device content as DeviceArray.
        device_arrays = utils.get_from_first_device(sharded, as_numpy=False)
        jax.tree_map(lambda x: self.assertIsInstance(x, jax.xla.DeviceArray),
                     device_arrays)
        jax.tree_map(np.testing.assert_array_equal, want, device_arrays)

        # Get the zeroth device content as numpy arrays.
        numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True)
        jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray),
                     numpy_arrays)
        jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays)
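
For orientation, here is a minimal sketch of a helper with the same contract as the `utils.get_from_first_device` used in this test. It is an illustrative assumption, not the library's actual implementation: take shard 0 of every leaf and optionally convert to NumPy.

import jax
import numpy as np

def get_from_first_device_sketch(nest, as_numpy=True):
    # Indexing a sharded array at [0] addresses the shard held by the first device.
    first = jax.tree_map(lambda x: x[0], nest)
    return jax.tree_map(np.asarray, first) if as_numpy else first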
Example #2
    def __next__(self) -> types.NestedArray:
        try:
            if not self.pmapped_user:
                item = next(self.iterator)
                if self.split_fn is None:
                    return jax.device_put(item, self.devices[0])
                item_split = self.split_fn(item)
                return PrefetchingSplit(host=item_split.host,
                                        device=jax.device_put(
                                            item_split.device,
                                            self.devices[0]))

            items = itertools.islice(self.iterator, self.num_devices)
            items = tuple(items)
            if len(items) < self.num_devices:
                raise StopIteration
            if self.split_fn is None:
                return jax.device_put_sharded(tuple(items), self.devices)
            else:
                # ((host: x1, device: y1), ..., (host: xN, device: yN)).
                items_split = (self.split_fn(item) for item in items)
                # (host: (x1, ..., xN), device: (y1, ..., yN)).
                split = tree.map_structure_up_to(PrefetchingSplit(None, None),
                                                 lambda *x: x, *items_split)

                return PrefetchingSplit(host=np.stack(split.host),
                                        device=jax.device_put_sharded(
                                            split.device, self.devices))

        except StopIteration:
            raise

        except Exception:  # pylint: disable=broad-except
            logging.exception('Error for %s', self.iterable)
            raise
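
The core of the pmapped branch above, shown as a standalone hedged sketch with toy data (not the real iterator types): pull one item per local device from a host iterator and ship them in a single `jax.device_put_sharded` call.

import itertools

import jax
import numpy as np

devices = jax.local_devices()
host_iter = iter(np.ones((16, 8), np.float32))  # toy stream of per-device items
items = tuple(itertools.islice(host_iter, len(devices)))
if len(items) == len(devices):
    sharded_batch = jax.device_put_sharded(items, devices)  # shape [n_devices, 8]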
Example #3
    def producer():
        """Enqueues batched items from `iterable` on a given thread."""
        try:
            # Build a new iterable for each thread. This is crucial if working with
            # tensorflow datasets because tf.Graph objects are thread local.
            it = iter(iterable)
            while True:
                # `islice` returns an iterator, which is always truthy, so
                # materialize the slice before checking whether it is empty.
                items = tuple(itertools.islice(it, len(devices)))
                if not items:
                    break
                if split_fn is None:
                    buffer.put(jax.device_put_sharded(tuple(items), devices))
                else:
                    # ((host: x1, device: y1), ..., (host: xN, device: yN)).
                    items_split = (split_fn(item) for item in items)
                    # (host: (x1, ..., xN), device: (y1, ..., yN)).
                    split = tree.map_structure_up_to(
                        PrefetchingSplit(None, None), lambda *x: x,
                        *items_split)

                    buffer.put(
                        PrefetchingSplit(host=np.stack(split.host),
                                         device=jax.device_put_sharded(
                                             split.device, devices)))
        except Exception as e:  # pylint: disable=broad-except
            logging.exception('Error in producer thread for %s', iterable)
            producer_error.append(e)
        finally:
            buffer.put(end)
Example #4
 def _gen_array(self, gen_fn):
     array = [gen_fn() for _ in range(len(self.devices))]
     array_sharded = jax.device_put_sharded(array, self.devices)
     array_stacked = np.stack(array)
     array_stacked = array_stacked.reshape([-1] +
                                           list(array_stacked.shape[2:]))
     return array_stacked, array_sharded
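
A hedged usage sketch of the pattern in `_gen_array` above, with a toy stand-in for `gen_fn`: the host-side stacked copy and the device-sharded copy hold the same values, just laid out differently.

import jax
import numpy as np

devices = jax.local_devices()
per_device = [np.random.rand(2, 3) for _ in devices]        # stand-in for gen_fn()
array_sharded = jax.device_put_sharded(per_device, devices)  # one shard per device
array_stacked = np.stack(per_device).reshape([-1, 3])        # [n_devices * 2, 3]
np.testing.assert_allclose(array_stacked, np.asarray(array_sharded).reshape([-1, 3]))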
Example #5
  def test_get_globally_consistent_batch_positions(self, seed, batch_size,
                                                   n_mentions):
    np.random.seed(seed)

    mention_batch_positions_sharded = jax.device_put_sharded(
        list(np.random.randint(batch_size, size=(self.n_devices, n_mentions))),
        self.devices)

    fn = functools.partial(
        mention_utils.get_globally_consistent_batch_positions,
        batch_size=batch_size)

    # Test function in the multi-device setting
    (local_mention_batch_positions, global_mention_batch_positions) = jax.pmap(
        fn, axis_name='batch')(
            mention_batch_positions_sharded)

    for i in range(self.n_devices):
      self.assertArrayEqual(local_mention_batch_positions[i],
                            mention_batch_positions_sharded[i] + batch_size * i)

    local_mention_batch_positions = local_mention_batch_positions.reshape(-1)
    for i in range(self.n_devices):
      self.assertArrayEqual(global_mention_batch_positions[i],
                            local_mention_batch_positions)

    # Test function in the single-device setting
    for i in range(self.n_devices):
      (local_mention_batch_positions,
       global_mention_batch_positions) = fn(mention_batch_positions_sharded[i])
      self.assertArrayEqual(local_mention_batch_positions,
                            mention_batch_positions_sharded[i])
      self.assertArrayEqual(global_mention_batch_positions,
                            mention_batch_positions_sharded[i])
Example #6
 def replicate(self):
     """A context manager to use in a with statement that replicates the variables in this collection to multiple
     devices. This is used typically prior to call to objax.Parallel, so that all variables have a copy on each
     device.
     Important: replicating also updates the random state in order to have a new one per device.
     """
     replicated, saved_states = [], []
     devices = get_local_devices()
     ndevices = len(devices)
     for v in self:
         if isinstance(v, RandomState):
             replicated.append(jax.device_put_sharded([shard for shard in v.split(ndevices)], devices))
             saved_states.append(v.value)
         else:
             replicated.append(jax.device_put_replicated(v.value, devices))
     self.assign(replicated)
     yield
     visited = set()
     saved_states.reverse()
     for k, v in self.items():
         if isinstance(v, TrainRef):
             v = v.ref
             assert not isinstance(v, TrainRef)
         if id(v) not in visited:  # Careful not to reduce twice in case of a variable and a reference to it.
             if isinstance(v, RandomState):
                 v.assign(saved_states.pop())
             else:
                 v.reduce(v.value)
             visited.add(id(v))
Example #7
  def replicate(self):
    """A context manager to use in a with statement that replicates
    the variables in this collection to multiple devices.

    Important: replicating also updates the random state in order
    to have a new one per device.
    """
    global math
    if math is None: from brainpy import math

    replicated, saved_states = {}, {}
    x = jnp.zeros((jax.local_device_count(), 1), dtype=math.float_)
    sharded_x = jax.pmap(lambda x: x, axis_name='device')(x)
    devices = [b.device() for b in sharded_x.device_buffers]
    num_device = len(devices)
    for k, d in self.items():
      if isinstance(d, math.random.RandomState):
        replicated[k] = jax.device_put_sharded([shard for shard in d.split(num_device)], devices)
        saved_states[k] = d.value
      else:
        replicated[k] = jax.device_put_replicated(d.value, devices)
    self.assign(replicated)
    yield
    visited = set()
    for k, d in self.items():
      # Careful not to reduce twice in case of
      # a variable and a reference to it.
      if id(d) not in visited:
        if isinstance(d, math.random.RandomState):
          d.value = saved_states[k]
        else:
          d.value = reduce_func(d)
        visited.add(id(d))
Example #8
  def test_get_num_common_unique_items_multi_devices(self, seed: int,
                                                     batch_size: int,
                                                     n_mentions: int,
                                                     vocab_size: int):
    np.random.seed(seed)

    batch_positions_sharded = jax.device_put_sharded(
        list(np.random.randint(batch_size, size=(self.n_devices, n_mentions))),
        self.devices)
    ids_sharded = jax.device_put_sharded(
        list(np.random.randint(vocab_size, size=(self.n_devices, n_mentions))),
        self.devices)

    fn = functools.partial(
        mention_utils.get_num_common_unique_items, batch_size=batch_size)

    actual_per_sample, actual_per_mention = jax.pmap(
        fn, axis_name='batch')(batch_positions_sharded, ids_sharded)

    batches = []
    for i in range(self.n_devices):
      batches.append([])
      for j in range(batch_size):
        batches[i].append(set())
      for j in range(n_mentions):
        if ids_sharded[i, j] > 0:
          batches[i][batch_positions_sharded[i, j]].add(ids_sharded[i,
                                                                    j].item())

    for d_i in range(self.n_devices):
      for b_i in range(batch_size):
        for d_j in range(self.n_devices):
          for b_j in range(batch_size):
            self.assertLen(
                batches[d_i][b_i].intersection(batches[d_j][b_j]),
                actual_per_sample[d_i, b_i, d_j * batch_size + b_j].item())

    for d_i in range(self.n_devices):
      for m_i in range(n_mentions):
        b_i = batch_positions_sharded[d_i, m_i].item()
        for d_j in range(self.n_devices):
          for m_j in range(n_mentions):
            b_j = batch_positions_sharded[d_j, m_j].item()
            self.assertLen(
                batches[d_i][b_i].intersection(batches[d_j][b_j]),
                actual_per_mention[d_i, m_i, d_j * n_mentions + m_j].item())
Example #9
def _array_shard(
    x,
    devices = None
):
  """Shards a single array over the first axis across multiple local devices."""
  devices = devices or jax.local_devices()
  x = jnp.asarray(x)
  assert x.shape[0] == len(devices)
  return jax.device_put_sharded(list(x), devices)
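
A usage sketch for `_array_shard` above, under its own assumption that the leading axis already equals the number of local devices: each row lands on its own device.

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
batch = jnp.arange(n_dev * 3).reshape(n_dev, 3)
sharded = _array_shard(batch)   # one (3,)-shaped shard per local device
assert sharded.shape == (n_dev, 3)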
Example #10
 def setUp(self):
   super().setUp()
   test_utils.force_multi_devices(self.n_devices)
   self.devices = jax.local_devices()
   mention_batch_positions = [
       np.random.randint(self.batch_size, size=(self.n_mentions))
       for _ in range(self.n_devices)
   ]
   self.mention_batch_positions_sharded = jax.device_put_sharded(
       mention_batch_positions, self.devices)
Example #11
 def run(shared_input, clients):
     for block in _blockify(clients, block_size):
         p_state = p_client_init(
             shared_input,
             jax.device_put_sharded(block.client_input, devices))
         p_step_results = []
         for p_batch, p_mask in block.masked_batches:
             p_state, p_step_result = p_client_step(
                 p_state, jax.device_put_sharded(p_batch, devices),
                 jax.device_put_sharded(p_mask, devices))
             p_step_results.append(p_step_result)
         p_client_output = p_client_final(shared_input, p_state)
         for i in range(len(block.client_id)):
             if not block.client_mask[i]:
                 continue
             client_output, step_results = jax.tree_util.tree_map(
                 lambda x: x[i],  # pylint: disable=cell-var-from-loop
                 (p_client_output, p_step_results))
             yield (block.client_id[i], client_output,
                    step_results[:block.num_batches[i]])
Example #12
    def device_put_sharded(x):
        if not isinstance(x, (jnp.ndarray, np.ndarray)):
            return x

        # `jax.device_put_sharded` expects a sequence of arrays, one per local
        # device, so we split the input along the zeroth (device) dimension.
        x = np.reshape(x, [jax.local_device_count(), -1, x.shape[2]])
        x_list = np.split(x, x.shape[0], axis=0)

        # Squeeze out the dummy dimension.
        x_list = jax.tree_map(lambda y: np.squeeze(y, axis=0), x_list)

        # Put the sharded pieces onto the local devices.
        return jax.device_put_sharded(x_list, jax.local_devices())
Example #13
def _replicate(x, devices=None):
    """Replicate an object on each device."""
    x = jax.numpy.array(x)
    if devices is None:
        devices = jax.local_devices()
    return jax.device_put_sharded(len(devices) * [x], devices)
Example #14
    def test_linking_layer(self):
        """Testing linking layer."""

        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()

        encoded_input = jnp.ones(shape=(self.n_devices, self.bsz, self.seq_len,
                                        self.hidden_size),
                                 dtype=self.dtype)
        encoded_input = jax.device_put_sharded(list(encoded_input), devices)
        mention_batch_positions = np.random.randint(self.bsz,
                                                    size=(self.n_devices,
                                                          self.n_mentions))
        mention_batch_positions = jax.device_put_sharded(
            list(mention_batch_positions), devices)
        mention_start_positions = np.random.randint(self.seq_len - 1,
                                                    size=(self.n_devices,
                                                          self.n_mentions))
        mention_end_positions = mention_start_positions + 1
        mention_start_positions = jax.device_put_sharded(
            list(mention_start_positions), devices)
        mention_end_positions = jax.device_put_sharded(
            list(mention_end_positions), devices)
        mention_mask = jnp.ones(shape=(self.n_devices, self.n_mentions))
        mention_mask = jax.device_put_sharded(list(mention_mask), devices)
        mention_entity_ids = jnp.arange(
            self.n_devices * self.n_mentions).reshape(self.n_devices,
                                                      self.n_mentions)
        mention_entity_ids = jax.device_put_sharded(list(mention_entity_ids),
                                                    devices)

        model = memory_extraction_layer.MemoryExtractionLayer(
            memory_key_dim=self.memory_key_dim,
            memory_value_dim=self.memory_value_dim,
            dtype=self.dtype,
        )
        pinit_with_output = jax.pmap(model.init_with_output, axis_name='batch')

        rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(rng, self.n_devices)
        result_dict, _ = pinit_with_output(
            split_rng,
            encoding=encoded_input,
            mention_batch_positions=mention_batch_positions,
            mention_start_positions=mention_start_positions,
            mention_end_positions=mention_end_positions,
            mention_mask=mention_mask,
            mention_entity_ids=mention_entity_ids,
        )

        # Check shapes are as expected
        self.assertSequenceEqual(result_dict['memory_keys'].shape,
                                 (self.n_devices, self.n_devices *
                                  self.n_mentions, self.memory_key_dim))
        self.assertSequenceEqual(result_dict['memory_values'].shape,
                                 (self.n_devices, self.n_devices *
                                  self.n_mentions, self.memory_value_dim))
        self.assertSequenceEqual(
            result_dict['memory_mask'].shape,
            (self.n_devices, self.n_devices * self.n_mentions))
        self.assertSequenceEqual(
            result_dict['memory_entity_ids'].shape,
            (self.n_devices, self.n_devices * self.n_mentions))

        # Memory mask and entity ids should just have been all gathered
        self.assertTrue(
            jnp.all(result_dict['memory_mask'][0].reshape(
                self.n_devices, self.n_mentions) == mention_mask))
        self.assertTrue(
            jnp.all(result_dict['memory_entity_ids'][0].reshape(
                self.n_devices, self.n_mentions) == mention_entity_ids))
Example #15
    def run_model(self, config, entity_vocab_size):
        """Initialize and run the model once, perform sanity checks."""
        np.random.seed(0)

        # Save arrays to test retrieval saver.
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       self.devices)
        memory_entity_ids = memory_identifiers
        config['memory_entity_id_pattern'] = self.save_sharded_array(
            memory_entity_ids, 'memory_entity_id')
        memory_text = np.random.randint(
            config['model_config']['encoder_config']['vocab_size'],
            size=(self.n_devices, self.table_size, self.memory_text_length),
            dtype=np.int32)
        config['memory_text_pattern'] = self.save_sharded_array(
            memory_text, 'memory_text')
        memory_positions = np.random.randint(self.memory_text_length,
                                             size=(self.n_devices,
                                                   self.table_size, 2),
                                             dtype=np.int32)
        config['memory_positions_pattern'] = self.save_sharded_array(
            memory_positions, 'memory_positions')

        config = ml_collections.FrozenConfigDict(config)
        model_config = config.model_config
        encoder_config = model_config.encoder_config

        rows = encoder_config.rows
        preprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_collater_fn(
            config)
        postprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_output_postprocess_fn(
            config)

        model = mention_based_entity_qa_task.MentionBasedEntityQATask.build_model(
            model_config)
        dummy_input = mention_based_entity_qa_task.MentionBasedEntityQATask.dummy_input(
            config)
        dummy_input = jax.device_put_replicated(dummy_input, self.devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, self.devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, self.devices)

        initial_variables = jax.pmap(model.init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'encoder': {
                'memory_keys': memory_keys,
                'memory_values': memory_values,
                'memory_identifiers': memory_identifiers,
                'memory_entity_ids': memory_entity_ids,
            }
        }

        def sample_batch():
            processed_examples = []
            for _ in range(config.per_device_batch_size):
                raw_example = test_utils.gen_mention_pretraining_sample(
                    self.text_length,
                    self.n_mentions,
                    self.n_linked_mentions,
                    entity_vocab_size=entity_vocab_size,
                    max_length=encoder_config.max_length)
                processed_example = preprocess_fn(raw_example)
                processed_examples.append(processed_example)
            batch = stack(processed_examples)
            batch = collater_fn(batch)
            batch = {
                key: test_utils.tensor_to_numpy(value)
                for key, value in batch.items()
            }
            return batch

        batch = stack([sample_batch() for _ in range(self.n_devices)])
        batch = {
            key: jax.device_put_sharded(list(value), self.devices)
            for key, value in batch.items()
        }

        loss_fn = jax.pmap(
            mention_based_entity_qa_task.MentionBasedEntityQATask.make_loss_fn(
                config),
            'batch',
            static_broadcasted_argnums=(0, 4))
        _, metrics, auxiliary_output = loss_fn(
            model_config,
            initial_variables['params'],
            {'constants': initial_variables['constants']},
            batch,
            True,
        )

        self.assertArrayEqual(metrics['agg']['denominator'],
                              batch['mention_target_weights'].sum(1))

        features = postprocess_fn(batch, auxiliary_output)
        # Check features are JSON-serializable
        json.dumps(features)
        # Check features match the original batch
        for key in batch.keys():
            self.assertArrayEqual(np.array(features[key]), batch[key])

        n_mentions_per_device = (config.per_device_batch_size *
                                 config.max_mention_targets)
        k_top_final = (encoder_config.final_k_top_post_selection
                       or encoder_config.final_k_top_device * self.n_devices)
        self.assertSequenceEqual(
            np.array(features['memory_text']).shape, [
                self.n_devices, n_mentions_per_device, k_top_final,
                self.memory_text_length
            ])
        self.assertSequenceEqual(
            np.array(features['memory_positions']).shape,
            [self.n_devices, n_mentions_per_device, k_top_final, 2])

        return batch, initial_variables, metrics
Example #16
def bcast_local_devices(value):
    """Broadcasts an object to all local devices."""
    devices = jax.local_devices()
    return jax.tree_map(
        lambda v: jax.device_put_sharded(len(devices) * [v], devices), value)
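
A usage sketch for `bcast_local_devices` above: every leaf of the pytree gets a copy on each local device, so the result carries a leading device axis and is ready to be fed to `jax.pmap`.

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
replicated = bcast_local_devices(params)
assert replicated['w'].shape == (jax.local_device_count(), 3, 3)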
Example #17
def _device_put_sharded(sharded_tree, devices):
    leaves, treedef = jax.tree_flatten(sharded_tree)
    n = leaves[0].shape[0]
    return jax.device_put_sharded([
        jax.tree_unflatten(treedef, [l[i] for l in leaves]) for i in range(n)
    ], devices)
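
A usage sketch for `_device_put_sharded` above, assuming the input tree is already "stacked" on the host with a leading axis equal to the device count: the helper unstacks it into one sub-tree per device and puts each shard on its device.

import jax
import numpy as np

devices = jax.local_devices()
stacked_tree = {
    'x': np.zeros((len(devices), 4)),
    'y': np.ones((len(devices), 2)),
}
sharded_tree = _device_put_sharded(stacked_tree, devices)
assert sharded_tree['x'].shape == (len(devices), 4)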
Example #18
def build_dataset_iterator(
    data_root: str,
    split: str,
    dynamic_batch_size_config: config_dict.ConfigDict,
    online_subsampling_kwargs: dict,  # pylint: disable=g-bare-generic
    debug: bool = False,
    is_training: bool = True,
    k_fold_split_id: Optional[int] = None,
    ratio_unlabeled_data_to_labeled_data: float = 0.0,
    use_all_labels_when_not_training: bool = False,
    use_dummy_adjacencies: bool = False,
):
    """Returns an iterator over Batches from the dataset."""

    if split == 'test':
        use_all_labels_when_not_training = True

    if not is_training:
        ratio_unlabeled_data_to_labeled_data = 0.0

    # Load the master data arrays.
    with LOADING_RAW_ARRAYS_LOCK:
        array_dict = data_utils.get_arrays(
            data_root,
            k_fold_split_id=k_fold_split_id,
            use_dummy_adjacencies=use_dummy_adjacencies)

    node_labels = array_dict['paper_label'].reshape(-1)
    train_indices = array_dict['train_indices'].astype(np.int32)
    is_train_index = np.zeros(node_labels.shape[0], dtype=np.int32)
    is_train_index[train_indices] = 1
    valid_indices = array_dict['valid_indices'].astype(np.int32)
    is_valid_index = np.zeros(node_labels.shape[0], dtype=np.int32)
    is_valid_index[valid_indices] = 1
    is_train_or_valid_index = is_train_index + is_valid_index

    def sstable_to_intermediate_graph(graph):
        indices = tf.cast(graph.nodes['index'], tf.int32)
        first_index = indices[..., 0]

        # Add an additional absolute index by adding offsets to author and
        # institution indices.
        absolute_index = graph.nodes['index']
        is_author = graph.nodes['type'] == 1
        absolute_index = tf.where(is_author,
                                  absolute_index + data_utils.NUM_PAPERS,
                                  absolute_index)
        is_institution = graph.nodes['type'] == 2
        absolute_index = tf.where(
            is_institution,
            absolute_index + data_utils.NUM_PAPERS + data_utils.NUM_AUTHORS,
            absolute_index)

        is_same_as_central_node = tf.math.equal(indices, first_index)
        input_nodes = graph.nodes
        graph = graph._replace(
            nodes={
                'one_hot_type':
                tf.one_hot(tf.cast(input_nodes['type'], tf.int32), 3),
                'one_hot_depth':
                tf.one_hot(tf.cast(input_nodes['depth'], tf.int32),
                           _MAX_DEPTH_IN_SUBGRAPH),
                'year':
                tf.expand_dims(input_nodes['year'], axis=-1),
                'label':
                tf.one_hot(tf.cast(input_nodes['label'], tf.int32),
                           NUM_CLASSES),
                'is_same_as_central_node':
                is_same_as_central_node,
                # Only first node in graph has a valid label.
                'is_central_node':
                tf.one_hot(0,
                           tf.shape(input_nodes['label'])[0]),
                'index':
                input_nodes['index'],
                'absolute_index':
                absolute_index,
            },
            globals=tf.expand_dims(graph.globals, axis=-1),
        )

        return graph

    ds = data_utils.get_graph_subsampling_dataset(
        split,
        array_dict,
        shuffle_indices=is_training,
        ratio_unlabeled_data_to_labeled_data=
        ratio_unlabeled_data_to_labeled_data,
        max_nodes=dynamic_batch_size_config.n_node - 1,  # Keep space for pads.
        max_edges=dynamic_batch_size_config.n_edge,
        **online_subsampling_kwargs)
    if debug:
        ds = ds.take(50)
    ds = ds.map(sstable_to_intermediate_graph,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if is_training:
        ds = ds.shard(jax.process_count(), jax.process_index())
        ds = ds.shuffle(buffer_size=1 if debug else 128)
        ds = ds.repeat()
    ds = ds.prefetch(1 if debug else tf.data.experimental.AUTOTUNE)
    np_ds = iter(tfds.as_numpy(ds))
    batched_np_ds = batching_utils.dynamically_batch(
        np_ds,
        **dynamic_batch_size_config,
    )

    def intermediate_graph_to_batch(graph):
        central_node_mask = graph.nodes['is_central_node']
        label = graph.nodes['label']
        node_indices = graph.nodes['index']
        absolute_indices = graph.nodes['absolute_index']

        ### Construct label as a feature for non-central nodes.
        # First do a lookup with node indices, with a np.minimum to ensure we do not
        # index out of bounds due to num_authors being larger than num_papers.
        is_same_as_central_node = graph.nodes['is_same_as_central_node']
        capped_indices = np.minimum(node_indices, node_labels.shape[0] - 1)
        label_as_feature = node_labels[capped_indices]
        # Nodes which are not in train set should get `num_classes` label.
        # Nodes in test set or non-arXiv nodes have -1 or nan labels.

        # Mask out invalid labels and non-papers.
        use_label_as_feature = np.logical_and(
            label_as_feature >= 0, graph.nodes['one_hot_type'][..., 0])
        if split == 'train' or not use_all_labels_when_not_training:
            # Mask out validation papers and non-arxiv papers who
            # got labels from fusing with arxiv papers.
            use_label_as_feature = np.logical_and(
                is_train_index[capped_indices], use_label_as_feature)
        label_as_feature = np.where(use_label_as_feature, label_as_feature,
                                    NUM_CLASSES)
        # Mask out central node label in case it appears again.
        label_as_feature = np.where(is_same_as_central_node, NUM_CLASSES,
                                    label_as_feature)
        # Nodes which are not papers get `NUM_CLASSES+1` label.
        label_as_feature = np.where(graph.nodes['one_hot_type'][..., 0],
                                    label_as_feature, NUM_CLASSES + 1)

        nodes = {
            'label_as_feature':
            label_as_feature,
            'year':
            graph.nodes['year'],
            'bitstring_year':
            _get_bitstring_year_representation(graph.nodes['year']),
            'one_hot_type':
            graph.nodes['one_hot_type'],
            'one_hot_depth':
            graph.nodes['one_hot_depth'],
        }

        graph = graph._replace(
            nodes=nodes,
            globals={},
        )
        is_train_or_valid_node = np.logical_and(
            is_train_or_valid_index[capped_indices],
            graph.nodes['one_hot_type'][..., 0])
        if is_training:
            label_mask = np.logical_and(central_node_mask,
                                        is_train_or_valid_node)
        else:
            # `label_mask` is used to index into valid central nodes by prediction
            # calculator. Since that computation is only done when not training, and
            # at that time we are guaranteed all central nodes have valid labels,
            # we just set label_mask = central_node_mask when not training.
            label_mask = central_node_mask
        batch = Batch(graph=graph,
                      node_labels=label,
                      central_node_mask=central_node_mask,
                      label_mask=label_mask,
                      node_indices=node_indices,
                      absolute_node_indices=absolute_indices)

        # Transform integers into one-hots.
        batch = _add_one_hot_features_to_batch(batch)

        # Gather PCA features.
        return _add_embeddings_to_batch(batch, array_dict['bert_pca_129'])

    batch_list = []
    for batch in batched_np_ds:
        with jax.profiler.StepTraceAnnotation('batch_postprocessing'):
            batch = intermediate_graph_to_batch(batch)
        if is_training:
            batch_list.append(batch)
            if len(batch_list) == jax.local_device_count():
                yield jax.device_put_sharded(batch_list, jax.local_devices())
                batch_list = []
        else:
            yield batch
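
The sharding tail of the training branch above, as a hedged standalone sketch: buffer one post-processed batch per local device, then ship them together with a single `jax.device_put_sharded` call.

import jax

def shard_batches(batch_iterator):
    """Yields device-sharded groups of `jax.local_device_count()` batches."""
    buffer = []
    for batch in batch_iterator:
        buffer.append(batch)
        if len(buffer) == jax.local_device_count():
            yield jax.device_put_sharded(buffer, jax.local_devices())
            buffer = []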
Example #19
    def run_model(self, config, entity_vocab_size):
        """Initialize and run the model once, perform sanity checks."""
        np.random.seed(0)

        # Save arrays to test retrieval saver.
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       self.devices)
        memory_entity_ids = memory_identifiers
        config['memory_entity_id_pattern'] = self.save_sharded_array(
            memory_entity_ids, 'memory_entity_id')
        memory_text = np.random.randint(
            config['model_config']['encoder_config']['vocab_size'],
            size=(self.n_devices, self.table_size, self.memory_text_length),
            dtype=np.int32)
        config['memory_text_pattern'] = self.save_sharded_array(
            memory_text, 'memory_text')
        memory_positions = np.random.randint(self.memory_text_length,
                                             size=(self.n_devices,
                                                   self.table_size, 2),
                                             dtype=np.int32)
        config['memory_positions_pattern'] = self.save_sharded_array(
            memory_positions, 'memory_positions')

        config = ml_collections.FrozenConfigDict(config)
        model_config = config.model_config
        encoder_config = model_config.encoder_config

        rows = encoder_config.rows
        preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
            config)
        postprocess_fn = mention_memory_task.MentionMemoryTask.make_output_postprocess_fn(
            config)

        model = mention_memory_task.MentionMemoryTask.build_model(model_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, self.devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, self.devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, self.devices)

        # `memory_text_entities` are assumed to contain unique IDs in the last dim.
        memory_text_entities = np.zeros(
            (self.n_devices, self.table_size,
             encoder_config.n_memory_text_entities), np.int32)
        for device_index in range(self.n_devices):
            for t_index in range(self.table_size):
                current_text_entities = np.random.choice(
                    entity_vocab_size,
                    size=(min(encoder_config.n_memory_text_entities,
                              entity_vocab_size)),
                    replace=False)
                memory_text_entities[
                    device_index, t_index,
                    :len(current_text_entities)] = current_text_entities

        memory_text_entities = jax.device_put_sharded(
            list(memory_text_entities), self.devices)

        initial_variables = jax.pmap(model.init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'encoder': {
                'memory_keys': memory_keys,
                'memory_values': memory_values,
                'memory_identifiers': memory_identifiers,
                'memory_entity_ids': memory_entity_ids,
                'memory_text_entities': memory_text_entities,
            }
        }

        def sample_batch():
            processed_examples = []
            for _ in range(config.per_device_batch_size):
                raw_example = test_utils.gen_mention_pretraining_sample(
                    self.text_length,
                    self.n_mentions,
                    self.n_linked_mentions,
                    entity_vocab_size=entity_vocab_size,
                    max_length=encoder_config.max_length)
                processed_example = preprocess_fn(raw_example)
                processed_examples.append(processed_example)
            batch = stack(processed_examples)
            batch = collater_fn(batch)
            batch = {
                key: test_utils.tensor_to_numpy(value)
                for key, value in batch.items()
            }
            text_ids = batch['text_ids']
            for i in range(config.per_device_batch_size):
                for j in range(config.max_mlm_targets):
                    if batch['mlm_target_weights'][i, j] > 0:
                        text_ids[i, batch['mlm_target_positions'][
                            i, j]] = batch['mlm_target_ids'][i, j]
            mention_batch_positions = batch['mention_batch_positions']
            text_identifiers = batch['text_identifiers'].astype(
                np.int32).tolist()
            expected_text_identifiers = [
                mention_preprocess_utils.text_hash(
                    text_ids[mention_batch_positions[index]]).astype(np.int32)
                for index in range(len(mention_batch_positions))
            ]
            self.assertSequenceEqual(text_identifiers,
                                     expected_text_identifiers)
            return batch

        batch = stack([sample_batch() for _ in range(self.n_devices)])
        batch = {
            key: jax.device_put_sharded(list(value), self.devices)
            for key, value in batch.items()
        }

        loss_fn = jax.pmap(
            mention_memory_task.MentionMemoryTask.make_loss_fn(config),
            'batch',
            static_broadcasted_argnums=(0, 4))
        _, metrics, auxiliary_output = loss_fn(
            model_config,
            initial_variables['params'],
            {'constants': initial_variables['constants']},
            batch,
            True,
        )

        metrics_per_first_device = jax.tree_map(lambda x: x[0], metrics)
        self.assertEqual(metrics_per_first_device['mlm']['denominator'],
                         batch['mlm_target_weights'][0].sum())

        features = postprocess_fn(batch, auxiliary_output)
        # Check features are JSON-serializable
        json.dumps(features)
        # Check features match the original batch
        for key in batch.keys():
            self.assertArrayEqual(np.array(features[key]), batch[key])

        n_mentions_per_device = (config.per_device_batch_size *
                                 config.max_mentions)
        if config.save_k_retrieval is not None:
            k_top_saved = min(config.save_k_retrieval,
                              encoder_config.k_top_post_selection)
        else:
            k_top_saved = (encoder_config.k_top_post_selection
                           or encoder_config.k_top_device * self.n_devices)
        self.assertSequenceEqual(
            np.array(features['memory_text']).shape, [
                self.n_devices, n_mentions_per_device, k_top_saved,
                self.memory_text_length
            ])
        self.assertSequenceEqual(
            np.array(features['memory_positions']).shape,
            [self.n_devices, n_mentions_per_device, k_top_saved, 2])

        if encoder_config.get('num_intermediate_layers') is not None:
            self.assertSequenceEqual(
                np.array(features['second_memory_text']).shape, [
                    self.n_devices, n_mentions_per_device, k_top_saved,
                    self.memory_text_length
                ])
            self.assertSequenceEqual(
                np.array(features['second_memory_positions']).shape,
                [self.n_devices, n_mentions_per_device, k_top_saved, 2])

        return batch, initial_variables, metrics
Example #20
 def split_and_put(x: jnp.ndarray) -> jnp.ndarray:
     return jax.device_put_sharded(np.split(x[:self._dataset_size],
                                            len(device)),
                                   devices=device)
Example #21
 def _replicate(x):
   """Replicate an object on each device."""
   x = jnp.array(x)
   return jax.device_put_sharded(len(devices) * [x], devices)
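
A usage sketch for `_replicate` above, assuming a module-level `devices` list (e.g. `jax.local_devices()`) stands in for the closure variable it relies on: a scalar becomes a per-device replicated array.

import jax
import jax.numpy as jnp  # `_replicate` above relies on `jnp` and `devices` being in scope

devices = jax.local_devices()
step = _replicate(0)                  # replicated scalar step counter
assert step.shape == (len(devices),)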
Example #22
  def test_memory_attention_backward(self):
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    pinit = jax.pmap(
        model.init, axis_name='batch', static_broadcasted_argnums=(9, 10))

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    mention_batch_positions = jnp.tile(
        jnp.asarray([[0], [1], [2]]), (1, self.bsz)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)

    memory_table = jnp.asarray(memory_table, dtype=self.dtype)
    memory_table = memory_table.reshape(self.n_devices, self.rows,
                                        self.table_size // self.rows,
                                        self.memory_key_dim)
    memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    initial_parameters = pinit(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_table_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        None,  # memory_values
        text_identifiers=None,
    )

    def step_fn(
        params,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys,
        memory_identifiers,
        memory_entity_ids,
    ):

      def loss_fn(params):
        encoded_output, _, _ = model.apply(
            {'params': params},
            rngs=None,
            encoded_input=encoded_input,
            mention_batch_positions=mention_batch_positions,
            mention_start_positions=mention_start_positions,
            mention_end_positions=mention_end_positions,
            mention_mask=mention_mask,
            memory_keys=memory_keys,
            memory_identifiers=memory_identifiers,
            memory_entity_ids=memory_entity_ids,
            deterministic=True,
            text_identifiers=None,
        )
        return encoded_output.sum()

      loss, grad = jax.value_and_grad(loss_fn)(params)
      return loss, grad

    pstep = jax.pmap(step_fn, axis_name='batch')

    _ = pstep(
        initial_parameters['params'],
        encoded_input=encoded_input,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        memory_keys=memory_table_sharded,
        memory_identifiers=memory_identifiers,
        memory_entity_ids=memory_entity_ids,
    )
Example #23
def replicate_in_all_devices(nest: N,
                             devices: Optional[Sequence[jax.xla.Device]] = None
                             ) -> N:
    """Replicate array nest in all available devices."""
    devices = devices or jax.local_devices()
    return jax.device_put_sharded([nest] * len(devices), devices)
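
A usage sketch for `replicate_in_all_devices` above (assuming its typing imports, such as the `N` TypeVar, are in scope): an entire parameter tree is copied to every local device, giving each leaf a leading device axis suitable for `jax.pmap`.

import jax
import jax.numpy as jnp

params = {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros((2,))}
replicated = replicate_in_all_devices(params)
assert replicated['kernel'].shape == (jax.local_device_count(), 2, 2)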
Example #24
  def test_compare_retrievals_with_numpy(self, seed, k_top_post_selection,
                                         max_text_identifiers,
                                         same_passage_memory_policy):
    """Test whether retrieval results are correct."""
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()
    n_text_entities_per_memory = 3

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=(9, 10, 13))

    rng = jax.random.PRNGKey(seed)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jax.random.uniform(
        rng,
        shape=(self.n_devices, self.bsz, self.seq_len, self.input_dim),
        dtype=self.dtype)
    mention_batch_positions = jax.random.randint(
        rng, minval=0, maxval=self.bsz, shape=(self.n_devices, self.n_mentions))
    mention_start_positions = jax.random.randint(
        rng,
        minval=0,
        maxval=self.seq_len,
        shape=(self.n_devices, self.n_mentions))
    mention_end_positions = mention_start_positions
    mention_mask = jnp.ones(shape=(self.n_devices, self.n_mentions))

    memory_table = jax.random.uniform(
        rng,
        shape=(self.n_devices, self.rows, self.table_size // self.rows,
               self.memory_key_dim))
    memory_entity_ids = jax.random.randint(
        rng,
        minval=0,
        maxval=self.entity_vocab_size,
        shape=(self.n_devices, self.table_size))
    if max_text_identifiers is not None:
      memory_identifiers = jax.random.randint(
          rng,
          minval=0,
          maxval=max_text_identifiers,
          shape=(self.n_devices, self.table_size))
      text_identifiers = jax.random.randint(
          rng,
          minval=0,
          maxval=max_text_identifiers,
          shape=(self.n_devices, self.n_mentions))
    else:
      text_identifiers = None

    if n_text_entities_per_memory is not None:
      memory_text_entities = jax.random.randint(
          rng,
          minval=0,
          maxval=self.entity_vocab_size,
          shape=(self.n_devices, self.table_size, n_text_entities_per_memory))
    else:
      memory_text_entities = None

    encoded_input_sharded = jax.device_put_sharded(list(encoded_input), devices)
    mention_batch_positions_sharded = jax.device_put_sharded(
        list(mention_batch_positions), devices)
    mention_start_positions_sharded = jax.device_put_sharded(
        list(mention_start_positions), devices)
    mention_end_positions_sharded = jax.device_put_sharded(
        list(mention_end_positions), devices)
    mention_mask_sharded = jax.device_put_sharded(list(mention_mask), devices)
    memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)
    memory_entity_ids_sharded = jax.device_put_sharded(
        list(memory_entity_ids), devices)
    if max_text_identifiers is not None:
      memory_identifiers_sharded = jax.device_put_sharded(
          list(memory_identifiers), devices)
      text_identifiers_sharded = jax.device_put_sharded(
          list(text_identifiers), devices)
    else:
      memory_identifiers_sharded = None
      text_identifiers_sharded = None

    if memory_text_entities is not None:
      memory_text_entities_sharded = jax.device_put_sharded(
          list(memory_text_entities), devices)
    else:
      memory_text_entities_sharded = None

    memory_ids = jnp.arange(self.n_devices * self.table_size)
    memory_ids = memory_ids.reshape(self.n_devices, self.table_size)

    (_, loss_helpers, logging_helpers), params = pinit_with_output(
        split_rng,
        encoded_input_sharded,
        mention_batch_positions_sharded,
        mention_start_positions_sharded,
        mention_end_positions_sharded,
        mention_mask_sharded,
        memory_table_sharded,
        memory_identifiers_sharded,
        memory_entity_ids_sharded,
        True,
        None,  # memory_values
        text_identifiers_sharded,
        memory_text_entities_sharded,
        same_passage_memory_policy,
    )

    params = params.unfreeze()['params']

    mention_encodings = []
    for device_id in range(self.n_devices):
      mention_start_encodings = encoded_input[device_id][
          mention_batch_positions[device_id],
          mention_start_positions[device_id]]
      mention_end_encodings = encoded_input[device_id][
          mention_batch_positions[device_id], mention_end_positions[device_id]]
      mention_encodings_on_device = jnp.concatenate(
          [mention_start_encodings, mention_end_encodings], axis=-1)
      mention_encodings_on_device = np.matmul(
          mention_encodings_on_device,
          params['query_projector']['kernel'][device_id])
      mention_encodings_on_device += params['query_projector']['bias'][
          device_id]
      mention_encodings.append(mention_encodings_on_device)

    # [n_devices, n_mentions, memory_key_dim]
    mention_encodings_stacked = jnp.stack(mention_encodings)
    mention_encodings_stacked = mention_encodings_stacked.reshape(
        [self.n_devices * self.n_mentions, self.memory_key_dim])

    # Object which represents a single retrieval result with additional info.
    RetrievedMemory = collections.namedtuple('RetrievedMemory', [
        'device', 'row', 'rowwise_index', 'devicewise_index', 'global_index',
        'score', 'memory', 'entity_id', 'memory_hash',
        'memory_passage_text_entities'
    ])

    num_disallowed_per_device = [0 for _ in range(self.n_devices)]
    # Manually simulate retrieval for every query.
    for query_id in range(self.n_devices * self.n_mentions):
      query = mention_encodings_stacked[query_id]
      top_memories_query = []
      # Collect retrievals for a single query on each device separately.
      for device_id in range(self.n_devices):
        top_memories_per_device = []
        for row_id in range(self.rows):
          scores = np.einsum('mh,h->m', memory_table[device_id, row_id], query)
          top_index = np.argmax(scores)
          devicewise_index = row_id * (self.table_size // self.rows) + top_index
          global_index = memory_ids[device_id, devicewise_index]
          self.assertEqual(global_index,
                           devicewise_index + device_id * self.table_size)
          if max_text_identifiers is not None:
            memory_hash = memory_identifiers[device_id, devicewise_index].item()
          else:
            memory_hash = None
          if memory_text_entities is not None:
            memory_passage_text_entities = memory_text_entities[
                device_id, devicewise_index]
          else:
            memory_passage_text_entities = None
          top_memories_per_device.append(
              RetrievedMemory(
                  device=device_id,
                  row=row_id,
                  rowwise_index=top_index,
                  devicewise_index=devicewise_index,
                  global_index=global_index,
                  score=scores[top_index].item(),
                  memory=memory_table[device_id, row_id, top_index],
                  entity_id=memory_entity_ids[device_id,
                                              devicewise_index].item(),
                  memory_hash=memory_hash,
                  memory_passage_text_entities=memory_passage_text_entities,
              ))
        # Sort by score. In case two scores are equal (likely because both
        # were considered "disallowed"), we compare by entity IDs.
        top_memories_per_device.sort(
            key=lambda x: (x.score, x.entity_id), reverse=True)
        top_memories_per_device = top_memories_per_device[:self.k_top_device]
        top_memories_query.extend(top_memories_per_device)

      top_memories_query.sort(
          key=lambda x: (x.score, x.entity_id), reverse=True)
      if k_top_post_selection is not None:
        top_memories_query = top_memories_query[:k_top_post_selection]

      if max_text_identifiers is not None:
        num_current_disallowed = 0
        text_id = text_identifiers[query_id // self.n_mentions,
                                   query_id % self.n_mentions].item()
        for i in range(len(top_memories_query)):
          if top_memories_query[i].memory_hash == text_id:
            num_current_disallowed += 1

          if ((same_passage_memory_policy == 'disallow' and
               top_memories_query[i].memory_hash == text_id) or
              (same_passage_memory_policy == 'only' and
               top_memories_query[i].memory_hash != text_id)):
            top_memories_query[i] = top_memories_query[i]._replace(
                score=-_LARGE_NUMBER)
        num_disallowed_per_device[query_id //
                                  self.n_mentions] += num_current_disallowed
        top_memories_query.sort(
            key=lambda x: (x.score, x.global_index), reverse=True)

      actual_entity_ids = loss_helpers['top_entity_ids'][query_id //
                                                         self.n_mentions,
                                                         query_id %
                                                         self.n_mentions]
      actual_memory_ids = loss_helpers['top_memory_ids'][query_id //
                                                         self.n_mentions,
                                                         query_id %
                                                         self.n_mentions]
      actual_attention_weights = loss_helpers['memory_attention_weights'][
          query_id // self.n_mentions, query_id % self.n_mentions]

      # We sort retrieved results first by the attention score and then
      # by memory ID.
      p = list(range(len(actual_attention_weights)))
      # pylint: disable=cell-var-from-loop
      p.sort(
          key=lambda i: (actual_attention_weights[i], actual_memory_ids[i]),
          reverse=True)
      # pylint: enable=cell-var-from-loop
      p = np.array(p)

      actual_entity_ids = list(actual_entity_ids[p])
      actual_attention_weights = list(actual_attention_weights[p])
      actual_memory_ids = list(actual_memory_ids[p])

      expected_entity_ids = [x.entity_id for x in top_memories_query]
      self.assertSequenceEqual(expected_entity_ids, actual_entity_ids)

      expected_attention_weights = scipy.special.softmax(
          [x.score for x in top_memories_query])
      self.assertSequenceAlmostEqual(expected_attention_weights,
                                     actual_attention_weights, 5)

      expected_memory_ids = [x.global_index for x in top_memories_query]
      self.assertSequenceEqual(expected_memory_ids, actual_memory_ids)

      actual_top_text_entities = loss_helpers['memory_top_text_entities'][
          query_id // self.n_mentions, query_id % self.n_mentions]
      actual_top_text_entities = actual_top_text_entities[p]

      expected_top_text_entities = [
          x.memory_passage_text_entities for x in top_memories_query
      ]
      self.assertEqual(
          len(actual_top_text_entities), len(expected_top_text_entities))

      # Comparing `actual_top_text_entities` and `expected_top_text_entities`
      # directly is troublesome since we cannot guarantee the order of
      # retrieval results with the same attention weights. Therefore, we first
      # sort both `actual_top_text_entities` and `expected_top_text_entities`
      # by the attention weight and then by their elements.
      def sort_entities(top_text_entities):
        result = [(expected_attention_weights[i], list(top_text_entities[i]))
                  for i in range(len(top_text_entities))]
        result.sort()
        return [x[1] for x in result]

      actual_top_text_entities = sort_entities(actual_top_text_entities)
      expected_top_text_entities = sort_entities(expected_top_text_entities)

      for i in range(len(actual_top_text_entities)):
        self.assertSequenceEqual(
            list(actual_top_text_entities[i]),
            list(expected_top_text_entities[i]))

    if max_text_identifiers is not None:
      self.assertSequenceEqual(num_disallowed_per_device,
                               logging_helpers['n_disallowed'])
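
As a side note on the softmax comparison above: replacing a memory's raw score with `-_LARGE_NUMBER` is what drives its attention weight to (numerically) zero after the softmax. Below is a minimal standalone sketch of that behaviour; the scores are made up and the value of the constant is an assumption (the test module defines its own).

import numpy as np
import scipy.special

_LARGE_NUMBER = 1e8  # assumed value; the test module defines its own constant
scores = np.array([2.0, 1.0, 0.5])
masked_scores = scores.copy()
masked_scores[1] = -_LARGE_NUMBER  # "disallow" the second memory

weights = scipy.special.softmax(masked_scores)
# The disallowed memory contributes numerically nothing to the attention.
assert weights[1] < 1e-12
assert abs(weights.sum() - 1.0) < 1e-6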
Example #25
0
  def test_mention_memory_layer(self, separate_memory_values):
    """Testing memory attention layer."""

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    static_argnums = (9,) if separate_memory_values else (9, 10)
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=static_argnums)

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    mention_batch_positions = jnp.tile(
        jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    n_mentions = mention_start_positions.shape[-1]

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)
    # Make sure id 0 or 1 will be highest scoring
    memory_table[0] = memory_table[0] * 2.0
    memory_table[1] = memory_table[1] * -2.0
    memory_table = jnp.asarray(memory_table, dtype=self.dtype)

    memory_keys = memory_table.reshape(self.n_devices, self.rows,
                                       self.table_size // self.rows,
                                       self.memory_key_dim)

    memory_keys_sharded = jax.device_put_sharded(list(memory_keys), devices)
    if separate_memory_values:
      memory_values = memory_table.reshape(self.n_devices, self.table_size,
                                           self.memory_key_dim)
      memory_values = jax.device_put_sharded(list(memory_values), devices)
    else:
      memory_values = None

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    (encoded_output, loss_helpers, _), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,
        text_identifiers=None,
    )

    attention_weights = loss_helpers['memory_attention_weights']
    entity_ids = loss_helpers['top_entity_ids']

    normed_input = encoded_input - 1.0

    # Check input was changed
    self.assertFalse(jnp.allclose(encoded_output, normed_input))

    # Check input was not changed where it should not be
    all_indices = set(
        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
    # Note that mention positions is the same across all of the devices
    start_indices = set(
        zip(mention_batch_positions[0].tolist(),
            mention_start_positions[0].tolist()))
    non_start_indices = all_indices.difference(start_indices)
    non_start_indices_1, non_start_indices_2 = zip(*non_start_indices)
    non_start_indices_1 = jnp.asarray(non_start_indices_1)
    non_start_indices_2 = jnp.asarray(non_start_indices_2)

    non_start_outputs = encoded_output[:, non_start_indices_1,
                                       non_start_indices_2]
    non_start_inputs = normed_input[:, non_start_indices_1, non_start_indices_2]
    self.assertTrue(jnp.allclose(non_start_outputs, non_start_inputs))

    # Check shapes as expected
    self.assertSequenceEqual(
        encoded_output.shape,
        (self.n_devices, self.bsz, self.seq_len, self.input_dim))

    self.assertSequenceEqual(
        attention_weights.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    self.assertSequenceEqual(
        entity_ids.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    # Check id 0 or 1 retrieved
    self.assertTrue(
        jnp.all((entity_ids[..., 0] == 0) + (entity_ids[..., 0] == 1)))

    # Set some text identifiers to 0 and others to 1 so that some retrievals
    # are disallowed.
    text_identifiers = np.zeros((n_mentions), dtype=np.int32)
    text_identifiers[:n_mentions // 2] = 1
    text_identifiers = jax.device_put_replicated(text_identifiers, devices)

    # Initialize and run one forward pass of model
    (_, loss_helpers, logging_helpers), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,  # memory_values
        text_identifiers=text_identifiers,
    )
    attention_weights_wid = loss_helpers['memory_attention_weights']
    entity_ids_wid = loss_helpers['top_entity_ids']
    n_disallowed = logging_helpers['n_disallowed'][0]

    # Check no effect on ids
    self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

    # Check id 0 or 1 have 0 scores
    text_identifiers = jnp.expand_dims(text_identifiers, -1)
    score_masked = (text_identifiers == entity_ids_wid) * attention_weights_wid
    self.assertAlmostEqual(score_masked.sum(), 0.0)

    # Check number disallowed as expected
    self.assertEqual(n_disallowed, n_mentions // 2)
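
Stripped of model detail, the pmapped call above has roughly the following structure: inputs that are identical on every device go through `jax.device_put_replicated`, per-device memory shards go through `jax.device_put_sharded`, and plain Python arguments such as the `deterministic` flag are listed in `static_broadcasted_argnums`. The sketch below is not part of the original test; the `lookup` function and all shapes are made up, only the JAX calls are real.

import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
n = len(devices)

def lookup(queries, table, deterministic):
  # `queries` is identical on every device, `table` is a different shard per
  # device, and `deterministic` is a plain Python bool (hence static).
  scores = jnp.einsum('qd,md->qm', queries, table)
  if not deterministic:
    scores = scores + 1.0  # placeholder for e.g. dropout noise
  return scores

queries = jnp.ones((3, 4))
queries = jax.device_put_replicated(queries, devices)  # same array on all devices
table = np.arange(n * 5 * 4, dtype=np.float32).reshape(n, 5, 4)
table = jax.device_put_sharded(list(table), devices)  # one shard per device

plookup = jax.pmap(
    lookup, axis_name='batch', static_broadcasted_argnums=(2,))
scores = plookup(queries, table, True)  # shape (n, 3, 5)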
Example #26
0
    def load_memory(config: ml_collections.ConfigDict) -> Dict[str, Any]:
        """Load mention memory."""
        model_config = config.model_config
        encoder_config = model_config.encoder_config

        process_count = jax.process_count()
        # Reduce number of loaded memory shards by this proportion. Total shards
        # must be divisible by memory_reduction * process_count.
        memory_reduction = config.get('memory_reduction', 1)
        process_index = jax.process_index()
        local_devices = jax.local_devices()

        memory_prop = config.get('memory_prop', None)
        rows = encoder_config.rows
        memory_key_dim = encoder_config.memory_key_dim

        memory_arrays = {}
        memory_component_names = [
            'memory_keys', 'memory_identifiers', 'memory_entity_ids'
        ]
        # The following arrays should be converted to the 32-bit integer type.
        # The rest of the arrays will be converted to the model type (typically,
        # bfloat16 or float32).
        memory_component_int_dtypes = {
            'memory_identifiers', 'memory_entity_ids', 'memory_text_entities'
        }
        patterns = [
            config.memory_key_pattern, config.memory_id_pattern,
            config.memory_entity_id_pattern
        ]

        if encoder_config.separate_memory_values:
            memory_component_names.append('memory_values')
            patterns.append(config.memory_value_pattern)

        if config.get('same_entity_set_retrieval_weight', 0) > 0:
            memory_component_names.append('memory_text_entities')
            patterns.append(config.memory_text_entities_pattern)

        for key, pattern in zip(memory_component_names, patterns):
            memory_arrays[key] = data_utils.load_sharded_array(
                pattern, process_count * memory_reduction, process_index)

        memory_variables = {}

        cpu_device = jax.local_devices(backend='cpu')[0]
        dtype = encoder_config.dtype
        for key, array in memory_arrays.items():
            if memory_prop is not None:
                array = array[:int(memory_prop * array.shape[0])]
            if key == 'memory_keys':
                array = array.reshape(len(local_devices), rows, -1,
                                      memory_key_dim)
            else:
                array = array.reshape((len(local_devices), -1) +
                                      array.shape[1:])
            array = jax.device_put(array, cpu_device)
            if key in memory_component_int_dtypes:
                array = jnp.asarray(array, dtype=jnp.int32)
            else:
                array = jnp.asarray(array, dtype=dtype)
            array = jax.device_put_sharded(list(array), local_devices)
            memory_variables[key] = array
        return memory_variables
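
The final loop above follows the usual reshape-then-shard pattern: each memory array gets a leading axis of length `len(local_devices)` and one slice is placed on each device with `jax.device_put_sharded`. A minimal sketch of that pattern with toy shapes and hypothetical names:

import jax
import numpy as np

local_devices = jax.local_devices()
n_local = len(local_devices)

# A flat host-side table whose first dimension is divisible by the device count.
table = np.arange(n_local * 4 * 2, dtype=np.float32).reshape(n_local * 4, 2)

# Add a leading per-device axis, then place one slice on each device.
per_device = table.reshape((n_local, -1) + table.shape[1:])
sharded = jax.device_put_sharded(list(per_device), local_devices)

# The result behaves like a single array of shape (n_local, 4, 2), but each
# (4, 2) slice lives on its own device.
assert sharded.shape == (n_local, 4, 2)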