Example #1
 def apply_gradient(self, hyper_params, params, state, grads):
     grads = jax.tree_map(self._cross_replica_mean, grads)
     return self.optimizer_def.apply_gradient(hyper_params, params, state,
                                              grads)
Example #2
 def test_frozen_dict_maps(self):
     xs = {'a': 1, 'b': {'c': 2}}
     frozen = FrozenDict(xs)
     frozen2 = jax.tree_map(lambda x: x + x, frozen)
     self.assertEqual(unfreeze(frozen2), {'a': 2, 'b': {'c': 4}})
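The same leaf-wise mapping works on plain Python containers without Flax; a minimal standalone sketch:

import jax

xs = {'a': 1, 'b': {'c': 2}}
doubled = jax.tree_map(lambda x: x + x, xs)
print(doubled)  # {'a': 2, 'b': {'c': 4}}: the nesting is preserved, only the leaves change.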
Example #3
 def update_fn(updates, state, params=None):
     del params  # unused by the test optimizer
     aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
     updates = jax.tree_map(lambda u: step_size * u, updates)
     return updates, TestOptimizerState(aggregate_grads, is_reset=False)
Example #4
  def beam_search_loop_body_fn(state):
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model.  Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(lax.dynamic_slice(
        state.live_seqs,
        (0, 0, state.cur_index),
        (batch_size, beam_size, 1)))
    # Flatten beam dimension into batch to be compatible with model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    flat_cache = jax.tree_map(
        lambda x: flatten_beam_dim(x, batch_size), state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(
        flat_ids, flat_cache, jax.random.PRNGKey(state.cur_index))

    # unflatten beam dimension
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
    # Unflatten beam dimension in attention cache arrays
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = jax.tree_map(
        lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

    # Gather log probabilities from logits
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (candidate_log_probs +
                 jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams and beam-associated caches.
    # --> [batch, 2*beams, length], {[batch, 2*beams, ...], ...}
    topk_seq, new_cache = gather_beams([state.live_seqs, new_cache],
                                       topk_beam_indices,
                                       batch_size, beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences.
    # Recover the token ID by modulo division and expand the ID array for
    # broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(
        topk_seq, topk_ids, (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # --> [batch, beams, length], [batch, beams], {[batch, beams, ...], ...}
    top_alive_seq, top_alive_log_probs, top_alive_cache = gather_topk_beams(
        [topk_seq, new_log_probs, new_cache],
        new_log_probs,
        batch_size, beam_size)

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    return BeamState(cur_index=state.cur_index + 1,
                     live_logprobs=top_alive_log_probs,
                     finished_scores=top_finished_scores,
                     live_seqs=top_alive_seq,
                     finished_seqs=top_finished_seq,
                     finished_flags=top_finished_flags,
                     cache=top_alive_cache)
Example #5
    def new_weights(self, input_signature):
        weights = super().new_weights(input_signature)
        if self.init_checkpoint is None:
            return weights

        print('Loading pre-trained weights from', self.init_checkpoint)
        ckpt = tf.train.load_checkpoint(self.init_checkpoint)

        def reshape_qkv(name):
            x = ckpt.get_tensor(name)
            return x.reshape((x.shape[0], -1, 64)).swapaxes(0, 1)

        def reshape_o(name):
            x = ckpt.get_tensor(name)
            return x.reshape((-1, 64, x.shape[-1]))

        def reshape_bias(name):
            x = ckpt.get_tensor(name)
            return x.reshape((-1, 64))

        new_w = [
            ckpt.get_tensor('bert/embeddings/word_embeddings'),
            ckpt.get_tensor('bert/embeddings/token_type_embeddings'),
            ckpt.get_tensor('bert/embeddings/position_embeddings')[None, ...],
            ckpt.get_tensor('bert/embeddings/LayerNorm/gamma'),
            ckpt.get_tensor('bert/embeddings/LayerNorm/beta'),
        ]

        for i in range(12):  # 12 layers
            new_w += [
                reshape_qkv(
                    f'bert/encoder/layer_{i}/attention/self/query/kernel'),
                reshape_qkv(
                    f'bert/encoder/layer_{i}/attention/self/key/kernel'),
                reshape_qkv(
                    f'bert/encoder/layer_{i}/attention/self/value/kernel'),
                reshape_o(
                    f'bert/encoder/layer_{i}/attention/output/dense/kernel'),
                reshape_bias(
                    f'bert/encoder/layer_{i}/attention/self/query/bias'),
                reshape_bias(
                    f'bert/encoder/layer_{i}/attention/self/key/bias'),
                reshape_bias(
                    f'bert/encoder/layer_{i}/attention/self/value/bias'),
                ckpt.get_tensor(
                    f'bert/encoder/layer_{i}/attention/output/dense/bias'),
                ckpt.get_tensor(
                    f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'
                ),
                ckpt.get_tensor(
                    f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'),
                ckpt.get_tensor(
                    f'bert/encoder/layer_{i}/intermediate/dense/kernel'),
                ckpt.get_tensor(
                    f'bert/encoder/layer_{i}/intermediate/dense/bias'),
                ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/kernel'),
                ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/bias'),
                ckpt.get_tensor(
                    f'bert/encoder/layer_{i}/output/LayerNorm/gamma'),
                ckpt.get_tensor(
                    f'bert/encoder/layer_{i}/output/LayerNorm/beta'),
            ]

        new_w += [
            ckpt.get_tensor('bert/pooler/dense/kernel'),
            ckpt.get_tensor('bert/pooler/dense/bias'),
        ]

        for a, b in zip(fastmath.tree_leaves(weights), new_w):
            assert a.shape == b.shape, (
                f'Expected shape {a.shape}, got shape {b.shape}')
        weights = jax.tree_unflatten(jax.tree_structure(weights), new_w)
        move_to_device = jax.jit(lambda x: x)
        weights = jax.tree_map(move_to_device, weights)
        return weights
Example #6
def shard(xs):
  """Split data into shards for multiple devices along the first dimension."""
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
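A usage sketch for the shard() helper above; the hypothetical batch dict and the resulting shapes assume 8 local devices:

import jax
import numpy as np

batch = {'inputs': np.zeros((32, 128)), 'labels': np.zeros((32,))}
sharded = shard(batch)
print(jax.tree_map(np.shape, sharded))
# With 8 local devices: {'inputs': (8, 4, 128), 'labels': (8, 4)},
# i.e. a leading device axis that jax.pmap can map over.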
Example #7
def evaluate_sequence_accuracy(p_pred_step,
                               p_init_cache,
                               state,
                               ds,
                               config,
                               split,
                               workdir,
                               num_eval_steps=-1):
    """Evaluate classification on the given dataset."""
    prediction_dir = os.path.join(workdir, 'predictions')
    tf.io.gfile.makedirs(prediction_dir)
    logging.info('Starting evaluating sequence accuracy on %s split.', split)
    outputs = []

    test_metrics = collections.defaultdict(list)

    data_dir = config.dataset.data_dir
    input_vocab_file = os.path.join(data_dir, 'training_input_vocab.txt')
    target_vocab_file = os.path.join(data_dir, 'training_target_vocab.txt')
    dataset_file = os.path.join(data_dir, 'dataset.txt')
    eos_idx = config.dataset.eos_idx

    with tf.io.gfile.GFile(input_vocab_file, 'r') as f:
        input_vocab = json.load(f)
    with tf.io.gfile.GFile(target_vocab_file, 'r') as f:
        target_vocab = json.load(f)
    with tf.io.gfile.GFile(dataset_file, 'r') as f:
        annotations = json.load(f)

    for step, batch in enumerate(ds):  # pytype: disable=wrong-arg-types
        batch = jax.tree_map(np.asarray, batch)
        cache = p_init_cache(batch)
        batch['predictions'] = p_pred_step(batch, state, cache, eos_idx)
        batch = remove_pad(tohost(batch))
        target_token = batch['target_token']
        predictions = batch['predictions']
        for i, (prediction, target) in enumerate(zip(predictions,
                                                     target_token)):
            prediction = remove_special_tokens(prediction.tolist(), eos_idx)
            target = remove_special_tokens(target.tolist(), eos_idx)
            acc = evaluation.sequence_accuracy(prediction, target)
            test_metrics['test_accuracy'].append(acc)
            exact_match = 100 if acc == 100 else 0
            test_metrics['test_exact_match'].append(exact_match)

            input_command = remove_special_tokens(batch['token'][i].tolist(),
                                                  eos_idx)
            index = int(batch['index'][i][0])
            example = annotations['examples'][split][index]
            outputs.append({
                'split': split,
                'index': index,
                'input': array_to_sentence(input_command, input_vocab),
                'prediction': array_to_sentence(prediction, target_vocab),
                'target': array_to_sentence(target, target_vocab),
                'derivation': [example['derivation']],
                'situation': [example['situation']],
                'accuracy': acc,
                'exact_match': acc == 100,
                'attention_weights_input': [],
                'attention_weights_situation': [],
            })
        if num_eval_steps > 0 and step + 1 == num_eval_steps:
            break
    test_metrics = {k: sum(v) / len(v) for k, v in test_metrics.items()}
    step = flax_utils.unreplicate(state).step
    out_path = os.path.join(prediction_dir, f'{split}_predict_{step}.json')
    with tf.io.gfile.GFile(out_path, 'w') as f:
        json.dump(outputs, f, indent=2)
    return test_metrics
Example #8
 def initial_state(self, batch_size):
     return jax.tree_map(jnp.zeros_like, self._state)
Example #9
def prepare_tf_data_unbatched(xs):
  """Prepare TF dataset into unbatched numpy arrays."""
  # Use _numpy() for zero-copy conversion between TF and NumPy.
  # pylint: disable=protected-access
  return jax.tree_map(lambda x: x._numpy(), xs)
Example #10
 def post_pmap(xs):
   return jax.tree_map(lambda x: x[0], xs)
Example #11
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = json.loads(FLAGS.xm_parameters)
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in hparam_str_dict.keys()])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)
  indices_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_strings_per_task)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)  # pylint: disable=unused-variable

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
    spec_vocab = robust_fill_dsl.CHARACTER + '|'
    spec_id_token_table = {i+3: token
                           for i, token in enumerate(spec_vocab)}
    bos_id = 1
    eos_id = 2
    spec_id_token_table[bos_id] = robust_fill_dsl.BOS
    spec_id_token_table[eos_id] = robust_fill_dsl.EOS
    spec_token_id_table = {token: id
                           for id, token in spec_id_token_table.items()}
    spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
    program_id_token_table, program_token_id_table = (
        dsl_tokens.build_token_tables())
    program_vocab_size = len(program_id_token_table) + 1  # pylint: disable=unused-variable
  elif FLAGS.dataset_type == 'scan':
    # TODO(jxihong): Scan is not handled yet.
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    if FLAGS.dataset_type == 'robust_fill':
      def decode_str(s):
        """Decode string tokens."""
        return ''.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps, outs = [], []
      for inp, out in zip(inputs, outputs):
        inps.append(decode_str(inp))
        outs.append(decode_str(out))
      return inps, outs

    elif FLAGS.dataset_type == 'scan':
      def decode_str(s):
        """Decode string tokens."""
        return ' '.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps = [decode_str(inp) for inp in inputs]
      dummy_outs = [''] * len(inps)
      return inps, dummy_outs

    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_target(target):
    """Decode program tokens."""
    target = target[:np.argmax(target == eos_id)].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      target = target[target != bos_id].tolist()
      return ''.join([spec_id_token_table[t_id] for t_id in target if t_id > 0])
    elif FLAGS.dataset_type == 'scan':
      # TODO(jxihong): Scan is not handled yet.
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      # Returns either a Concat program object, or None.
      program = program[program != bos_id].tolist()
      try:
        return robust_fill_dsl.decode_program(program, program_id_token_table)
      except:  # pylint: disable=bare-except
        return None  # Program does not compile.
    elif FLAGS.dataset_type == 'scan':
      # Returns a string.
      program = program[jnp.logical_and(program != bos_id,
                                        program != eos_id)].tolist()
      return ' '.join(scan_vocab.decode(program, program_id_token_table))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program_str(program):  # pylint: disable=unused-variable
    """Decode program tokens into a string."""
    decoded = decode_program(program)
    if FLAGS.dataset_type == 'robust_fill':
      try:
        return decoded.to_string()
      except:  # pylint: disable=bare-except
        return 'did not compile'
    else:
      assert isinstance(decoded, str), '{} should be string'.format(decoded)
      return decoded

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = {
      'inputs': io_shape[1:],
      'outputs': io_shape[1:],
      'spec_parts': target_shape[1:],
      'start_index': indices_shape[1:],
      'end_index': indices_shape[1:],
      # TODO(kshi): include programs.
      # 'program': program_shape[1:],
  }
  logging.info('padded_shapes: %s', padded_shapes)

  if FLAGS.dataset_type == 'robust_fill':
    create_dataset_fn = input_pipeline.create_robust_fill_dataset_from_tf_record
  elif FLAGS.dataset_type == 'scan':
    raise NotImplementedError()  # TODO(kshi): Implement.
    # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  dataset = create_dataset_fn(
      FLAGS.dataset_filepattern, program_token_id_table, spec_token_id_table,
      max_target_length=FLAGS.max_target_length)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=padded_shapes,
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_padded_shapes = padded_shapes.copy()
  predict_padded_shapes['inputs'] = predict_io_shape[1:]
  predict_padded_shapes['outputs'] = predict_io_shape[1:]
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = create_dataset_fn(
      FLAGS.test_dataset_filepattern, program_token_id_table,
      spec_token_id_table, max_target_length=FLAGS.max_target_length)
  test_dataset = test_dataset.padded_batch(
      batch_size,
      padded_shapes=predict_padded_shapes,
      drop_remainder=False)
  quick_test_dataset = (test_dataset
                        .take(FLAGS.num_quick_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))
  final_test_dataset = (test_dataset
                        .take(FLAGS.num_final_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  base_config = base_models.TransformerConfig(
      vocab_size=spec_vocab_size,
      output_vocab_size=spec_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_id,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=FLAGS.max_distance,
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=FLAGS.max_distance,
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=FLAGS.max_distance,
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=FLAGS.max_distance,
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=FLAGS.max_distance,
      bidirectional_program_attention=FLAGS.bidirectional_program_attention)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      dataset_type=FLAGS.dataset_type)
  eval_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(deterministic=True),
      dataset_type=FLAGS.dataset_type)
  predict_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(
          shift=False, deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)),
      dataset_type=FLAGS.dataset_type)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info('Checking that start_step (%s) >= finetune_start_step (%s)',
                   start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr,
        factors='constant')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step,
                        eos_token=eos_id,
                        config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_id,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, targets = load_data(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, targets, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, targets = load_data(batches)

        metrics = p_eval_step(optimizer.target, inputs, outputs, targets)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time()-t_evaluation_start, step, eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for dataset, predict_or_test in [(predict_ds, 'predict'),
                                       (test_ds, 'test')]:

        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0

          ios, targets_list, predictions, top_of_beams = [], [], [], []
          for batches in dataset.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, targets = load_data(pred_batch)

            cache = (p_init_cache(inputs, outputs, targets)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, targets = map(tohost, (inputs, outputs, targets))

            for i, beams in enumerate(predicted):
              inps, outs = decode_io(inputs[i], outputs[i])
              ground_truth = decode_target(targets[i])
              beams_target = [decode_target(beam) for beam in beams]
              predicted_target = beams_target[0]
              for beam_target in beams_target:
                if beam_target == ground_truth:
                  predicted_target = beam_target
                  total_successes += 1
                  break
              total_denominator += 1

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets_list.append(ground_truth)
              predictions.append(predicted_target)
              logging.info('')
              logging.info('ios: %s', ios[-1])
              logging.info('targets[%s]: %s', i, targets[i])
              logging.info('ground_truth: %s', ground_truth)
              logging.info('predicted beam: %s', '\n'.join(beams_target))
              logging.info('predicted_target: %s', predicted_target)
              logging.info('beams: %s', beams)

              if not ground_truth:
                logging.warning('ground_truth is empty!')

              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decode_target(beam), beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))

          # Record beam search results as text summaries.
          message = []
          for n in np.random.choice(np.arange(len(predictions)), 8):
            text = (f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n'
                    f'predicted: {predictions[n]}\n\n'
                    f'top of beam:\n\n{top_of_beams[n]}\n\n')
            message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size,
                all_total_successes, all_total_denominator, accuracy,
                time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)

            summary_writer.text('{}-samples-beam-{}'.format(predict_or_test,
                                                            beam_size),
                                '\n------\n'.join(message), step)
            summary_writer.flush()

    # Save a Checkpoint. Do this at the end of the training loop, so that if a
    # worker is descheduled during a round of prediction (which takes a while),
    # we will redo prediction upon restarting (to avoid losing data).
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)
Example #12
 def pre_pmap(xs):
   return jax.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
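A minimal sketch pairing pre_pmap with the post_pmap helper from Example #10 above, assuming at least one local device:

import jax
import jax.numpy as jnp

xs = {'x': jnp.arange(3.0)}
batched = pre_pmap(xs)                         # each leaf gains a leading axis of size 1
out = jax.pmap(lambda t: {'x': t['x'] * 2})(batched)
print(post_pmap(out))                          # {'x': [0., 2., 4.]}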
Example #13
def standard_train_step(
    state,
    batch,
    rng_key,
    dynamic_state,
    *,
    static_state,
    loss_fn,
    learning_rate_fn,
    model_cls,
    grad_clip=None,
    use_bfloat16=False,
    parallel=True,
    vmap_batch=False,
    ema_decay_rate=0.9,
    ema_burn_in=1000,
    threshold=0.0,
):
    """Perform a single standard training step.

  Args:
    state: a TrainState object containing the optimizer and EMA params.
    batch: dictionary or tuple
    rng_key: Jax RNG for model Dropout and additional RNG.
    dynamic_state: a dict of dynamic objects that should be passed to the model.
    static_state: any additional state to be passed to the model. The model will
      be recompiled when this changes.
    loss_fn: loss function which takes a function and batch and returns a loss.
    learning_rate_fn: function that returns the learning rate for a given
      iteration.
    model_cls: an nn.Module type to use for training. Must have a train attr.
    grad_clip: if not None, a float which determines the grad clipping norm.
    use_bfloat16: if True, round gradients to bfloat16 during training.
    parallel: if True, pmean reduces across device dimension.
    vmap_batch: if True, apply vmap over the batch axis.
    ema_decay_rate: the rate at which ema stats decay.
    ema_burn_in: the number of steps to skip before rejecting outliers.
    threshold: the probability below which any loss sample will be rejected. Set
      to 0 to disable EMA outlier rejection. Note that this currently doesn't
      work because we don't update the loss across devices properly. So please
      do not enable this.

  Returns:
    the updated optimizer, a metrics dict, and the new Jax RNG key.
  """
    logging.info('Recompiling train_step.')  # only called when recompiling

    optimizer = state.optimizer

    # We handle PRNG splitting inside the top pmap to improve efficiency.
    step = state.step
    lr = learning_rate_fn(step)

    model = model_cls(train=True)
    apply_key, rng_key = jrandom.split(rng_key)
    loss_key, rng_key = jrandom.split(rng_key)

    model_apply = utils.make_model_apply(model, apply_key)

    loss_fn = build_vmapped_loss(loss_fn,
                                 batch,
                                 loss_key,
                                 dynamic_state,
                                 is_eval=False,
                                 model_apply=model_apply,
                                 static_state=static_state,
                                 vmap_batch=vmap_batch)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (metrics, _)), grad = grad_fn(optimizer.target)

    if use_bfloat16:
        grad = jax.tree_map(lambda x: x.astype(jnp.bfloat16), grad)

    if parallel:
        grad = jax.lax.pmean(grad, 'device')

    if grad_clip is not None:
        # Clip gradients after pmean aggregation
        unclipped_grad = grad
        grad = jax.experimental.optimizers.clip_grads(grad, grad_clip)

    new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
    metrics['nn/learning_rate'] = lr

    # Gradient norms
    grad_l2_tree = l2_norm(grad)
    grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
    grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
    metrics['nn/l2_grad_sum'] = grad_l2_sum
    metrics['nn/l2_grad_max'] = grad_l2_max

    if grad_clip is not None:
        # Unclipped gradient norms (if applicable).
        grad_l2_tree = l2_norm(unclipped_grad)
        grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
        grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
        metrics['nn/l2_noclip_grad_sum'] = grad_l2_sum
        metrics['nn/l2_noclip_grad_max'] = grad_l2_max

    if threshold > 0:
        normal_pdf = jax.scipy.stats.norm.pdf(loss,
                                              loc=state.ema_loss,
                                              scale=jnp.sqrt(
                                                  state.ema_variance))
        metrics['nn/normal_pdf'] = normal_pdf
        should_replace = (normal_pdf > threshold) | (state.step < ema_burn_in)
    else:
        should_replace = True

    grads_ok = jnp.all(
        jnp.asarray([
            jnp.all(jnp.isfinite(p)) for p in jax.tree_leaves(new_optimizer)
        ]))

    loss_ok = jnp.all(jnp.isfinite(loss))
    should_replace = should_replace & grads_ok & loss_ok

    metrics['nn/step_skipped'] = 1 - should_replace
    metrics['nn/ema_loss'] = state.ema_loss
    metrics['nn/ema_variance'] = state.ema_variance
    metrics['nn/step'] = state.step
    metrics['nn/grads_ok'] = grads_ok
    metrics['nn/loss_ok'] = loss_ok

    delta = (loss - state.ema_loss)

    new_state = TrainState(
        optimizer=new_optimizer,
        step=state.step + 1,
        ema_loss=state.ema_loss * ema_decay_rate + (1 - ema_decay_rate) * loss,
        ema_variance=state.ema_variance * ema_decay_rate +
        (1 - ema_decay_rate) * delta**2,
    )

    new_state = jax.tree_multimap(
        lambda a, b: jnp.where(should_replace, a, b),
        new_state,
        state,
    )

    return new_state, metrics, rng_key
Example #14
 def state_dict(self, target, state):
     state_dict = self.optimizer_def.state_dict(target, state)
     # Only the first copy of the parameters and optimizer state is stored.
     state_dict = jax.tree_map(lambda x: x[0], state_dict)
     return state_dict
Example #15
def unreplicate(tree):
    """Returns a single instance of a replicated array."""
    return jax.tree_map(lambda x: x[0], tree)
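A minimal sketch of stripping the device axis from pmap output with unreplicate(), assuming at least one local device:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
replicated = jax.pmap(lambda x: x * 2)(jnp.ones((n, 3)))
print(replicated.shape)               # (n, 3): one copy per device
print(unreplicate(replicated).shape)  # (3,): the first copy only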
Example #16
 def same_tree_with_value(self, tree, value):
     return jax.tree_map(lambda x: jnp.ones_like(x) * value, tree)
Example #17
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    batch_size = config.batch_size
    n_devices = jax.device_count()
    if jax.host_count() > 1:
        raise ValueError(
            'PixelCNN++ example should not be run on more than 1 host'
            ' (for now)')
    if batch_size % n_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_summary_writer, eval_summary_writer = get_summary_writers(workdir)
    # Load dataset
    data_source = input_pipeline.DataSource(config)
    train_ds = data_source.train_ds
    eval_ds = data_source.eval_ds
    steps_per_epoch = data_source.ds_info.splits[
        'train'].num_examples // config.batch_size
    # Create dataset batch iterators
    train_iter = iter(train_ds)
    num_train_steps = train_ds.cardinality().numpy()
    steps_per_checkpoint = 1000

    # Create the model using data-dependent initialization. Don't shard the init
    # batch.
    assert config.init_batch_size <= batch_size
    init_batch = next(train_iter)['image']._numpy()[:config.init_batch_size]

    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng, dropout_rng = jax.random.split(rng, 3)

    initial_variables = model(config).init(
        {
            'params': init_rng,
            'dropout': dropout_rng
        }, init_batch)['params']
    optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995)
    optimizer = optimizer_def.create(initial_variables)

    optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables)
    ema = initial_variables
    step_offset = int(optimizer.state.step)

    optimizer, ema = jax_utils.replicate((optimizer, ema))

    # Learning rate schedule
    learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step

    # pmap the train and eval functions
    p_train_step = jax.pmap(functools.partial(train_step, config,
                                              learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, config=config),
                           axis_name='batch')

    # Gather metrics
    train_metrics = []

    for step, batch in zip(range(step_offset, num_train_steps), train_iter):
        # Load and shard the TF batch
        batch = load_and_shard_tf_batch(batch)

        # Generate a PRNG key that will be rolled into the batch.
        rng, step_rng = jax.random.split(rng)
        sharded_rngs = common_utils.shard_prng_key(step_rng)

        # Train step
        optimizer, ema, metrics = p_train_step(optimizer, ema, batch,
                                               sharded_rngs)
        train_metrics.append(metrics)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                            step)

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            # We've finished an epoch
            train_metrics = common_utils.get_metrics(train_metrics)
            # Get training epoch summary for logging
            train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
            # Send stats to Tensorboard
            for key, vals in train_metrics.items():
                for i, val in enumerate(vals):
                    train_summary_writer.scalar(key, val,
                                                step - len(vals) + i + 1)
            # Reset train metrics
            train_metrics = []

            # Evaluation
            eval_metrics = []
            for eval_batch in eval_ds:
                # Load and shard the TF batch
                eval_batch = load_and_shard_tf_batch(eval_batch)
                # Step
                metrics = p_eval_step(ema, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            # Get eval epoch summary for logging
            eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)

            # Log epoch summary
            logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch,
                         train_summary['loss'], eval_summary['loss'])

            eval_summary_writer.scalar('loss', eval_summary['loss'], step)
            train_summary_writer.flush()
            eval_summary_writer.flush()

        if ((step + 1) % steps_per_checkpoint == 0
                or step + 1 == num_train_steps):
            save_checkpoint(workdir, optimizer, ema, step)
Example #18
        attn = Attn(attn_module=self.attn_module,
                    qkv_features=qkv_features // self.num_heads,
                    out_features=out_features)

        # evaluate multi-headed-attention.
        y = attn(inputs_q, inputs_kv, bias)
        return y.mean(axis=-2)


# run it.

if __name__ == '__main__':

    inputs = jnp.ones((8, 97, 256))
    rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
    model = MultiHeadDotProductAttention(
        broadcast_dropout=False,
        qkv_features=256,
        out_features=256,
        attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1),
        num_heads=8,
        batch_axes=(0, ),
    )

    y, params = model.init_with_output(rngs, inputs, inputs)

    print('input shape: ', inputs.shape)
    print('parameter shapes:')
    pprint(jax.tree_map(jnp.shape, unfreeze(params)))
    print('output shape: ', y.shape)
Example #19
def to_device(xs):
  """Transfer data to devices (GPU/TPU)."""
  return jax.tree_map(jnp.array, xs)
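A minimal usage sketch of to_device(); jnp.array copies each host-side NumPy leaf onto the default JAX device:

import numpy as np

host_batch = {'x': np.zeros((4, 3), dtype=np.float32)}
device_batch = to_device(host_batch)
print(type(device_batch['x']))  # a jax.numpy array placed on the default device (CPU/GPU/TPU)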
Example #20
        return x


class gMLPModel(nn.Module):
    ffn_dim: int
    model_dim: int
    num_blocks: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(name='embedding', features=self.model_dim)(x)
        for i in range(self.num_blocks):
            x = gMLPBlock(ffn_dim=self.ffn_dim, model_dim=self.model_dim)(x)
        return x


tiny_settings = {'ffn_dim': 768, 'model_dim': 128, 'num_blocks': 30}

if __name__ == "__main__":
    key = random.PRNGKey(2)
    x = random.normal(key, shape=[8, 12, 18, 100])
    model = gMLPModel(**tiny_settings)
    model_state = model.init(key, x)
    y = model.apply(model_state, x)
    print(
        json.dumps(jax.tree_map(np.shape,
                                flax.core.unfreeze(model_state['params'])),
                   indent=2))
    num_params = functools.reduce(
        operator.add, map(np.size, jax.tree_leaves(model_state['params'])))
Example #21
def remove_pad(x):
    """Remove padding examples."""
    if 'mask' in x:
        ind = jnp.where(jnp.array(x.pop('mask'), dtype=jnp.int32) > 0)
        x = jax.tree_map(lambda v: v[ind], x)  # pylint: disable=cell-var-from-loop
    return x
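A minimal sketch of remove_pad(), assuming 'mask' marks real examples with 1 and padding with 0:

import jax.numpy as jnp

batch = {'token': jnp.arange(8).reshape(4, 2), 'mask': jnp.array([1, 1, 0, 0])}
print(remove_pad(batch)['token'])  # keeps only the first two (unpadded) rows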
Example #22
 def _wrapped(*args):
     return jax.tree_map(pfn, *args)
Example #23
  def __init__(self, model, weights, max_decode_len, beam_size=1, temperature=0,
               alpha=0.0, eos_id=None):
    """Construct an inference wrapper for an autoregressive model.

    The default behavior is to do greedy decoding:
      s = Search(model, weights, max_decode_len, eos_id=eos_id)
    Passing a temperature parameter will switch to sampling:
      s = Search(model, weights, max_decode_len, temperature=1, eos_id=eos_id)
    Passing a beam_size parameter will switch to beam search. For machine
    translation with Transformer models, Vaswani et al. (2017) recommend a beam
    size of 4 and length normalization with alpha=0.6.
      s = Search(model, weights, max_decode_len, beam_size=4, alpha=0.6,
                 eos_id=eos_id)

    After constructing the class, see Search.decode for how to decode a batch
    of examples.

    Args:
      model: function to construct a model (e.g. trax.models.Reformer)
      weights: model weights
      max_decode_len: maximum length to decode
      beam_size: beam size, for beam search
      temperature: temperature parameter for sampling; set to nonzero to switch
        from greedy/beam-search behavior to sampling.
      alpha: length penalty alpha coefficient for beam search.
      eos_id: end-of-sentence token for target vocabulary.
    """
    # TODO(kitaev): k and p parameters for top-k and nucleus sampling.
    self.model = model
    self.model_infer = model(mode='predict')
    # Weights are stored on device, but not replicated.
    self.model_weights = jax.tree_map(jax.jit(lambda x: x), weights)

    self.sample = (temperature != 0)
    self.temperature = temperature

    if self.sample and beam_size > 1:
      # TODO(kitaev): perform stochastic beam search in this case
      # (https://arxiv.org/abs/1903.06059)
      raise ValueError('beam_size parameter is not supported when sampling')

    is_cache = [isinstance(l, tl.Cache) for l in self.model_infer.sublayers]
    if any(is_cache):
      assert sum([int(x) for x in is_cache]) == 1, (
          'At most one usage of tl.Cache currently supported')
      self.encoder_idx = is_cache.index(True) + 1
    else:
      self.encoder_idx = None

    beam_search_partial = functools.partial(
        self._unreplicated_beam_search,
        beam_size=beam_size, alpha=alpha,
        eos_token=eos_id if eos_id is not None else -1,
        max_decode_len=max_decode_len + 1)  # Add 1 to account for start token.

    if trax.math.device_count() == 1:
      self._jit_beam_search = jax.jit(beam_search_partial, static_argnums=(2,))
    else:
      self._jit_beam_search = jax.pmap(beam_search_partial, axis_name='batch',
                                       static_broadcasted_argnums=(2,))

    # Work around a jax error
    # Ref: https://github.com/google/jax/issues/1919#issuecomment-569985681
    jax_partial_eval._thread_local_state.remat = True  # pylint: disable=protected-access
Example #24
 def init_fn(key, sx=sx):
     return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx)
Example #25
 def batch_loss(params, rng):
     stuff = jax.vmap(
         lambda rng: self.loss_and_metrics_one_pair(params, rng))(
             jax.random.split(rng, self.batch_size))
     return jax.tree_map(jnp.mean, stuff)
Example #26
 def apply_fn(params, x):
     return jax.tree_map(lambda p, v: p + v, params, x)
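jax.tree_map also accepts several trees with matching structure, which is what apply_fn above relies on; a minimal sketch with plain dicts:

import jax

params = {'a': 1.0, 'b': {'c': 2.0}}
delta = {'a': 0.5, 'b': {'c': -1.0}}
print(jax.tree_map(lambda p, v: p + v, params, delta))
# {'a': 1.5, 'b': {'c': 1.0}}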
Example #27
 def init_fn(params):
     aggregate_grads = jax.tree_map(jnp.zeros_like, params)
     return TestOptimizerState(aggregate_grads, is_reset=True)
Example #28
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type, but we only need
    # floatXX, which would cut the number of operations in half.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # We do this also for the standard gradient of the energy.
    # This avoids errors in #867, #789, #850.
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
Example #29
 def _masked_sgd_on_updates(m, upd):
     return jax.tree_map(lambda x: -x, upd) if m else upd
Example #30
 def init_state(self, params):
     param_states = jax.tree_map(self.init_param_state, params)
     state = OptimizerState(0, param_states)
     return state