Example #1
    def make_initial_state(key):
      """Initialises the training state (parameters and optimiser state)."""
      # policy stuff
      key, sub_key = jax.random.split(key)
      policy_params = networks.policy_network.init(sub_key)
      policy_optimizer_state = policy_optimizer.init(policy_params)

      devices = jax.local_devices()
      replicated_policy_params = jax.device_put_replicated(
          policy_params, devices)
      replicated_optim_state = jax.device_put_replicated(
          policy_optimizer_state, devices)

      if use_img_encoder:
        """
        Load pretrained img_encoder_params and do:
        replicated_img_encoder_params = jax.device_put_replicated(
            img_encoder_params, devices)
        """
        class EncoderTrainingState(NamedTuple):
          encoder_params: hk.Params
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params
        raise NotImplementedError('Need to load a checkpoint.')
      else:
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params

      state = TrainingState(
          policy_optimizer_state=replicated_optim_state,
          policy_params=replicated_policy_params,
          key=key,
          img_encoder_params=replicated_img_encoder_params)
      return state
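Note on the API used throughout these examples: jax.device_put_replicated(tree, devices) puts one copy of every leaf on each listed device and returns arrays with a new leading axis of length len(devices), which is the layout jax.pmap expects. A minimal standalone sketch of that behaviour (not part of the example above):

import jax
import jax.numpy as jnp

devices = jax.local_devices()
params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}

# Every leaf is copied to each device and gains a leading device axis.
replicated = jax.device_put_replicated(params, devices)
assert replicated['w'].shape == (len(devices), 3, 2)

# A pmapped function then consumes one copy per device.
totals = jax.pmap(lambda p: p['w'].sum() + p['b'].sum())(replicated)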
Example #2
        def make_initial_state(key):
            """Initialises the training state (parameters and optimiser state)."""
            num_devices = jax.device_count()
            # critic stuff
            # model params
            key, sub_key = jax.random.split(key)
            shared_params, ensemble_params = networks.q_ensemble_init(
                ensemble_size, sub_key)
            # replicated_shared_params = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), shared_params)
            replicated_shared_params = jax.device_put_replicated(
                shared_params, jax.local_devices())

            # optim params
            _, shared_params_optim_state, ensemble_params_optim_state = ensemble_utils.build_ensemble_optimizer(
                ensemble_size, shared_params, ensemble_params, optax.adam,
                {'learning_rate': q_lr})
            # replicated_shared_params_optim_state = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), shared_params_optim_state)
            replicated_shared_params_optim_state = jax.device_put_replicated(
                shared_params_optim_state, jax.local_devices())

            # policy stuff
            key, sub_key = jax.random.split(key)
            policy_params = networks.policy_network.init(sub_key)
            policy_optimizer_state = policy_optimizer.init(policy_params)

            # replicated_policy_params = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), policy_params)
            # replicated_policy_optimizer_state = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), policy_optimizer_state)
            replicated_policy_params = jax.device_put_replicated(
                policy_params, jax.local_devices())
            replicated_policy_optimizer_state = jax.device_put_replicated(
                policy_optimizer_state, jax.local_devices())

            state = TrainingState(
                replicated_policy_optimizer_state=replicated_policy_optimizer_state,
                replicated_shared_q_optim_state=replicated_shared_params_optim_state,
                ensemble_q_optim_state=ensemble_params_optim_state,
                replicated_policy_params=replicated_policy_params,
                replicated_shared_q_params=replicated_shared_params,
                ensemble_q_params=ensemble_params,
                target_replicated_shared_q_params=replicated_shared_params,
                target_ensemble_q_params=ensemble_params,
                key=key,
            )

            # entropy stuff
            if adaptive_entropy_coefficient:
                state = state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=log_alpha)

            # jax.tree_map(lambda t: print(t.shape), replicated_shared_params_optim_state)

            return state
Example #3
  def replicate(self):
    """A context manager to use in a with statement that replicates
    the variables in this collection to multiple devices.

    Important: replicating also updates the random state in order
    to have a new one per device.
    """
    global math
    if math is None: from brainpy import math

    replicated, saved_states = {}, {}
    x = jnp.zeros((jax.local_device_count(), 1), dtype=math.float_)
    sharded_x = jax.pmap(lambda x: x, axis_name='device')(x)
    devices = [b.device() for b in sharded_x.device_buffers]
    num_device = len(devices)
    for k, d in self.items():
      if isinstance(d, math.random.RandomState):
        replicated[k] = jax.device_put_sharded([shard for shard in d.split(num_device)], devices)
        saved_states[k] = d.value
      else:
        replicated[k] = jax.device_put_replicated(d.value, devices)
    self.assign(replicated)
    yield
    visited = set()
    for k, d in self.items():
      # Careful not to reduce twice in case of
      # a variable and a reference to it.
      if id(d) not in visited:
        if isinstance(d, math.random.RandomState):
          d.value = saved_states[k]
        else:
          d.value = reduce_func(d)
        visited.add(id(d))
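The method above is written as a generator (it yields once), so it is presumably wrapped as a context manager elsewhere in brainpy. A hypothetical usage sketch; the collection train_vars and the pmapped parallel_step are assumptions, not part of the source:

# Hypothetical usage: replicate variables only for the duration of a pmapped call.
with train_vars.replicate():         # each variable gains a leading device axis
    losses = parallel_step(batch)    # runs once per local device
# On exit, non-random variables are reduced back to one copy and the saved
# random states are restored, per the code above.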
Example #4
 def f(x):
   if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
     return jax.device_put_replicated(x, jax.local_devices())
   elif n_devices > 1:
     return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
   else:
     return x
  def test_pmap_update_nested(self):
    local_device_count = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    x = {
        'a': (jnp.arange(15 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 5),
        'b': (jnp.arange(6 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 2),
    }

    devices = jax.local_devices()
    state = jax.device_put_replicated(state, devices)
    pmap_axis_name = 'i'
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    normalized = jax.pmap(running_statistics.normalize)(x, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)
Example #6
File: variable.py Project: google/objax
 def replicate(self):
     """A context manager to use in a with statement that replicates the variables in this collection to multiple
     devices. This is used typically prior to call to objax.Parallel, so that all variables have a copy on each
     device.
     Important: replicating also updates the random state in order to have a new one per device.
     """
     replicated, saved_states = [], []
     devices = get_local_devices()
     ndevices = len(devices)
     for v in self:
         if isinstance(v, RandomState):
             replicated.append(jax.device_put_sharded([shard for shard in v.split(ndevices)], devices))
             saved_states.append(v.value)
         else:
             replicated.append(jax.device_put_replicated(v.value, devices))
     self.assign(replicated)
     yield
     visited = set()
     saved_states.reverse()
     for k, v in self.items():
         if isinstance(v, TrainRef):
             v = v.ref
             assert not isinstance(v, TrainRef)
         if id(v) not in visited:  # Careful not to reduce twice in case of a variable and a reference to it.
             if isinstance(v, RandomState):
                 v.assign(saved_states.pop())
             else:
                 v.reduce(v.value)
             visited.add(id(v))
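Per the docstring, replication typically precedes a call to objax.Parallel. A hedged usage sketch; model_vars and train_op are assumptions, not part of the source:

# Hypothetical usage: give every device its own copy before running in parallel.
parallel_train = objax.Parallel(train_op, model_vars)
with model_vars.replicate():
    loss = parallel_train(x_batch, y_batch)
# Leaving the with-block reduces the variables back to a single copy.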
Example #7
def opt_state(state, optimizer):
    new_grad_acc, new_opt_state, new_params = opt_jit(state["grad_acc"],
                                                      state["opt_state"],
                                                      state["params"],
                                                      optimizer)

    state["grad_acc"] = new_grad_acc
    state["opt_state"] = new_opt_state
    state["params"] = jax.device_put_replicated(new_params, jax.local_devices())
    state["grad_count"] = np.array(0)
    return state
Example #8
        def make_initial_state(key):
            """Initialises the training state (parameters and optimiser state)."""
            key_policy, key_q, key = jax.random.split(key, 3)
            devices = jax.local_devices()

            policy_params = networks.policy_network.init(key_policy)
            policy_optimizer_state = policy_optimizer.init(policy_params)
            policy_params = jax.device_put_replicated(policy_params, devices)
            policy_optimizer_state = jax.device_put_replicated(
                policy_optimizer_state, devices)

            q_params = networks.q_network.init(key_q)
            q_optimizer_state = q_optimizer.init(q_params)
            q_params = jax.device_put_replicated(q_params, devices)
            q_optimizer_state = jax.device_put_replicated(
                q_optimizer_state, devices)

            key, sub_key = jax.random.split(key)
            c_dim = 42  # TODO(kamyar): implement this
            snr_state = snr_utils.snr_state_init(
                c_dim,
                sub_key,
                snr_kwargs,
            )
            snr_state = jax.device_put_replicated(snr_state, devices)

            state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=q_params,
                key=key,
                snr_state=snr_state)

            if adaptive_entropy_coefficient:
                state = state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=log_alpha)
            return state
Example #9
    def test_build_coref_positive_negative_mask(self):
        all_mention_target_ids = jax.device_put_replicated(
            self.mention_target_ids_stacked, self.devices)

        get_batch_positions = functools.partial(
            mention_utils.get_globally_consistent_batch_positions,
            batch_size=self.batch_size)
        get_batch_positions = jax.pmap(get_batch_positions, axis_name='batch')

        (local_mention_batch_positions,
         global_mention_batch_positions) = get_batch_positions(
             self.mention_batch_positions_sharded)

        (positive_mask, negative_mask) = jax.pmap(
            mention_losses.build_coref_positive_negative_mask,
            axis_name='batch')(local_mention_batch_positions,
                               global_mention_batch_positions,
                               self.mention_target_ids_sharded,
                               all_mention_target_ids)

        n_all_mentions = self.n_mentions * self.n_devices
        self.assertSequenceEqual(positive_mask.shape, negative_mask.shape)
        self.assertSequenceEqual(
            positive_mask.shape,
            (self.n_devices, self.n_mentions, n_all_mentions))
        positive_mask = positive_mask.reshape(-1, n_all_mentions)
        negative_mask = negative_mask.reshape(-1, n_all_mentions)

        for i in range(n_all_mentions):
            for j in range(n_all_mentions):
                is_same_device = i // self.n_mentions == j // self.n_mentions
                is_same_passage = (self.mention_batch_positions_stacked[i] ==
                                   self.mention_batch_positions_stacked[j])
                is_same_passage = is_same_passage and is_same_device

                if (self.mention_target_ids_stacked[i] == 0
                        or self.mention_target_ids_stacked[j] == 0
                        or is_same_passage):
                    self.assertEqual(positive_mask[i, j], 0)
                    self.assertEqual(negative_mask[i, j], 0)
                    continue

                self.assertEqual(
                    positive_mask[i, j], self.mention_target_ids_stacked[i] ==
                    self.mention_target_ids_stacked[j])
                self.assertEqual(
                    negative_mask[i, j], self.mention_target_ids_stacked[i] !=
                    self.mention_target_ids_stacked[j])
Example #10
def tree_replicate_by_name(
    param_tree,
    filter_fn,
    devices = None):
  """Replicates leaf arrays whose name is matched by a filter.

  Args:
    param_tree: Tree of parameters.
    filter_fn: Leaf node filter function.
    devices: XLA devices.

  Returns:
    A tree identical in structure to `param_tree` except that those leaves which
    satisfy `filter_fn` are replicated across devices.
  """
  devices = devices or jax.local_devices()
  return tree_map_with_names(lambda x: jax.device_put_replicated(x, devices),
                             param_tree, filter_fn)
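For illustration, a hypothetical call; the filter and the parameter tree are assumptions, and tree_map_with_names is assumed to pass each leaf's name to filter_fn:

# Replicate only the (frozen) encoder parameters onto all local devices.
param_tree = tree_replicate_by_name(
    param_tree,
    filter_fn=lambda name: name.startswith('encoder/'),
    devices=jax.local_devices())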
Example #11
def main(_):
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0
    save_dir = FLAGS.model_dir if FLAGS.save_dir is None else FLAGS.save_dir
    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine = jax.random.split(rng, 3)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir, FLAGS.config, num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    train_items, val_items, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    step = int(state.step)
    state = jax.device_put_replicated(state, jax.local_devices())

    # TODO: TPU Colab breaks without message if this is a list
    # a list is preferred because tqdm can show an ETA
    render_dict = {
        "train": zip(range(train_items), train_ds),
        "val": zip(range(val_items), val_ds),
        "test": zip(range(test_items), test_ds),
        "poses": zip(range(num_poses), render_ds),
    }
    render_poses = render_dict[FLAGS.render_video_set]

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    if FLAGS.render_video:
        rgb_list = []
        disp_list = []
        losses = []
        for _, inputs in tqdm(render_poses, desc="Rays render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds)
            rgb_list.append(preds["rgb"])
            disp_list.append(preds["disp"])

            if FLAGS.config.render_factor == 1 and FLAGS.render_video_set != "render":
                loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                losses.append(loss)

        if FLAGS.config.render_factor == 1 and FLAGS.render_video_set != "render":
            loss = np.mean(losses)
            logging.info("Loss %.5f", loss)
            logging.info("PSNR %.5f", psnr_fn(loss))
        gen_video(save_dir, np.stack(rgb_list), "rgb", r_hwf, step)
        disp = np.stack(disp_list)
        gen_video(save_dir,
                  disp_post(disp, FLAGS.config),
                  "disp",
                  r_hwf,
                  step,
                  ch=1)

    if FLAGS.render_testset:
        test_losses = []
        for idx, inputs in tqdm(zip(range(test_items), test_ds),
                                desc="Test render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds)
            save_test_imgs(save_dir, preds["rgb"], r_hwf, step, idx)

            if FLAGS.config.render_factor == 1:
                loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                test_losses.append(loss)
        if FLAGS.config.render_factor == 1:
            loss = np.mean(test_losses)
            logging.info("Loss %.5f", loss)
            logging.info("PSNR %.5f", psnr_fn(loss))
Example #12
    def run_model(self, config, entity_vocab_size):
        """Initialize and run the model once, perform sanity checks."""
        np.random.seed(0)

        # Save arrays to test retrieval saver.
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       self.devices)
        memory_entity_ids = memory_identifiers
        config['memory_entity_id_pattern'] = self.save_sharded_array(
            memory_entity_ids, 'memory_entity_id')
        memory_text = np.random.randint(
            config['model_config']['encoder_config']['vocab_size'],
            size=(self.n_devices, self.table_size, self.memory_text_length),
            dtype=np.int32)
        config['memory_text_pattern'] = self.save_sharded_array(
            memory_text, 'memory_text')
        memory_positions = np.random.randint(self.memory_text_length,
                                             size=(self.n_devices,
                                                   self.table_size, 2),
                                             dtype=np.int32)
        config['memory_positions_pattern'] = self.save_sharded_array(
            memory_positions, 'memory_positions')

        config = ml_collections.FrozenConfigDict(config)
        model_config = config.model_config
        encoder_config = model_config.encoder_config

        rows = encoder_config.rows
        preprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_collater_fn(
            config)
        postprocess_fn = mention_based_entity_qa_task.MentionBasedEntityQATask.make_output_postprocess_fn(
            config)

        model = mention_based_entity_qa_task.MentionBasedEntityQATask.build_model(
            model_config)
        dummy_input = mention_based_entity_qa_task.MentionBasedEntityQATask.dummy_input(
            config)
        dummy_input = jax.device_put_replicated(dummy_input, self.devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, self.devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, self.devices)

        initial_variables = jax.pmap(model.init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'encoder': {
                'memory_keys': memory_keys,
                'memory_values': memory_values,
                'memory_identifiers': memory_identifiers,
                'memory_entity_ids': memory_entity_ids,
            }
        }

        def sample_batch():
            processed_examples = []
            for _ in range(config.per_device_batch_size):
                raw_example = test_utils.gen_mention_pretraining_sample(
                    self.text_length,
                    self.n_mentions,
                    self.n_linked_mentions,
                    entity_vocab_size=entity_vocab_size,
                    max_length=encoder_config.max_length)
                processed_example = preprocess_fn(raw_example)
                processed_examples.append(processed_example)
            batch = stack(processed_examples)
            batch = collater_fn(batch)
            batch = {
                key: test_utils.tensor_to_numpy(value)
                for key, value in batch.items()
            }
            return batch

        batch = stack([sample_batch() for _ in range(self.n_devices)])
        batch = {
            key: jax.device_put_sharded(list(value), self.devices)
            for key, value in batch.items()
        }

        loss_fn = jax.pmap(
            mention_based_entity_qa_task.MentionBasedEntityQATask.make_loss_fn(
                config),
            'batch',
            static_broadcasted_argnums=(0, 4))
        _, metrics, auxiliary_output = loss_fn(
            model_config,
            initial_variables['params'],
            {'constants': initial_variables['constants']},
            batch,
            True,
        )

        self.assertArrayEqual(metrics['agg']['denominator'],
                              batch['mention_target_weights'].sum(1))

        features = postprocess_fn(batch, auxiliary_output)
        # Check features are JSON-serializable
        json.dumps(features)
        # Check features match the original batch
        for key in batch.keys():
            self.assertArrayEqual(np.array(features[key]), batch[key])

        n_mentions_per_device = (config.per_device_batch_size *
                                 config.max_mention_targets)
        k_top_final = (encoder_config.final_k_top_post_selection
                       or encoder_config.final_k_top_device * self.n_devices)
        self.assertSequenceEqual(
            np.array(features['memory_text']).shape, [
                self.n_devices, n_mentions_per_device, k_top_final,
                self.memory_text_length
            ])
        self.assertSequenceEqual(
            np.array(features['memory_positions']).shape,
            [self.n_devices, n_mentions_per_device, k_top_final, 2])

        return batch, initial_variables, metrics
Example #13
from PIL import Image
import jax
import numpy as np  # needed for np.expand_dims / np.repeat below
import time

import clip_jax

image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load(
    'ViT-B/32', "cpu")

batch_size = 2048

devices = jax.local_devices()

print(f"jax devices: {devices}")

jax_params = jax.device_put_replicated(jax_params, devices)
image_fn = jax.pmap(image_fn)
text_fn = jax.pmap(text_fn)

jax_image = np.expand_dims(jax_preprocess(Image.open("CLIP.png")), (0, 1))
jax_image = np.repeat(jax_image, len(devices), axis=0)
jax_image = np.repeat(jax_image, batch_size, axis=1)

jax_text = np.expand_dims(clip_jax.tokenize(["a diagram"]), 0)
jax_text = np.repeat(jax_text, len(devices), axis=0)
jax_text = np.repeat(jax_text, batch_size, axis=1)

start = time.time()
jax_image_embed = image_fn(jax_params, jax_image)
jax_text_embed = text_fn(jax_params, jax_text)
total = time.time() - start
Example #14
    def run_model(self, config, entity_vocab_size):
        """Initialize and run the model once, perform sanity checks."""
        np.random.seed(0)

        # Save arrays to test retrieval saver.
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       self.devices)
        memory_entity_ids = memory_identifiers
        config['memory_entity_id_pattern'] = self.save_sharded_array(
            memory_entity_ids, 'memory_entity_id')
        memory_text = np.random.randint(
            config['model_config']['encoder_config']['vocab_size'],
            size=(self.n_devices, self.table_size, self.memory_text_length),
            dtype=np.int32)
        config['memory_text_pattern'] = self.save_sharded_array(
            memory_text, 'memory_text')
        memory_positions = np.random.randint(self.memory_text_length,
                                             size=(self.n_devices,
                                                   self.table_size, 2),
                                             dtype=np.int32)
        config['memory_positions_pattern'] = self.save_sharded_array(
            memory_positions, 'memory_positions')

        config = ml_collections.FrozenConfigDict(config)
        model_config = config.model_config
        encoder_config = model_config.encoder_config

        rows = encoder_config.rows
        preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
            config)
        postprocess_fn = mention_memory_task.MentionMemoryTask.make_output_postprocess_fn(
            config)

        model = mention_memory_task.MentionMemoryTask.build_model(model_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, self.devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, self.devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, self.devices)

        # `memory_text_entities` are assumed to contain unique IDs in the last dim.
        memory_text_entities = np.zeros(
            (self.n_devices, self.table_size,
             encoder_config.n_memory_text_entities), np.int32)
        for device_index in range(self.n_devices):
            for t_index in range(self.table_size):
                current_text_entities = np.random.choice(
                    entity_vocab_size,
                    size=(min(encoder_config.n_memory_text_entities,
                              entity_vocab_size)),
                    replace=False)
                memory_text_entities[device_index,
                                     t_index, :len(current_text_entities
                                                   )] = current_text_entities

        memory_text_entities = jax.device_put_sharded(
            list(memory_text_entities), self.devices)

        initial_variables = jax.pmap(model.init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'encoder': {
                'memory_keys': memory_keys,
                'memory_values': memory_values,
                'memory_identifiers': memory_identifiers,
                'memory_entity_ids': memory_entity_ids,
                'memory_text_entities': memory_text_entities,
            }
        }

        def sample_batch():
            processed_examples = []
            for _ in range(config.per_device_batch_size):
                raw_example = test_utils.gen_mention_pretraining_sample(
                    self.text_length,
                    self.n_mentions,
                    self.n_linked_mentions,
                    entity_vocab_size=entity_vocab_size,
                    max_length=encoder_config.max_length)
                processed_example = preprocess_fn(raw_example)
                processed_examples.append(processed_example)
            batch = stack(processed_examples)
            batch = collater_fn(batch)
            batch = {
                key: test_utils.tensor_to_numpy(value)
                for key, value in batch.items()
            }
            text_ids = batch['text_ids']
            for i in range(config.per_device_batch_size):
                for j in range(config.max_mlm_targets):
                    if batch['mlm_target_weights'][i, j] > 0:
                        text_ids[i, batch['mlm_target_positions'][
                            i, j]] = batch['mlm_target_ids'][i, j]
            mention_batch_positions = batch['mention_batch_positions']
            text_identifiers = batch['text_identifiers'].astype(
                np.int32).tolist()
            expected_text_identifiers = [
                mention_preprocess_utils.text_hash(
                    text_ids[mention_batch_positions[index]]).astype(np.int32)
                for index in range(len(mention_batch_positions))
            ]
            self.assertSequenceEqual(text_identifiers,
                                     expected_text_identifiers)
            return batch

        batch = stack([sample_batch() for _ in range(self.n_devices)])
        batch = {
            key: jax.device_put_sharded(list(value), self.devices)
            for key, value in batch.items()
        }

        loss_fn = jax.pmap(
            mention_memory_task.MentionMemoryTask.make_loss_fn(config),
            'batch',
            static_broadcasted_argnums=(0, 4))
        _, metrics, auxiliary_output = loss_fn(
            model_config,
            initial_variables['params'],
            {'constants': initial_variables['constants']},
            batch,
            True,
        )

        metrics_per_first_device = jax.tree_map(lambda x: x[0], metrics)
        self.assertEqual(metrics_per_first_device['mlm']['denominator'],
                         batch['mlm_target_weights'][0].sum())

        features = postprocess_fn(batch, auxiliary_output)
        # Check features are JSON-serializable
        json.dumps(features)
        # Check features match the original batch
        for key in batch.keys():
            self.assertArrayEqual(np.array(features[key]), batch[key])

        n_mentions_per_device = (config.per_device_batch_size *
                                 config.max_mentions)
        if config.save_k_retrieval is not None:
            k_top_saved = min(config.save_k_retrieval,
                              encoder_config.k_top_post_selection)
        else:
            k_top_saved = (encoder_config.k_top_post_selection
                           or encoder_config.k_top_device * self.n_devices)
        self.assertSequenceEqual(
            np.array(features['memory_text']).shape, [
                self.n_devices, n_mentions_per_device, k_top_saved,
                self.memory_text_length
            ])
        self.assertSequenceEqual(
            np.array(features['memory_positions']).shape,
            [self.n_devices, n_mentions_per_device, k_top_saved, 2])

        if encoder_config.get('num_intermediate_layers') is not None:
            self.assertSequenceEqual(
                np.array(features['second_memory_text']).shape, [
                    self.n_devices, n_mentions_per_device, k_top_saved,
                    self.memory_text_length
                ])
            self.assertSequenceEqual(
                np.array(features['second_memory_positions']).shape,
                [self.n_devices, n_mentions_per_device, k_top_saved, 2])

        return batch, initial_variables, metrics
Example #15
 def load_weights(config: ml_collections.ConfigDict) -> Dict[str, Any]:
   """Load model weights from file."""
   params = checkpoint_utils.load_weights(config.load_weights)
   params = jax.device_put_replicated(params, jax.local_devices())
   return {'params': params}
Example #16
    def test_model_shape(
        self,
        separate_memory_values=False,
        num_intermediate_layers=None,
    ):
        """Test loss function runs and produces expected values."""
        config = copy.deepcopy(self.config)
        config['model_config']['encoder_config'][
            'separate_memory_values'] = separate_memory_values
        config['model_config']['encoder_config'][
            'num_intermediate_layers'] = num_intermediate_layers
        config = ml_collections.FrozenConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config

        rows = encoder_config.rows
        preprocess_fn = mention_memory_task.MentionMemoryTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = mention_memory_task.MentionMemoryTask.make_collater_fn(
            config)

        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()

        model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, devices)
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       devices)
        memory_entity_ids = memory_identifiers
        memory_text_entities = np.zeros(
            (self.table_size, encoder_config.n_memory_text_entities),
            dtype=np.int32)
        memory_text_entities = jax.device_put_replicated(
            memory_text_entities, devices)

        def model_init(*args, **kwargs):
            return model.init(*args, method=model.forward, **kwargs)

        initial_variables = jax.pmap(model_init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        }

        raw_example = test_utils.gen_mention_pretraining_sample(
            self.text_length,
            self.n_mentions,
            self.n_linked_mentions,
            max_length=encoder_config.max_length)
        processed_example = preprocess_fn(raw_example)
        batch = {
            key: np.tile(value, (config.per_device_batch_size, 1))
            for key, value in processed_example.items()
        }
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }
        batch = {
            key: jax.device_put_replicated(value, devices)
            for key, value in batch.items()
        }

        def model_apply(*args, **kwargs):
            return model.apply(*args, method=model.forward, **kwargs)

        papply = jax.pmap(model_apply, 'batch', static_broadcasted_argnums=(2))
        encoded_output, loss_helpers, _ = papply(
            {
                'params': initial_variables['params'],
                'constants': initial_variables['constants'],
            },
            batch,
            True,
        )

        self.assertEqual(
            encoded_output.shape,
            (self.n_devices, config.per_device_batch_size,
             encoder_config.max_length, encoder_config.hidden_size))

        memory_value_dim = encoder_config.memory_value_dim
        memory_key_dim = encoder_config.memory_key_dim
        memory_size = memory_value_dim if memory_value_dim else memory_key_dim
        self.assertEqual(loss_helpers['target_mention_encodings'].shape,
                         (self.n_devices, config.max_mention_targets *
                          config.per_device_batch_size, memory_size))
Example #17
def replicate_model_state(model_states: TrainState) -> TrainState:
  """Replicates the model states."""
  return jax.device_put_replicated(model_states, jax.local_devices())
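A common counterpart (not shown in the source) is dropping the replication again, e.g. before checkpointing, by taking the first copy along the leading device axis; a minimal sketch under that assumption:

def unreplicate_model_state(model_states: TrainState) -> TrainState:
  """Returns a single-device copy by taking index 0 of the leading device axis."""
  return jax.tree_map(lambda x: x[0], model_states)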
Example #18
    def test_load_weights(self,
                          separate_memory_values=False,
                          memory_only=False):
        """Test saving and loading model recovers original parameters."""

        config = copy.deepcopy(self.config)
        config['model_config']['encoder_config'][
            'separate_memory_values'] = separate_memory_values
        config = ml_collections.ConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config
        rows = encoder_config.rows
        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()
        model = mention_memory_encoder.MentionMemoryEncoder(**encoder_config)
        dummy_input = mention_memory_task.MentionMemoryTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

        memory_table = np.random.rand(rows, self.table_size // rows,
                                      encoder_config.memory_key_dim)
        memory_keys = jax.device_put_replicated(memory_table, devices)
        memory_values = memory_table.reshape(-1, encoder_config.memory_key_dim)
        memory_values = jax.device_put_replicated(memory_values, devices)
        memory_identifiers = np.arange(self.table_size)
        memory_identifiers = jax.device_put_replicated(memory_identifiers,
                                                       devices)
        memory_entity_ids = memory_identifiers
        memory_text_entities = np.zeros(
            (self.table_size, encoder_config.n_memory_text_entities),
            dtype=np.int32)
        memory_text_entities = jax.device_put_replicated(
            memory_text_entities, devices)

        def model_init(*args, **kwargs):
            return model.init(*args, method=model.forward, **kwargs)

        initial_variables = jax.pmap(model_init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        initial_variables = {'params': initial_variables['params']}
        initial_variables['constants'] = {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_identifiers': memory_identifiers,
            'memory_entity_ids': memory_entity_ids,
            'memory_text_entities': memory_text_entities,
        }
        n_shards = 4

        tempdir_obj = self.create_tempdir()
        tempdir = tempdir_obj.full_path

        memory_key_base = os.path.join(tempdir, 'memory_keys')
        memory_value_base = os.path.join(tempdir, 'memory_values')
        memory_id_base = os.path.join(tempdir, 'memory_id')
        memory_entity_id_base = os.path.join(tempdir, 'memory_entity_id')
        memory_text_entities_base = os.path.join(tempdir,
                                                 'memory_text_entities')

        unreplicated_variables = jax_utils.unreplicate(initial_variables)
        unreplicated_variables['params'] = unreplicated_variables[
            'params'].unfreeze()

        if memory_only:
            load_weights = 'memory_only'
        else:
            load_weights = os.path.join(tempdir, 'weights')
            checkpoint_utils.save_weights(load_weights,
                                          unreplicated_variables['params'])

        memory_keys = initial_variables['constants']['memory_keys']
        memory_keys = memory_keys.reshape(n_shards, -1,
                                          encoder_config.memory_key_dim)
        memory_values = initial_variables['constants']['memory_values']
        memory_values = memory_values.reshape(n_shards, -1,
                                              encoder_config.memory_key_dim)
        memory_ids = initial_variables['constants'][
            'memory_identifiers'].reshape(n_shards, -1)
        memory_entity_ids = initial_variables['constants'][
            'memory_entity_ids'].reshape(n_shards, -1)
        memory_text_entities = initial_variables['constants'][
            'memory_text_entities'].reshape(
                n_shards, -1, encoder_config.n_memory_text_entities)

        for shard in range(n_shards):
            np.save(memory_key_base + str(shard), memory_keys[shard])
            np.save(memory_value_base + str(shard), memory_values[shard])
            np.save(memory_id_base + str(shard), memory_ids[shard])
            np.save(memory_entity_id_base + str(shard),
                    memory_entity_ids[shard])
            np.save(memory_text_entities_base + str(shard),
                    memory_text_entities[shard])

        config.memory_key_pattern = memory_key_base + '*'
        config.memory_value_pattern = memory_value_base + '*'
        config.memory_id_pattern = memory_id_base + '*'
        config.memory_entity_id_pattern = memory_entity_id_base + '*'
        config.memory_text_entities_pattern = memory_text_entities_base + '*'
        config.load_weights = load_weights

        loaded_variables = mention_memory_encoder.MentionMemoryEncoder.load_weights(
            config)

        arrayeq = lambda x, y: jnp.all(x == y)
        constants = {
            key: value
            for key, value in initial_variables['constants'].items()
            if not (key == 'memory_values' and not separate_memory_values)
        }
        comparison_variables = {'constants': constants}
        if not memory_only:
            comparison_variables['params'] = initial_variables[
                'params'].unfreeze()

        self.assertTrue(
            jax.tree_map(arrayeq, loaded_variables, comparison_variables))
Example #19
File: main.py Project: myagues/flax_nerf
def main(_):
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters' has no effect when 'batching' the dataset")
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0

    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    iter_render_ds = zip(range(num_poses), render_ds)
    iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
    iter_test_ds = zip(range(test_items), test_ds)
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h),
                                  jnp.arange(img_w),
                                  indexing="ij")
            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        _ = [h(step) for h in hooks]

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if step % FLAGS.config.i_img == 0 and step > 0 or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                rays, padding = prepare_render_data(inputs["rays"]._numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding),
                                         preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config),
                           "disp",
                           r_hwf,
                           step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    for idx, inputs in tqdm(iter_vdirs_ds,
                                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    if FLAGS.config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
Example #20
    def test_loss_fn(
        self,
        k_top,
        num_intermediate_layers=None,
        shared_initial_encoder=True,
        shared_intermediate_encoder=True,
        shared_final_encoder=True,
        no_retrieval=False,
        same_passage_retrieval_policy='allow',
        extract_unlinked_mentions=False,
        no_retrieval_for_masked_mentions=False,
    ):
        """Test loss function runs and produces expected values."""
        config = copy.deepcopy(self.config)
        encoder_config = copy.deepcopy(self.encoder_config)
        encoder_config['k_top'] = k_top
        encoder_config['num_intermediate_layers'] = num_intermediate_layers
        encoder_config['shared_initial_encoder'] = shared_initial_encoder
        encoder_config[
            'shared_intermediate_encoder'] = shared_intermediate_encoder
        encoder_config['shared_final_encoder'] = shared_final_encoder
        encoder_config['no_retrieval'] = no_retrieval
        encoder_config[
            'same_passage_retrieval_policy'] = same_passage_retrieval_policy
        encoder_config['extract_unlinked_mentions'] = extract_unlinked_mentions
        encoder_config[
            'no_retrieval_for_masked_mentions'] = no_retrieval_for_masked_mentions
        config['model_config']['encoder_config'] = encoder_config
        if no_retrieval:
            config['el_im_weight'] = 0
        if num_intermediate_layers is not None:
            config['second_el_im_weight'] = 0.1
        config = ml_collections.FrozenConfigDict(config)

        model_config = config.model_config
        encoder_config = model_config.encoder_config

        preprocess_fn = readtwice_task.ReadTwiceTask.make_preprocess_fn(config)  # pylint: disable=line-too-long
        collater_fn = readtwice_task.ReadTwiceTask.make_collater_fn(config)
        postprocess_fn = readtwice_task.ReadTwiceTask.make_output_postprocess_fn(
            config)

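        # Force a multi-device setup so the pmap / device_put_replicated
        # paths below are exercised.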
        test_utils.force_multi_devices(self.n_devices)
        devices = jax.local_devices()

        model = readtwice_task.ReadTwiceTask.build_model(model_config)
        dummy_input = readtwice_task.ReadTwiceTask.dummy_input(config)
        dummy_input = jax.device_put_replicated(dummy_input, devices)
        init_rng = jax.random.PRNGKey(0)
        split_rng = jax.random.split(init_rng, self.n_devices)

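        # Initialize parameters on every device; the trailing boolean flag is
        # a Python scalar and is broadcast statically (argnum 2).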
        initial_variables = jax.pmap(model.init,
                                     'batch',
                                     static_broadcasted_argnums=2)(
                                         split_rng,
                                         dummy_input,
                                         True,
                                     )
        raw_example = test_utils.gen_mention_pretraining_sample(
            self.text_length,
            self.n_mentions,
            self.n_linked_mentions,
            max_length=encoder_config.max_length)
        processed_example = preprocess_fn(raw_example)
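        # Tile the single processed example into a per-device batch, collate it,
        # convert tensors to NumPy, and replicate the result across devices.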
        batch = {
            key: np.tile(value, (config.per_device_batch_size, 1))
            for key, value in processed_example.items()
        }
        batch = collater_fn(batch)
        batch = {
            key: test_utils.tensor_to_numpy(value)
            for key, value in batch.items()
        }
        batch = {
            key: jax.device_put_replicated(value, devices)
            for key, value in batch.items()
        }

        loss_fn = jax.pmap(readtwice_task.ReadTwiceTask.make_loss_fn(config),
                           'batch',
                           static_broadcasted_argnums=(0, 4))
        _, metrics, auxiliary_output = loss_fn(
            model_config,
            initial_variables['params'],
            {},  # model vars
            batch,
            True,  # deterministic
        )

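        # Metrics and the batch are replicated across devices; inspect the copy
        # held by the first device.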
        take_first = lambda x: x[0]
        metrics = jax.tree_map(take_first, metrics)
        np_batch = jax.tree_map(take_first, batch)

        # mlm losses
        expected_mlm_denom = np_batch['mlm_target_weights'].sum()
        expected_mlm_mention_denom = (np_batch['mlm_target_weights'] *
                                      np_batch['mlm_target_is_mention']).sum()
        expected_mlm_non_mention_denom = (
            np_batch['mlm_target_weights'] *
            (1 - np_batch['mlm_target_is_mention'])).sum()
        self.assertEqual(metrics['mlm']['denominator'], expected_mlm_denom)
        self.assertEqual(metrics['mlm_mention']['denominator'],
                         expected_mlm_mention_denom)
        self.assertEqual(metrics['mlm_non_mention']['denominator'],
                         expected_mlm_non_mention_denom)
        self.assertEqual(metrics['mlm_first']['denominator'],
                         expected_mlm_denom)
        self.assertEqual(metrics['mlm_mention_first']['denominator'],
                         expected_mlm_mention_denom)
        self.assertEqual(metrics['mlm_non_mention_first']['denominator'],
                         expected_mlm_non_mention_denom)

        # same entity retrieval loss
        if not no_retrieval:
            expected_same_entity_denom = np_batch[
                'mention_target_weights'].sum()
            self.assertEqual(metrics['el_intermediate']['denominator'],
                             expected_same_entity_denom)
            if num_intermediate_layers is not None:
                self.assertEqual(
                    metrics['second_el_intermediate']['denominator'],
                    expected_same_entity_denom)

        # coref losses
        expected_coref_denom = np_batch['mention_target_weights'].sum()
        expected_coref_masked_denom = (
            np_batch['mention_target_weights'] *
            np_batch['mention_target_is_masked']).sum()
        expected_coref_non_masked_denom = (
            np_batch['mention_target_weights'] *
            (1 - np_batch['mention_target_is_masked'])).sum()

        for coref_type in {'key', 'value', 'final'}:
            self.assertEqual(
                metrics[coref_type + '_coref_resolution']['denominator'],
                expected_coref_denom)
            self.assertEqual(
                metrics[coref_type +
                        '_coref_resolution_masked']['denominator'],
                expected_coref_masked_denom)
            self.assertEqual(
                metrics[coref_type +
                        '_coref_resolution_non_masked']['denominator'],
                expected_coref_non_masked_denom)

        # mtb losses
        for mtb_type in {'key', 'value', 'final'}:
            self.assertIn(mtb_type + '_mtb', metrics)
            self.assertIn(mtb_type + '_mtb_masked', metrics)
            self.assertIn(mtb_type + '_mtb_non_masked', metrics)

        features = postprocess_fn(batch, auxiliary_output)
        # Check features are JSON-serializable
        json.dumps(features)
        # Check features match the original batch
        for key in batch.keys():
            self.assertArrayEqual(np.array(features[key]), batch[key])
Example #21
0
  def test_mention_memory_layer(self, separate_memory_values):
    """Testing memory attention layer."""

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

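    # `deterministic` (arg 9) is a Python bool, so it is always broadcast
    # statically. Without separate memory values, `memory_values` (arg 10) is
    # None and is treated as static as well.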
    static_argnums = (9,) if separate_memory_values else (9, 10)
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=static_argnums)

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    mention_batch_positions = jnp.tile(
        jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    n_mentions = mention_start_positions.shape[-1]

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)
    # Make sure id 0 or 1 will be highest scoring
    memory_table[0] = memory_table[0] * 2.0
    memory_table[1] = memory_table[1] * -2.0
    memory_table = jnp.asarray(memory_table, dtype=self.dtype)

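    # Reshape into per-device (rows, cells-per-row, key_dim) slices and place
    # one slice on each device.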
    memory_keys = memory_table.reshape(self.n_devices, self.rows,
                                       self.table_size // self.rows,
                                       self.memory_key_dim)

    memory_keys_sharded = jax.device_put_sharded(list(memory_keys), devices)
    if separate_memory_values:
      memory_values = memory_table.reshape(self.n_devices, self.table_size,
                                           self.memory_key_dim)
      memory_values = jax.device_put_sharded(list(memory_values), devices)
    else:
      memory_values = None

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    (encoded_output, loss_helpers, _), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,
        text_identifiers=None,
    )

    attention_weights = loss_helpers['memory_attention_weights']
    entity_ids = loss_helpers['top_entity_ids']

    normed_input = encoded_input - 1.0

    # Check input was changed
    self.assertFalse(jnp.allclose(encoded_output, normed_input))

    # Check input was not changed where it should not be
    all_indices = set(
        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
    # Note that mention positions is the same across all of the devices
    start_indices = set(
        zip(mention_batch_positions[0].tolist(),
            mention_start_positions[0].tolist()))
    non_start_indices = all_indices.difference(start_indices)
    non_start_indices_1, non_start_indices_2 = zip(*non_start_indices)
    non_start_indices_1 = jnp.asarray(non_start_indices_1)
    non_start_indices_2 = jnp.asarray(non_start_indices_2)

    non_start_outputs = encoded_output[:, non_start_indices_1,
                                       non_start_indices_2]
    non_start_inputs = normed_input[:, non_start_indices_1, non_start_indices_2]
    self.assertTrue(jnp.allclose(non_start_outputs, non_start_inputs))

    # Check shapes as expected
    self.assertSequenceEqual(
        encoded_output.shape,
        (self.n_devices, self.bsz, self.seq_len, self.input_dim))

    self.assertSequenceEqual(
        attention_weights.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    self.assertSequenceEqual(
        entity_ids.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    # Check id 0 or 1 retrieved
    self.assertTrue(
        jnp.all((entity_ids[..., 0] == 0) + (entity_ids[..., 0] == 1)))

    # Set some text identifiers to 0 and others to 1 so that some are binding
    text_identifiers = np.zeros((n_mentions,), dtype=np.int32)
    text_identifiers[:n_mentions // 2] = 1
    text_identifiers = jax.device_put_replicated(text_identifiers, devices)

    # Initialize and run one forward pass of model
    (_, loss_helpers, logging_helpers), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,  # memory_values
        text_identifiers=text_identifiers,
    )
    attention_weights_wid = loss_helpers['memory_attention_weights']
    entity_ids_wid = loss_helpers['top_entity_ids']
    n_disallowed = logging_helpers['n_disallowed'][0]

    # Check no effect on ids
    self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

    # Check id 0 or 1 have 0 scores
    text_identifiers = jnp.expand_dims(text_identifiers, -1)
    score_masked = (text_identifiers == entity_ids_wid) * attention_weights_wid
    self.assertAlmostEqual(score_masked.sum(), 0.0)

    # Check number disallowed as expected
    self.assertEqual(n_disallowed, n_mentions // 2)
Example #22
0
  def test_memory_attention_backward(self):
    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    pinit = jax.pmap(
        model.init, axis_name='batch', static_broadcasted_argnums=(9, 10))

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    mention_batch_positions = jnp.tile(
        jnp.asarray([[0], [1], [2]]), (1, self.bsz)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)

    memory_table = jnp.asarray(memory_table, dtype=self.dtype)
    memory_table = memory_table.reshape(self.n_devices, self.rows,
                                        self.table_size // self.rows,
                                        self.memory_key_dim)
    memory_table_sharded = jax.device_put_sharded(list(memory_table), devices)

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    initial_parameters = pinit(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_table_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        None,  # memory_values
        text_identifiers=None,
    )

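    # step_fn differentiates a scalar summary of the layer output with respect
    # to the parameters, exercising the backward pass through the memory
    # attention layer under pmap.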
    def step_fn(
        params,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys,
        memory_identifiers,
        memory_entity_ids,
    ):

      def loss_fn(params):
        encoded_output, _, _ = model.apply(
            {'params': params},
            rngs=None,
            encoded_input=encoded_input,
            mention_batch_positions=mention_batch_positions,
            mention_start_positions=mention_start_positions,
            mention_end_positions=mention_end_positions,
            mention_mask=mention_mask,
            memory_keys=memory_keys,
            memory_identifiers=memory_identifiers,
            memory_entity_ids=memory_entity_ids,
            deterministic=True,
            text_identifiers=None,
        )
        return encoded_output.sum()

      loss, grad = jax.value_and_grad(loss_fn)(params)
      return loss, grad

    pstep = jax.pmap(step_fn, axis_name='batch')

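    # Run one pmapped forward/backward step; the test only verifies that it
    # executes without error.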
    _ = pstep(
        initial_parameters['params'],
        encoded_input=encoded_input,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        memory_keys=memory_table_sharded,
        memory_identifiers=memory_identifiers,
        memory_entity_ids=memory_entity_ids,
    )