def setUp(self):
    """Force a multi-device backend and shard random mention batch positions.

    For each of the `n_devices` devices, draws `n_mentions` random integers
    in `[0, batch_size)` and places that array on its own local device.
    """
    super().setUp()
    test_utils.force_multi_devices(self.n_devices)
    self.devices = jax.local_devices()

    per_device_positions = []
    for _ in range(self.n_devices):
        per_device_positions.append(
            np.random.randint(self.batch_size, size=self.n_mentions))

    self.mention_batch_positions_sharded = jax.device_put_sharded(
        per_device_positions, self.devices)
def setUp(self):
    """Force a multi-device backend and build random mention fixtures.

    Four zero-argument generators are handed to `self._gen_array`, and each
    returned pair is stored as a `*_stacked` / `*_sharded` attribute pair.

    NOTE(review): `_gen_array` is defined outside this block; presumably it
    invokes the generator once per device -- confirm against its definition.
    """
    super().setUp()
    test_utils.force_multi_devices(self.n_devices)
    self.devices = jax.local_devices()

    def random_encodings():
        # Uniform values in [0, 10) of shape (n_mentions, hidden_size).
        return 10.0 * np.random.random((self.n_mentions, self.hidden_size))

    def random_target_ids():
        return np.random.randint(self.entity_vocab_size, size=self.n_mentions)

    def random_batch_positions():
        return np.random.randint(self.batch_size, size=self.n_mentions)

    def random_is_masked():
        # Binary flags: 0 or 1 per mention.
        return np.random.randint(2, size=self.n_mentions)

    stacked, sharded = self._gen_array(random_encodings)
    self.mention_encodings_stacked = stacked
    self.mention_encodings_sharded = sharded

    stacked, sharded = self._gen_array(random_target_ids)
    self.mention_target_ids_stacked = stacked
    self.mention_target_ids_sharded = sharded

    stacked, sharded = self._gen_array(random_batch_positions)
    self.mention_batch_positions_stacked = stacked
    self.mention_batch_positions_sharded = sharded

    stacked, sharded = self._gen_array(random_is_masked)
    self.mention_target_is_masked_stacked = stacked
    self.mention_target_is_masked_sharded = sharded
def test_linking_layer(self):
    """Run the memory extraction layer under pmap and verify its outputs.

    Checks the shapes of the returned memory keys, values, mask and entity
    ids, and that the mask and entity ids at index 0 of the leading axis,
    reshaped to `[n_devices, n_mentions]`, equal the per-device inputs.
    """
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()
    per_device_shape = (self.n_devices, self.n_mentions)
    total_mentions = self.n_devices * self.n_mentions

    def shard(stacked):
        # One slice along the leading axis is placed on each device.
        return jax.device_put_sharded(list(stacked), devices)

    encoding = shard(
        jnp.ones(
            shape=(self.n_devices, self.bsz, self.seq_len,
                   self.hidden_size),
            dtype=self.dtype))

    batch_positions = shard(
        np.random.randint(self.bsz, size=per_device_shape))

    # Every mention covers two consecutive tokens: start and start + 1.
    raw_starts = np.random.randint(self.seq_len - 1, size=per_device_shape)
    raw_ends = raw_starts + 1
    start_positions = shard(raw_starts)
    end_positions = shard(raw_ends)

    mask = shard(jnp.ones(shape=per_device_shape))
    entity_ids = shard(
        jnp.arange(total_mentions).reshape(self.n_devices, self.n_mentions))

    layer = memory_extraction_layer.MemoryExtractionLayer(
        memory_key_dim=self.memory_key_dim,
        memory_value_dim=self.memory_value_dim,
        dtype=self.dtype,
    )
    pmapped_init = jax.pmap(layer.init_with_output, axis_name='batch')
    device_rngs = jax.random.split(jax.random.PRNGKey(0), self.n_devices)

    outputs, _ = pmapped_init(
        device_rngs,
        encoding=encoding,
        mention_batch_positions=batch_positions,
        mention_start_positions=start_positions,
        mention_end_positions=end_positions,
        mention_mask=mask,
        mention_entity_ids=entity_ids,
    )

    # Check shapes are as expected
    self.assertSequenceEqual(
        outputs['memory_keys'].shape,
        (self.n_devices, total_mentions, self.memory_key_dim))
    self.assertSequenceEqual(
        outputs['memory_values'].shape,
        (self.n_devices, total_mentions, self.memory_value_dim))
    self.assertSequenceEqual(outputs['memory_mask'].shape,
                             (self.n_devices, total_mentions))
    self.assertSequenceEqual(outputs['memory_entity_ids'].shape,
                             (self.n_devices, total_mentions))

    # Memory mask and entity ids should just have been all gathered
    gathered_mask = outputs['memory_mask'][0].reshape(
        self.n_devices, self.n_mentions)
    self.assertTrue(jnp.all(gathered_mask == mask))
    gathered_entity_ids = outputs['memory_entity_ids'][0].reshape(
        self.n_devices, self.n_mentions)
    self.assertTrue(jnp.all(gathered_entity_ids == entity_ids))
def setUp(self):
    """Force `self.n_devices` devices and record the local device list."""
    super().setUp()
    device_count = self.n_devices
    test_utils.force_multi_devices(device_count)
    self.devices = jax.local_devices()
def test_multi_node_training(self):
    """Run `trainer.train` on `self.test_config` with 8 forced devices.

    There are no explicit assertions; the test passes when training
    completes without raising.
    """
    device_count = 8
    test_utils.force_multi_devices(device_count)
    trainer.train(self.test_config)
def test_model_shape(
    self,
    separate_memory_values=False,
    num_intermediate_layers=None,
):
    """Check output shapes of a pmapped forward pass of the encoder.

    Initializes `MentionMemoryEncoder` on every device, attaches a replicated
    random memory table as the `constants` collection, runs `model.forward`
    on one generated example tiled to the per-device batch size, and asserts
    the shapes of the encoded output and of the target mention encodings.

    Args:
        separate_memory_values: written to the encoder config under
            `separate_memory_values`.
        num_intermediate_layers: written to the encoder config under
            `num_intermediate_layers`.
    """
    task = mention_memory_task.MentionMemoryTask

    config = copy.deepcopy(self.config)
    encoder_overrides = config['model_config']['encoder_config']
    encoder_overrides['separate_memory_values'] = separate_memory_values
    encoder_overrides['num_intermediate_layers'] = num_intermediate_layers
    config = ml_collections.FrozenConfigDict(config)
    encoder_config = config.model_config.encoder_config
    rows = encoder_config.rows
    key_dim = encoder_config.memory_key_dim

    preprocess_fn = task.make_preprocess_fn(config)
    collater_fn = task.make_collater_fn(config)

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    def replicate(value):
        return jax.device_put_replicated(value, devices)

    model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
    dummy_input = replicate(task.dummy_input(config))
    device_rngs = jax.random.split(jax.random.PRNGKey(0), self.n_devices)

    # The same random table backs both keys (row-structured) and values
    # (flattened over rows).
    memory_table = np.random.rand(rows, self.table_size // rows, key_dim)
    replicated_ids = replicate(np.arange(self.table_size))
    constants = {
        'memory_keys': replicate(memory_table),
        'memory_values': replicate(memory_table.reshape(-1, key_dim)),
        'memory_identifiers': replicated_ids,
        # Entity ids reuse the identifier array.
        'memory_entity_ids': replicated_ids,
        'memory_text_entities': replicate(
            np.zeros(
                (self.table_size, encoder_config.n_memory_text_entities),
                dtype=np.int32)),
    }

    def forward_init(*args, **kwargs):
        return model.init(*args, method=model.forward, **kwargs)

    def forward_apply(*args, **kwargs):
        return model.apply(*args, method=model.forward, **kwargs)

    # Argument 2 (the `True` flag) is broadcast as a static value.
    pmapped_init = jax.pmap(
        forward_init, 'batch', static_broadcasted_argnums=2)
    variables = {
        'params': pmapped_init(device_rngs, dummy_input, True)['params'],
        'constants': constants,
    }

    raw_example = test_utils.gen_mention_pretraining_sample(
        self.text_length,
        self.n_mentions,
        self.n_linked_mentions,
        max_length=encoder_config.max_length)
    tiled_example = {}
    for name, value in preprocess_fn(raw_example).items():
        tiled_example[name] = np.tile(
            value, (config.per_device_batch_size, 1))
    batch = {
        name: replicate(test_utils.tensor_to_numpy(value))
        for name, value in collater_fn(tiled_example).items()
    }

    pmapped_apply = jax.pmap(
        forward_apply, 'batch', static_broadcasted_argnums=(2))
    encoded_output, loss_helpers, _ = pmapped_apply(variables, batch, True)

    self.assertEqual(
        encoded_output.shape,
        (self.n_devices, config.per_device_batch_size,
         encoder_config.max_length, encoder_config.hidden_size))

    # Fall back to the key dimension when no value dimension is configured.
    memory_size = encoder_config.memory_value_dim or key_dim
    self.assertEqual(
        loss_helpers['target_mention_encodings'].shape,
        (self.n_devices,
         config.max_mention_targets * config.per_device_batch_size,
         memory_size))
def test_load_weights(self, separate_memory_values=False, memory_only=False):
    """Test saving and loading model recovers original parameters.

    Initializes the encoder, writes its parameters (unless `memory_only`) and
    the memory tables (split into shards) to a temporary directory, loads
    them back through `MentionMemoryEncoder.load_weights`, and checks that
    every loaded leaf equals the corresponding original leaf.

    Args:
        separate_memory_values: written to the encoder config under
            `separate_memory_values`. When false, `memory_values` is left out
            of the comparison.
        memory_only: when true, no parameter checkpoint is written,
            `config.load_weights` is set to the literal `'memory_only'`, and
            only the `constants` collection is compared.
    """
    config = copy.deepcopy(self.config)
    config['model_config']['encoder_config'][
        'separate_memory_values'] = separate_memory_values
    # Mutable config: the memory file patterns are filled in further down.
    config = ml_collections.ConfigDict(config)
    model_config = config.model_config
    encoder_config = model_config.encoder_config
    rows = encoder_config.rows

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
    dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
    dummy_input = jax.device_put_replicated(dummy_input, devices)
    init_rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(init_rng, self.n_devices)

    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = jax.device_put_replicated(memory_table, devices)
    memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
    memory_values = jax.device_put_replicated(memory_values, devices)
    memory_identifiers = np.arange(self.table_size)
    memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                   devices)
    memory_entity_ids = memory_identifiers
    memory_text_entities = np.zeros(
        (self.table_size, encoder_config.n_memory_text_entities),
        dtype=np.int32)
    memory_text_entities = jax.device_put_replicated(memory_text_entities,
                                                     devices)

    def model_init(*args, **kwargs):
        return model.init(*args, method=model.forward, **kwargs)

    initial_variables = jax.pmap(
        model_init, 'batch', static_broadcasted_argnums=2)(
            split_rng,
            dummy_input,
            True,
        )
    initial_variables = {'params': initial_variables['params']}
    initial_variables['constants'] = {
        'memory_keys': memory_keys,
        'memory_values': memory_values,
        'memory_identifiers': memory_identifiers,
        'memory_entity_ids': memory_entity_ids,
        'memory_text_entities': memory_text_entities,
    }

    n_shards = 4

    tempdir_obj = self.create_tempdir()
    tempdir = tempdir_obj.full_path

    memory_key_base = os.path.join(tempdir, 'memory_keys')
    memory_value_base = os.path.join(tempdir, 'memory_values')
    memory_id_base = os.path.join(tempdir, 'memory_id')
    memory_entity_id_base = os.path.join(tempdir, 'memory_entity_id')
    memory_text_entities_base = os.path.join(tempdir, 'memory_text_entities')

    unreplicated_variables = jax_utils.unreplicate(initial_variables)
    unreplicated_variables['params'] = unreplicated_variables[
        'params'].unfreeze()

    if memory_only:
        load_weights = 'memory_only'
    else:
        load_weights = os.path.join(tempdir, 'weights')
        checkpoint_utils.save_weights(load_weights,
                                      unreplicated_variables['params'])

    # The constants below are the replicated arrays, so the leading device
    # axis is folded into the shard/row axes by these reshapes.
    memory_keys = initial_variables['constants']['memory_keys']
    memory_keys = memory_keys.reshape(n_shards, -1,
                                      encoder_config.memory_key_dim)
    memory_values = initial_variables['constants']['memory_values']
    memory_values = memory_values.reshape(n_shards, -1,
                                          encoder_config.memory_key_dim)
    memory_ids = initial_variables['constants'][
        'memory_identifiers'].reshape(n_shards, -1)
    memory_entity_ids = initial_variables['constants'][
        'memory_entity_ids'].reshape(n_shards, -1)
    memory_text_entities = initial_variables['constants'][
        'memory_text_entities'].reshape(
            n_shards, -1, encoder_config.n_memory_text_entities)

    # Each table is written exactly once per shard; `np.save` appends
    # `.npy`, which the `<base>*` glob patterns below still match.
    shard_files = (
        (memory_key_base, memory_keys),
        (memory_value_base, memory_values),
        (memory_id_base, memory_ids),
        (memory_entity_id_base, memory_entity_ids),
        (memory_text_entities_base, memory_text_entities),
    )
    for shard in range(n_shards):
        for base, sharded_array in shard_files:
            np.save(base + str(shard), sharded_array[shard])

    config.memory_key_pattern = memory_key_base + '*'
    config.memory_value_pattern = memory_value_base + '*'
    config.memory_id_pattern = memory_id_base + '*'
    config.memory_entity_id_pattern = memory_entity_id_base + '*'
    config.memory_text_entities_pattern = memory_text_entities_base + '*'
    config.load_weights = load_weights

    loaded_variables = mention_memory_encoder.MentionMemoryEncoder.load_weights(
        config)

    arrayeq = lambda x, y: jnp.all(x == y)
    constants = {
        key: value
        for key, value in initial_variables['constants'].items()
        if not (key == 'memory_values' and not separate_memory_values)
    }
    comparison_variables = {'constants': constants}
    if not memory_only:
        comparison_variables['params'] = initial_variables[
            'params'].unfreeze()

    # `tree_map` returns a pytree of per-leaf booleans. A non-empty pytree is
    # always truthy, so the leaves must be reduced explicitly for the
    # assertion to be able to fail.
    leaf_matches = jax.tree_util.tree_map(arrayeq, loaded_variables,
                                          comparison_variables)
    self.assertTrue(all(jax.tree_util.tree_leaves(leaf_matches)))
def test_mention_memory_layer(self, separate_memory_values):
    """Testing memory attention layer.

    Runs the pmapped layer twice on constant inputs: once without text
    identifiers and once with text identifiers set for half of the mentions.
    Verifies that only mention start positions are modified, that output
    shapes are as expected, that the top retrieved entity is id 0 or 1, and
    that retrievals whose entity id equals the mention's text identifier
    carry zero attention weight.

    Args:
        separate_memory_values: when true, a `[n_devices, table_size,
            memory_key_dim]` view of the memory table is sharded and passed as
            `memory_values`; otherwise `None` is passed and that argument is
            marked static for pmap.
    """
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    # Positional argument 9 is the `deterministic` flag and argument 10 is
    # `memory_values`; the latter is only static when it is `None`.
    static_argnums = (9) if separate_memory_values else (9, 10)
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=static_argnums)

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    # Three mentions per batch row: batch positions [0, 0, 0, 1, 1, 1, ...].
    mention_batch_positions = jnp.tile(
        jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    n_mentions = mention_start_positions.shape[-1]

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)
    # Make sure id 0 or 1 will be highest scoring
    memory_table[0] = memory_table[0] * 2.0
    memory_table[1] = memory_table[1] * -2.0
    memory_table = jnp.asarray(memory_table, dtype=self.dtype)
    # Split the global table into one row-structured slice per device.
    memory_keys = memory_table.reshape(self.n_devices, self.rows,
                                       self.table_size // self.rows,
                                       self.memory_key_dim)
    memory_keys_sharded = jax.device_put_sharded(list(memory_keys), devices)
    if separate_memory_values:
        memory_values = memory_table.reshape(self.n_devices, self.table_size,
                                             self.memory_key_dim)
        memory_values = jax.device_put_sharded(list(memory_values), devices)
    else:
        memory_values = None

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    (encoded_output, loss_helpers, _), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,
        text_identifiers=None,
    )

    attention_weights = loss_helpers['memory_attention_weights']
    entity_ids = loss_helpers['top_entity_ids']

    # The input is all ones, so this reference is all zeros.
    # NOTE(review): presumably the layer normalizes its output, which would
    # map untouched constant positions to zero -- confirm against the layer.
    normed_input = encoded_input - 1.0

    # Check input was changed
    self.assertFalse(jnp.allclose(encoded_output, normed_input))

    # Check input was not changed where it should not be
    all_indices = set(
        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
    # Note that mention positions is the same across all of the devices
    start_indices = set(
        zip(mention_batch_positions[0].tolist(),
            mention_start_positions[0].tolist()))
    non_start_indices = all_indices.difference(start_indices)
    # Unzip the (batch, token) pairs into two parallel index arrays.
    non_start_indices_1, non_start_indices_2 = zip(*non_start_indices)
    non_start_indices_1 = jnp.asarray(non_start_indices_1)
    non_start_indices_2 = jnp.asarray(non_start_indices_2)

    non_start_outputs = encoded_output[:, non_start_indices_1,
                                       non_start_indices_2]
    non_start_inputs = normed_input[:, non_start_indices_1, non_start_indices_2]
    self.assertTrue(jnp.allclose(non_start_outputs, non_start_inputs))

    # Check shapes as expected
    self.assertSequenceEqual(
        encoded_output.shape,
        (self.n_devices, self.bsz, self.seq_len, self.input_dim))

    self.assertSequenceEqual(
        attention_weights.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    self.assertSequenceEqual(
        entity_ids.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    # Check id 0 or 1 retrieved
    self.assertTrue(
        jnp.all((entity_ids[..., 0] == 0) + (entity_ids[..., 0] == 1)))

    # Set some text identifiers to 0 and others to 1 so that some are binding
    text_identifiers = np.zeros((n_mentions), dtype=np.int32)
    text_identifiers[:n_mentions // 2] = 1
    text_identifiers = jax.device_put_replicated(text_identifiers, devices)

    # Initialize and run one forward pass of model
    (_, loss_helpers, logging_helpers), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,  # memory_values
        text_identifiers=text_identifiers,
    )
    attention_weights_wid = loss_helpers['memory_attention_weights']
    entity_ids_wid = loss_helpers['top_entity_ids']
    # Only the value at index 0 of the leading axis is checked.
    n_disallowed = logging_helpers['n_disallowed'][0]

    # Check no effect on ids
    self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

    # Check id 0 or 1 have 0 scores
    # Add a trailing axis so identifiers broadcast against the top-k axis.
    text_identifiers = jnp.expand_dims(text_identifiers, -1)
    score_masked = (text_identifiers == entity_ids_wid) * attention_weights_wid
    self.assertAlmostEqual(score_masked.sum(), 0.0)

    # Check number disallowed as expected
    self.assertEqual(n_disallowed, n_mentions // 2)
def test_compare_retrievals_with_numpy(self, seed, k_top_post_selection,
                                       max_text_identifiers,
                                       same_passage_memory_policy):
    """Test whether retrieval results are correct.

    Runs the pmapped memory attention layer on random inputs, then re-derives
    the expected retrievals for every query with plain NumPy (per-row argmax
    on each device, per-device top-k, global top-k, same-passage masking) and
    compares entity ids, memory ids, attention weights, retrieved text
    entities and the per-device disallowed counts against the layer outputs.

    Args:
        seed: seed for the `jax.random.PRNGKey` used for all random inputs.
        k_top_post_selection: passed to the layer; when not `None`, the
            simulated retrievals are also truncated to this many entries.
        max_text_identifiers: exclusive upper bound for random memory and
            text identifiers; when `None`, no identifiers are generated and
            the same-passage masking step is skipped.
        same_passage_memory_policy: passed to the layer as a static argument;
            the simulation handles the values `'disallow'` and `'only'`.
    """
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()
    n_text_entities_per_memory = 3

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    # Static positional arguments: 9 (`True`, the deterministic flag),
    # 10 (`None`, memory_values) and 13 (same_passage_memory_policy).
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=(9, 10, 13))

    rng = jax.random.PRNGKey(seed)
    split_rng = jax.random.split(rng, self.n_devices)
    # NOTE(review): the same `rng` key is passed to every `jax.random` call
    # below, so the generated arrays are not independent of each other.
    encoded_input = jax.random.uniform(
        rng,
        shape=(self.n_devices, self.bsz, self.seq_len, self.input_dim),
        dtype=self.dtype)
    mention_batch_positions = jax.random.randint(
        rng, minval=0, maxval=self.bsz, shape=(self.n_devices, self.n_mentions))
    mention_start_positions = jax.random.randint(
        rng,
        minval=0,
        maxval=self.seq_len,
        shape=(self.n_devices, self.n_mentions))
    # Single-token mentions: the end position equals the start position.
    mention_end_positions = mention_start_positions
    mention_mask = jnp.ones(shape=(self.n_devices, self.n_mentions))
    memory_table = jax.random.uniform(
        rng,
        shape=(self.n_devices, self.rows, self.table_size // self.rows,
               self.memory_key_dim))
    memory_entity_ids = jax.random.randint(
        rng,
        minval=0,
        maxval=self.entity_vocab_size,
        shape=(self.n_devices, self.table_size))
    if max_text_identifiers is not None:
        memory_identifiers = jax.random.randint(
            rng,
            minval=0,
            maxval=max_text_identifiers,
            shape=(self.n_devices, self.table_size))
        text_identifiers = jax.random.randint(
            rng,
            minval=0,
            maxval=max_text_identifiers,
            shape=(self.n_devices, self.n_mentions))
    else:
        text_identifiers = None

    # `n_text_entities_per_memory` is fixed to 3 above, so this branch is
    # always taken in the current test.
    if n_text_entities_per_memory is not None:
        memory_text_entities = jax.random.randint(
            rng,
            minval=0,
            maxval=self.entity_vocab_size,
            shape=(self.n_devices, self.table_size, n_text_entities_per_memory))
    else:
        memory_text_entities = None

    # Place one slice along the leading axis on each device.
    encoded_input_sharded = jax.device_put_sharded(list(encoded_input), devices)
    mention_batch_positions_sharded = jax.device_put_sharded(
        list(mention_batch_positions), devices)
    mention_start_positions_sharded = jax.device_put_sharded(
        list(mention_start_positions), devices)
    mention_end_positions_sharded = jax.device_put_sharded(
        list(mention_end_positions), devices)
    mention_mask_sharded = jax.device_put_sharded(list(mention_mask), devices)
    memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)
    memory_entity_ids_sharded = jax.device_put_sharded(
        list(memory_entity_ids), devices)
    if max_text_identifiers is not None:
        memory_identifiers_sharded = jax.device_put_sharded(
            list(memory_identifiers), devices)
        text_identifiers_sharded = jax.device_put_sharded(
            list(text_identifiers), devices)
    else:
        memory_identifiers_sharded = None
        text_identifiers_sharded = None

    if memory_text_entities is not None:
        memory_text_entities_sharded = jax.device_put_sharded(
            list(memory_text_entities), devices)
    else:
        memory_text_entities_sharded = None

    # Expected global memory ids: consecutive integers laid out as
    # [n_devices, table_size].
    memory_ids = jnp.arange(self.n_devices * self.table_size)
    memory_ids = memory_ids.reshape(self.n_devices, self.table_size)

    (_, loss_helpers, logging_helpers), params = pinit_with_output(
        split_rng,
        encoded_input_sharded,
        mention_batch_positions_sharded,
        mention_start_positions_sharded,
        mention_end_positions_sharded,
        mention_mask_sharded,
        memory_table_sharded,
        memory_identifiers_sharded,
        memory_entity_ids_sharded,
        True,
        None,  # memory_values
        text_identifiers_sharded,
        memory_text_entities_sharded,
        same_passage_memory_policy,
    )

    params = params.unfreeze()['params']

    # Recompute the query vectors outside the layer: concatenate start and
    # end token encodings, then apply each device's `query_projector`
    # kernel and bias.
    mention_encodings = []
    for device_id in range(self.n_devices):
        mention_start_encodings = encoded_input[device_id][
            mention_batch_positions[device_id],
            mention_start_positions[device_id]]
        mention_end_encodings = encoded_input[device_id][
            mention_batch_positions[device_id],
            mention_end_positions[device_id]]
        mention_encodings_on_device = jnp.concatenate(
            [mention_start_encodings, mention_end_encodings], axis=-1)
        mention_encodings_on_device = np.matmul(
            mention_encodings_on_device,
            params['query_projector']['kernel'][device_id])
        mention_encodings_on_device += params['query_projector']['bias'][
            device_id]
        mention_encodings.append(mention_encodings_on_device)

    # [n_devices, n_mentions, memory_key_dim]
    mention_encodings_stacked = jnp.stack(mention_encodings)
    mention_encodings_stacked = mention_encodings_stacked.reshape(
        [self.n_devices * self.n_mentions, self.memory_key_dim])

    # Object which represents a single retrieval result with additional info.
    RetrievedMemory = collections.namedtuple('RetrievedMemory', [
        'device', 'row', 'rowwise_index', 'devicewise_index', 'global_index',
        'score', 'memory', 'entity_id', 'memory_hash',
        'memory_passage_text_entities'
    ])

    num_disallowed_per_device = [0 for _ in range(self.n_devices)]
    # Manually simulate retrieval per every query
    for query_id in range(self.n_devices * self.n_mentions):
        query = mention_encodings_stacked[query_id]
        top_memories_query = []
        # Collect retrievals for a single query on each device separately
        for device_id in range(self.n_devices):
            top_memories_per_device = []
            for row_id in range(self.rows):
                # Dot-product score of the query against every memory in
                # this row; only the best memory per row is a candidate.
                scores = np.einsum('mh,h->m', memory_table[device_id, row_id],
                                   query)
                top_index = np.argmax(scores)
                devicewise_index = row_id * (self.table_size //
                                             self.rows) + top_index
                global_index = memory_ids[device_id, devicewise_index]
                self.assertEqual(global_index,
                                 devicewise_index + device_id * self.table_size)
                if max_text_identifiers is not None:
                    memory_hash = memory_identifiers[device_id,
                                                     devicewise_index].item()
                else:
                    memory_hash = None
                if memory_text_entities is not None:
                    memory_passage_text_entities = memory_text_entities[
                        device_id, devicewise_index]
                else:
                    memory_passage_text_entities = None
                top_memories_per_device.append(
                    RetrievedMemory(
                        device=device_id,
                        row=row_id,
                        rowwise_index=top_index,
                        devicewise_index=devicewise_index,
                        global_index=global_index,
                        score=scores[top_index].item(),
                        memory=memory_table[device_id, row_id, top_index],
                        entity_id=memory_entity_ids[device_id,
                                                    devicewise_index].item(),
                        memory_hash=memory_hash,
                        memory_passage_text_entities=memory_passage_text_entities,
                    ))
            # Sort by score. In case two scores are equal (likely because both
            # were considered "disallowed") we compare by entity IDs.
            top_memories_per_device.sort(
                key=lambda x: (x.score, x.entity_id), reverse=True)
            top_memories_per_device = top_memories_per_device[:self.k_top_device]
            top_memories_query.extend(top_memories_per_device)

        # Merge the per-device candidates and keep the global top entries.
        top_memories_query.sort(
            key=lambda x: (x.score, x.entity_id), reverse=True)
        if k_top_post_selection is not None:
            top_memories_query = top_memories_query[:k_top_post_selection]

        if max_text_identifiers is not None:
            num_current_disallowed = 0
            text_id = text_identifiers[query_id // self.n_mentions,
                                       query_id % self.n_mentions].item()
            for i in range(len(top_memories_query)):
                # Count every same-passage retrieval, whatever the policy.
                if top_memories_query[i].memory_hash == text_id:
                    num_current_disallowed += 1

                # Push retrievals excluded by the policy to the bottom by
                # giving them a very large negative score.
                if ((same_passage_memory_policy == 'disallow' and
                     top_memories_query[i].memory_hash == text_id) or
                    (same_passage_memory_policy == 'only' and
                     top_memories_query[i].memory_hash != text_id)):
                    top_memories_query[i] = top_memories_query[i]._replace(
                        score=-_LARGE_NUMBER)
            num_disallowed_per_device[query_id //
                                      self.n_mentions] += num_current_disallowed
        top_memories_query.sort(
            key=lambda x: (x.score, x.global_index), reverse=True)

        # `query_id` enumerates (device, mention) pairs in row-major order.
        actual_entity_ids = loss_helpers['top_entity_ids'][query_id //
                                                           self.n_mentions,
                                                           query_id %
                                                           self.n_mentions]
        actual_memory_ids = loss_helpers['top_memory_ids'][query_id //
                                                           self.n_mentions,
                                                           query_id %
                                                           self.n_mentions]
        actual_attention_weights = loss_helpers['memory_attention_weights'][
            query_id // self.n_mentions, query_id % self.n_mentions]

        # We sort retrieved results first by the attention score and then
        # by memory ID.
        p = list(range(len(actual_attention_weights)))
        # pylint: disable=cell-var-from-loop
        p.sort(
            key=lambda i: (actual_attention_weights[i], actual_memory_ids[i]),
            reverse=True)
        # pylint: enable=cell-var-from-loop
        p = np.array(p)

        actual_entity_ids = list(actual_entity_ids[p])
        actual_attention_weights = list(actual_attention_weights[p])
        actual_memory_ids = list(actual_memory_ids[p])

        expected_entity_ids = [x.entity_id for x in top_memories_query]
        self.assertSequenceEqual(expected_entity_ids, actual_entity_ids)

        expected_attention_weights = scipy.special.softmax(
            [x.score for x in top_memories_query])
        self.assertSequenceAlmostEqual(expected_attention_weights,
                                       actual_attention_weights, 5)

        expected_memory_ids = [x.global_index for x in top_memories_query]
        self.assertSequenceEqual(expected_memory_ids, actual_memory_ids)

        actual_top_text_entities = loss_helpers['memory_top_text_entities'][
            query_id // self.n_mentions, query_id % self.n_mentions]
        actual_top_text_entities = actual_top_text_entities[p]

        expected_top_text_entities = [
            x.memory_passage_text_entities for x in top_memories_query
        ]

        self.assertEqual(
            len(actual_top_text_entities), len(expected_top_text_entities))

        # Comparing `actual_top_text_entities` and `expected_top_text_entities`
        # directly is troublesome since we cannot guarantee the order for
        # retrieval results with the same attention weights. Therefore, we
        # first sort both `actual_top_text_entities` and
        # `expected_top_text_entities` by the attention weight first and then
        # by their elements.
        def sort_entities(top_text_entities):
            # Both lists are keyed by `expected_attention_weights`, which
            # were already checked against the actual weights above.
            result = [(expected_attention_weights[i], list(top_text_entities[i]))
                      for i in range(len(top_text_entities))]
            result.sort()
            return [x[1] for x in result]

        actual_top_text_entities = sort_entities(actual_top_text_entities)
        expected_top_text_entities = sort_entities(expected_top_text_entities)

        for i in range(len(actual_top_text_entities)):
            self.assertSequenceEqual(
                list(actual_top_text_entities[i]),
                list(expected_top_text_entities[i]))

    if max_text_identifiers is not None:
        self.assertSequenceEqual(num_disallowed_per_device,
                                 logging_helpers['n_disallowed'])
def test_memory_attention_backward(self):
    """Check that a backward pass through the memory attention layer runs.

    Initializes the layer on every device, then executes one pmapped
    `jax.value_and_grad` step whose loss is the sum of the encoded output.
    There are no explicit assertions; the test passes when the forward and
    backward passes complete without raising.
    """
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    # Positional arguments 9 (`deterministic`) and 10 (`memory_values`,
    # passed as `None`) are broadcast as static values.
    pinit = jax.pmap(
        model.init, axis_name='batch', static_broadcasted_argnums=(9, 10))

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    # Three mentions per batch row: [0, 0, 0, 1, 1, 1, ...]. Deriving the
    # rows from `self.bsz` keeps every index inside the batch for any batch
    # size; an out-of-range index would be clamped silently by JAX instead
    # of raising.
    mention_batch_positions = jnp.tile(
        jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(
        mention_batch_positions, devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), self.bsz)
    mention_start_positions = jax.device_put_replicated(
        mention_start_positions, devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), self.bsz)
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), self.bsz)
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)
    memory_table = jnp.asarray(memory_table, dtype=self.dtype)
    # Split the global table into one row-structured slice per device.
    memory_table = memory_table.reshape(self.n_devices, self.rows,
                                        self.table_size // self.rows,
                                        self.memory_key_dim)
    memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids),
                                               devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    initial_parameters = pinit(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_table_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        None,  # memory_values
        text_identifiers=None,
    )

    def step_fn(
        params,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys,
        memory_identifiers,
        memory_entity_ids,
    ):
        """Return the summed-output loss and its gradient w.r.t. `params`."""

        def loss_fn(params):
            encoded_output, _, _ = model.apply(
                {'params': params},
                rngs=None,
                encoded_input=encoded_input,
                mention_batch_positions=mention_batch_positions,
                mention_start_positions=mention_start_positions,
                mention_end_positions=mention_end_positions,
                mention_mask=mention_mask,
                memory_keys=memory_keys,
                memory_identifiers=memory_identifiers,
                memory_entity_ids=memory_entity_ids,
                deterministic=True,
                text_identifiers=None,
            )
            return encoded_output.sum()

        loss, grad = jax.value_and_grad(loss_fn)(params)
        return loss, grad

    pstep = jax.pmap(step_fn, axis_name='batch')

    # The result is intentionally discarded: only successful execution of
    # the forward and backward passes is being tested.
    _ = pstep(
        initial_parameters['params'],
        encoded_input=encoded_input,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        memory_keys=memory_table_sharded,
        memory_identifiers=memory_identifiers,
        memory_entity_ids=memory_entity_ids,
    )
def setUp(self):
    """Run the parent set-up, then force `self.n_devices` devices."""
    super().setUp()
    device_count = self.n_devices
    test_utils.force_multi_devices(device_count)
def test_loss_fn(
    self,
    k_top,
    num_intermediate_layers=None,
    shared_initial_encoder=True,
    shared_intermediate_encoder=True,
    shared_final_encoder=True,
    no_retrieval=False,
    same_passage_retrieval_policy='allow',
    extract_unlinked_mentions=False,
    no_retrieval_for_masked_mentions=False,
):
    """Test loss function runs and produces expected values.

    Builds a ReadTwice model from the overridden config, evaluates the
    pmapped loss function on one generated example tiled to the per-device
    batch size, and checks metric denominators against values recomputed
    from the batch. Also checks that the post-processed features are
    JSON-serializable and reproduce the batch.

    Args:
        k_top: stored in the encoder config under `k_top`.
        num_intermediate_layers: stored in the encoder config; when not
            `None`, `second_el_im_weight` is set to 0.1 and the
            `second_el_intermediate` metric is checked as well.
        shared_initial_encoder: stored in the encoder config.
        shared_intermediate_encoder: stored in the encoder config.
        shared_final_encoder: stored in the encoder config.
        no_retrieval: stored in the encoder config; when true,
            `el_im_weight` is set to 0 and the entity-linking denominators
            are not checked.
        same_passage_retrieval_policy: stored in the encoder config.
        extract_unlinked_mentions: stored in the encoder config.
        no_retrieval_for_masked_mentions: stored in the encoder config.
    """
    task = readtwice_task.ReadTwiceTask

    config = copy.deepcopy(self.config)
    encoder_config = copy.deepcopy(self.encoder_config)
    encoder_overrides = {
        'k_top': k_top,
        'num_intermediate_layers': num_intermediate_layers,
        'shared_initial_encoder': shared_initial_encoder,
        'shared_intermediate_encoder': shared_intermediate_encoder,
        'shared_final_encoder': shared_final_encoder,
        'no_retrieval': no_retrieval,
        'same_passage_retrieval_policy': same_passage_retrieval_policy,
        'extract_unlinked_mentions': extract_unlinked_mentions,
        'no_retrieval_for_masked_mentions': no_retrieval_for_masked_mentions,
    }
    for option, setting in encoder_overrides.items():
        encoder_config[option] = setting
    config['model_config']['encoder_config'] = encoder_config
    if no_retrieval:
        config['el_im_weight'] = 0
    if num_intermediate_layers is not None:
        config['second_el_im_weight'] = 0.1
    config = ml_collections.FrozenConfigDict(config)

    model_config = config.model_config
    encoder_config = model_config.encoder_config

    preprocess_fn = task.make_preprocess_fn(config)
    collater_fn = task.make_collater_fn(config)
    postprocess_fn = task.make_output_postprocess_fn(config)

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    def replicate(value):
        return jax.device_put_replicated(value, devices)

    def take_first(leaf):
        # Drop the leading axis by keeping its first entry.
        return leaf[0]

    model = task.build_model(model_config)
    dummy_input = replicate(task.dummy_input(config))
    device_rngs = jax.random.split(jax.random.PRNGKey(0), self.n_devices)

    # Argument 2 (the `True` flag) is broadcast as a static value.
    pmapped_init = jax.pmap(model.init, 'batch', static_broadcasted_argnums=2)
    initial_variables = pmapped_init(device_rngs, dummy_input, True)

    raw_example = test_utils.gen_mention_pretraining_sample(
        self.text_length,
        self.n_mentions,
        self.n_linked_mentions,
        max_length=encoder_config.max_length)
    tiled_example = {}
    for name, value in preprocess_fn(raw_example).items():
        tiled_example[name] = np.tile(
            value, (config.per_device_batch_size, 1))
    batch = {
        name: replicate(test_utils.tensor_to_numpy(value))
        for name, value in collater_fn(tiled_example).items()
    }

    # Arguments 0 (model config) and 4 (deterministic flag) are static.
    loss_fn = jax.pmap(
        task.make_loss_fn(config), 'batch', static_broadcasted_argnums=(0, 4))
    _, metrics, auxiliary_output = loss_fn(
        model_config,
        initial_variables['params'],
        {},  # model vars
        batch,
        True,  # deterministic
    )

    metrics = jax.tree_map(take_first, metrics)
    np_batch = jax.tree_map(take_first, batch)

    # mlm losses
    mlm_weights = np_batch['mlm_target_weights']
    mlm_is_mention = np_batch['mlm_target_is_mention']
    expected_mlm_denom = mlm_weights.sum()
    expected_mlm_mention_denom = (mlm_weights * mlm_is_mention).sum()
    expected_mlm_non_mention_denom = (
        mlm_weights * (1 - mlm_is_mention)).sum()
    expected_mlm_denominators = (
        ('mlm', expected_mlm_denom),
        ('mlm_mention', expected_mlm_mention_denom),
        ('mlm_non_mention', expected_mlm_non_mention_denom),
        ('mlm_first', expected_mlm_denom),
        ('mlm_mention_first', expected_mlm_mention_denom),
        ('mlm_non_mention_first', expected_mlm_non_mention_denom),
    )
    for metric_name, expected in expected_mlm_denominators:
        self.assertEqual(metrics[metric_name]['denominator'], expected)

    # same entity retrieval loss
    if not no_retrieval:
        expected_same_entity_denom = np_batch['mention_target_weights'].sum()
        self.assertEqual(metrics['el_intermediate']['denominator'],
                         expected_same_entity_denom)
        if num_intermediate_layers is not None:
            self.assertEqual(
                metrics['second_el_intermediate']['denominator'],
                expected_same_entity_denom)

    # coref losses
    target_weights = np_batch['mention_target_weights']
    target_is_masked = np_batch['mention_target_is_masked']
    expected_coref_denom = target_weights.sum()
    expected_coref_masked_denom = (target_weights * target_is_masked).sum()
    expected_coref_non_masked_denom = (
        target_weights * (1 - target_is_masked)).sum()

    encoding_kinds = ('key', 'value', 'final')
    for kind in encoding_kinds:
        prefix = kind + '_coref_resolution'
        self.assertEqual(metrics[prefix]['denominator'],
                         expected_coref_denom)
        self.assertEqual(metrics[prefix + '_masked']['denominator'],
                         expected_coref_masked_denom)
        self.assertEqual(metrics[prefix + '_non_masked']['denominator'],
                         expected_coref_non_masked_denom)

    # mtb losses
    for kind in encoding_kinds:
        for suffix in ('_mtb', '_mtb_masked', '_mtb_non_masked'):
            self.assertIn(kind + suffix, metrics)

    features = postprocess_fn(batch, auxiliary_output)
    # Check features are JSON-serializable
    json.dumps(features)
    # Check features match the original batch
    for name, expected in batch.items():
        self.assertArrayEqual(np.array(features[name]), expected)