Example #1
def initialize_cache(inputs, outputs, programs, max_decode_len, config):
    """Initialize a cache for a given input shape and max decode length."""
    target_shape = (programs.shape[0], max_decode_len)
    initial_variables = models.ProgramTransformer(config).init(
        jax.random.PRNGKey(0), jnp.ones(inputs.shape, config.dtype),
        jnp.ones(outputs.shape, config.dtype),
        jnp.ones(target_shape, config.dtype))
    return initial_variables['cache']
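A minimal usage sketch, not part of the original example: the cache initializer is usually wrapped with functools.partial and jax.pmap so every local device builds its own fast-decoding cache from its shard of the batch, as the training script further below does; predict_config, max_program_length, and the sharded inputs/outputs/programs arrays are assumed to exist.

import functools

import jax

p_init_cache = jax.pmap(
    functools.partial(initialize_cache,
                      max_decode_len=max_program_length,
                      config=predict_config),
    axis_name='batch')
cache = p_init_cache(inputs, outputs, programs)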
Example #2
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 eos_token,
                 max_decode_len,
                 beam_size,
                 config):
  """Predict translation with fast decoding beam search on a batch."""
  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  flat_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(config).apply(
          {'params': params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(config).apply(
        {'params': params, 'cache': flat_cache},
        flat_ids,
        flat_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      bos_token=config.bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs
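A minimal usage sketch, not part of the original example: predict_step is meant to be pmapped with eos_token, max_decode_len, and beam_size passed as static arguments, since they determine trace-time shapes; this mirrors the p_pred_step setup in the training script further below, and assumes replicated params, sharded batches, and a cache produced by initialize_cache.

import functools

import jax

p_pred_step = jax.pmap(
    functools.partial(predict_step, config=predict_config),
    axis_name='batch',
    # Positional args 4, 5, 6 are eos_token, max_decode_len, beam_size.
    static_broadcasted_argnums=(4, 5, 6))
beam_seqs = p_pred_step(params, inputs, outputs, cache,
                        eos_token, max_decode_len, beam_size)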
Example #3
def loss_fn(params):
  """Loss function used for training."""
  logits = models.ProgramTransformer(config).apply(
      {'params': params},
      inputs,
      outputs,
      programs,
      rngs={'dropout': train_rng})
  loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)
  mean_loss = loss / weight_sum
  return mean_loss, logits
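A minimal sketch of how a loss_fn with this signature is typically driven, not part of the original example: because it returns (mean_loss, logits), it pairs with jax.value_and_grad(..., has_aux=True); the code below assumes it runs inside a train step that is pmapped over a 'batch' axis and that optimizer is a Flax optimizer holding the parameters.

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(mean_loss, logits), grads = grad_fn(optimizer.target)
# Average gradients across devices before applying the optimizer update.
grads = jax.lax.pmean(grads, 'batch')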
Example #4
def loss_fn(params, lp_params):
    """Loss function used for training."""
    latent_logits = models.ProgramTransformer(lp_config).apply(
        {'params': lp_params},
        inputs,
        outputs,
        latent_indices,
        rngs={'dropout': train_rng})
    latent_loss, latent_weight_sum = compute_weighted_cross_entropy(
        latent_logits, latent_indices, latent_weights)

    # End-to-end prediction.
    encoded = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        inputs,
        outputs,
        mutable=False,
        rngs={'dropout': train_rng},
        method=models.LatentProgramTransformer.encode)
    latents = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        latent_logits,
        mutable=False,
        rngs={'dropout': train_rng},
        method=models.LatentProgramTransformer.quantize)
    logits = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        programs,
        latents,
        encoded,
        latents_mask,
        encoded_mask,
        mutable=False,
        rngs={'dropout': train_rng},
        method=models.LatentProgramTransformer.decode)
    loss, weight_sum = compute_weighted_cross_entropy(
        logits, programs, weights)

    mean_loss = latent_loss / latent_weight_sum
    if not pretrain:
        mean_loss += loss / weight_sum
    return mean_loss, (logits, latent_logits)
Example #5
def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
  """Token slice to logits from decoder model."""
  # --> [batch * beam, 1, vocab]
  flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
      {'params': lp_params, 'cache': flat_lp_cache},
      flat_ids,
      flat_lp_encoded,
      flat_encoded_padding_mask,
      mutable=['cache'],
      method=models.ProgramTransformer.decode)
  new_flat_lp_cache = new_vars['cache']
  # Remove singleton sequence-length dimension:
  # [batch * beam, 1, vocab] --> [batch * beam, vocab]
  flat_logits = flat_logits.squeeze(axis=1)
  return flat_logits, new_flat_lp_cache
Example #6
def main(_):
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if not gfile.isdir(FLAGS.save_dir):
        gfile.mkdir(FLAGS.save_dir)

    hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
    # Get hyperparameters.
    if FLAGS.xm_parameters:
        for key, value in json.loads(FLAGS.xm_parameters).items():
            if key not in hparam_str_dict:
                hparam_str_dict[key] = value

    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        io_string = ''
        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
            io_string += inps[-1] + ' < ' + outs[-1] + ' > '
        return inps, outs, io_string[:-3]  # Remove last separator.

    def decode_program(program):
        """Decode program tokens."""
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        try:
            p = dsl.decode_program(program, id_token_table)
            return p, p.to_string()
        except:  # pylint: disable=bare-except
            return None, ''  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(io_shape[1:], io_shape[1:],
                                                  program_shape[1:]),
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease the batch size of the predict dataset to accommodate beam search.
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)),
        padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
    train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
    train_iter = train_ds.as_numpy_iterator()

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_token)
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(shift=False,
                                          deterministic=True,
                                          decode=True)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    m = models.ProgramTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
    p_train_step = jax.pmap(functools.partial(
        train_lib.train_step,
        learning_rate_fn=learning_rate_fn,
        config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(train_lib.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_lib.initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(functools.partial(train_lib.predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, 5, 6))

    # Main Train Loop
    # ---------------------------------------------------------------------------
    train_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    metrics_all = []
    tick = time.time()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, train_rngs = p_train_step(optimizer,
                                                      inputs,
                                                      outputs,
                                                      programs,
                                                      train_rng=train_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == FLAGS.num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if not step or step % FLAGS.log_freq != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums)
        summary['learning_rate'] = lr
        # Calculate (clipped) perplexity after averaging log-perplexities:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

        if jax.host_id() == 0:
            logging.info('Train in step: %d, loss: %.4f', step,
                         summary['loss'])
            tock = time.time()
            steps_per_sec = FLAGS.log_freq / (tock - tick)
            tick = tock
            summary_writer.scalar('train/steps per second', steps_per_sec,
                                  step)
            for key, val in summary.items():
                summary_writer.scalar('train/' + key, val, step)
            summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        metrics_all = []

        # Evaluation Metrics
        logging.info('Gathering evaluation metrics.')
        t_evaluation_start = time.time()
        eval_metrics = []
        for batches in eval_ds.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)

            metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
            eval_metrics.append(metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        if jax.host_id() == 0:
            logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                         time.time() - t_evaluation_start, step,
                         eval_summary['loss'])
            for key, val in eval_summary.items():
                summary_writer.scalar('eval/' + key, val, step)
            summary_writer.flush()

        # Beam search metrics.
        logging.info('Gathering beam search metrics.')
        for beam_size in [10, 100]:
            t_inference_start = time.time()
            pred_acc = 0
            pred_denominator = 0

            ios, targets, predictions = [], [], []
            for batches in predict_ds.as_numpy_iterator():
                pred_batch = batches
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch[0].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    # pylint: disable=cell-var-from-loop
                    pred_batch = jax.tree_map(
                        lambda x: train_lib.pad_examples(x, padded_size),
                        pred_batch)
                inputs, outputs, programs = common_utils.shard(pred_batch)

                cache = p_init_cache(inputs, outputs, programs)
                predicted = p_pred_step(optimizer.target, inputs, outputs,
                                        cache, eos_token, programs.shape[-1],
                                        beam_size)
                predicted = train_lib.tohost(predicted)
                inputs, outputs, programs = map(train_lib.tohost,
                                                (inputs, outputs, programs))

                pred_denominator += programs.shape[0]
                for i, beams in enumerate(predicted):
                    inps, outs, io_string = decode_io(inputs[i], outputs[i])
                    p, p_score = train_lib.eval_predicted(
                        beams,
                        inps,
                        outs,
                        parse_beam_fn=lambda x: decode_program(x)[0])
                    if p_score >= len(inps):
                        pred_acc += 1
                    ios.append(io_string)
                    targets.append(decode_program(programs[i])[1])
                    predictions.append(p.to_string() if p else '')

            all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
                jax.tree_map(np.array, (pred_acc, pred_denominator)))

            # Record beam search results as text summaries.
            message = []
            for n in np.random.choice(np.arange(len(predictions)), 8):
                text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                        f'predicted: {predictions[n]}\n\n')
                message.append(text)

            # Write to tensorboard.
            if jax.host_id() == 0:
                logging.info(
                    'Prediction time (beam %d): %.4f s step %d, score %.4f.',
                    beam_size,
                    time.time() - t_inference_start, step,
                    all_pred_acc / all_pred_denominator)
                summary_writer.scalar('predict/score-{}'.format(beam_size),
                                      all_pred_acc / all_pred_denominator,
                                      step)
                summary_writer.text('samples-{}'.format(beam_size),
                                    '\n------\n'.join(message), step)
                summary_writer.flush()
Example #7
    def test_train(self):
        tf.enable_v2_behavior()

        tf.random.set_seed(0)
        np.random.seed(0)
        random.seed(0)

        dataset_filepattern = os.path.join(
            os.path.dirname(__file__),
            'tasks/robust_fill/dataset/test_dataset/program_tasks.tf_records-*'
        )

        print('dataset_filepattern = {}'.format(dataset_filepattern))

        batch_size = 4
        num_strings_per_task = 4
        max_characters = 10
        max_program_length = 15

        # Build token tables.
        id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
        char_id_table = {char: id for id, char in id_char_table.items()}
        _, token_id_table = dsl_tokens.build_token_tables()
        io_vocab_size = len(char_id_table) + 1  # For padding.
        program_vocab_size = len(token_id_table) + 1

        bos_token = token_id_table[dsl.BOS]

        # Load dataset.
        dataset = input_pipeline.create_dataset_from_tf_record(
            dataset_filepattern, token_id_table, char_id_table)
        dataset = dataset.padded_batch(batch_size,
                                       padded_shapes=((num_strings_per_task,
                                                       max_characters),
                                                      (num_strings_per_task,
                                                       max_characters),
                                                      (max_program_length, )),
                                       drop_remainder=True)
        dataset_iter = dataset.repeat().as_numpy_iterator()

        train_config = models.TransformerConfig(
            vocab_size=io_vocab_size,
            output_vocab_size=program_vocab_size,
            shift=True,
            emb_dim=32,
            num_heads=4,
            num_layers=2,
            qkv_dim=32,
            mlp_dim=32,
            max_len=max(max_characters, max_program_length),
            deterministic=False,
            decode=False,
            bos_token=bos_token)
        eval_config = train_config.replace(deterministic=True)

        rng = jax.random.PRNGKey(0)
        rng, init_rng = jax.random.split(rng)

        m = models.ProgramTransformer(eval_config)
        initial_variables = jax.jit(m.init)(
            init_rng,
            jnp.ones((batch_size, num_strings_per_task, max_characters),
                     jnp.float32),
            jnp.ones((batch_size, num_strings_per_task, max_characters),
                     jnp.float32),
            jnp.ones((batch_size, max_program_length), jnp.float32))

        optimizer_def = optim.Adam(1e-2,
                                   beta1=0.9,
                                   beta2=0.98,
                                   eps=1e-9,
                                   weight_decay=0.1)
        optimizer = optimizer_def.create(initial_variables['params'])

        del initial_variables  # Don't keep a copy of the initial model.

        optimizer = jax_utils.replicate(optimizer)

        learning_rate_fn = train_lib.create_learning_rate_scheduler(
            base_learning_rate=1e-2)
        p_train_step = jax.pmap(functools.partial(
            train_lib.train_step,
            learning_rate_fn=learning_rate_fn,
            config=train_config),
                                axis_name='batch')
        p_eval_step = jax.pmap(functools.partial(train_lib.eval_step,
                                                 config=eval_config),
                               axis_name='batch')

        # Training loop.
        start_step = 0
        rngs = jax.random.split(rng, jax.local_device_count())
        del rng

        for _ in range(start_step, 1000):
            inputs, outputs, programs = common_utils.shard(next(dataset_iter))
            optimizer, _, rngs = p_train_step(optimizer,
                                              inputs,
                                              outputs,
                                              programs,
                                              train_rng=rngs)

        # Evaluation.
        eval_metrics = []
        for batches in dataset.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)

            metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
            eval_metrics.append(metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        if jax.host_id() == 0:
            self.assertGreater(eval_summary['accuracy'], 0.1)
Example #8
def predict_step(state, inputs, outputs, cache, lp_cache, beam_size, bos_token,
                 eos_token, max_decode_len, config, lp_config):
    """Predict translation with fast decoding beam search on a batch."""
    params = state.optimizer.target
    lp_params = state.lp_optimizer.target

    # Split beam over latent sequences and programs.
    per_latent_beam_size = beam_size // FLAGS.latent_beam_size
    beam_size = FLAGS.latent_beam_size * per_latent_beam_size

    flat_lp_encoded = decode.flat_batch_beam_expand(
        models.ProgramTransformer(lp_config).apply(
            {'params': lp_params},
            inputs,
            outputs,
            method=models.ProgramTransformer.encode), FLAGS.latent_beam_size)

    encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, FLAGS.latent_beam_size)

    def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
            {
                'params': lp_params,
                'cache': flat_lp_cache
            },
            flat_ids,
            flat_lp_encoded,
            flat_encoded_padding_mask,
            mutable=['cache'],
            method=models.ProgramTransformer.decode)
        new_flat_lp_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_lp_cache

    # Step 1: Beam-search over latent tokens.
    latent_beam_seqs, _ = decode.beam_search(
        inputs,
        lp_cache,
        tokens_ids_to_latent_logits,
        beam_size=FLAGS.latent_beam_size,
        alpha=0.6,
        bos_token=bos_token,
        eos_token=eos_token,
        max_decode_len=np.ceil(max_decode_len / 2**FLAGS.c).astype(np.int32))

    flat_latent_seqs = decode.flat_batch_beam_expand(
        decode.flatten_beam_dim(latent_beam_seqs), per_latent_beam_size)
    # Quantize the predicted latent codes.
    flat_latents = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        flat_latent_seqs,
        mutable=False,
        method=models.LatentProgramTransformer.quantize)

    flat_encoded = decode.flat_batch_beam_expand(
        models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            inputs,
            outputs,
            mutable=False,
            method=models.LatentProgramTransformer.encode), beam_size)

    # Padding masks.
    flat_latents_mask = jnp.where(
        jnp.logical_and(flat_latent_seqs > 0, flat_latent_seqs != eos_token),
        1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, beam_size)

    def tokens_ids_to_logits(flat_ids, flat_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state,
                'cache': flat_cache
            },
            flat_ids,
            flat_latents,
            flat_encoded,
            flat_latents_mask,
            flat_encoded_padding_mask,
            mutable=['cache'],
            method=models.LatentProgramTransformer.decode)
        new_flat_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_cache

    # Step 2: Beam-search over program tokens.
    per_latent_inputs = decode.flat_batch_beam_expand(inputs,
                                                      FLAGS.latent_beam_size)
    per_latent_cache = jax.tree_map(
        lambda x: decode.flat_batch_beam_expand(x, FLAGS.latent_beam_size),
        cache)
    beam_seqs, _ = decode.beam_search(per_latent_inputs,
                                      per_latent_cache,
                                      tokens_ids_to_logits,
                                      beam_size=per_latent_beam_size,
                                      alpha=0.6,
                                      bos_token=bos_token,
                                      eos_token=eos_token,
                                      max_decode_len=max_decode_len)
    # Collapse both beam dimensions into one.
    beam_seqs = beam_seqs.reshape((inputs.shape[0], beam_size) +
                                  beam_seqs.shape[2:])
    latent_beam_seqs = jnp.repeat(latent_beam_seqs,
                                  per_latent_beam_size,
                                  axis=1)

    # Beam search returns [n_batch, n_beam, n_length] with beam dimension
    # sorted in increasing order of log-probability.
    return beam_seqs, latent_beam_seqs
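A small usage note, not part of the original example: since the beam dimension is sorted in increasing order of log-probability, the highest-scoring program and its corresponding latent sequence for each task are the last entries along that axis.

best_programs = beam_seqs[:, -1, :]
best_latents = latent_beam_seqs[:, -1, :]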
Example #9
def eval_step(state, inputs, outputs, programs, bos_token, eos_token, config,
              lp_config):
    """Evaluate on batch of program tasks."""
    params = state.optimizer.target
    lp_params = state.lp_optimizer.target

    weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)
    # Embedding mask for autoencoding.
    emb_mask = jnp.ones((1, FLAGS.latent_vocab_size),
                        jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

    ae_logits, vq = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        inputs,
        outputs,
        programs,
        emb_mask,
        mutable=False)

    # Postprocess latent indices.
    latent_indices = add_eos_token(vq['latent_indices'], eos_token)
    latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)

    encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    # Additionally mask out eos token in latents.
    latents_mask = jnp.where(
        jnp.logical_and(latent_indices > 0, latent_indices != eos_token), 1,
        0).astype(jnp.float32)

    latent_logits = models.ProgramTransformer(lp_config).apply(
        {'params': lp_params}, inputs, outputs, latent_indices)

    encoded = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        inputs,
        outputs,
        mutable=False,
        method=models.LatentProgramTransformer.encode)
    latents = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        latent_logits,
        mutable=False,
        method=models.LatentProgramTransformer.quantize)
    logits = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        programs,
        latents,
        encoded,
        latents_mask,
        encoded_mask,
        mutable=False,
        method=models.LatentProgramTransformer.decode)

    metrics = compute_metrics(logits, programs, weights)
    metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
    latent_metrics = compute_metrics(latent_logits,
                                     latent_indices,
                                     latent_weights,
                                     prefix='latent_')
    return metrics, latent_metrics
Example #10
def eval_step(params, inputs, outputs, programs, config):
    weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)
    logits = models.ProgramTransformer(config).apply({'params': params},
                                                     inputs, outputs, programs)

    return compute_metrics(logits, programs, weights)
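A minimal usage sketch, not part of the original example: like the other step functions in this document, eval_step is wrapped with jax.pmap over a 'batch' axis and fed sharded batches; eval_config, replicated params (e.g. optimizer.target), and an eval_ds tf.data.Dataset are assumed to exist.

import functools

import jax
from flax.training import common_utils

p_eval_step = jax.pmap(
    functools.partial(eval_step, config=eval_config),
    axis_name='batch')

eval_metrics = []
for batch in eval_ds.as_numpy_iterator():
    inputs, outputs, programs = common_utils.shard(batch)
    eval_metrics.append(p_eval_step(params, inputs, outputs, programs))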