def test_get_from_first_device(self):
    """Checks `utils.get_from_first_device` for device and numpy outputs.

    Two arrays are sharded across the local devices; the value read back
    from the first device must equal the first row of each array, both when
    device arrays are requested and when numpy arrays are requested.
    """
    num_devices = jax.local_device_count()
    local_devices = jax.local_devices()

    def shard(total, width):
        # One row of `width` consecutive integers per local device.
        rows = jnp.arange(total).reshape([num_devices, width])
        return jax.device_put_sharded(list(rows), local_devices)

    sharded = {'a': shard(16, 4), 'b': shard(8, 2)}
    want = {'a': jnp.arange(4), 'b': jnp.arange(2)}

    # Get zeroth device content as DeviceArray.
    device_arrays = utils.get_from_first_device(sharded, as_numpy=False)
    jax.tree_map(
        lambda leaf: self.assertIsInstance(leaf, jax.xla.DeviceArray),
        device_arrays)
    jax.tree_map(np.testing.assert_array_equal, want, device_arrays)

    # Get the zeroth device content as numpy arrays.
    numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True)
    jax.tree_map(
        lambda leaf: self.assertIsInstance(leaf, np.ndarray), numpy_arrays)
    jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays)
def __next__(self) -> types.NestedArray:
    """Returns the next element, already placed on the target device(s).

    Without a pmapped user a single item is put on the first device. With a
    pmapped user one item per device is pulled and sharded across all
    devices; an incomplete trailing group ends the iteration. When a
    `split_fn` is set, only the `device` part of each split is transferred
    and the `host` part is returned as is (stacked in the pmapped case).

    Raises:
      StopIteration: when the underlying iterator cannot supply a full group.
    """
    try:
        if self.pmapped_user:
            group = tuple(itertools.islice(self.iterator, self.num_devices))
            if len(group) < self.num_devices:
                raise StopIteration
            if self.split_fn is None:
                return jax.device_put_sharded(group, self.devices)
            # ((host: x1, device: y1), ..., (host: xN, device: yN)).
            per_item = (self.split_fn(element) for element in group)
            # (host: (x1, ..., xN), device: (y1, ..., yN)).
            merged = tree.map_structure_up_to(
                PrefetchingSplit(None, None), lambda *parts: parts, *per_item)
            return PrefetchingSplit(
                host=np.stack(merged.host),
                device=jax.device_put_sharded(merged.device, self.devices))

        element = next(self.iterator)
        if self.split_fn is None:
            return jax.device_put(element, self.devices[0])
        halves = self.split_fn(element)
        return PrefetchingSplit(
            host=halves.host,
            device=jax.device_put(halves.device, self.devices[0]))
    except StopIteration:
        # Normal end of iteration: propagate without logging.
        raise
    except Exception:  # pylint: disable=broad-except
        logging.exception('Error for %s', self.iterable)
        raise
def producer():
    """Enqueues batched items from `iterable` on a given thread.

    Pulls `len(devices)` items at a time, shards them across `devices` and
    puts the result on `buffer`. When `split_fn` is set, only the `device`
    part of each split is sharded and the `host` parts are stacked. Any
    exception is logged and appended to `producer_error`. The `end` sentinel
    is always enqueued last so the consumer can stop.
    """
    try:
        # Build a new iterable for each thread. This is crucial if working with
        # tensorflow datasets because tf.Graph objects are thread local.
        it = iter(iterable)
        num_devices = len(devices)
        while True:
            # Materialize the slice: an `islice` object is always truthy, so
            # testing it directly can never detect an exhausted iterator.
            items = tuple(itertools.islice(it, num_devices))
            if not items or len(items) < num_devices:
                # Source exhausted. An incomplete trailing group cannot be
                # sharded over all devices, so it is dropped (the same policy
                # the iterator's `__next__` applies).
                break
            if split_fn is None:
                buffer.put(jax.device_put_sharded(items, devices))
            else:
                # ((host: x1, device: y1), ..., (host: xN, device: yN)).
                items_split = (split_fn(item) for item in items)
                # (host: (x1, ..., xN), device: (y1, ..., yN)).
                split = tree.map_structure_up_to(
                    PrefetchingSplit(None, None), lambda *x: x, *items_split)
                buffer.put(
                    PrefetchingSplit(
                        host=np.stack(split.host),
                        device=jax.device_put_sharded(split.device, devices)))
    except Exception as e:  # pylint: disable=broad-except
        logging.exception('Error in producer thread for %s', iterable)
        producer_error.append(e)
    finally:
        buffer.put(end)
def _gen_array(self, gen_fn): array = [gen_fn() for _ in range(len(self.devices))] array_sharded = jax.device_put_sharded(array, self.devices) array_stacked = np.stack(array) array_stacked = array_stacked.reshape([-1] + list(array_stacked.shape[2:])) return array_stacked, array_sharded
def test_get_globally_consistent_batch_positions(self, seed, batch_size,
                                                 n_mentions):
    """Checks batch-position offsetting in pmapped and single-device runs."""
    np.random.seed(seed)
    host_positions = np.random.randint(
        batch_size, size=(self.n_devices, n_mentions))
    mention_batch_positions_sharded = jax.device_put_sharded(
        list(host_positions), self.devices)

    fn = functools.partial(
        mention_utils.get_globally_consistent_batch_positions,
        batch_size=batch_size)

    # Test function in the multi-device setting
    local_positions, global_positions = jax.pmap(
        fn, axis_name='batch')(mention_batch_positions_sharded)

    # Locally, every device shifts its positions by `batch_size * device`.
    for device in range(self.n_devices):
        expected = (
            mention_batch_positions_sharded[device] + batch_size * device)
        self.assertArrayEqual(local_positions[device], expected)

    # Globally, every device sees the concatenation of all local positions.
    flat_local_positions = local_positions.reshape(-1)
    for device in range(self.n_devices):
        self.assertArrayEqual(global_positions[device], flat_local_positions)

    # Test function in the single-device setting
    for device in range(self.n_devices):
        shard = mention_batch_positions_sharded[device]
        single_local, single_global = fn(shard)
        self.assertArrayEqual(single_local, shard)
        self.assertArrayEqual(single_global, shard)
def replicate(self):
    """Generator that replicates this collection's variables across devices.

    Intended for use in a `with` statement, typically before calling
    objax.Parallel, so that every variable has one copy per local device.

    Important: random states are not copied but split, so that each device
    receives its own state. Their value is saved at entry and restored
    after the `yield`; all other variables are reduced back on exit.
    """
    local_devices = get_local_devices()
    device_count = len(local_devices)
    per_device_values, rng_backups = [], []
    for var in self:
        if not isinstance(var, RandomState):
            per_device_values.append(
                jax.device_put_replicated(var.value, local_devices))
            continue
        per_device_values.append(
            jax.device_put_sharded(list(var.split(device_count)),
                                   local_devices))
        # Saved after the split above, in iteration order.
        rng_backups.append(var.value)
    self.assign(per_device_values)
    yield
    seen_ids = set()
    for _, var in self.items():
        if isinstance(var, TrainRef):
            var = var.ref
            assert not isinstance(var, TrainRef)
        if id(var) in seen_ids:
            # Careful not to reduce twice in case of a variable and a
            # reference to it.
            continue
        if isinstance(var, RandomState):
            # Backups are consumed oldest-first.
            var.assign(rng_backups.pop(0))
        else:
            var.reduce(var.value)
        seen_ids.add(id(var))
def replicate(self):
    """Generator that replicates this collection's variables across devices.

    Intended for use in a `with` statement.

    Important: random states are split rather than copied, so each device
    gets its own state. Their value is saved at entry and written back after
    the `yield`; every other variable is set to `reduce_func` of itself.
    """
    global math
    if math is None:
        from brainpy import math

    # Run an identity pmap over a dummy array and read the devices back from
    # the resulting buffers.
    probe = jnp.zeros((jax.local_device_count(), 1), dtype=math.float_)
    sharded_probe = jax.pmap(lambda x: x, axis_name='device')(probe)
    devices = [buf.device() for buf in sharded_probe.device_buffers]
    num_device = len(devices)

    replicated, saved_states = {}, {}
    for key, var in self.items():
        if not isinstance(var, math.random.RandomState):
            replicated[key] = jax.device_put_replicated(var.value, devices)
            continue
        replicated[key] = jax.device_put_sharded(
            list(var.split(num_device)), devices)
        saved_states[key] = var.value
    self.assign(replicated)
    yield
    seen_ids = set()
    for key, var in self.items():
        # Careful not to reduce twice in case of
        # a variable and a reference to it.
        if id(var) in seen_ids:
            continue
        if isinstance(var, math.random.RandomState):
            var.value = saved_states[key]
        else:
            var.value = reduce_func(var)
        seen_ids.add(id(var))
def test_get_num_common_unique_items_multi_devices(self, seed: int,
                                                   batch_size: int,
                                                   n_mentions: int,
                                                   vocab_size: int):
    """Compares pmapped common-item counts against a set-based reference."""
    np.random.seed(seed)
    shard_shape = (self.n_devices, n_mentions)
    batch_positions_sharded = jax.device_put_sharded(
        list(np.random.randint(batch_size, size=shard_shape)), self.devices)
    ids_sharded = jax.device_put_sharded(
        list(np.random.randint(vocab_size, size=shard_shape)), self.devices)

    fn = functools.partial(
        mention_utils.get_num_common_unique_items, batch_size=batch_size)
    actual_per_sample, actual_per_mention = jax.pmap(
        fn, axis_name='batch')(batch_positions_sharded, ids_sharded)

    # Reference: for every (device, sample), the set of positive ids that
    # appear in that sample.
    batches = []
    for device in range(self.n_devices):
        sample_sets = [set() for _ in range(batch_size)]
        for mention in range(n_mentions):
            if ids_sharded[device, mention] > 0:
                sample_sets[batch_positions_sharded[device, mention]].add(
                    ids_sharded[device, mention].item())
        batches.append(sample_sets)

    # Per-sample counts: every sample against every sample on every device.
    for dev_a in range(self.n_devices):
        for sample_a in range(batch_size):
            for dev_b in range(self.n_devices):
                for sample_b in range(batch_size):
                    shared = batches[dev_a][sample_a].intersection(
                        batches[dev_b][sample_b])
                    column = dev_b * batch_size + sample_b
                    self.assertLen(
                        shared,
                        actual_per_sample[dev_a, sample_a, column].item())

    # Per-mention counts: each mention is represented by its sample's set.
    for dev_a in range(self.n_devices):
        for mention_a in range(n_mentions):
            sample_a = batch_positions_sharded[dev_a, mention_a].item()
            for dev_b in range(self.n_devices):
                for mention_b in range(n_mentions):
                    sample_b = batch_positions_sharded[dev_b,
                                                       mention_b].item()
                    shared = batches[dev_a][sample_a].intersection(
                        batches[dev_b][sample_b])
                    column = dev_b * n_mentions + mention_b
                    self.assertLen(
                        shared,
                        actual_per_mention[dev_a, mention_a, column].item())
def _array_shard( x, devices = None ): """Shards a single array over the first axis across multiple local devices.""" devices = devices or jax.local_devices() x = jnp.asarray(x) assert x.shape[0] == len(devices) return jax.device_put_sharded(list(x), devices)
def setUp(self):
    """Forces multiple devices and shards random mention batch positions."""
    super().setUp()
    test_utils.force_multi_devices(self.n_devices)
    self.devices = jax.local_devices()

    # One independently sampled position vector per device.
    per_device_positions = []
    for _ in range(self.n_devices):
        per_device_positions.append(
            np.random.randint(self.batch_size, size=(self.n_mentions)))
    self.mention_batch_positions_sharded = jax.device_put_sharded(
        per_device_positions, self.devices)
def run(shared_input, clients):
    """Runs the pmapped client init/step/final functions block by block.

    Clients are grouped into blocks by `_blockify`. For each block the
    per-client inputs, batches and masks are sharded across `devices`, and
    the per-client results are sliced out of the stacked outputs.

    Yields:
      `(client_id, client_output, step_results)` for every client whose
      `client_mask` entry is truthy; `step_results` is truncated to that
      client's `num_batches`.
    """
    for block in _blockify(clients, block_size):
        sharded_input = jax.device_put_sharded(block.client_input, devices)
        p_state = p_client_init(shared_input, sharded_input)

        p_step_results = []
        for p_batch, p_mask in block.masked_batches:
            sharded_batch = jax.device_put_sharded(p_batch, devices)
            sharded_mask = jax.device_put_sharded(p_mask, devices)
            p_state, p_step_result = p_client_step(p_state, sharded_batch,
                                                   sharded_mask)
            p_step_results.append(p_step_result)

        p_client_output = p_client_final(shared_input, p_state)

        for index in range(len(block.client_id)):
            if block.client_mask[index]:
                # Bind `index` as a default so the lambda does not capture
                # the loop variable by reference.
                client_output, step_results = jax.tree_util.tree_map(
                    lambda x, index=index: x[index],
                    (p_client_output, p_step_results))
                yield (block.client_id[index], client_output,
                       step_results[:block.num_batches[index]])
def device_put_sharded(x):
    """Shards an array across the local devices along a new leading axis.

    Non-array inputs are returned unchanged. Arrays are reshaped to
    `[local_device_count, -1, x.shape[2]]`, so the input must have at least
    three dimensions and a total size divisible by the device count.
    NOTE(review): the reshape keeps only `x.shape[2]` as the trailing axis;
    it looks intended for rank-3 inputs. Confirm against callers.

    Args:
      x: a `jnp.ndarray` / `np.ndarray`, or any other object.

    Returns:
      `x` itself when it is not an array; otherwise the array returned by
      `jax.device_put_sharded` with one shard per local device.
    """
    if not isinstance(x, (jnp.ndarray, np.ndarray)):
        return x
    # Later, device_put_sharded takes a sequence of tensors, one tensor for
    # every local device. So we split it on the zeroth (device) dimension.
    x = np.reshape(x, [jax.local_device_count(), -1, x.shape[2]])
    x_list = np.split(x, x.shape[0], axis=0)
    # Squeeze out the dummy dimension. A plain comprehension is enough for a
    # flat list and avoids the top-level `jax.tree_map` alias, which is
    # deprecated and absent from recent JAX releases.
    x_list = [np.squeeze(y, axis=0) for y in x_list]
    # Send the sharded array in devices.
    return jax.device_put_sharded(x_list, jax.local_devices())
def _replicate(x, devices=None): """Replicate an object on each device.""" x = jax.numpy.array(x) if devices is None: devices = jax.local_devices() return jax.device_put_sharded(len(devices) * [x], devices)
def test_linking_layer(self):
    """Testing linking layer."""
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()
    mention_shape = (self.n_devices, self.n_mentions)

    def shard(stacked):
        # Row `i` of `stacked` goes to device `i`.
        return jax.device_put_sharded(list(stacked), devices)

    encoded_input = shard(
        jnp.ones(
            shape=(self.n_devices, self.bsz, self.seq_len, self.hidden_size),
            dtype=self.dtype))
    mention_batch_positions = shard(
        np.random.randint(self.bsz, size=mention_shape))
    host_start_positions = np.random.randint(
        self.seq_len - 1, size=mention_shape)
    host_end_positions = host_start_positions + 1
    mention_start_positions = shard(host_start_positions)
    mention_end_positions = shard(host_end_positions)
    mention_mask = shard(jnp.ones(shape=mention_shape))
    mention_entity_ids = shard(
        jnp.arange(self.n_devices * self.n_mentions).reshape(
            self.n_devices, self.n_mentions))

    model = memory_extraction_layer.MemoryExtractionLayer(
        memory_key_dim=self.memory_key_dim,
        memory_value_dim=self.memory_value_dim,
        dtype=self.dtype,
    )
    pinit_with_output = jax.pmap(model.init_with_output, axis_name='batch')

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    result_dict, _ = pinit_with_output(
        split_rng,
        encoding=encoded_input,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        mention_entity_ids=mention_entity_ids,
    )

    # Check shapes are as expected
    gathered_mentions = self.n_devices * self.n_mentions
    expected_shapes = {
        'memory_keys':
            (self.n_devices, gathered_mentions, self.memory_key_dim),
        'memory_values':
            (self.n_devices, gathered_mentions, self.memory_value_dim),
        'memory_mask': (self.n_devices, gathered_mentions),
        'memory_entity_ids': (self.n_devices, gathered_mentions),
    }
    for name, expected_shape in expected_shapes.items():
        self.assertSequenceEqual(result_dict[name].shape, expected_shape)

    # Memory mask and entity ids should just have been all gathered
    gathered_mask = result_dict['memory_mask'][0].reshape(
        self.n_devices, self.n_mentions)
    self.assertTrue(jnp.all(gathered_mask == mention_mask))
    gathered_entity_ids = result_dict['memory_entity_ids'][0].reshape(
        self.n_devices, self.n_mentions)
    self.assertTrue(jnp.all(gathered_entity_ids == mention_entity_ids))
def run_model(self, config, entity_vocab_size):
    """Initialize and run the model once, perform sanity checks.

    Builds replicated memory tables, saves sharded memory arrays via
    `self.save_sharded_array`, initializes the task model under `jax.pmap`,
    runs the pmapped loss function once on a freshly sampled batch, and
    checks the metrics and post-processed features.

    Args:
      config: mutable config mapping. It is extended with the saved array
        patterns and then frozen into a `FrozenConfigDict`.
      entity_vocab_size: forwarded to
        `test_utils.gen_mention_pretraining_sample`.

    Returns:
      Tuple `(batch, initial_variables, metrics)`.
    """
    # The NumPy RNG is seeded once; the order of the `np.random` calls below
    # therefore determines the generated data.
    np.random.seed(0)

    # Save arrays to test retrieval saver.
    memory_identifiers = np.arange(self.table_size)
    memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                   self.devices)
    # Entity ids reuse the identifiers array.
    memory_entity_ids = memory_identifiers
    config['memory_entity_id_pattern'] = self.save_sharded_array(
        memory_entity_ids, 'memory_entity_id')
    memory_text = np.random.randint(
        config['model_config']['encoder_config']['vocab_size'],
        size=(self.n_devices, self.table_size, self.memory_text_length),
        dtype=np.int32)
    config['memory_text_pattern'] = self.save_sharded_array(
        memory_text, 'memory_text')
    memory_positions = np.random.randint(
        self.memory_text_length,
        size=(self.n_devices, self.table_size, 2),
        dtype=np.int32)
    config['memory_positions_pattern'] = self.save_sharded_array(
        memory_positions, 'memory_positions')

    config = ml_collections.FrozenConfigDict(config)
    model_config = config.model_config
    encoder_config = model_config.encoder_config

    rows = encoder_config.rows
    preprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_preprocess_fn(config)  # pylint: disable=line-too-long
    collater_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_collater_fn(
        config)
    postprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_output_postprocess_fn(
        config)

    model = mention_based_entity_qa_task.MentionBasedEntityQATask.build_model(
        model_config)
    dummy_input = mention_based_entity_qa_task.MentionBasedEntityQATask.dummy_input(
        config)
    dummy_input = jax.device_put_replicated(dummy_input, self.devices)
    init_rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(init_rng, self.n_devices)

    # Keys are stored as (rows, table_size // rows, dim); values are the same
    # table flattened to two dimensions.
    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = jax.device_put_replicated(memory_table, self.devices)
    memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
    memory_values = jax.device_put_replicated(memory_values, self.devices)

    initial_variables = jax.pmap(
        model.init, 'batch', static_broadcasted_argnums=2)(
            split_rng,
            dummy_input,
            True,
        )
    # Keep only the parameters and attach the memory tables as constants.
    initial_variables = {'params': initial_variables['params']}
    initial_variables['constants'] = {
        'encoder': {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
        }
    }

    def sample_batch():
        """Generates, preprocesses and collates one per-device batch."""
        processed_examples = []
        for _ in range(config.per_device_batch_size):
            raw_example = test_utils.gen_mention_pretraining_sample(
                self.text_length,
                self.n_mentions,
                self.n_linked_mentions,
                entity_vocab_size=entity_vocab_size,
                max_length=encoder_config.max_length)
            processed_example = preprocess_fn(raw_example)
            processed_examples.append(processed_example)
        # `stack` is defined outside this block.
        batch = stack(processed_examples)
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }
        return batch

    # One batch per device, stacked and then sharded along the leading axis.
    batch = stack([sample_batch() for _ in range(self.n_devices)])
    batch = {
        key: jax.device_put_sharded(list(value), self.devices)
        for key, value in batch.items()
    }

    loss_fn = jax.pmap(
        mention_based_entity_qa_task.MentionBasedEntityQATask.make_loss_fn(
            config),
        'batch',
        static_broadcasted_argnums=(0, 4))
    _, metrics, auxiliary_output = loss_fn(
        model_config,
        initial_variables['params'],
        {'constants': initial_variables['constants']},
        batch,
        True,
    )

    self.assertArrayEqual(metrics['agg']['denominator'],
                          batch['mention_target_weights'].sum(1))

    features = postprocess_fn(batch, auxiliary_output)
    # Check features are JSON-serializable
    json.dumps(features)
    # Check features match the original batch
    for key in batch.keys():
        self.assertArrayEqual(np.array(features[key]), batch[key])

    n_mentions_per_device = (config.per_device_batch_size *
                             config.max_mention_targets)
    # Falls back to the per-device top-k times the device count when no
    # post-selection top-k is configured.
    k_top_final = (encoder_config.final_k_top_post_selection or
                   encoder_config.final_k_top_device * self.n_devices)
    self.assertSequenceEqual(
        np.array(features['memory_text']).shape, [
            self.n_devices, n_mentions_per_device, k_top_final,
            self.memory_text_length
        ])
    self.assertSequenceEqual(
        np.array(features['memory_positions']).shape,
        [self.n_devices, n_mentions_per_device, k_top_final, 2])

    return batch, initial_variables, metrics
def bcast_local_devices(value):
    """Broadcasts an object to all local devices.

    Args:
      value: a pytree; every leaf is copied once per local device.

    Returns:
      A pytree with the same structure as `value` whose leaves are the
      results of `jax.device_put_sharded` over the identical copies.
    """
    devices = jax.local_devices()
    num_devices = len(devices)
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from recent releases.
    return jax.tree_util.tree_map(
        lambda leaf: jax.device_put_sharded(num_devices * [leaf], devices),
        value)
def _device_put_sharded(sharded_tree, devices): leaves, treedef = jax.tree_flatten(sharded_tree) n = leaves[0].shape[0] return jax.device_put_sharded([ jax.tree_unflatten(treedef, [l[i] for l in leaves]) for i in range(n) ], devices)
def build_dataset_iterator(
    data_root: str,
    split: str,
    dynamic_batch_size_config: config_dict.ConfigDict,
    online_subsampling_kwargs: dict,  # pylint: disable=g-bare-generic
    debug: bool = False,
    is_training: bool = True,
    k_fold_split_id: Optional[int] = None,
    ratio_unlabeled_data_to_labeled_data: float = 0.0,
    use_all_labels_when_not_training: bool = False,
    use_dummy_adjacencies: bool = False,
):
    """Returns an iterator over Batches from the dataset.

    This is a generator. When `is_training` is true it accumulates
    `jax.local_device_count()` batches and yields them sharded across the
    local devices; otherwise it yields each batch on its own.

    Args:
      data_root: forwarded to `data_utils.get_arrays`.
      split: dataset split name. `'test'` forces
        `use_all_labels_when_not_training` to True.
      dynamic_batch_size_config: its `n_node` and `n_edge` bound the
        subsampled graphs, and the whole config is unpacked into
        `batching_utils.dynamically_batch`.
      online_subsampling_kwargs: extra keyword arguments for
        `data_utils.get_graph_subsampling_dataset`.
      debug: if true, only 50 examples are taken and the shuffle and
        prefetch buffers are reduced to 1.
      is_training: if true, the dataset is sharded per process, shuffled and
        repeated.
      k_fold_split_id: forwarded to `data_utils.get_arrays`.
      ratio_unlabeled_data_to_labeled_data: forwarded to the subsampling
        dataset; reset to 0.0 when not training.
      use_all_labels_when_not_training: if false (or if `split == 'train'`),
        only labels of training papers are exposed as input features.
      use_dummy_adjacencies: forwarded to `data_utils.get_arrays`.

    Yields:
      `Batch` objects (sharded across local devices when training).
    """
    if split == 'test':
        use_all_labels_when_not_training = True

    if not is_training:
        ratio_unlabeled_data_to_labeled_data = 0.0

    # Load the master data arrays.
    with LOADING_RAW_ARRAYS_LOCK:
        array_dict = data_utils.get_arrays(
            data_root, k_fold_split_id=k_fold_split_id,
            use_dummy_adjacencies=use_dummy_adjacencies)

    # 0/1 membership vectors over all labelled nodes for the train and valid
    # index sets.
    node_labels = array_dict['paper_label'].reshape(-1)
    train_indices = array_dict['train_indices'].astype(np.int32)
    is_train_index = np.zeros(node_labels.shape[0], dtype=np.int32)
    is_train_index[train_indices] = 1
    valid_indices = array_dict['valid_indices'].astype(np.int32)
    is_valid_index = np.zeros(node_labels.shape[0], dtype=np.int32)
    is_valid_index[valid_indices] = 1
    is_train_or_valid_index = is_train_index + is_valid_index

    def sstable_to_intermediate_graph(graph):
        """Adds one-hot and index features to a sampled graph (TF ops)."""
        indices = tf.cast(graph.nodes['index'], tf.int32)
        first_index = indices[..., 0]

        # Add an additional absolute index, but adding offsets to authors, and
        # institution indices.
        absolute_index = graph.nodes['index']
        is_author = graph.nodes['type'] == 1
        absolute_index = tf.where(
            is_author, absolute_index + data_utils.NUM_PAPERS, absolute_index)
        is_institution = graph.nodes['type'] == 2
        absolute_index = tf.where(
            is_institution,
            absolute_index + data_utils.NUM_PAPERS + data_utils.NUM_AUTHORS,
            absolute_index)

        is_same_as_central_node = tf.math.equal(indices, first_index)
        input_nodes = graph.nodes
        graph = graph._replace(
            nodes={
                'one_hot_type':
                    tf.one_hot(tf.cast(input_nodes['type'], tf.int32), 3),
                'one_hot_depth':
                    tf.one_hot(
                        tf.cast(input_nodes['depth'], tf.int32),
                        _MAX_DEPTH_IN_SUBGRAPH),
                'year':
                    tf.expand_dims(input_nodes['year'], axis=-1),
                'label':
                    tf.one_hot(
                        tf.cast(input_nodes['label'], tf.int32), NUM_CLASSES),
                'is_same_as_central_node':
                    is_same_as_central_node,
                # Only first node in graph has a valid label.
                'is_central_node':
                    tf.one_hot(0, tf.shape(input_nodes['label'])[0]),
                'index':
                    input_nodes['index'],
                'absolute_index':
                    absolute_index,
            },
            globals=tf.expand_dims(graph.globals, axis=-1),
        )

        return graph

    ds = data_utils.get_graph_subsampling_dataset(
        split,
        array_dict,
        shuffle_indices=is_training,
        ratio_unlabeled_data_to_labeled_data=
        ratio_unlabeled_data_to_labeled_data,
        max_nodes=dynamic_batch_size_config.n_node - 1,  # Keep space for pads.
        max_edges=dynamic_batch_size_config.n_edge,
        **online_subsampling_kwargs)
    if debug:
        ds = ds.take(50)
    ds = ds.map(
        sstable_to_intermediate_graph,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if is_training:
        # Each process reads its own shard of the data.
        ds = ds.shard(jax.process_count(), jax.process_index())
        ds = ds.shuffle(buffer_size=1 if debug else 128)
        ds = ds.repeat()
    ds = ds.prefetch(1 if debug else tf.data.experimental.AUTOTUNE)
    np_ds = iter(tfds.as_numpy(ds))
    batched_np_ds = batching_utils.dynamically_batch(
        np_ds,
        **dynamic_batch_size_config,
    )

    def intermediate_graph_to_batch(graph):
        """Turns a batched NumPy graph into a `Batch` with label features."""
        central_node_mask = graph.nodes['is_central_node']
        label = graph.nodes['label']
        node_indices = graph.nodes['index']
        absolute_indices = graph.nodes['absolute_index']

        ### Construct label as a feature for non-central nodes.
        # First do a lookup with node indices, with a np.minimum to ensure we do not
        # index out of bounds due to num_authors being larger than num_papers.
        is_same_as_central_node = graph.nodes['is_same_as_central_node']
        capped_indices = np.minimum(node_indices, node_labels.shape[0] - 1)
        label_as_feature = node_labels[capped_indices]
        # Nodes which are not in train set should get `num_classes` label.
        # Nodes in test set or non-arXiv nodes have -1 or nan labels.

        # Mask out invalid labels and non-papers.
        use_label_as_feature = np.logical_and(
            label_as_feature >= 0, graph.nodes['one_hot_type'][..., 0])
        if split == 'train' or not use_all_labels_when_not_training:
            # Mask out validation papers and non-arxiv papers who
            # got labels from fusing with arxiv papers.
            use_label_as_feature = np.logical_and(
                is_train_index[capped_indices], use_label_as_feature)
        label_as_feature = np.where(
            use_label_as_feature, label_as_feature, NUM_CLASSES)
        # Mask out central node label in case it appears again.
        label_as_feature = np.where(
            is_same_as_central_node, NUM_CLASSES, label_as_feature)
        # Nodes which are not papers get `NUM_CLASSES+1` label.
        label_as_feature = np.where(
            graph.nodes['one_hot_type'][..., 0],
            label_as_feature, NUM_CLASSES + 1)

        nodes = {
            'label_as_feature': label_as_feature,
            'year': graph.nodes['year'],
            'bitstring_year': _get_bitstring_year_representation(
                graph.nodes['year']),
            'one_hot_type': graph.nodes['one_hot_type'],
            'one_hot_depth': graph.nodes['one_hot_depth'],
        }

        graph = graph._replace(
            nodes=nodes,
            globals={},
        )
        is_train_or_valid_node = np.logical_and(
            is_train_or_valid_index[capped_indices],
            graph.nodes['one_hot_type'][..., 0])
        if is_training:
            label_mask = np.logical_and(
                central_node_mask, is_train_or_valid_node)
        else:
            # `label_mask` is used to index into valid central nodes by prediction
            # calculator. Since that computation is only done when not training, and
            # at that time we are guaranteed all central nodes have valid labels,
            # we just set label_mask = central_node_mask when not training.
            label_mask = central_node_mask
        batch = Batch(
            graph=graph,
            node_labels=label,
            central_node_mask=central_node_mask,
            label_mask=label_mask,
            node_indices=node_indices,
            absolute_node_indices=absolute_indices)

        # Transform integers into one-hots.
        batch = _add_one_hot_features_to_batch(batch)

        # Gather PCA features.
        return _add_embeddings_to_batch(batch, array_dict['bert_pca_129'])

    # When training, batches are collected until there is one per local
    # device and are then yielded together, sharded across those devices.
    batch_list = []
    for batch in batched_np_ds:
        with jax.profiler.StepTraceAnnotation('batch_postprocessing'):
            batch = intermediate_graph_to_batch(batch)
        if is_training:
            batch_list.append(batch)
            if len(batch_list) == jax.local_device_count():
                yield jax.device_put_sharded(batch_list, jax.local_devices())
                batch_list = []
        else:
            yield batch
def run_model(self, config, entity_vocab_size):
    """Initialize and run the model once, perform sanity checks.

    Builds replicated and sharded memory tables, saves sharded memory arrays
    via `self.save_sharded_array`, initializes the task model under
    `jax.pmap`, runs the pmapped loss function once on a freshly sampled
    batch, and checks the metrics and post-processed features.

    Args:
      config: mutable config mapping. It is extended with the saved array
        patterns and then frozen into a `FrozenConfigDict`.
      entity_vocab_size: upper bound for the sampled memory text entity ids;
        also forwarded to `test_utils.gen_mention_pretraining_sample`.

    Returns:
      Tuple `(batch, initial_variables, metrics)`.
    """
    # The NumPy RNG is seeded once; the order of the `np.random` calls below
    # therefore determines the generated data.
    np.random.seed(0)

    # Save arrays to test retrieval saver.
    memory_identifiers = np.arange(self.table_size)
    memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                   self.devices)
    # Entity ids reuse the identifiers array.
    memory_entity_ids = memory_identifiers
    config['memory_entity_id_pattern'] = self.save_sharded_array(
        memory_entity_ids, 'memory_entity_id')
    memory_text = np.random.randint(
        config['model_config']['encoder_config']['vocab_size'],
        size=(self.n_devices, self.table_size, self.memory_text_length),
        dtype=np.int32)
    config['memory_text_pattern'] = self.save_sharded_array(
        memory_text, 'memory_text')
    memory_positions = np.random.randint(
        self.memory_text_length,
        size=(self.n_devices, self.table_size, 2),
        dtype=np.int32)
    config['memory_positions_pattern'] = self.save_sharded_array(
        memory_positions, 'memory_positions')

    config = ml_collections.FrozenConfigDict(config)
    model_config = config.model_config
    encoder_config = model_config.encoder_config

    rows = encoder_config.rows
    preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
    collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
        config)
    postprocess_fn = mention_memory_task.MentionMemoryTask.make_output_postprocess_fn(
        config)

    model = mention_memory_task.MentionMemoryTask.build_model(model_config)
    dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
    dummy_input = jax.device_put_replicated(dummy_input, self.devices)
    init_rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(init_rng, self.n_devices)

    # Keys are stored as (rows, table_size // rows, dim); values are the same
    # table flattened to two dimensions.
    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = jax.device_put_replicated(memory_table, self.devices)
    memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
    memory_values = jax.device_put_replicated(memory_values, self.devices)

    # `memory_text_entities` are assumed to contain unique IDs in the last dim.
    memory_text_entities = np.zeros(
        (self.n_devices, self.table_size,
         encoder_config.n_memory_text_entities), np.int32)
    for device_index in range(self.n_devices):
        for t_index in range(self.table_size):
            # Sampling without replacement keeps the ids unique; if the
            # vocabulary is smaller than the slot count, the remaining slots
            # stay zero.
            current_text_entities = np.random.choice(
                entity_vocab_size,
                size=(min(encoder_config.n_memory_text_entities,
                          entity_vocab_size)),
                replace=False)
            memory_text_entities[device_index, t_index, :len(
                current_text_entities)] = current_text_entities
    memory_text_entities = jax.device_put_sharded(
        list(memory_text_entities), self.devices)

    initial_variables = jax.pmap(
        model.init, 'batch', static_broadcasted_argnums=2)(
            split_rng,
            dummy_input,
            True,
        )
    # Keep only the parameters and attach the memory tables as constants.
    initial_variables = {'params': initial_variables['params']}
    initial_variables['constants'] = {
        'encoder': {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        }
    }

    def sample_batch():
        """Generates one per-device batch and checks its text identifiers."""
        processed_examples = []
        for _ in range(config.per_device_batch_size):
            raw_example = test_utils.gen_mention_pretraining_sample(
                self.text_length,
                self.n_mentions,
                self.n_linked_mentions,
                entity_vocab_size=entity_vocab_size,
                max_length=encoder_config.max_length)
            processed_example = preprocess_fn(raw_example)
            processed_examples.append(processed_example)
        # `stack` is defined outside this block.
        batch = stack(processed_examples)
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }
        # Write the MLM target ids back into `text_ids` at the target
        # positions before hashing. `text_ids` is the batch's own array, not
        # a copy.
        text_ids = batch['text_ids']
        for i in range(config.per_device_batch_size):
            for j in range(config.max_mlm_targets):
                if batch['mlm_target_weights'][i, j] > 0:
                    text_ids[i, batch['mlm_target_positions'][
                        i, j]] = batch['mlm_target_ids'][i, j]
        mention_batch_positions = batch['mention_batch_positions']
        text_identifiers = batch['text_identifiers'].astype(
            np.int32).tolist()
        # Each mention's identifier must equal the hash of its passage text.
        expected_text_identifiers = [
            mention_preprocess_utils.text_hash(
                text_ids[mention_batch_positions[index]]).astype(np.int32)
            for index in range(len(mention_batch_positions))
        ]
        self.assertSequenceEqual(text_identifiers, expected_text_identifiers)
        return batch

    # One batch per device, stacked and then sharded along the leading axis.
    batch = stack([sample_batch() for _ in range(self.n_devices)])
    batch = {
        key: jax.device_put_sharded(list(value), self.devices)
        for key, value in batch.items()
    }

    loss_fn = jax.pmap(
        mention_memory_task.MentionMemoryTask.make_loss_fn(config),
        'batch',
        static_broadcasted_argnums=(0, 4))
    _, metrics, auxiliary_output = loss_fn(
        model_config,
        initial_variables['params'],
        {'constants': initial_variables['constants']},
        batch,
        True,
    )

    # Compare the first device's metrics against the first device's batch.
    metrics_per_first_device = jax.tree_map(lambda x: x[0], metrics)
    self.assertEqual(metrics_per_first_device['mlm']['denominator'],
                     batch['mlm_target_weights'][0].sum())

    features = postprocess_fn(batch, auxiliary_output)
    # Check features are JSON-serializable
    json.dumps(features)
    # Check features match the original batch
    for key in batch.keys():
        self.assertArrayEqual(np.array(features[key]), batch[key])

    n_mentions_per_device = (config.per_device_batch_size *
                             config.max_mentions)
    # Number of retrievals expected in the saved features.
    if config.save_k_retrieval is not None:
        k_top_saved = min(config.save_k_retrieval,
                          encoder_config.k_top_post_selection)
    else:
        k_top_saved = (encoder_config.k_top_post_selection or
                       encoder_config.k_top_device * self.n_devices)
    self.assertSequenceEqual(
        np.array(features['memory_text']).shape, [
            self.n_devices, n_mentions_per_device, k_top_saved,
            self.memory_text_length
        ])
    self.assertSequenceEqual(
        np.array(features['memory_positions']).shape,
        [self.n_devices, n_mentions_per_device, k_top_saved, 2])

    # The `second_*` features are only checked when intermediate layers are
    # configured.
    if encoder_config.get('num_intermediate_layers') is not None:
        self.assertSequenceEqual(
            np.array(features['second_memory_text']).shape, [
                self.n_devices, n_mentions_per_device, k_top_saved,
                self.memory_text_length
            ])
        self.assertSequenceEqual(
            np.array(features['second_memory_positions']).shape,
            [self.n_devices, n_mentions_per_device, k_top_saved, 2])

    return batch, initial_variables, metrics
def split_and_put(x: jnp.ndarray) -> jnp.ndarray:
    """Truncates `x` to the dataset size and shards it evenly over devices.

    The first `self._dataset_size` entries of `x` are split into
    `len(device)` equal sections (`np.split` raises if the split is not
    even), one per device.
    """
    usable = x[:self._dataset_size]
    sections = np.split(usable, len(device))
    return jax.device_put_sharded(sections, devices=device)
def _replicate(x):
    """Copies `x` onto each of `devices`, one identical shard per device."""
    array = jnp.array(x)
    copies = [array] * len(devices)
    return jax.device_put_sharded(copies, devices)
def test_memory_attention_backward(self):
    """Runs one pmapped value-and-grad step through the attention layer.

    The test has no explicit assertions: it passes when initialization and
    the backward pass complete without raising.
    """
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    pinit = jax.pmap(
        model.init, axis_name='batch', static_broadcasted_argnums=(9, 10))

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)

    def replicate(array):
        # Same value on every device.
        return jax.device_put_replicated(array, devices)

    def shard(stacked):
        # Row `i` of `stacked` goes to device `i`.
        return jax.device_put_sharded(list(stacked), devices)

    encoded_input = replicate(
        jnp.ones(
            shape=(self.bsz, self.seq_len, self.input_dim),
            dtype=self.dtype))
    mention_batch_positions = replicate(
        jnp.tile(jnp.asarray([[0], [1], [2]]), (1, self.bsz)).reshape(-1))
    mention_start_positions = replicate(
        jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz)))
    mention_end_positions = replicate(
        jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz)))
    mention_mask = replicate(jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz)))

    host_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)
    host_table = jnp.asarray(host_table, dtype=self.dtype)
    host_table = host_table.reshape(self.n_devices, self.rows,
                                    self.table_size // self.rows,
                                    self.memory_key_dim)
    memory_table_sharded = shard(host_table)

    memory_entity_ids = shard(
        np.arange(self.n_devices * self.table_size).reshape(
            self.n_devices, self.table_size))

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    initial_parameters = pinit(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_table_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        None,  # memory_values
        text_identifiers=None,
    )

    def step_fn(
        params,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys,
        memory_identifiers,
        memory_entity_ids,
    ):
        """Returns the summed layer output and its gradient w.r.t. params."""

        def loss_fn(params):
            encoded_output, _, _ = model.apply(
                {'params': params},
                rngs=None,
                encoded_input=encoded_input,
                mention_batch_positions=mention_batch_positions,
                mention_start_positions=mention_start_positions,
                mention_end_positions=mention_end_positions,
                mention_mask=mention_mask,
                memory_keys=memory_keys,
                memory_identifiers=memory_identifiers,
                memory_entity_ids=memory_entity_ids,
                deterministic=True,
                text_identifiers=None,
            )
            return encoded_output.sum()

        return jax.value_and_grad(loss_fn)(params)

    pstep = jax.pmap(step_fn, axis_name='batch')

    _ = pstep(
        initial_parameters['params'],
        encoded_input=encoded_input,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        memory_keys=memory_table_sharded,
        memory_identifiers=memory_identifiers,
        memory_entity_ids=memory_entity_ids,
    )
def replicate_in_all_devices(nest: N,
                             devices: Optional[Sequence[jax.xla.Device]] = None
                            ) -> N:
  """Puts one copy of `nest` on each device.

  Args:
    nest: the array nest to replicate.
    devices: target devices. When omitted (or empty), `jax.local_devices()` is
      used instead.

  Returns:
    The result of `jax.device_put_sharded` called with one copy of `nest` per
    device.
  """
  if not devices:
    devices = jax.local_devices()
  copies = [nest for _ in devices]
  return jax.device_put_sharded(copies, devices)
def test_compare_retrievals_with_numpy(self, seed, k_top_post_selection,
                                       max_text_identifiers,
                                       same_passage_memory_policy):
  """Compares the layer's retrievals with a manual NumPy simulation.

  Random inputs and a random memory table are sharded across devices and fed
  through a pmapped `init_with_output`. The expected retrievals are then
  rebuilt query by query on the host (per-row argmax, per-device top-k,
  optional global top-k, optional same-passage masking) and compared with the
  entity ids, memory ids, attention weights and text entities reported in the
  layer's `loss_helpers`.

  Args:
    seed: seed for the `jax.random.PRNGKey` behind all random inputs.
    k_top_post_selection: passed to the layer; when not None, the expected
      retrievals for each query are truncated to this many entries.
    max_text_identifiers: when not None, random memory and text identifiers
      are drawn from `[0, max_text_identifiers)`; when None, no identifiers
      are passed to the layer.
    same_passage_memory_policy: forwarded to the layer as a static argument.
      The values 'disallow' and 'only' trigger score masking in the
      simulation.
  """
  test_utils.force_multi_devices(self.n_devices)
  devices = jax.local_devices()
  n_text_entities_per_memory = 3

  model = memory_attention_layer.MemoryAttentionLayer(
      memory_key_dim=self.memory_key_dim,
      input_dim=self.input_dim,
      memory_update_type=self.memory_update_type,
      memory_update_config=self.memory_update_config,
      k_top_device=self.k_top_device,
      k_top_post_selection=k_top_post_selection,
      splits=self.splits,
      dtype=self.dtype)

  # Positional arguments 9 (`True`), 10 (memory_values) and 13
  # (same_passage_memory_policy) of the call below are static.
  pinit_with_output = jax.pmap(
      model.init_with_output,
      axis_name='batch',
      static_broadcasted_argnums=(9, 10, 13))

  # NOTE: every random draw below reuses the same key `rng`.
  rng = jax.random.PRNGKey(seed)
  split_rng = jax.random.split(rng, self.n_devices)
  encoded_input = jax.random.uniform(
      rng,
      shape=(self.n_devices, self.bsz, self.seq_len, self.input_dim),
      dtype=self.dtype)
  mention_batch_positions = jax.random.randint(
      rng, minval=0, maxval=self.bsz, shape=(self.n_devices, self.n_mentions))
  mention_start_positions = jax.random.randint(
      rng,
      minval=0,
      maxval=self.seq_len,
      shape=(self.n_devices, self.n_mentions))
  # End positions reuse the start positions.
  mention_end_positions = mention_start_positions
  mention_mask = jnp.ones(shape=(self.n_devices, self.n_mentions))
  memory_table = jax.random.uniform(
      rng,
      shape=(self.n_devices, self.rows, self.table_size // self.rows,
             self.memory_key_dim))
  memory_entity_ids = jax.random.randint(
      rng,
      minval=0,
      maxval=self.entity_vocab_size,
      shape=(self.n_devices, self.table_size))
  if max_text_identifiers is not None:
    memory_identifiers = jax.random.randint(
        rng,
        minval=0,
        maxval=max_text_identifiers,
        shape=(self.n_devices, self.table_size))
    text_identifiers = jax.random.randint(
        rng,
        minval=0,
        maxval=max_text_identifiers,
        shape=(self.n_devices, self.n_mentions))
  else:
    text_identifiers = None

  if n_text_entities_per_memory is not None:
    memory_text_entities = jax.random.randint(
        rng,
        minval=0,
        maxval=self.entity_vocab_size,
        shape=(self.n_devices, self.table_size, n_text_entities_per_memory))
  else:
    memory_text_entities = None

  # Place slice `i` of every host array on device `i`.
  encoded_input_sharded = jax.device_put_sharded(list(encoded_input), devices)
  mention_batch_positions_sharded = jax.device_put_sharded(
      list(mention_batch_positions), devices)
  mention_start_positions_sharded = jax.device_put_sharded(
      list(mention_start_positions), devices)
  mention_end_positions_sharded = jax.device_put_sharded(
      list(mention_end_positions), devices)
  mention_mask_sharded = jax.device_put_sharded(list(mention_mask), devices)
  memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)
  memory_entity_ids_sharded = jax.device_put_sharded(
      list(memory_entity_ids), devices)
  if max_text_identifiers is not None:
    memory_identifiers_sharded = jax.device_put_sharded(
        list(memory_identifiers), devices)
    text_identifiers_sharded = jax.device_put_sharded(
        list(text_identifiers), devices)
  else:
    memory_identifiers_sharded = None
    text_identifiers_sharded = None

  if memory_text_entities is not None:
    memory_text_entities_sharded = jax.device_put_sharded(
        list(memory_text_entities), devices)
  else:
    memory_text_entities_sharded = None

  # Expected global memory ids: consecutive integers, `table_size` per device.
  memory_ids = jnp.arange(self.n_devices * self.table_size)
  memory_ids = memory_ids.reshape(self.n_devices, self.table_size)

  (_, loss_helpers, logging_helpers), params = pinit_with_output(
      split_rng,
      encoded_input_sharded,
      mention_batch_positions_sharded,
      mention_start_positions_sharded,
      mention_end_positions_sharded,
      mention_mask_sharded,
      memory_table_sharded,
      memory_identifiers_sharded,
      memory_entity_ids_sharded,
      True,
      None,  # memory_values
      text_identifiers_sharded,
      memory_text_entities_sharded,
      same_passage_memory_policy,
  )

  params = params.unfreeze()['params']

  # Recompute the queries on the host: concatenate the encodings at the
  # mention start and end positions and apply each device's copy of the
  # `query_projector` kernel and bias.
  mention_encodings = []
  for device_id in range(self.n_devices):
    mention_start_encodings = encoded_input[device_id][
        mention_batch_positions[device_id],
        mention_start_positions[device_id]]
    mention_end_encodings = encoded_input[device_id][
        mention_batch_positions[device_id], mention_end_positions[device_id]]
    mention_encodings_on_device = jnp.concatenate(
        [mention_start_encodings, mention_end_encodings], axis=-1)
    mention_encodings_on_device = np.matmul(
        mention_encodings_on_device,
        params['query_projector']['kernel'][device_id])
    mention_encodings_on_device += params['query_projector']['bias'][
        device_id]
    mention_encodings.append(mention_encodings_on_device)

  # [n_devices, n_mentions, memory_key_dim]
  mention_encodings_stacked = jnp.stack(mention_encodings)
  # Flatten to [n_devices * n_mentions, memory_key_dim] so that
  # `query_id // n_mentions` is the device and `query_id % n_mentions` is the
  # mention index on that device.
  mention_encodings_stacked = mention_encodings_stacked.reshape(
      [self.n_devices * self.n_mentions, self.memory_key_dim])

  # Object which represents a single retrieval result with additional info.
  RetrievedMemory = collections.namedtuple('RetrievedMemory', [
      'device', 'row', 'rowwise_index', 'devicewise_index', 'global_index',
      'score', 'memory', 'entity_id', 'memory_hash',
      'memory_passage_text_entities'
  ])

  num_disallowed_per_device = [0 for _ in range(self.n_devices)]
  # Manually simulate retrieval for every query.
  for query_id in range(self.n_devices * self.n_mentions):
    query = mention_encodings_stacked[query_id]
    top_memories_query = []
    # Collect retrievals for a single query on each device separately.
    for device_id in range(self.n_devices):
      top_memories_per_device = []
      # Every row of the device's table contributes its best-scoring memory.
      for row_id in range(self.rows):
        scores = np.einsum('mh,h->m', memory_table[device_id, row_id], query)
        top_index = np.argmax(scores)
        devicewise_index = row_id * (self.table_size // self.rows) + top_index
        global_index = memory_ids[device_id, devicewise_index]
        self.assertEqual(global_index,
                         devicewise_index + device_id * self.table_size)
        if max_text_identifiers is not None:
          memory_hash = memory_identifiers[device_id, devicewise_index].item()
        else:
          memory_hash = None
        if memory_text_entities is not None:
          memory_passage_text_entities = memory_text_entities[
              device_id, devicewise_index]
        else:
          memory_passage_text_entities = None
        top_memories_per_device.append(
            RetrievedMemory(
                device=device_id,
                row=row_id,
                rowwise_index=top_index,
                devicewise_index=devicewise_index,
                global_index=global_index,
                score=scores[top_index].item(),
                memory=memory_table[device_id, row_id, top_index],
                entity_id=memory_entity_ids[device_id,
                                            devicewise_index].item(),
                memory_hash=memory_hash,
                memory_passage_text_entities=memory_passage_text_entities,
            ))
      # Sort by score. In case two scores are equal (likely because both
      # were considered "disallowed") we compare by entity IDs.
      top_memories_per_device.sort(
          key=lambda x: (x.score, x.entity_id), reverse=True)
      # Keep only the `k_top_device` best candidates of this device.
      top_memories_per_device = top_memories_per_device[:self.k_top_device]
      top_memories_query.extend(top_memories_per_device)

    # Merge the per-device candidates and optionally keep the global top-k.
    top_memories_query.sort(
        key=lambda x: (x.score, x.entity_id), reverse=True)
    if k_top_post_selection is not None:
      top_memories_query = top_memories_query[:k_top_post_selection]

    if max_text_identifiers is not None:
      num_current_disallowed = 0
      text_id = text_identifiers[query_id // self.n_mentions,
                                 query_id % self.n_mentions].item()
      for i in range(len(top_memories_query)):
        # Memories whose identifier equals the query's text identifier are
        # counted regardless of the policy.
        if top_memories_query[i].memory_hash == text_id:
          num_current_disallowed += 1

        # 'disallow' masks memories with a matching identifier, 'only' masks
        # the ones without; masking replaces the score by -_LARGE_NUMBER.
        if ((same_passage_memory_policy == 'disallow' and
             top_memories_query[i].memory_hash == text_id) or
            (same_passage_memory_policy == 'only' and
             top_memories_query[i].memory_hash != text_id)):
          top_memories_query[i] = top_memories_query[i]._replace(
              score=-_LARGE_NUMBER)
      num_disallowed_per_device[query_id //
                                self.n_mentions] += num_current_disallowed
      # Re-sort after masking; ties are now broken by global memory id.
      top_memories_query.sort(
          key=lambda x: (x.score, x.global_index), reverse=True)

    actual_entity_ids = loss_helpers['top_entity_ids'][query_id //
                                                       self.n_mentions,
                                                       query_id %
                                                       self.n_mentions]
    actual_memory_ids = loss_helpers['top_memory_ids'][query_id //
                                                       self.n_mentions,
                                                       query_id %
                                                       self.n_mentions]
    actual_attention_weights = loss_helpers['memory_attention_weights'][
        query_id // self.n_mentions, query_id % self.n_mentions]

    # We sort retrieved results first by the attention score and then
    # by memory ID.
    p = list(range(len(actual_attention_weights)))
    # pylint: disable=cell-var-from-loop
    p.sort(
        key=lambda i: (actual_attention_weights[i], actual_memory_ids[i]),
        reverse=True)
    # pylint: enable=cell-var-from-loop
    p = np.array(p)

    actual_entity_ids = list(actual_entity_ids[p])
    actual_attention_weights = list(actual_attention_weights[p])
    actual_memory_ids = list(actual_memory_ids[p])

    expected_entity_ids = [x.entity_id for x in top_memories_query]
    self.assertSequenceEqual(expected_entity_ids, actual_entity_ids)

    # Expected attention weights are a softmax over the (possibly masked)
    # scores; they are compared to 5 decimal places.
    expected_attention_weights = scipy.special.softmax(
        [x.score for x in top_memories_query])
    self.assertSequenceAlmostEqual(expected_attention_weights,
                                   actual_attention_weights, 5)

    expected_memory_ids = [x.global_index for x in top_memories_query]
    self.assertSequenceEqual(expected_memory_ids, actual_memory_ids)

    actual_top_text_entities = loss_helpers['memory_top_text_entities'][
        query_id // self.n_mentions, query_id % self.n_mentions]
    actual_top_text_entities = actual_top_text_entities[p]

    expected_top_text_entities = [
        x.memory_passage_text_entities for x in top_memories_query
    ]

    self.assertEqual(
        len(actual_top_text_entities), len(expected_top_text_entities))

    # Comparing `actual_top_text_entities` and `expected_top_text_entities`
    # directly is troublesome since we cannot guarantee the order for
    # retrieval results with the same attention weights. Therefore, we first
    # sort both `actual_top_text_entities` and `expected_top_text_entities`
    # by the attention weight first and then by their elements.
    def sort_entities(top_text_entities):
      # Both lists are keyed by `expected_attention_weights`.
      result = [(expected_attention_weights[i], list(top_text_entities[i]))
                for i in range(len(top_text_entities))]
      result.sort()
      return [x[1] for x in result]

    actual_top_text_entities = sort_entities(actual_top_text_entities)
    expected_top_text_entities = sort_entities(expected_top_text_entities)

    for i in range(len(actual_top_text_entities)):
      self.assertSequenceEqual(
          list(actual_top_text_entities[i]),
          list(expected_top_text_entities[i]))

  # The layer's per-device count must match the count accumulated above.
  if max_text_identifiers is not None:
    self.assertSequenceEqual(num_disallowed_per_device,
                             logging_helpers['n_disallowed'])
def test_mention_memory_layer(self, separate_memory_values):
  """Runs the memory attention layer forward and checks its outputs.

  Two pmapped `init_with_output` passes are executed on identical inputs: the
  first without text identifiers, the second with text identifiers set so
  that half of the mentions carry identifier 1. The test checks output
  shapes, that positions other than mention starts are left untouched, that
  entity 0 or 1 is retrieved first, and that retrievals whose entity id
  equals the text identifier receive zero attention weight.

  Args:
    separate_memory_values: when true, a separately sharded value table is
      passed to the layer and only positional argument 9 is static; otherwise
      `memory_values` is None and positional arguments 9 and 10 are static.
  """
  test_utils.force_multi_devices(self.n_devices)
  devices = jax.local_devices()

  def replicate(value):
    """Places an identical copy of `value` on every local device."""
    return jax.device_put_replicated(value, devices)

  def shard(value):
    """Places slice `i` along the leading axis of `value` on device `i`."""
    return jax.device_put_sharded(list(value), devices)

  model = memory_attention_layer.MemoryAttentionLayer(
      memory_key_dim=self.memory_key_dim,
      input_dim=self.input_dim,
      memory_update_type=self.memory_update_type,
      memory_update_config=self.memory_update_config,
      k_top_device=self.k_top_device,
      k_top_post_selection=self.k_top_post_selection,
      splits=self.splits,
      dtype=self.dtype)

  if separate_memory_values:
    static_argnums = 9
  else:
    static_argnums = (9, 10)
  pinit_with_output = jax.pmap(
      model.init_with_output,
      axis_name='batch',
      static_broadcasted_argnums=static_argnums)

  split_rng = jax.random.split(jax.random.PRNGKey(0), self.n_devices)

  encoded_input = replicate(
      jnp.ones(
          shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype))

  # Three mentions per sample: batch positions are 0, 0, 0, 1, 1, 1, ...
  mention_batch_positions = replicate(
      jnp.tile(jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1))
  mention_start_positions = replicate(
      jnp.tile(jnp.asarray([0, 5, 10]), self.bsz))
  mention_end_positions = replicate(
      jnp.tile(jnp.asarray([2, 7, 12]), self.bsz))
  n_mentions = mention_start_positions.shape[-1]
  mention_mask = replicate(jnp.tile(jnp.asarray([1, 1, 1]), self.bsz))

  memory_table = np.ones(
      (self.n_devices * self.table_size, self.memory_key_dim),
      dtype=self.dtype)
  # Make sure id 0 or 1 will be highest scoring
  memory_table[0] = memory_table[0] * 2.0
  memory_table[1] = memory_table[1] * -2.0
  memory_table = jnp.asarray(memory_table, dtype=self.dtype)

  memory_keys_sharded = shard(
      memory_table.reshape(self.n_devices, self.rows,
                           self.table_size // self.rows,
                           self.memory_key_dim))

  memory_values = None
  if separate_memory_values:
    memory_values = shard(
        memory_table.reshape(self.n_devices, self.table_size,
                             self.memory_key_dim))

  # Consecutive entity ids, also used as the memory identifiers.
  memory_entity_ids = shard(
      np.arange(self.n_devices * self.table_size).reshape(
          self.n_devices, self.table_size))
  memory_identifiers = memory_entity_ids

  # Positional arguments shared by both forward passes.
  shared_args = (
      split_rng,
      encoded_input,
      mention_batch_positions,
      mention_start_positions,
      mention_end_positions,
      mention_mask,
      memory_keys_sharded,
      memory_identifiers,
      memory_entity_ids,
      True,  # deterministic
      memory_values,
  )

  # First pass: no text identifiers.
  first_outputs, _ = pinit_with_output(*shared_args, text_identifiers=None)
  encoded_output, loss_helpers, _ = first_outputs

  attention_weights = loss_helpers['memory_attention_weights']
  entity_ids = loss_helpers['top_entity_ids']

  normed_input = encoded_input - 1.0

  # The output as a whole must differ from the reference.
  self.assertFalse(jnp.allclose(encoded_output, normed_input))

  # ... but must equal it everywhere except at mention start positions.
  all_indices = set(
      itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
  # Mention positions are identical on all devices, so device 0 is used.
  start_indices = set(
      zip(mention_batch_positions[0].tolist(),
          mention_start_positions[0].tolist()))
  non_start_indices = all_indices - start_indices
  batch_index, seq_index = zip(*non_start_indices)
  batch_index = jnp.asarray(batch_index)
  seq_index = jnp.asarray(seq_index)
  self.assertTrue(
      jnp.allclose(encoded_output[:, batch_index, seq_index],
                   normed_input[:, batch_index, seq_index]))

  # Shapes.
  retrieval_shape = (self.n_devices, n_mentions, self.k_top_post_selection)
  self.assertSequenceEqual(
      encoded_output.shape,
      (self.n_devices, self.bsz, self.seq_len, self.input_dim))
  self.assertSequenceEqual(attention_weights.shape, retrieval_shape)
  self.assertSequenceEqual(entity_ids.shape, retrieval_shape)

  # The first retrieved entity must be id 0 or id 1.
  first_entity_ids = entity_ids[..., 0]
  self.assertTrue(
      jnp.all((first_entity_ids == 0) + (first_entity_ids == 1)))

  # Set some text identifiers to 0 and others to 1 so that some are binding
  host_text_identifiers = np.zeros(n_mentions, dtype=np.int32)
  host_text_identifiers[:n_mentions // 2] = 1
  text_identifiers = replicate(host_text_identifiers)

  # Second pass: same inputs plus the text identifiers.
  second_outputs, _ = pinit_with_output(
      *shared_args, text_identifiers=text_identifiers)
  _, loss_helpers, logging_helpers = second_outputs

  attention_weights_wid = loss_helpers['memory_attention_weights']
  entity_ids_wid = loss_helpers['top_entity_ids']
  n_disallowed = logging_helpers['n_disallowed'][0]

  # Text identifiers must not change which entities are retrieved.
  self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

  # Retrievals whose entity id equals the text identifier must carry zero
  # attention weight.
  identifier_matches = (
      jnp.expand_dims(text_identifiers, -1) == entity_ids_wid)
  score_masked = identifier_matches * attention_weights_wid
  self.assertAlmostEqual(score_masked.sum(), 0.0)

  # Check number disallowed as expected
  self.assertEqual(n_disallowed, n_mentions // 2)
def load_memory(config: ml_collections.ConfigDict) -> Dict[str, Any]:
  """Loads this process's mention memory and shards it over local devices.

  Args:
    config: experiment config. Reads `model_config.encoder_config` (`rows`,
      `memory_key_dim`, `separate_memory_values`, `dtype`), the memory file
      patterns, and the optional keys `memory_reduction`, `memory_prop` and
      `same_entity_set_retrieval_weight`.

  Returns:
    A dict from memory component name to the result of
    `jax.device_put_sharded` over the local devices. It always contains
    'memory_keys', 'memory_identifiers' and 'memory_entity_ids';
    'memory_values' and 'memory_text_entities' are added when enabled in the
    config.
  """
  encoder_config = config.model_config.encoder_config

  # Reduce number of loaded memory shards by this proportion. Total shards
  # must be divisible by memory_reduction * process_count.
  memory_reduction = config.get('memory_reduction', 1)
  num_shards = jax.process_count() * memory_reduction
  process_index = jax.process_index()

  local_devices = jax.local_devices()
  num_local_devices = len(local_devices)

  memory_prop = config.get('memory_prop', None)
  rows = encoder_config.rows
  memory_key_dim = encoder_config.memory_key_dim

  # These components are stored as int32; every other component is converted
  # to the model dtype (typically bfloat16 or float32).
  int32_components = {
      'memory_identifiers', 'memory_entity_ids', 'memory_text_entities'
  }

  # (component name, file pattern) pairs, in loading order.
  components = [
      ('memory_keys', config.memory_key_pattern),
      ('memory_identifiers', config.memory_id_pattern),
      ('memory_entity_ids', config.memory_entity_id_pattern),
  ]
  if encoder_config.separate_memory_values:
    components.append(('memory_values', config.memory_value_pattern))
  if config.get('same_entity_set_retrieval_weight', 0) > 0:
    components.append(
        ('memory_text_entities', config.memory_text_entities_pattern))

  # Load every component before placing anything on the devices.
  memory_arrays = {
      name: data_utils.load_sharded_array(pattern, num_shards, process_index)
      for name, pattern in components
  }

  cpu_device = jax.local_devices(backend='cpu')[0]
  dtype = encoder_config.dtype

  memory_variables = {}
  for name, array in memory_arrays.items():
    if memory_prop is not None:
      # Keep only the leading `memory_prop` fraction of the entries.
      num_kept = int(memory_prop * array.shape[0])
      array = array[:num_kept]

    # The leading axis indexes local devices. Keys are additionally split
    # into `rows` rows per device.
    if name == 'memory_keys':
      per_device_shape = (num_local_devices, rows, -1, memory_key_dim)
    else:
      per_device_shape = (num_local_devices, -1) + array.shape[1:]
    array = array.reshape(per_device_shape)

    # Convert the dtype on the CPU device, then put one slice on each local
    # device.
    target_dtype = jnp.int32 if name in int32_components else dtype
    array = jax.device_put(array, cpu_device)
    array = jnp.asarray(array, dtype=target_dtype)
    memory_variables[name] = jax.device_put_sharded(
        list(array), local_devices)

  return memory_variables