def make_initial_state(key):
    """Build the initial, device-replicated policy training state.

    Args:
        key: JAX PRNG key. It is split once; one half seeds the policy
            network initialiser and the other half is stored in the state.

    Returns:
        A `TrainingState` whose policy parameters and policy optimizer state
        are replicated onto every local device. `img_encoder_params` is an
        empty dict because no image encoder is used.

    Raises:
        NotImplementedError: if `use_img_encoder` is set. Supporting it
            requires loading pretrained image-encoder parameters from a
            checkpoint and replicating them with
            `jax.device_put_replicated(img_encoder_params, devices)`.
    """
    # Fail fast: the image-encoder path is not implemented, so do not spend
    # time initialising and replicating the policy first.
    if use_img_encoder:
        raise NotImplementedError('Need to load a checkpoint.')

    # Policy parameters and optimizer state.
    key, sub_key = jax.random.split(key)
    policy_params = networks.policy_network.init(sub_key)
    policy_optimizer_state = policy_optimizer.init(policy_params)

    devices = jax.local_devices()
    replicated_policy_params = jax.device_put_replicated(
        policy_params, devices)
    replicated_optim_state = jax.device_put_replicated(
        policy_optimizer_state, devices)

    return TrainingState(
        policy_optimizer_state=replicated_optim_state,
        policy_params=replicated_policy_params,
        key=key,
        img_encoder_params={})
def make_initial_state(key):
    """Build the initial training state for the policy and the Q-ensemble.

    Shared Q parameters, policy parameters and their optimizer states are
    replicated onto every local device. The per-member ensemble parameters
    and their optimizer state are stored exactly as returned by the ensemble
    helpers, without replication.

    Args:
        key: JAX PRNG key; split to seed the critic and policy initialisers.

    Returns:
        A `TrainingState`. The target Q parameters start out as the same
        objects as the online Q parameters. When
        `adaptive_entropy_coefficient` is set, `alpha_optimizer_state` and
        `log_alpha` from the enclosing scope are attached as well.
    """
    devices = jax.local_devices()

    # Critic: model parameters.
    key, sub_key = jax.random.split(key)
    shared_params, ensemble_params = networks.q_ensemble_init(
        ensemble_size, sub_key)
    replicated_shared_params = jax.device_put_replicated(
        shared_params, devices)

    # Critic: optimizer state (the optimizer object itself is not needed).
    _, shared_params_optim_state, ensemble_params_optim_state = (
        ensemble_utils.build_ensemble_optimizer(
            ensemble_size,
            shared_params,
            ensemble_params,
            optax.adam,
            {'learning_rate': q_lr}))
    replicated_shared_params_optim_state = jax.device_put_replicated(
        shared_params_optim_state, devices)

    # Policy parameters and optimizer state.
    key, sub_key = jax.random.split(key)
    policy_params = networks.policy_network.init(sub_key)
    policy_optimizer_state = policy_optimizer.init(policy_params)
    replicated_policy_params = jax.device_put_replicated(
        policy_params, devices)
    replicated_policy_optimizer_state = jax.device_put_replicated(
        policy_optimizer_state, devices)

    state = TrainingState(
        replicated_policy_optimizer_state=replicated_policy_optimizer_state,
        replicated_shared_q_optim_state=replicated_shared_params_optim_state,
        ensemble_q_optim_state=ensemble_params_optim_state,
        replicated_policy_params=replicated_policy_params,
        replicated_shared_q_params=replicated_shared_params,
        ensemble_q_params=ensemble_params,
        target_replicated_shared_q_params=replicated_shared_params,
        target_ensemble_q_params=ensemble_params,
        key=key,
    )

    # Entropy coefficient.
    if adaptive_entropy_coefficient:
        state = state._replace(
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=log_alpha)

    return state
def replicate(self):
    """Generator backing a `with`-statement that replicates this collection.

    Before yielding, every variable's value is copied onto each device, and
    random states are split so that each device receives its own shard.
    After the `with` body finishes, random states get back the value that
    was saved here, and every other variable is assigned the result of
    `reduce_func`.

    Important: replicating also updates the random state in order to have a
    new one per device.
    """
    global math
    if math is None:
        from brainpy import math

    # Push a dummy array through an identity pmap and read the devices off
    # the resulting shards.
    probe = jnp.zeros((jax.local_device_count(), 1), dtype=math.float_)
    sharded_probe = jax.pmap(lambda x: x, axis_name='device')(probe)
    devices = [buf.device() for buf in sharded_probe.device_buffers]
    n_dev = len(devices)

    replicated = {}
    saved_states = {}
    for name, var in self.items():
        if not isinstance(var, math.random.RandomState):
            replicated[name] = jax.device_put_replicated(var.value, devices)
            continue
        # Split first, then record the value: the order is kept as in the
        # original because splitting may advance the random state.
        replicated[name] = jax.device_put_sharded(
            list(var.split(n_dev)), devices)
        saved_states[name] = var.value
    self.assign(replicated)

    yield

    seen = set()
    for name, var in self.items():
        ident = id(var)
        # Careful not to reduce twice in case of a variable and a reference
        # to it.
        if ident in seen:
            continue
        if isinstance(var, math.random.RandomState):
            var.value = saved_states[name]
        else:
            var.value = reduce_func(var)
        seen.add(ident)
def f(x):
    """Give `x` a leading device axis when more than one device is in use.

    With a single device `x` is returned untouched. With several devices it
    is replicated onto the local devices on the JAX backend, and broadcast
    to shape `(n_devices,) + x.shape` on any other backend.
    """
    if n_devices <= 1:
        return x
    if fastmath.is_backend(fastmath.Backend.JAX):
        return jax.device_put_replicated(x, jax.local_devices())
    target_shape = (n_devices,) + jnp.asarray(x).shape
    return jnp.broadcast_to(x, target_shape)
def test_pmap_update_nested(self):
    """Updates nested running statistics under pmap and checks normalisation.

    The same batch is fed twice through `update_and_validate`; the batch
    normalised with the resulting state must then have zero mean and unit
    standard deviation over the device and batch axes.
    """
    n_local = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32),
    })
    x = {
        'a': jnp.arange(15 * n_local, dtype=jnp.float32).reshape(
            n_local, 3, 5),
        'b': jnp.arange(6 * n_local, dtype=jnp.float32).reshape(
            n_local, 3, 2),
    }
    state = jax.device_put_replicated(state, jax.local_devices())

    pmap_axis_name = 'i'
    # Two identical update passes over the same data.
    for _ in range(2):
        update_fn = functools.partial(
            update_and_validate, pmap_axis_name=pmap_axis_name)
        state = jax.pmap(update_fn, pmap_axis_name)(state, x)

    normalized = jax.pmap(running_statistics.normalize)(x, state)

    mean = tree.map_structure(
        lambda leaf: jnp.mean(leaf, axis=(0, 1)), normalized)
    std = tree.map_structure(
        lambda leaf: jnp.std(leaf, axis=(0, 1)), normalized)

    tree.map_structure(
        lambda leaf: self.assert_allclose(leaf, jnp.zeros_like(leaf)), mean)
    tree.map_structure(
        lambda leaf: self.assert_allclose(leaf, jnp.ones_like(leaf)), std)
def replicate(self):
    """Generator backing a `with`-statement that replicates this collection.

    Every variable's value is copied onto each local device before the
    `with` body runs, so that all variables have a copy on each device
    (typically done prior to a call to objax.Parallel). Afterwards each
    distinct variable is reduced back, and random states are restored from
    the values saved here.

    Important: replicating also updates the random state in order to have a
    new one per device.
    """
    devices = get_local_devices()
    n_dev = len(devices)

    replicated = []
    saved_states = []
    for var in self:
        if not isinstance(var, RandomState):
            replicated.append(jax.device_put_replicated(var.value, devices))
            continue
        # Split first, then record the value (same order as the original).
        replicated.append(
            jax.device_put_sharded(list(var.split(n_dev)), devices))
        saved_states.append(var.value)
    self.assign(replicated)

    yield

    seen = set()
    # Reversed so that pop() hands the saved states back in insertion order.
    saved_states.reverse()
    for _name, var in self.items():
        if isinstance(var, TrainRef):
            var = var.ref
            assert not isinstance(var, TrainRef)
        # Careful not to reduce twice in case of a variable and a reference
        # to it.
        if id(var) in seen:
            continue
        if isinstance(var, RandomState):
            var.assign(saved_states.pop())
        else:
            var.reduce(var.value)
        seen.add(id(var))
def opt_state(state, optimizer):
    """Run one optimizer step via `opt_jit` and write the results into `state`.

    Args:
        state: Mapping with "grad_acc", "opt_state" and "params" entries.
            It is updated in place.
        optimizer: Passed through to `opt_jit` unchanged.

    Returns:
        The same `state` object, with the new accumulator and optimizer
        state, the new params replicated onto every local device, and
        "grad_count" reset to zero.
    """
    updated = opt_jit(
        state["grad_acc"], state["opt_state"], state["params"], optimizer)
    grad_acc, optimizer_state, params = updated

    state["grad_acc"] = grad_acc
    state["opt_state"] = optimizer_state
    state["params"] = jax.device_put_replicated(params, jax.local_devices())
    state["grad_count"] = np.array(0)
    return state
def make_initial_state(key):
    """Initialises the training state (parameters and optimiser state).

    Policy, Q-network and SNR state are each created on the host and then
    replicated onto every local device. The target Q parameters start out
    as the same replicated object as the online Q parameters.
    """
    key_policy, key_q, key = jax.random.split(key, 3)
    devices = jax.local_devices()

    def _replicate(tree):
        return jax.device_put_replicated(tree, devices)

    # Policy.
    host_policy_params = networks.policy_network.init(key_policy)
    host_policy_opt_state = policy_optimizer.init(host_policy_params)
    policy_params = _replicate(host_policy_params)
    policy_optimizer_state = _replicate(host_policy_opt_state)

    # Q-network.
    host_q_params = networks.q_network.init(key_q)
    host_q_opt_state = q_optimizer.init(host_q_params)
    q_params = _replicate(host_q_params)
    q_optimizer_state = _replicate(host_q_opt_state)

    # SNR state.
    key, sub_key = jax.random.split(key)
    c_dim = 42  # TODO(kamyar): implement this
    snr_state = _replicate(
        snr_utils.snr_state_init(
            c_dim,
            sub_key,
            snr_kwargs,
        ))

    state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        q_optimizer_state=q_optimizer_state,
        policy_params=policy_params,
        q_params=q_params,
        target_q_params=q_params,
        key=key,
        snr_state=snr_state)

    if not adaptive_entropy_coefficient:
        return state
    return state._replace(
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=log_alpha)
def test_build_coref_positive_negative_mask(self):
    """Checks the coref positive/negative masks against a brute-force rule.

    For every pair of mentions (i, j) across all devices, both masks must be
    zero when either target id is 0 or when the two mentions share a passage
    on the same device. Otherwise the positive mask must equal "same target
    id" and the negative mask "different target id".
    """
    all_mention_target_ids = jax.device_put_replicated(
        self.mention_target_ids_stacked, self.devices)

    get_batch_positions = jax.pmap(
        functools.partial(
            mention_utils.get_globally_consistent_batch_positions,
            batch_size=self.batch_size),
        axis_name='batch')
    local_positions, global_positions = get_batch_positions(
        self.mention_batch_positions_sharded)

    build_masks = jax.pmap(
        mention_losses.build_coref_positive_negative_mask, axis_name='batch')
    positive_mask, negative_mask = build_masks(
        local_positions,
        global_positions,
        self.mention_target_ids_sharded,
        all_mention_target_ids)

    n_all_mentions = self.n_mentions * self.n_devices
    self.assertSequenceEqual(positive_mask.shape, negative_mask.shape)
    self.assertSequenceEqual(
        positive_mask.shape,
        (self.n_devices, self.n_mentions, n_all_mentions))

    # Flatten the device axis so rows are indexed by global mention id.
    positive_mask = positive_mask.reshape(-1, n_all_mentions)
    negative_mask = negative_mask.reshape(-1, n_all_mentions)

    target_ids = self.mention_target_ids_stacked
    positions = self.mention_batch_positions_stacked
    for i in range(n_all_mentions):
        for j in range(n_all_mentions):
            same_device = i // self.n_mentions == j // self.n_mentions
            same_passage = (positions[i] == positions[j]) and same_device
            if target_ids[i] == 0 or target_ids[j] == 0 or same_passage:
                self.assertEqual(positive_mask[i, j], 0)
                self.assertEqual(negative_mask[i, j], 0)
            else:
                self.assertEqual(
                    positive_mask[i, j], target_ids[i] == target_ids[j])
                self.assertEqual(
                    negative_mask[i, j], target_ids[i] != target_ids[j])
def tree_replicate_by_name(param_tree, filter_fn, devices = None):
    """Replicates leaf arrays whose name is matched by a filter.

    Args:
        param_tree: Tree of parameters.
        filter_fn: Leaf node filter function.
        devices: XLA devices. A falsy value (the default) selects
            `jax.local_devices()`.

    Returns:
        A tree identical in structure to `param_tree` except that those
        leaves which satisfy `filter_fn` are replicated across devices.
    """
    target_devices = devices or jax.local_devices()

    def _replicate_leaf(leaf):
        return jax.device_put_replicated(leaf, target_devices)

    return tree_map_with_names(_replicate_leaf, param_tree, filter_fn)
def main(_):
    """Restore a checkpoint and render a video and/or the test set.

    Steps, in order: load the datasets, build the coarse (and optionally
    fine) model and its optimizer, restore the latest checkpoint from
    `FLAGS.model_dir`, replicate the state onto the local devices, then

    * if `FLAGS.render_video`: render every item of the set chosen by
      `FLAGS.render_video_set` and write "rgb" and "disp" videos;
    * if `FLAGS.render_testset`: render every test item and save the images.

    When `render_factor == 1` the mean squared error against the reference
    images and the corresponding PSNR are logged.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0
    save_dir = FLAGS.model_dir if FLAGS.save_dir is None else FLAGS.save_dir

    # jax.process_count() replaces the deprecated jax.host_count() alias and
    # matches the jax.process_index() call next to it.
    logging.info("JAX host: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine = jax.random.split(rng, 3)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir, FLAGS.config, num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    train_items, val_items, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(
        apply_fn=(model_coarse.apply, None),
        params={"coarse": params_coarse},
        tx=tx)

    if FLAGS.config.num_importance > 0:
        # A fine model is evaluated on the coarse samples plus the
        # importance samples, so its point shape is larger.
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    # Read the step on the host before the state gains a device axis.
    step = int(state.step)
    state = jax.device_put_replicated(state, jax.local_devices())

    # TODO: TPU Colab breaks without message if this is a list
    # a list is preferred bc tqdm can show an ETA
    render_dict = {
        "train": zip(range(train_items), train_ds),
        "val": zip(range(val_items), val_ds),
        "test": zip(range(test_items), test_ds),
        "poses": zip(range(num_poses), render_ds),
    }
    render_poses = render_dict[FLAGS.render_video_set]

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far,
                                    state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(render_fn, axis_name="batch")

    if FLAGS.render_video:
        rgb_list = []
        disp_list = []
        losses = []
        # NOTE(review): "render" is not a key of render_dict above, so the
        # second comparison is always true for any set that got this far.
        # Confirm the intended set name (possibly "poses").
        compute_loss = (FLAGS.config.render_factor == 1 and
                        FLAGS.render_video_set != "render")

        for _, inputs in tqdm(render_poses, desc="Rays render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            # jax.tree_util.tree_map replaces the deprecated jax.tree_map.
            preds = jax.tree_util.tree_map(
                lambda x: to_np(x, r_hwf, padding), preds)
            rgb_list.append(preds["rgb"])
            disp_list.append(preds["disp"])

            if compute_loss:
                loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                losses.append(loss)

        if compute_loss:
            loss = np.mean(losses)
            logging.info("Loss %.5f", loss)
            logging.info("PSNR %.5f", psnr_fn(loss))

        gen_video(save_dir, np.stack(rgb_list), "rgb", r_hwf, step)
        disp = np.stack(disp_list)
        gen_video(save_dir, disp_post(disp, FLAGS.config), "disp", r_hwf,
                  step, ch=1)

    if FLAGS.render_testset:
        test_losses = []
        compute_test_loss = FLAGS.config.render_factor == 1

        for idx, inputs in tqdm(zip(range(test_items), test_ds),
                                desc="Test render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            preds = jax.tree_util.tree_map(
                lambda x: to_np(x, r_hwf, padding), preds)
            save_test_imgs(save_dir, preds["rgb"], r_hwf, step, idx)

            if compute_test_loss:
                loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                test_losses.append(loss)

        if compute_test_loss:
            loss = np.mean(test_losses)
            logging.info("Loss %.5f", loss)
            logging.info("PSNR %.5f", psnr_fn(loss))
def run_model(self, config, entity_vocab_size):
    """Initialize and run the model once, perform sanity checks.

    Args:
        config: Mutable, dict-style configuration. The patterns of the
            saved memory arrays are written into it before it is frozen.
        entity_vocab_size: Forwarded to the sample generator as
            `entity_vocab_size`.

    Returns:
        A `(batch, initial_variables, metrics)` tuple: the sharded input
        batch, the model variables (params plus memory constants) and the
        metrics produced by the loss function.
    """
    # NumPy RNG is seeded once; the order of the np.random calls below
    # therefore determines the generated data.
    np.random.seed(0)

    # Save arrays to test retrieval saver.
    memory_identifiers = np.arange(self.table_size)
    memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                   self.devices)
    # Entity ids reuse the identifiers array.
    memory_entity_ids = memory_identifiers
    config['memory_entity_id_pattern'] = self.save_sharded_array(
        memory_entity_ids, 'memory_entity_id')
    memory_text = np.random.randint(
        config['model_config']['encoder_config']['vocab_size'],
        size=(self.n_devices, self.table_size, self.memory_text_length),
        dtype=np.int32)
    config['memory_text_pattern'] = self.save_sharded_array(
        memory_text, 'memory_text')
    memory_positions = np.random.randint(self.memory_text_length,
                                         size=(self.n_devices,
                                               self.table_size, 2),
                                         dtype=np.int32)
    config['memory_positions_pattern'] = self.save_sharded_array(
        memory_positions, 'memory_positions')

    # From here on the config is read-only.
    config = ml_collections.FrozenConfigDict(config)
    model_config = config.model_config
    encoder_config = model_config.encoder_config

    rows = encoder_config.rows
    preprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_preprocess_fn(config)  # pylint: disable=line-too-long
    collater_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_collater_fn(
        config)
    postprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_output_postprocess_fn(
        config)

    model = mention_based_entity_qa_task.MentionBasedEntityQATask.build_model(
        model_config)
    dummy_input = mention_based_entity_qa_task.MentionBasedEntityQATask.dummy_input(
        config)
    dummy_input = jax.device_put_replicated(dummy_input, self.devices)
    init_rng = jax.random.PRNGKey(0)
    # One PRNG key per device for the pmapped init.
    split_rng = jax.random.split(init_rng, self.n_devices)

    # Keys and values come from the same random table; the values are the
    # table with its first two axes flattened.
    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = jax.device_put_replicated(memory_table, self.devices)
    memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
    memory_values = jax.device_put_replicated(memory_values, self.devices)

    initial_variables = jax.pmap(model.init,
                                 'batch',
                                 static_broadcasted_argnums=2)(
                                     split_rng,
                                     dummy_input,
                                     True,
                                 )
    # Keep only the params from init and attach the memory constants built
    # above.
    initial_variables = {'params': initial_variables['params']}
    initial_variables['constants'] = {
        'encoder': {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
        }
    }

    def sample_batch():
        """Generate, preprocess and collate one per-device batch."""
        processed_examples = []
        for _ in range(config.per_device_batch_size):
            raw_example = test_utils.gen_mention_pretraining_sample(
                self.text_length,
                self.n_mentions,
                self.n_linked_mentions,
                entity_vocab_size=entity_vocab_size,
                max_length=encoder_config.max_length)
            processed_example = preprocess_fn(raw_example)
            processed_examples.append(processed_example)
        batch = stack(processed_examples)
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }
        return batch

    # One independently sampled batch per device, sharded along the leading
    # axis.
    batch = stack([sample_batch() for _ in range(self.n_devices)])
    batch = {
        key: jax.device_put_sharded(list(value), self.devices)
        for key, value in batch.items()
    }

    loss_fn = jax.pmap(
        mention_based_entity_qa_task.MentionBasedEntityQATask.make_loss_fn(
            config),
        'batch',
        static_broadcasted_argnums=(0, 4))
    _, metrics, auxiliary_output = loss_fn(
        model_config,
        initial_variables['params'],
        {'constants': initial_variables['constants']},
        batch,
        True,
    )

    self.assertArrayEqual(metrics['agg']['denominator'],
                          batch['mention_target_weights'].sum(1))

    features = postprocess_fn(batch, auxiliary_output)
    # Check features are JSON-serializable
    json.dumps(features)
    # Check features match the original batch
    for key in batch.keys():
        self.assertArrayEqual(np.array(features[key]), batch[key])

    n_mentions_per_device = (config.per_device_batch_size *
                             config.max_mention_targets)
    # Falls back to the per-device top-k times the device count when no
    # post-selection top-k is configured.
    k_top_final = (encoder_config.final_k_top_post_selection or
                   encoder_config.final_k_top_device * self.n_devices)
    self.assertSequenceEqual(
        np.array(features['memory_text']).shape, [
            self.n_devices, n_mentions_per_device, k_top_final,
            self.memory_text_length
        ])
    self.assertSequenceEqual(
        np.array(features['memory_positions']).shape,
        [self.n_devices, n_mentions_per_device, k_top_final, 2])

    return batch, initial_variables, metrics
"""Time one pmapped image and text forward pass of a clip_jax model."""
import time

import jax
import numpy as np
from PIL import Image

import clip_jax

image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load(
    'ViT-B/32', "cpu")

batch_size = 2048
devices = jax.local_devices()
print(f"jax devices: {devices}")

# Every device gets its own copy of the parameters and runs the same
# function on its slice of the leading axis.
jax_params = jax.device_put_replicated(jax_params, devices)
image_fn = jax.pmap(image_fn)
text_fn = jax.pmap(text_fn)

# Build (n_devices, batch_size, ...) inputs by repeating one example.
with Image.open("CLIP.png") as image_file:
    jax_image = np.expand_dims(jax_preprocess(image_file), (0, 1))
jax_image = np.repeat(jax_image, len(devices), axis=0)
jax_image = np.repeat(jax_image, batch_size, axis=1)

jax_text = np.expand_dims(clip_jax.tokenize(["a diagram"]), 0)
jax_text = np.repeat(jax_text, len(devices), axis=0)
jax_text = np.repeat(jax_text, batch_size, axis=1)

start = time.perf_counter()
jax_image_embed = image_fn(jax_params, jax_image)
jax_text_embed = text_fn(jax_params, jax_text)
# JAX dispatches work asynchronously, so wait for the results before
# stopping the clock; otherwise only the dispatch time is measured.
jax.tree_util.tree_map(lambda out: out.block_until_ready(),
                       (jax_image_embed, jax_text_embed))
total = time.perf_counter() - start
def run_model(self, config, entity_vocab_size):
    """Initialize and run the model once, perform sanity checks.

    Args:
        config: Mutable, dict-style configuration. The patterns of the
            saved memory arrays are written into it before it is frozen.
        entity_vocab_size: Size of the entity id space used both for the
            generated samples and for the random memory text entities.

    Returns:
        A `(batch, initial_variables, metrics)` tuple: the sharded input
        batch, the model variables (params plus memory constants) and the
        metrics produced by the loss function.
    """
    # NumPy RNG is seeded once; the order of the np.random calls below
    # therefore determines the generated data.
    np.random.seed(0)

    # Save arrays to test retrieval saver.
    memory_identifiers = np.arange(self.table_size)
    memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                   self.devices)
    # Entity ids reuse the identifiers array.
    memory_entity_ids = memory_identifiers
    config['memory_entity_id_pattern'] = self.save_sharded_array(
        memory_entity_ids, 'memory_entity_id')
    memory_text = np.random.randint(
        config['model_config']['encoder_config']['vocab_size'],
        size=(self.n_devices, self.table_size, self.memory_text_length),
        dtype=np.int32)
    config['memory_text_pattern'] = self.save_sharded_array(
        memory_text, 'memory_text')
    memory_positions = np.random.randint(self.memory_text_length,
                                         size=(self.n_devices,
                                               self.table_size, 2),
                                         dtype=np.int32)
    config['memory_positions_pattern'] = self.save_sharded_array(
        memory_positions, 'memory_positions')

    # From here on the config is read-only.
    config = ml_collections.FrozenConfigDict(config)
    model_config = config.model_config
    encoder_config = model_config.encoder_config

    rows = encoder_config.rows
    preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
    collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
        config)
    postprocess_fn = mention_memory_task.MentionMemoryTask.make_output_postprocess_fn(
        config)

    model = mention_memory_task.MentionMemoryTask.build_model(model_config)
    dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
    dummy_input = jax.device_put_replicated(dummy_input, self.devices)
    init_rng = jax.random.PRNGKey(0)
    # One PRNG key per device for the pmapped init.
    split_rng = jax.random.split(init_rng, self.n_devices)

    # Keys and values come from the same random table; the values are the
    # table with its first two axes flattened.
    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = jax.device_put_replicated(memory_table, self.devices)
    memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
    memory_values = jax.device_put_replicated(memory_values, self.devices)

    # `memory_text_entities` are assumed to contain unique IDs in the last
    # dim.
    memory_text_entities = np.zeros(
        (self.n_devices, self.table_size,
         encoder_config.n_memory_text_entities), np.int32)
    for device_index in range(self.n_devices):
        for t_index in range(self.table_size):
            # Sampling without replacement keeps the ids unique; positions
            # that are not filled stay zero.
            current_text_entities = np.random.choice(
                entity_vocab_size,
                size=(min(encoder_config.n_memory_text_entities,
                          entity_vocab_size)),
                replace=False)
            memory_text_entities[
                device_index,
                t_index, :len(current_text_entities)] = current_text_entities
    memory_text_entities = jax.device_put_sharded(
        list(memory_text_entities), self.devices)

    initial_variables = jax.pmap(model.init,
                                 'batch',
                                 static_broadcasted_argnums=2)(
                                     split_rng,
                                     dummy_input,
                                     True,
                                 )
    # Keep only the params from init and attach the memory constants built
    # above.
    initial_variables = {'params': initial_variables['params']}
    initial_variables['constants'] = {
        'encoder': {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        }
    }

    def sample_batch():
        """Generate one per-device batch and verify its text identifiers."""
        processed_examples = []
        for _ in range(config.per_device_batch_size):
            raw_example = test_utils.gen_mention_pretraining_sample(
                self.text_length,
                self.n_mentions,
                self.n_linked_mentions,
                entity_vocab_size=entity_vocab_size,
                max_length=encoder_config.max_length)
            processed_example = preprocess_fn(raw_example)
            processed_examples.append(processed_example)
        batch = stack(processed_examples)
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }

        # Write the MLM target ids back into their positions. `text_ids` is
        # the array stored in `batch`, so the returned batch carries the
        # modified ids.
        text_ids = batch['text_ids']
        for i in range(config.per_device_batch_size):
            for j in range(config.max_mlm_targets):
                if batch['mlm_target_weights'][i, j] > 0:
                    text_ids[i, batch['mlm_target_positions'][
                        i, j]] = batch['mlm_target_ids'][i, j]
        mention_batch_positions = batch['mention_batch_positions']
        text_identifiers = batch['text_identifiers'].astype(
            np.int32).tolist()
        expected_text_identifiers = [
            mention_preprocess_utils.text_hash(
                text_ids[mention_batch_positions[index]]).astype(np.int32)
            for index in range(len(mention_batch_positions))
        ]
        self.assertSequenceEqual(text_identifiers,
                                 expected_text_identifiers)
        return batch

    # One independently sampled batch per device, sharded along the leading
    # axis.
    batch = stack([sample_batch() for _ in range(self.n_devices)])
    batch = {
        key: jax.device_put_sharded(list(value), self.devices)
        for key, value in batch.items()
    }

    loss_fn = jax.pmap(
        mention_memory_task.MentionMemoryTask.make_loss_fn(config),
        'batch',
        static_broadcasted_argnums=(0, 4))
    _, metrics, auxiliary_output = loss_fn(
        model_config,
        initial_variables['params'],
        {'constants': initial_variables['constants']},
        batch,
        True,
    )

    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    metrics_per_first_device = jax.tree_util.tree_map(
        lambda x: x[0], metrics)
    self.assertEqual(metrics_per_first_device['mlm']['denominator'],
                     batch['mlm_target_weights'][0].sum())

    features = postprocess_fn(batch, auxiliary_output)
    # Check features are JSON-serializable
    json.dumps(features)
    # Check features match the original batch
    for key in batch.keys():
        self.assertArrayEqual(np.array(features[key]), batch[key])

    n_mentions_per_device = (config.per_device_batch_size *
                             config.max_mentions)
    if config.save_k_retrieval is not None:
        k_top_saved = min(config.save_k_retrieval,
                          encoder_config.k_top_post_selection)
    else:
        # Falls back to the per-device top-k times the device count when no
        # post-selection top-k is configured.
        k_top_saved = (encoder_config.k_top_post_selection or
                       encoder_config.k_top_device * self.n_devices)
    self.assertSequenceEqual(
        np.array(features['memory_text']).shape, [
            self.n_devices, n_mentions_per_device, k_top_saved,
            self.memory_text_length
        ])
    self.assertSequenceEqual(
        np.array(features['memory_positions']).shape,
        [self.n_devices, n_mentions_per_device, k_top_saved, 2])

    if encoder_config.get('num_intermediate_layers') is not None:
        self.assertSequenceEqual(
            np.array(features['second_memory_text']).shape, [
                self.n_devices, n_mentions_per_device, k_top_saved,
                self.memory_text_length
            ])
        self.assertSequenceEqual(
            np.array(features['second_memory_positions']).shape,
            [self.n_devices, n_mentions_per_device, k_top_saved, 2])

    return batch, initial_variables, metrics
def load_weights(config: ml_collections.ConfigDict) -> Dict[str, Any]:
    """Load model weights from file and replicate them on the local devices.

    Args:
        config: Configuration whose `load_weights` entry is handed to
            `checkpoint_utils.load_weights`.

    Returns:
        A dict with a single 'params' entry holding the replicated weights.
    """
    host_params = checkpoint_utils.load_weights(config.load_weights)
    replicated_params = jax.device_put_replicated(
        host_params, jax.local_devices())
    return {'params': replicated_params}
def test_model_shape(
    self,
    separate_memory_values=False,
    num_intermediate_layers=None,
):
    """Test loss function runs and produces expected values.

    Builds the encoder with random memory constants, runs its forward pass
    under pmap on a batch tiled from one generated example, and checks the
    shapes of the encoded output and of the target mention encodings.
    """
    config = copy.deepcopy(self.config)
    encoder_overrides = config['model_config']['encoder_config']
    encoder_overrides['separate_memory_values'] = separate_memory_values
    encoder_overrides['num_intermediate_layers'] = num_intermediate_layers
    config = ml_collections.FrozenConfigDict(config)

    model_config = config.model_config
    encoder_config = model_config.encoder_config
    rows = encoder_config.rows

    task = mention_memory_task.MentionMemoryTask
    preprocess_fn = task.make_preprocess_fn(config)
    collater_fn = task.make_collater_fn(config)

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    def _replicate(value):
        return jax.device_put_replicated(value, devices)

    model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
    dummy_input = _replicate(task.dummy_input(config))
    split_rng = jax.random.split(jax.random.PRNGKey(0), self.n_devices)

    # Keys and values share one random table; the values are the table with
    # its first two axes flattened.
    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = _replicate(memory_table)
    memory_values = _replicate(
        memory_table.reshape(-1, encoder_config.memory_key_dim))
    memory_identifiers = _replicate(np.arange(self.table_size))
    memory_entity_ids = memory_identifiers
    memory_text_entities = _replicate(
        np.zeros((self.table_size, encoder_config.n_memory_text_entities),
                 dtype=np.int32))

    def model_init(*args, **kwargs):
        return model.init(*args, method=model.forward, **kwargs)

    p_init = jax.pmap(model_init, 'batch', static_broadcasted_argnums=2)
    init_output = p_init(split_rng, dummy_input, True)
    initial_variables = {
        'params': init_output['params'],
        'constants': {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        },
    }

    raw_example = test_utils.gen_mention_pretraining_sample(
        self.text_length,
        self.n_mentions,
        self.n_linked_mentions,
        max_length=encoder_config.max_length)
    processed_example = preprocess_fn(raw_example)
    # The same example is tiled to fill a per-device batch.
    tiled = {
        key: np.tile(value, (config.per_device_batch_size, 1))
        for key, value in processed_example.items()
    }
    collated = collater_fn(tiled)
    as_numpy = {
        key: test_utils.tensor_to_numpy(value)
        for key, value in collated.items()
    }
    batch = {key: _replicate(value) for key, value in as_numpy.items()}

    def model_apply(*args, **kwargs):
        return model.apply(*args, method=model.forward, **kwargs)

    papply = jax.pmap(model_apply, 'batch', static_broadcasted_argnums=2)
    encoded_output, loss_helpers, _ = papply(
        {
            'params': initial_variables['params'],
            'constants': initial_variables['constants'],
        },
        batch,
        True,
    )

    expected_encoded_shape = (self.n_devices, config.per_device_batch_size,
                              encoder_config.max_length,
                              encoder_config.hidden_size)
    self.assertEqual(encoded_output.shape, expected_encoded_shape)

    memory_value_dim = encoder_config.memory_value_dim
    memory_key_dim = encoder_config.memory_key_dim
    if memory_value_dim:
        memory_size = memory_value_dim
    else:
        memory_size = memory_key_dim
    self.assertEqual(
        loss_helpers['target_mention_encodings'].shape,
        (self.n_devices,
         config.max_mention_targets * config.per_device_batch_size,
         memory_size))
def replicate_model_state(model_states: TrainState) -> TrainState:
    """Returns `model_states` copied onto every local device."""
    local_devices = jax.local_devices()
    replicated_states = jax.device_put_replicated(model_states, local_devices)
    return replicated_states
def test_load_weights(self, separate_memory_values=False, memory_only=False):
    """Test saving and loading model recovers original parameters.

    The model is initialised with random memory constants, its params and
    the sharded memory arrays are written to a temporary directory, and
    `MentionMemoryEncoder.load_weights` must return variables equal to the
    originals.

    Args:
        separate_memory_values: Written into the encoder config. When
            False, 'memory_values' is left out of the comparison.
        memory_only: When True no params are saved, `load_weights` is set
            to the literal 'memory_only', and only constants are compared.
    """
    config = copy.deepcopy(self.config)
    config['model_config']['encoder_config'][
        'separate_memory_values'] = separate_memory_values
    config = ml_collections.ConfigDict(config)

    model_config = config.model_config
    encoder_config = model_config.encoder_config
    rows = encoder_config.rows

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
    dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
    dummy_input = jax.device_put_replicated(dummy_input, devices)
    init_rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(init_rng, self.n_devices)

    # Keys and values share one random table; the values are the table with
    # its first two axes flattened.
    memory_table = np.random.rand(rows, self.table_size // rows,
                                  encoder_config.memory_key_dim)
    memory_keys = jax.device_put_replicated(memory_table, devices)
    memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
    memory_values = jax.device_put_replicated(memory_values, devices)
    memory_identifiers = np.arange(self.table_size)
    memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                   devices)
    memory_entity_ids = memory_identifiers
    memory_text_entities = np.zeros(
        (self.table_size, encoder_config.n_memory_text_entities),
        dtype=np.int32)
    memory_text_entities = jax.device_put_replicated(
        memory_text_entities, devices)

    def model_init(*args, **kwargs):
        return model.init(*args, method=model.forward, **kwargs)

    initial_variables = jax.pmap(model_init,
                                 'batch',
                                 static_broadcasted_argnums=2)(
                                     split_rng,
                                     dummy_input,
                                     True,
                                 )
    initial_variables = {'params': initial_variables['params']}
    initial_variables['constants'] = {
        'memory_keys': memory_keys,
        'memory_values': memory_values,
        'memory_identifiers': memory_identifiers,
        'memory_entity_ids': memory_entity_ids,
        'memory_text_entities': memory_text_entities,
    }

    n_shards = 4

    tempdir_obj = self.create_tempdir()
    tempdir = tempdir_obj.full_path

    memory_key_base = os.path.join(tempdir, 'memory_keys')
    memory_value_base = os.path.join(tempdir, 'memory_values')
    memory_id_base = os.path.join(tempdir, 'memory_id')
    memory_entity_id_base = os.path.join(tempdir, 'memory_entity_id')
    memory_text_entities_base = os.path.join(tempdir,
                                             'memory_text_entities')

    unreplicated_variables = jax_utils.unreplicate(initial_variables)
    unreplicated_variables['params'] = unreplicated_variables[
        'params'].unfreeze()

    if memory_only:
        load_weights = 'memory_only'
    else:
        load_weights = os.path.join(tempdir, 'weights')
        checkpoint_utils.save_weights(load_weights,
                                      unreplicated_variables['params'])

    # Split every memory array into `n_shards` pieces along a new leading
    # axis so each piece can be written to its own file.
    memory_keys = initial_variables['constants']['memory_keys']
    memory_keys = memory_keys.reshape(n_shards, -1,
                                      encoder_config.memory_key_dim)
    memory_values = initial_variables['constants']['memory_values']
    memory_values = memory_values.reshape(n_shards, -1,
                                          encoder_config.memory_key_dim)
    memory_ids = initial_variables['constants'][
        'memory_identifiers'].reshape(n_shards, -1)
    memory_entity_ids = initial_variables['constants'][
        'memory_entity_ids'].reshape(n_shards, -1)
    memory_text_entities = initial_variables['constants'][
        'memory_text_entities'].reshape(
            n_shards, -1, encoder_config.n_memory_text_entities)

    for shard in range(n_shards):
        suffix = str(shard)
        np.save(memory_key_base + suffix, memory_keys[shard])
        np.save(memory_value_base + suffix, memory_values[shard])
        np.save(memory_id_base + suffix, memory_ids[shard])
        np.save(memory_entity_id_base + suffix, memory_entity_ids[shard])
        np.save(memory_text_entities_base + suffix,
                memory_text_entities[shard])

    config.memory_key_pattern = memory_key_base + '*'
    config.memory_value_pattern = memory_value_base + '*'
    config.memory_id_pattern = memory_id_base + '*'
    config.memory_entity_id_pattern = memory_entity_id_base + '*'
    config.memory_text_entities_pattern = memory_text_entities_base + '*'
    config.load_weights = load_weights

    loaded_variables = mention_memory_encoder.MentionMemoryEncoder.load_weights(
        config)

    arrayeq = lambda x, y: jnp.all(x == y)
    constants = {
        key: value
        for key, value in initial_variables['constants'].items()
        if not (key == 'memory_values' and not separate_memory_values)
    }
    comparison_variables = {'constants': constants}
    if not memory_only:
        comparison_variables['params'] = initial_variables[
            'params'].unfreeze()

    # tree_map returns a tree of per-leaf booleans. A non-empty tree is
    # always truthy, so asserting on the tree itself would never fail;
    # every leaf has to be checked.
    leaf_matches = jax.tree_util.tree_map(arrayeq, loaded_variables,
                                          comparison_variables)
    self.assertTrue(all(jax.tree_util.tree_leaves(leaf_matches)))
def main(_):
    """Run the full training loop.

    Builds the datasets, the coarse (and optionally fine) model state,
    restores the latest checkpoint from ``FLAGS.model_dir`` and then trains up
    to ``FLAGS.config.num_steps``, periodically writing training metrics,
    validation images, rendered videos, test-set renders and checkpoints.

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if ``precrop_iters`` is combined with ``batching``, or if
            ``down_factor`` / ``render_factor`` is not strictly positive.
    """
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters' has no effect when 'batching' the dataset")
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not (FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0):
        raise ValueError(
            "'down_factor' and 'render_factor' must both be positive")

    logging.info("JAX host: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        # The fine model sees the coarse samples plus the importance samples.
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            if step < FLAGS.config.precrop_iters:
                # Restrict sampled pixel coordinates to a central crop.
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            else:
                coords = jnp.meshgrid(jnp.arange(img_h),
                                      jnp.arange(img_w),
                                      indexing="ij")
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        for hook in hooks:
            hook(step)

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                                       train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if (step % FLAGS.config.i_img == 0 and step > 0) or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                # Public `.numpy()` accessor, as used by the render loops below.
                rays, padding = prepare_render_data(inputs["rays"].numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_util.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                # A zip object is single-pass, so the bounded iterator is
                # rebuilt for every render; reusing one created before the
                # loop yields nothing from the second render onwards.
                for _, inputs in tqdm(zip(range(num_poses), render_ds),
                                      total=num_poses,
                                      desc="Rays render"):
                    rays, padding = prepare_render_data(
                        inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_util.tree_map(
                        lambda x: to_np(x, r_hwf, padding), preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config),
                           "disp",
                           r_hwf,
                           step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    for _, inputs in tqdm(zip(range(num_poses),
                                              render_vdirs_ds),
                                          total=num_poses,
                                          desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                # Rebuilt per pass for the same reason as the video iterators.
                for idx, inputs in tqdm(zip(range(test_items), test_ds),
                                        total=test_items,
                                        desc="Test render"):
                    rays, padding = prepare_render_data(
                        inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    # NOTE(review): `render_factor` is validated to be > 0 at
                    # the top of this function, so this branch (and the
                    # summary below) never runs - confirm the intended
                    # condition.
                    if FLAGS.config.render_factor == 0:
                        loss = np.mean(
                            (preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
def test_loss_fn(
    self,
    k_top,
    num_intermediate_layers=None,
    shared_initial_encoder=True,
    shared_intermediate_encoder=True,
    shared_final_encoder=True,
    no_retrieval=False,
    same_passage_retrieval_policy='allow',
    extract_unlinked_mentions=False,
    no_retrieval_for_masked_mentions=False,
):
  """Test loss function runs and produces expected values.

  Every argument other than `self` is written into the encoder config under
  the key of the same name before the model is built. The test then runs the
  pmapped loss function on a replicated synthetic batch and checks the metric
  denominators against values recomputed from the batch, and that the
  post-processed features are JSON-serializable and match the batch.
  """
  config = copy.deepcopy(self.config)
  encoder_config = copy.deepcopy(self.encoder_config)
  encoder_config['k_top'] = k_top
  encoder_config['num_intermediate_layers'] = num_intermediate_layers
  encoder_config['shared_initial_encoder'] = shared_initial_encoder
  encoder_config['shared_intermediate_encoder'] = shared_intermediate_encoder
  encoder_config['shared_final_encoder'] = shared_final_encoder
  encoder_config['no_retrieval'] = no_retrieval
  encoder_config[
      'same_passage_retrieval_policy'] = same_passage_retrieval_policy
  encoder_config['extract_unlinked_mentions'] = extract_unlinked_mentions
  encoder_config[
      'no_retrieval_for_masked_mentions'] = no_retrieval_for_masked_mentions
  config['model_config']['encoder_config'] = encoder_config
  if no_retrieval:
    config['el_im_weight'] = 0
  if num_intermediate_layers is not None:
    config['second_el_im_weight'] = 0.1
  config = ml_collections.FrozenConfigDict(config)

  model_config = config.model_config
  encoder_config = model_config.encoder_config

  task = readtwice_task.ReadTwiceTask
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)
  postprocess_fn = task.make_output_postprocess_fn(config)

  test_utils.force_multi_devices(self.n_devices)
  devices = jax.local_devices()

  model = task.build_model(model_config)
  dummy_input = task.dummy_input(config)
  dummy_input = jax.device_put_replicated(dummy_input, devices)
  init_rng = jax.random.PRNGKey(0)
  split_rng = jax.random.split(init_rng, self.n_devices)
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(
          split_rng,
          dummy_input,
          True,
      )

  raw_example = test_utils.gen_mention_pretraining_sample(
      self.text_length,
      self.n_mentions,
      self.n_linked_mentions,
      max_length=encoder_config.max_length)
  processed_example = preprocess_fn(raw_example)
  # Repeat the single example to fill a per-device batch.
  batch = {
      key: np.tile(value, (config.per_device_batch_size, 1))
      for key, value in processed_example.items()
  }
  batch = collater_fn(batch)
  batch = {
      key: test_utils.tensor_to_numpy(value) for key, value in batch.items()
  }
  batch = {
      key: jax.device_put_replicated(value, devices)
      for key, value in batch.items()
  }

  loss_fn = jax.pmap(
      task.make_loss_fn(config),
      'batch',
      static_broadcasted_argnums=(0, 4))
  _, metrics, auxiliary_output = loss_fn(
      model_config,
      initial_variables['params'],
      {},  # model vars
      batch,
      True,  # deterministic
  )

  # Every device received the same batch, so inspect the first replica only.
  # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
  # alias is no longer available in recent JAX releases.
  take_first = lambda x: x[0]
  metrics = jax.tree_util.tree_map(take_first, metrics)
  np_batch = jax.tree_util.tree_map(take_first, batch)

  # mlm losses
  expected_mlm_denom = np_batch['mlm_target_weights'].sum()
  expected_mlm_mention_denom = (np_batch['mlm_target_weights'] *
                                np_batch['mlm_target_is_mention']).sum()
  expected_mlm_non_mention_denom = (
      np_batch['mlm_target_weights'] *
      (1 - np_batch['mlm_target_is_mention'])).sum()
  self.assertEqual(metrics['mlm']['denominator'], expected_mlm_denom)
  self.assertEqual(metrics['mlm_mention']['denominator'],
                   expected_mlm_mention_denom)
  self.assertEqual(metrics['mlm_non_mention']['denominator'],
                   expected_mlm_non_mention_denom)
  self.assertEqual(metrics['mlm_first']['denominator'], expected_mlm_denom)
  self.assertEqual(metrics['mlm_mention_first']['denominator'],
                   expected_mlm_mention_denom)
  self.assertEqual(metrics['mlm_non_mention_first']['denominator'],
                   expected_mlm_non_mention_denom)

  # same entity retrieval loss
  if not no_retrieval:
    expected_same_entity_denom = np_batch['mention_target_weights'].sum()
    self.assertEqual(metrics['el_intermediate']['denominator'],
                     expected_same_entity_denom)
    if num_intermediate_layers is not None:
      self.assertEqual(metrics['second_el_intermediate']['denominator'],
                       expected_same_entity_denom)

  # coref losses
  expected_coref_denom = np_batch['mention_target_weights'].sum()
  expected_coref_masked_denom = (np_batch['mention_target_weights'] *
                                 np_batch['mention_target_is_masked']).sum()
  expected_coref_non_masked_denom = (
      np_batch['mention_target_weights'] *
      (1 - np_batch['mention_target_is_masked'])).sum()

  # Tuples rather than sets so failures are reported in a stable order.
  for coref_type in ('key', 'value', 'final'):
    self.assertEqual(metrics[coref_type + '_coref_resolution']['denominator'],
                     expected_coref_denom)
    self.assertEqual(
        metrics[coref_type + '_coref_resolution_masked']['denominator'],
        expected_coref_masked_denom)
    self.assertEqual(
        metrics[coref_type + '_coref_resolution_non_masked']['denominator'],
        expected_coref_non_masked_denom)

  # mtb losses
  for mtb_type in ('key', 'value', 'final'):
    self.assertIn(mtb_type + '_mtb', metrics)
    self.assertIn(mtb_type + '_mtb_masked', metrics)
    self.assertIn(mtb_type + '_mtb_non_masked', metrics)

  features = postprocess_fn(batch, auxiliary_output)
  # Check features are JSON-serializable
  json.dumps(features)
  # Check features match the original batch
  for key in batch:
    self.assertArrayEqual(np.array(features[key]), batch[key])
def test_mention_memory_layer(self, separate_memory_values):
  """Testing memory attention layer.

  Runs the pmapped layer twice - once without and once with text identifiers -
  and checks output shapes, which positions were modified, which entity ids
  were retrieved and how many retrievals were disallowed.
  """
  test_utils.force_multi_devices(self.n_devices)
  devices = jax.local_devices()

  def replicate(array):
    # Same array on every local device.
    return jax.device_put_replicated(array, devices)

  def shard(array):
    # One slice along the leading axis per local device.
    return jax.device_put_sharded(list(array), devices)

  layer = memory_attention_layer.MemoryAttentionLayer(
      memory_key_dim=self.memory_key_dim,
      input_dim=self.input_dim,
      memory_update_type=self.memory_update_type,
      memory_update_config=self.memory_update_config,
      k_top_device=self.k_top_device,
      k_top_post_selection=self.k_top_post_selection,
      splits=self.splits,
      dtype=self.dtype)

  # Argument 9 is the `deterministic` flag. Argument 10 (memory values) is
  # `None` unless separate values are supplied, so it is only mapped then.
  static_argnums = 9 if separate_memory_values else (9, 10)
  pinit_with_output = jax.pmap(
      layer.init_with_output,
      axis_name='batch',
      static_broadcasted_argnums=static_argnums)

  split_rng = jax.random.split(jax.random.PRNGKey(0), self.n_devices)

  encoded_input = replicate(
      jnp.ones(
          shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype))

  # Three mentions per batch element, at fixed start/end offsets.
  mention_batch_positions = replicate(
      jnp.tile(jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1))
  mention_start_positions = replicate(
      jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz)))
  mention_end_positions = replicate(
      jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz)))
  n_mentions = mention_start_positions.shape[-1]
  mention_mask = replicate(jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz)))

  table = np.ones((self.n_devices * self.table_size, self.memory_key_dim),
                  dtype=self.dtype)
  # Make sure id 0 or 1 will be highest scoring
  table[0] = table[0] * 2.0
  table[1] = table[1] * -2.0
  table = jnp.asarray(table, dtype=self.dtype)

  memory_keys_sharded = shard(
      table.reshape(self.n_devices, self.rows, self.table_size // self.rows,
                    self.memory_key_dim))

  memory_values = None
  if separate_memory_values:
    memory_values = shard(
        table.reshape(self.n_devices, self.table_size, self.memory_key_dim))

  memory_entity_ids = shard(
      np.arange(self.n_devices * self.table_size).reshape(
          self.n_devices, self.table_size))
  # Use entity id as identifier here
  memory_identifiers = memory_entity_ids

  def run_layer(text_ids):
    # Initialize and run one forward pass of model
    return pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,
        text_identifiers=text_ids,
    )

  (encoded_output, loss_helpers, _), _ = run_layer(None)
  attention_weights = loss_helpers['memory_attention_weights']
  entity_ids = loss_helpers['top_entity_ids']

  normed_input = encoded_input - 1.0

  # Check input was changed
  self.assertFalse(jnp.allclose(encoded_output, normed_input))

  # Check input was not changed where it should not be
  every_position = set(
      itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
  # Note that mention positions is the same across all of the devices
  mention_starts = set(
      zip(mention_batch_positions[0].tolist(),
          mention_start_positions[0].tolist()))
  untouched = every_position.difference(mention_starts)
  untouched_batch, untouched_seq = zip(*untouched)
  untouched_batch = jnp.asarray(untouched_batch)
  untouched_seq = jnp.asarray(untouched_seq)
  self.assertTrue(
      jnp.allclose(encoded_output[:, untouched_batch, untouched_seq],
                   normed_input[:, untouched_batch, untouched_seq]))

  # Check shapes as expected
  retrieval_shape = (self.n_devices, n_mentions, self.k_top_post_selection)
  self.assertSequenceEqual(
      encoded_output.shape,
      (self.n_devices, self.bsz, self.seq_len, self.input_dim))
  self.assertSequenceEqual(attention_weights.shape, retrieval_shape)
  self.assertSequenceEqual(entity_ids.shape, retrieval_shape)

  # Check id 0 or 1 retrieved
  top_entity = entity_ids[..., 0]
  self.assertTrue(
      jnp.all(jnp.logical_or(top_entity == 0, top_entity == 1)))

  # Set some text identifiers to 0 and others to 1 so that some are binding
  text_identifiers = np.zeros((n_mentions), dtype=np.int32)
  text_identifiers[:n_mentions // 2] = 1
  text_identifiers = replicate(text_identifiers)

  (_, loss_helpers, logging_helpers), _ = run_layer(text_identifiers)
  attention_weights_wid = loss_helpers['memory_attention_weights']
  entity_ids_wid = loss_helpers['top_entity_ids']
  n_disallowed = logging_helpers['n_disallowed'][0]

  # Check no effect on ids
  self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

  # Check id 0 or 1 have 0 scores
  text_identifiers = jnp.expand_dims(text_identifiers, -1)
  score_masked = (text_identifiers == entity_ids_wid) * attention_weights_wid
  self.assertAlmostEqual(score_masked.sum(), 0.0)

  # Check number disallowed as expected
  self.assertEqual(n_disallowed, n_mentions // 2)
def test_memory_attention_backward(self):
  """Check that a pmapped value-and-gradient pass through the layer runs.

  Initializes the layer on every device, then differentiates the sum of the
  encoded output with respect to the parameters. Only successful execution is
  checked; the loss and gradients are discarded.
  """
  test_utils.force_multi_devices(self.n_devices)
  devices = jax.local_devices()

  model = memory_attention_layer.MemoryAttentionLayer(
      memory_key_dim=self.memory_key_dim,
      input_dim=self.input_dim,
      memory_update_type=self.memory_update_type,
      memory_update_config=self.memory_update_config,
      k_top_device=self.k_top_device,
      k_top_post_selection=self.k_top_post_selection,
      splits=self.splits,
      dtype=self.dtype)

  # Argument 9 is the `deterministic` flag and argument 10 the memory values,
  # which are passed as `None` below.
  pinit = jax.pmap(
      model.init, axis_name='batch', static_broadcasted_argnums=(9, 10))

  rng = jax.random.PRNGKey(0)
  split_rng = jax.random.split(rng, self.n_devices)
  encoded_input = jnp.ones(
      shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
  encoded_input = jax.device_put_replicated(encoded_input, devices)

  # Three mentions per batch element. The batch indices are derived from
  # `self.bsz` (as in test_mention_memory_layer) instead of the literal
  # [0, 1, 2], which is only a valid index set when bsz == 3.
  mention_batch_positions = jnp.tile(
      jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
  mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                      devices)
  mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
  mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                      devices)
  mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
  mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                    devices)
  mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
  mention_mask = jax.device_put_replicated(mention_mask, devices)

  memory_table = np.ones(
      (self.n_devices * self.table_size, self.memory_key_dim),
      dtype=self.dtype)
  memory_table = jnp.asarray(memory_table, dtype=self.dtype)
  memory_table = memory_table.reshape(self.n_devices, self.rows,
                                      self.table_size // self.rows,
                                      self.memory_key_dim)
  memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)

  memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
      self.n_devices, self.table_size)
  memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

  # Use entity id as identifier here
  memory_identifiers = memory_entity_ids

  initial_parameters = pinit(
      split_rng,
      encoded_input,
      mention_batch_positions,
      mention_start_positions,
      mention_end_positions,
      mention_mask,
      memory_table_sharded,
      memory_identifiers,
      memory_entity_ids,
      True,  # deterministic
      None,  # memory_values
      text_identifiers=None,
  )

  def step_fn(
      params,
      encoded_input,
      mention_batch_positions,
      mention_start_positions,
      mention_end_positions,
      mention_mask,
      memory_keys,
      memory_identifiers,
      memory_entity_ids,
  ):
    """Return the summed layer output and its gradient w.r.t. `params`."""

    def loss_fn(params):
      encoded_output, _, _ = model.apply(
          {'params': params},
          rngs=None,
          encoded_input=encoded_input,
          mention_batch_positions=mention_batch_positions,
          mention_start_positions=mention_start_positions,
          mention_end_positions=mention_end_positions,
          mention_mask=mention_mask,
          memory_keys=memory_keys,
          memory_identifiers=memory_identifiers,
          memory_entity_ids=memory_entity_ids,
          deterministic=True,
          text_identifiers=None,
      )
      return encoded_output.sum()

    loss, grad = jax.value_and_grad(loss_fn)(params)
    return loss, grad

  pstep = jax.pmap(step_fn, axis_name='batch')

  _ = pstep(
      initial_parameters['params'],
      encoded_input=encoded_input,
      mention_batch_positions=mention_batch_positions,
      mention_start_positions=mention_start_positions,
      mention_end_positions=mention_end_positions,
      mention_mask=mention_mask,
      memory_keys=memory_table_sharded,
      memory_identifiers=memory_identifiers,
      memory_entity_ids=memory_entity_ids,
  )