Example 1
def tokens_ids_to_logits(flat_ids, i):
  """Token slice to logits from decoder model."""
  # --> [batch * beam, 1, vocab]
  flat_logits = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer).apply(
          {'params': params},
          flat_ids,
          flat_encoded[:, i],
          flat_encoded_padding_mask[:, i],
          method=models.DecomposeExpandingLayerTransformer.decode)
  return flat_logits
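In the slow-decode path, this closure is specialized to a single partial-program index and handed to beam search. A minimal usage sketch, mirroring the call in Example 8 (inputs, cache, per_partial_beam_size, config, eos_token, and max_decode_len are assumed to be in scope from the surrounding predict_step):

import functools

beam_seqs, beam_log_probs = decode.beam_search(
    inputs,
    cache,
    functools.partial(tokens_ids_to_logits, i=i),  # fix the partial-program index
    beam_size=per_partial_beam_size,
    alpha=0.6,
    bos_token=config.bos_token,
    eos_token=eos_token,
    max_decode_len=max_decode_len,
    slow_decode=True)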
Example 2
def loss_fn(params):
  """Loss function used for training."""
  logits = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer).apply(
          {'params': params},
          inputs,
          outputs,
          programs,
          rngs={'dropout': train_rng})
  loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)
  mean_loss = loss / weight_sum
  return mean_loss, logits
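A sketch of how a loss_fn like this is then consumed for a gradient step, following the pattern shown in Example 9 (optimizer, lr, and the pmap axis name 'batch' are assumed from the enclosing train step):

import jax

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(mean_loss, logits), grad = grad_fn(optimizer.target)
grad = jax.lax.pmean(grad, 'batch')  # average gradients across devices
new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)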
Example 3
def initialize_cache(inputs, outputs, programs, num_partial_programs,
                     max_decode_len, config, use_expanding_layer):
  """Initialize a cache for a given input shape and max decode length."""
  target_shape = programs.shape[:-1] + (max_decode_len,)

  m = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer)
  initial_variables = m.init(
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, config.dtype),
      jnp.ones(outputs.shape, config.dtype),
      jnp.ones(target_shape, config.dtype))
  return initial_variables['cache']
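The returned cache feeds fast decoding. A hypothetical usage sketch via the pmapped wrapper p_init_cache defined in Example 7 (sharded_inputs, sharded_outputs, and sharded_programs are assumed per-device arrays, e.g. produced by common_utils.shard):

# Hypothetical call; p_init_cache is the jax.pmap wrapper from Example 7.
cache = p_init_cache(sharded_inputs, sharded_outputs, sharded_programs)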
Example 4
  def loss_fn(params):
    """Loss function used for training."""
    logits = models.DecomposeExpandingLayerTransformer(
        config=config, num_partial_programs=num_partial_programs,
        use_expanding_layer=use_expanding_layer).apply(
            {'params': params},
            inputs,
            outputs,
            programs,
            rngs={'dropout': pretrain_rng})
    ce_loss, weight_sum = compute_weighted_cross_entropy(
        logits, programs, weights)
    mean_ce_loss = ce_loss / weight_sum

    mean_encoded_loss = 0
    if match_split_encoding:
      encoded = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=True).apply(
              {'params': params},
              inputs,
              outputs,
              rngs={'dropout': pretrain_rng},
              method=models.DecomposeExpandingLayerTransformer.encode)
      encoded = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=True).apply(
              {'params': params},
              encoded,
              rngs={'dropout': pretrain_rng},
              method=models.DecomposeExpandingLayerTransformer.decompose)
      encoded_loss, weight_sum = compute_weighted_squared_error(
          encoded, split_encoded, split_encoded_weights)
      mean_encoded_loss = encoded_loss / weight_sum

    mean_loss = mean_ce_loss + FLAGS.alpha_encoding * mean_encoded_loss
    return mean_loss, (logits, mean_ce_loss, mean_encoded_loss)
Example 5
def eval_step(params, inputs, outputs, programs, num_partial_programs,
              eos_token, config, use_expanding_layer):
  """Evaluate on batch of program tasks."""
  weights = jnp.where(
      jnp.logical_and(programs > 0,
                      jnp.logical_and(programs != config.bos_token,
                                      programs != eos_token)),
      1, 0).astype(jnp.float32)

  m = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer)
  logits = m.apply({'params': params}, inputs, outputs, programs)

  return compute_metrics(logits, programs, weights)
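For intuition, the weight mask zeroes out padding, BOS, and EOS positions so only real program tokens contribute to the metrics. An illustrative toy example, assuming bos_token=1 and eos_token=2 (the actual ids come from token_id_table):

import jax.numpy as jnp

programs = jnp.array([[1, 5, 7, 2, 0, 0]])  # [BOS, tok, tok, EOS, pad, pad]
weights = jnp.where(
    jnp.logical_and(programs > 0,
                    jnp.logical_and(programs != 1, programs != 2)),
    1, 0).astype(jnp.float32)
# weights == [[0., 1., 1., 0., 0., 0.]]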
Example 6
def tokens_ids_to_logits(flat_ids, flat_cache, i):
  """Token slice to logits from decoder model."""
  # --> [batch * beam, 1, vocab]
  flat_logits, new_vars = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer).apply(
          {'params': params, 'cache': flat_cache},
          flat_ids,
          flat_encoded[:, i],
          flat_encoded_padding_mask[:, i],
          mutable=['cache'],
          method=models.DecomposeExpandingLayerTransformer.decode)
  new_flat_cache = new_vars['cache']
  # Remove singleton sequence-length dimension:
  # [batch * beam, 1, vocab] --> [batch * beam, vocab]
  flat_logits = flat_logits.squeeze(axis=1)
  return flat_logits, new_flat_cache
Example 7
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Set up DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern,
      token_id_table,
      char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Use a smaller batch size for the predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining.
  split_target = None  # Params of the pretrained split-program model.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warn('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)

  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()

    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast,
                                            beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
Example 8
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 beam_size,
                 num_partial_programs,
                 max_decode_len,
                 eos_token,
                 config,
                 use_expanding_layer,
                 slow_decode=False,
                 use_split_encoding=False,
                 split_params=None,
                 split_outputs=None):
  """Predict programs with fast decoding beam search on a batch."""
  per_partial_beam_size = max(beam_size // num_partial_programs, 1)

  m = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer)
  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  if use_split_encoding:
    # Use pretrained split model to get encodings
    assert (split_params is not None) and (split_outputs is not None)

    split_inputs = models.add_and_tile_dim(inputs, num_partial_programs, axis=1)
    # split_outputs shape == [batch_size, num_partial, num_io, length]
    split_outputs = jnp.swapaxes(split_outputs, 1, 2)
    encoded = models.DecomposeExpandingLayerTransformer(
        config=config, num_partial_programs=1,
        use_expanding_layer=False).apply(
            {'params': split_params},
            split_inputs,
            split_outputs,
            method=models.DecomposeExpandingLayerTransformer.encode)
    flat_encoded = decode.flat_batch_beam_expand(encoded, per_partial_beam_size)

    encoded_padding_mask = jnp.where(
        split_outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, per_partial_beam_size)
  else:
    flat_encoded = decode.flat_batch_beam_expand(
        m.apply(
            {'params': params},
            inputs,
            outputs,
            method=models.DecomposeExpandingLayerTransformer.encode),
        per_partial_beam_size)
    flat_encoded = m.apply(
        {'params': params},
        flat_encoded,
        method=models.DecomposeExpandingLayerTransformer.decompose)

    encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, per_partial_beam_size)
    flat_encoded_padding_mask = models.add_and_tile_dim(
        flat_encoded_padding_mask, num_partial_programs, axis=1)

  if slow_decode:
    def tokens_ids_to_logits(flat_ids, i):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=use_expanding_layer).apply(
              {'params': params},
              flat_ids,
              flat_encoded[:, i],
              flat_encoded_padding_mask[:, i],
              method=models.DecomposeExpandingLayerTransformer.decode)
      return flat_logits
  else:
    def tokens_ids_to_logits(flat_ids, flat_cache, i):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits, new_vars = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=use_expanding_layer).apply(
              {'params': params, 'cache': flat_cache},
              flat_ids,
              flat_encoded[:, i],
              flat_encoded_padding_mask[:, i],
              mutable=['cache'],
              method=models.DecomposeExpandingLayerTransformer.decode)
      new_flat_cache = new_vars['cache']
      # Remove singleton sequence-length dimension:
      # [batch * beam, 1, vocab] --> [batch * beam, vocab]
      flat_logits = flat_logits.squeeze(axis=1)
      return flat_logits, new_flat_cache

  # Perform beam search independently for each partial program.
  all_beam_seqs = []
  all_beam_log_probs = []
  for i in range(num_partial_programs):
    beam_seqs, beam_log_probs = decode.beam_search(
        inputs,
        cache,
        functools.partial(tokens_ids_to_logits, i=i),
        beam_size=per_partial_beam_size,
        alpha=0.6,
        bos_token=config.bos_token,
        eos_token=eos_token,
        max_decode_len=max_decode_len,
        slow_decode=slow_decode)
    all_beam_seqs.append(beam_seqs)
    all_beam_log_probs.append(beam_log_probs)

  all_beam_seqs = jnp.stack(all_beam_seqs, axis=1)
  all_beam_log_probs = jnp.stack(all_beam_log_probs, axis=1)

  # all_beam_seqs shape == [batch, n_partial, n_beam_per_partial, length]
  assert len(all_beam_seqs.shape) == 4
  # all_beam_log_probs shape == [batch, n_partial, n_beam_per_partial]
  assert len(all_beam_log_probs.shape) == 3

  # Sort beams in order of decreasing probability.
  order = jnp.argsort(all_beam_log_probs, axis=2)[:, :, ::-1]
  all_beam_log_probs = jnp.take_along_axis(all_beam_log_probs, order, axis=2)
  all_beam_seqs = jnp.take_along_axis(all_beam_seqs, order[Ellipsis, jnp.newaxis],
                                      axis=2)

  return all_beam_seqs, all_beam_log_probs
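When predict_step is wrapped with jax.pmap in Example 7, beam_size is positional argument 4 and is listed in static_broadcasted_argnums, so it is treated as a compile-time constant and a new beam size triggers a recompile. A hypothetical invocation, assuming sharded per-device arrays and a cache produced by initialize_cache:

# Hypothetical call to the pmapped wrapper p_pred_step from Example 7.
all_beam_seqs, all_beam_log_probs = p_pred_step(
    optimizer.target, sharded_inputs, sharded_outputs, cache, beam_size)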
Example 9
def pretrain_step(optimizer,
                  inputs,
                  outputs,
                  programs,
                  num_partial_programs,
                  learning_rate_fn,
                  config,
                  use_expanding_layer,
                  split_params=None,  # Pretrained split model parameters.
                  split_outputs=None,  # Outputs split by partial program.
                  pretrain_rng=None):
  """Pretrain on batch of program tasks."""
  pretrain_rng, new_pretrain_rng = jax.random.split(pretrain_rng)

  weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)

  match_split_encoding = ((split_params is not None) and
                          (split_outputs is not None))
  if match_split_encoding:
    # Get the i/o encodings of the pretrained split model.
    split_inputs = models.add_and_tile_dim(inputs, num_partial_programs, axis=1)
    # split_outputs shape == [batch_size, num_partial, num_io, length]
    split_outputs = jnp.swapaxes(split_outputs, 1, 2)
    split_encoded = models.DecomposeExpandingLayerTransformer(
        config=config.replace(deterministic=True), num_partial_programs=1,
        use_expanding_layer=False).apply(
            {'params': split_params},
            split_inputs,
            split_outputs,
            method=models.DecomposeExpandingLayerTransformer.encode)
    split_encoded_weights = jnp.where(
        split_outputs > 0, 1, 0).astype(jnp.float32)

  def loss_fn(params):
    """Loss function used for training."""
    logits = models.DecomposeExpandingLayerTransformer(
        config=config, num_partial_programs=num_partial_programs,
        use_expanding_layer=use_expanding_layer).apply(
            {'params': params},
            inputs,
            outputs,
            programs,
            rngs={'dropout': pretrain_rng})
    ce_loss, weight_sum = compute_weighted_cross_entropy(
        logits, programs, weights)
    mean_ce_loss = ce_loss / weight_sum

    mean_encoded_loss = 0
    if match_split_encoding:
      encoded = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=True).apply(
              {'params': params},
              inputs,
              outputs,
              rngs={'dropout': pretrain_rng},
              method=models.DecomposeExpandingLayerTransformer.encode)
      encoded = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=True).apply(
              {'params': params},
              encoded,
              rngs={'dropout': pretrain_rng},
              method=models.DecomposeExpandingLayerTransformer.decompose)
      encoded_loss, weight_sum = compute_weighted_squared_error(
          encoded, split_encoded, split_encoded_weights)
      mean_encoded_loss = encoded_loss / weight_sum

    mean_loss = mean_ce_loss + FLAGS.alpha_encoding * mean_encoded_loss
    return mean_loss, (logits, mean_ce_loss, mean_encoded_loss)

  step = optimizer.state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, _, encoded_loss)), grad = grad_fn(optimizer.target)
  grad = jax.lax.pmean(grad, 'batch')
  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)

  # Get metrics.
  metrics = compute_metrics(logits, programs, weights)
  metrics['learning_rate'] = lr
  if match_split_encoding:
    metrics['encoded_loss'] = encoded_loss

  return new_optimizer, metrics, new_pretrain_rng