Example #1
0
def main(_):
    """Sample random tasks and write them to a single TFRecord shard.

    Seeds the TensorFlow, NumPy and `random` generators from `FLAGS.seed`,
    makes sure `FLAGS.save_dir` exists, then samples `FLAGS.num_tasks` tasks
    with `sample_random.random_task` and writes each one, serialized by
    `serialize_example`, to
    `<save_dir>/program_tasks.tf_records-00000-of-00001`.

    Args:
        _: Unused positional argument.
    """
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    _, token_id_table = dsl_tokens.build_token_tables()

    # Use `makedirs` rather than `mkdir` so that a nested `save_dir` whose
    # parent directories do not exist yet can still be created.
    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    worker_fname = os.path.join(FLAGS.save_dir,
                                'program_tasks.tf_records-00000-of-00001')

    # Write the `tf.Example` observations to the file.
    with tf.io.TFRecordWriter(worker_fname) as writer:
        for _ in range(FLAGS.num_tasks):
            task = sample_random.random_task(
                max_expressions=FLAGS.max_expressions,
                min_expressions=FLAGS.min_expressions,
                max_k=5,
                max_input_tokens=10,
                max_input_length=FLAGS.max_characters,
                max_output_length=FLAGS.max_characters,
                num_examples=FLAGS.num_strings_per_task,
            )
            example = serialize_example(task, token_id_table)
            writer.write(example)
Example #2
0
def main(_):
    """Write whole-program and decomposition TFRecord files for one split.

    When `FLAGS.seed` is set, the TensorFlow, NumPy and `random` generators
    are seeded with it. `FLAGS.save_dir` is created if missing. For each of
    `FLAGS.num_tasks` tasks, one record from
    `serialize_entire_program_example` goes to the `entire_programs_*` file
    and every record from `serialize_decomposition_examples` goes to the
    `decomposition_data_*` file.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If an experiment other than NONE is selected and
            `FLAGS.split` is not one of 'train', 'valid', 'test', 'finetune'.
    """
    tf.enable_v2_behavior()

    seed = FLAGS.seed
    if seed is not None:
        for seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
            seed_fn(seed)

    _unused_id_token_table, token_id_table = dsl_tokens.build_token_tables()

    save_dir = FLAGS.save_dir
    if not gfile.isdir(save_dir):
        gfile.makedirs(save_dir)

    # Only a single shard is ever written.
    shard_id, total_shards = 0, 1

    entire_programs_fname = os.path.join(
        save_dir,
        'entire_programs_{}.tf_records-{:05d}-of-{:05d}'.format(
            FLAGS.split, shard_id, total_shards))
    decomposition_data_fname = os.path.join(
        save_dir,
        'decomposition_data_{}.tf_records-{:05d}-of-{:05d}'.format(
            FLAGS.split, shard_id, total_shards))

    def _is_train_example(task_index):
        """Map the current split (and task index) to the `is_train` flag."""
        split = FLAGS.split
        if split == 'finetune':
            # Finetune data alternates: odd indices are train-like.
            return bool(task_index % 2)
        if split == 'test':
            return False
        if split in ['train', 'valid']:
            return True
        raise ValueError('Unhandled split: {}'.format(FLAGS.split))

    # Write the `tf.Example` observations to the two files.
    with tf.io.TFRecordWriter(entire_programs_fname) as programs_out, \
        tf.io.TFRecordWriter(decomposition_data_fname) as decomposition_out:
        for task_index in range(FLAGS.num_tasks):
            if FLAGS.experiment != exp_module.Experiment.NONE.name:
                is_train = _is_train_example(task_index)
                task = generate_task_for_experiment(FLAGS.experiment, is_train)
            else:
                task = sample_random.random_task(
                    max_expressions=FLAGS.max_expressions,
                    min_expressions=FLAGS.min_expressions,
                    max_k=3,
                    max_input_tokens=5,
                    max_input_length=FLAGS.max_input_length,
                    num_examples=FLAGS.num_strings_per_task)

            whole_program_record = serialize_entire_program_example(
                task, token_id_table)
            programs_out.write(whole_program_record)

            for step_record in serialize_decomposition_examples(
                    task, token_id_table):
                decomposition_out.write(step_record)
Example #3
0
    def test_decode(self):
        """Encode a program, decode it again, and run the decoded program."""
        id_to_token, token_to_id = tokens.build_token_tables()
        self.assertEqual(len(token_to_id), len(id_to_token))

        # Build the program piece by piece, innermost expressions first.
        replace_spaces = dsl.Replace(' ', ',')
        span = dsl.GetSpan(dsl.Type.PROP_CASE, 1, dsl.Boundary.START,
                           dsl.Type.PROP_CASE, 4, dsl.Boundary.END)
        comma_joined_span = dsl.Compose(replace_spaces, span)
        separator = dsl.ConstStr('.')
        last_token = dsl.GetToken(dsl.Type.PROP_CASE, -1)
        program = dsl.Concat(comma_joined_span, separator, last_token)

        token_ids = program.encode(token_to_id)
        # The encoding must be terminated by the EOS token.
        self.assertEqual(token_ids[-1], token_to_id[dsl.EOS])

        round_tripped = dsl.decode_program(token_ids, id_to_token)

        # (input string, expected program output) pairs.
        cases = (
            ('Jacob Ethan James Alexander Michael',
             'Jacob,Ethan,James,Alexander.Michael'),
            ('Earth Fire Wind Water Pluto Sun',
             'Earth,Fire,Wind,Water.Sun'),
        )
        for text, expected in cases:
            self.assertEqual(round_tripped(text), expected)
def main(_):
    """Generate tasks for one dataset split and write them to a TFRecord file.

    When `FLAGS.seed` is set, the TensorFlow, NumPy and `random` generators
    are seeded with it. `FLAGS.save_dir` is created if missing. Each of the
    `FLAGS.num_tasks` tasks is serialized with `serialize_example` and written
    to `<save_dir>/program_tasks_<split>.tf_records-00000-of-00001`.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If an experiment other than NONE is selected and
            `FLAGS.split` is not one of 'train', 'valid', 'test', 'finetune'.
    """
    tf.enable_v2_behavior()

    if FLAGS.seed is not None:
        tf.random.set_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        random.seed(FLAGS.seed)

    _, token_id_table = dsl_tokens.build_token_tables()

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    worker_fname = os.path.join(
        FLAGS.save_dir,
        'program_tasks_{}.tf_records-00000-of-00001'.format(FLAGS.split))

    # Write the `tf.Example` observations to the file.
    with tf.io.TFRecordWriter(worker_fname) as writer:
        for i in range(FLAGS.num_tasks):
            # The flag may hold either the enum member or its name string,
            # depending on how it is declared. Comparing only against the
            # member would never match a string-valued flag, so accept both.
            # NOTE(review): assumes `exp_module.Experiment` is an
            # `enum.Enum` (so `.name` exists) -- confirm.
            if FLAGS.experiment in (exp_module.Experiment.NONE,
                                    exp_module.Experiment.NONE.name):
                task = sample_random.random_task(
                    max_expressions=FLAGS.max_expressions,
                    min_expressions=FLAGS.min_expressions,
                    max_k=3,
                    max_input_tokens=5,
                    max_input_length=FLAGS.max_input_length,
                    num_examples=FLAGS.num_strings_per_task)
            else:
                if FLAGS.split in ['train', 'valid']:
                    is_train = True
                elif FLAGS.split == 'test':
                    is_train = False
                elif FLAGS.split == 'finetune':
                    # Finetune data alternates: odd indices are train-like.
                    is_train = bool(i % 2)
                else:
                    raise ValueError('Unhandled split: {}'.format(FLAGS.split))
                task = generate_task_for_experiment(FLAGS.experiment, is_train)

            example = serialize_example(task, token_id_table)
            writer.write(example)
Example #5
0
def main(_):
    """Train a DecomposeAttentionTransformer and periodically evaluate it.

    The function, in order:
      1. Seeds the TensorFlow, NumPy and `random` generators and validates
         the attention-related flags.
      2. Builds the character and program-token tables.
      3. Builds the training, evaluation, prediction and test datasets.
      4. Builds the train/eval/predict model configs, initializes parameters
         and creates an Adam optimizer.
      5. Optionally restores the latest checkpoint and fast-forwards the
         training dataset and the dropout RNG.
      6. Runs the training loop, which periodically saves checkpoints, logs
         training metrics, computes evaluation metrics and runs beam search
         on the predict and test datasets.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `bos_special_attention` is combined with incompatible
            attention flags, or if `FLAGS.dataset_filepattern` is empty.
    """
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    # BOS special attention only makes sense if we are using relative attention
    # and it's not the baseline.
    if FLAGS.bos_special_attention and (not FLAGS.use_relative_attention
                                        or FLAGS.attention_mask_type
                                        == 'baseline'):
        raise ValueError(
            "bos_special_attention doesn't work when use_relative_attention={} and "
            'attention_mask_type={}'.format(FLAGS.use_relative_attention,
                                            FLAGS.attention_mask_type))

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    # `hparam_str` names the tensorboard and checkpoint sub-directories below.
    hparam_str_dict = json.loads(FLAGS.xm_parameters)
    hparam_str = ','.join([
        '%s=%s' % (shorten(k), str(hparam_str_dict[k]))
        for k in hparam_str_dict.keys()
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    # The summary writer only exists on host 0; every later use is guarded by
    # the same `jax.host_id() == 0` check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    # `batch_size` is the per-host batch; the `*_shape` tuples below are
    # per-device shapes.
    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    predict_io_shape = (FLAGS.per_device_batch_size,
                        FLAGS.num_strings_per_task,
                        FLAGS.predict_max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    # Character ids start at 1; `decode_str` below skips ids <= 0.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: id for id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    # Parse io and program token sequences (for eval).
    def decode_io(inputs, outputs):
        """Decode io examples tokens."""
        def decode_str(s):
            """Decode string tokens."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
        return inps, outs

    def decode_program(program):
        """Decode program tokens."""
        # Keep everything up to and including the first EOS token.
        # NOTE: if no EOS is present, `np.argmax` of an all-False mask is 0,
        # so only the first token is kept.
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        program = program[program != bos_token]

        try:
            return dsl.decode_program(program.tolist(), id_token_table)
        except:  # pylint: disable=bare-except
            return None  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
    padded_shapes = (io_shape[1:], io_shape[1:], program_shape[1:])
    logging.info('padded_shapes: %s', padded_shapes)
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=padded_shapes,
                                   drop_remainder=True)
    # Split evaluation and training.
    # The first `num_eval_steps` batches are held out for evaluation; training
    # uses the batches after them (see `train_ds` below).
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch of predict dataset to handle beam search.
    predict_padded_shapes = (predict_io_shape[1:], predict_io_shape[1:],
                             program_shape[1:])
    logging.info('predict_padded_shapes: %s', predict_padded_shapes)
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)
    train_ds = dataset.skip(FLAGS.num_eval_steps)
    if FLAGS.train_set_batches > 0:
        train_ds = train_ds.take(FLAGS.train_set_batches)
    train_ds = train_ds.repeat()

    # Test datasets are re-batched to the same reduced batch size as
    # `predict_ds`; the "quick" and "final" variants differ only in how many
    # batches are taken.
    test_dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.test_dataset_filepattern, token_id_table, char_id_table)
    test_dataset = test_dataset.padded_batch(
        batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False)
    quick_test_dataset = (test_dataset.take(
        FLAGS.num_quick_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))
    final_test_dataset = (test_dataset.take(
        FLAGS.num_final_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    # `default_config` is only used to read the default maximum distances,
    # which cap `FLAGS.max_distance` below.
    default_config = base_models.TransformerConfig(
        vocab_size=io_vocab_size, output_vocab_size=program_vocab_size)
    base_config = base_models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_token,
        num_input_relative_position_buckets=FLAGS.num_position_buckets,
        max_input_distance=min(FLAGS.max_distance,
                               default_config.max_input_distance),
        num_output_relative_position_buckets=FLAGS.num_position_buckets,
        max_output_distance=min(FLAGS.max_distance,
                                default_config.max_output_distance),
        num_input_cross_output_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_input_cross_output_distance=min(
            FLAGS.max_distance,
            default_config.max_input_cross_output_distance),
        num_program_relative_position_buckets=FLAGS.num_position_buckets,
        max_program_distance=min(FLAGS.max_distance,
                                 default_config.max_program_distance),
        num_program_cross_embed_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_program_cross_embed_distance=min(
            FLAGS.max_distance,
            default_config.max_program_cross_embed_distance),
        bidirectional_program_attention=FLAGS.bidirectional_program_attention)
    # The three configs share `base_config`; eval and predict switch to
    # deterministic mode, and predict additionally disables `shift` and
    # enlarges `max_len` to cover `predict_max_characters`.
    train_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config,
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)
    eval_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config.replace(deterministic=True),
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)
    predict_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config.replace(
            shift=False,
            deterministic=True,
            decode=not FLAGS.slow_decode,
            max_len=max(FLAGS.max_characters, FLAGS.max_program_length,
                        FLAGS.predict_max_characters)),
        attention_mask_type=FLAGS.attention_mask_type,
        bos_special_attention=FLAGS.bos_special_attention)

    # Derive a per-host RNG, then one init key and one dropout key per local
    # device.
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    dropout_rng = jax.random.split(rng, jax.local_device_count())
    del rng

    # Parameters are initialized with the deterministic (eval) config and
    # all-ones dummy inputs of the per-device shapes.
    m = models.DecomposeAttentionTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)
        if FLAGS.finetune_start_step > 0:
            logging.info(
                'Checking that start_step (%s) == finetune_start_step (%s)',
                start_step, FLAGS.finetune_start_step)
            # NOTE(review): the log message says `==` but the check is `>=`.
            assert start_step >= FLAGS.finetune_start_step
            # When finetuning, only the steps taken since finetuning began are
            # skipped in the (finetuning) dataset.
            steps_to_skip = start_step - FLAGS.finetune_start_step
        else:
            steps_to_skip = start_step

        # TODO(kshi): It is likely that this code can lead to the job stalling for
        # 10+ hours when restarting from a checkpoint that had been trained a long
        # time, possibly because dataset skipping is slow.
        logging.info('Skipping %s steps...', steps_to_skip)
        train_ds = train_ds.skip(steps_to_skip)
        # Advance the dropout RNG once per skipped step.
        # NOTE(review): presumably this mirrors how `train_step` advances
        # `dropout_rng` -- confirm against `train_step`.
        dummy_p_train_step = jax.pmap(
            lambda dropout_rng: jax.random.split(dropout_rng)[1])
        for _ in range(steps_to_skip):
            dropout_rng = dummy_p_train_step(dropout_rng)
        logging.info('Finished skipping steps')
        logging.info('Host %s has dropout_rng = %s', jax.host_id(),
                     dropout_rng)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    # TODO(jxihong): Implement fast decoding.
    assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

    if FLAGS.finetune_start_step <= 0:
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr)
    else:
        # Constant LR for finetuning.
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr, factors='constant')
    # pmap-ed step functions, each bound to its own config.
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn, config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             eos_token=eos_token,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
    # Positional argument 4 of `p_pred_step` is `beam_size` (see the call in
    # the beam search section); it is broadcast as a static value.
    p_pred_step = jax.pmap(functools.partial(
        predict_step,
        eos_token=eos_token,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config,
        slow_decode=FLAGS.slow_decode),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, ))

    # Main Train Loop
    # ---------------------------------------------------------------------------

    logging.info('Starting training!')
    metrics_all = []
    tick = time.time()
    train_iter = train_ds.as_numpy_iterator()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, dropout_rng = p_train_step(optimizer,
                                                       inputs,
                                                       outputs,
                                                       programs,
                                                       dropout_rng=dropout_rng)
        metrics_all.append(metrics)
        is_last_step = step == FLAGS.num_train_steps - 1

        # Save a Checkpoint
        if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.

        # Training Metrics
        if (step and step % FLAGS.log_freq == 0) or is_last_step:
            logging.info('Gathering training metrics.')
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            # Every remaining metric is summed and then normalized by the
            # summed 'denominator' entry.
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums)
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)

            if jax.host_id() == 0:
                logging.info('Train in step: %d, loss: %.4f', step,
                             summary['loss'])
                tock = time.time()
                # NOTE(review): this assumes `log_freq` steps elapsed since
                # the previous log, which may not hold on the last step.
                steps_per_sec = FLAGS.log_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('train/steps per second', steps_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar('train/' + key, val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

        # Evaluation Metrics
        if (step and step % FLAGS.eval_freq == 0) or is_last_step:
            logging.info('Gathering evaluation metrics.')
            t_evaluation_start = time.time()
            eval_metrics = []
            for batches in eval_ds.as_numpy_iterator():
                inputs, outputs, programs = common_utils.shard(batches)

                metrics = p_eval_step(optimizer.target, inputs, outputs,
                                      programs)
                eval_metrics.append(metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            if jax.host_id() == 0:
                logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                             time.time() - t_evaluation_start, step,
                             eval_summary['loss'])
                for key, val in eval_summary.items():
                    summary_writer.scalar('eval/' + key, val, step)
                summary_writer.flush()

        # Beam search metrics.
        if (step and step % FLAGS.predict_freq == 0) or is_last_step:
            logging.info('Gathering beam search metrics.')
            # The larger "final" test set is only used on the last step.
            test_ds = final_test_dataset if is_last_step else quick_test_dataset

            for dataset, predict_or_test in [(predict_ds, 'predict'),
                                             (test_ds, 'test')]:

                for beam_size in [1, 10]:
                    t_inference_start = time.time()
                    total_successes = 0
                    total_denominator = 0
                    # Success / example counts keyed by the number of
                    # expressions in the target program.
                    pred_successes = collections.defaultdict(int)
                    pred_denominators = collections.defaultdict(int)

                    ios, targets, predictions, top_of_beams = [], [], [], []
                    for batches in dataset.as_numpy_iterator():
                        pred_batch = batches
                        # Handle final odd-sized batch by padding instead of dropping it.
                        cur_pred_batch_size = pred_batch[0].shape[0]
                        if cur_pred_batch_size % n_devices:
                            padded_size = int(
                                np.ceil(cur_pred_batch_size / n_devices) *
                                n_devices)
                            # pylint: disable=cell-var-from-loop
                            pred_batch = jax.tree_map(
                                lambda x: pad_examples(x, padded_size),
                                pred_batch)
                        inputs, outputs, programs = common_utils.shard(
                            pred_batch)

                        # No cache is built in slow-decode mode.
                        cache = (p_init_cache(inputs, outputs, programs)
                                 if not FLAGS.slow_decode else None)
                        predicted = p_pred_step(optimizer.target, inputs,
                                                outputs, cache, beam_size)
                        predicted = tohost(predicted)
                        inputs, outputs, programs = map(
                            tohost, (inputs, outputs, programs))

                        for i, beams in enumerate(predicted):
                            inps, outs = decode_io(inputs[i], outputs[i])
                            p, p_score = eval_predicted(
                                beams,
                                inps,
                                outs,
                                parse_beam_fn=decode_program)

                            # Split by length of program.
                            # NOTE(review): `decode_program` returns None for a
                            # program that does not compile; the target program
                            # is assumed to always decode here.
                            program = programs[i]
                            num_expressions = len(
                                decode_program(program).expressions)
                            pred_denominators[num_expressions] += 1
                            total_denominator += 1
                            # Counted as a success when the score reaches the
                            # number of I/O examples.
                            if p_score >= len(inps):
                                pred_successes[num_expressions] += 1
                                total_successes += 1

                            ios.append(' ; '.join(map(str, zip(inps, outs))))
                            targets.append(
                                decode_program(programs[i]).to_string())
                            try:
                                predictions.append(p.to_string())
                            except:  # pylint: disable=bare-except
                                predictions.append('Did not compile')
                            logging.info('ios: %s', ios[-1])
                            logging.info('target: %s', targets[-1])
                            beams_log = []
                            for beam in beams:
                                try:
                                    beams_log.append(
                                        decode_program(beam).to_string())
                                except:  # pylint: disable=bare-except
                                    beams_log.append('Did not compile')
                            logging.info('predicted beam: %s',
                                         '\n'.join(beams_log))

                            # Record up to the last four beams, in reverse
                            # order.
                            # NOTE(review): presumably the best beams are at
                            # the end of `beams` -- confirm.
                            top_of_beam = []
                            for index, beam in enumerate(beams[:-5:-1]):
                                try:
                                    decoded_program = decode_program(
                                        beam).to_string()
                                except:  # pylint: disable=bare-except
                                    decoded_program = 'Did not compile'
                                top_of_beam.append(
                                    'index: {}, decoded: {}, tokens: {}'.
                                    format(index, decoded_program, beam))
                            top_of_beams.append('\n\n'.join(top_of_beam))

                    # Combine the per-host counts via `per_host_sum_pmap`.
                    all_total_successes, all_total_denominator = per_host_sum_pmap(
                        jax.tree_map(np.array,
                                     (total_successes, total_denominator)))
                    all_pred_successes, all_pred_denominators = per_host_sum_pmap(
                        jax.tree_map(np.array,
                                     (pred_successes, pred_denominators)))

                    # Record beam search results as text summaries.
                    # Eight example indices are drawn at random (with
                    # replacement, the `np.random.choice` default).
                    message = []
                    for n in np.random.choice(np.arange(len(predictions)), 8):
                        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                                f'predicted: {predictions[n]}\n\n'
                                f'top of beam:\n\n{top_of_beams[n]}\n\n')
                        message.append(text)

                    # Write to tensorboard.
                    if jax.host_id() == 0:
                        accuracy = 100 * all_total_successes / all_total_denominator
                        logging.info(
                            '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                            predict_or_test, step, beam_size,
                            all_total_successes, all_total_denominator,
                            accuracy,
                            time.time() - t_inference_start)
                        summary_writer.scalar(
                            '{}/beam-size-{}'.format(predict_or_test,
                                                     beam_size), accuracy,
                            step)

                        # Accuracy broken down by target program length.
                        for length in sorted(all_pred_successes.keys()):
                            this_length_accuracy = (
                                100 * all_pred_successes[length] /
                                all_pred_denominators[length])
                            logging.info(
                                '  accuracy for length %s: %s / %s = %.2f%%',
                                length, all_pred_successes[length],
                                all_pred_denominators[length],
                                this_length_accuracy)
                            summary_writer.scalar(
                                '{}-by-length/beam-size-{}-length-{}'.format(
                                    predict_or_test, beam_size, length),
                                this_length_accuracy, step)

                        summary_writer.text(
                            '{}-samples-beam-{}'.format(
                                predict_or_test, beam_size),
                            '\n------\n'.join(message), step)
                        summary_writer.flush()
Example #6
0
def main(_):
    """Train a `ProgramTransformer` and periodically evaluate / decode with it.

    The function seeds all RNGs from `FLAGS.seed`, builds the DSL token
    tables, loads the TFRecord dataset given by `FLAGS.dataset_filepattern`,
    creates the model and Adam optimizer (optionally restoring a checkpoint),
    and then runs the training loop. Every `FLAGS.log_freq` steps it writes
    training metrics, evaluation metrics and beam-search prediction scores
    to TensorBoard (host 0 only).

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `FLAGS.dataset_filepattern` is not set.
    """
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    # `makedirs` also creates missing parent directories, whereas `mkdir`
    # fails when the parent of `save_dir` does not exist yet.
    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
    # Get hyperparameters; explicitly set flags (seed, lr) take precedence
    # over values of the same name coming from `xm_parameters`.
    if FLAGS.xm_parameters:
        for key, value in json.loads(FLAGS.xm_parameters).items():
            if key not in hparam_str_dict:
                hparam_str_dict[key] = value

    # Sorted so the string (used as a directory name below) is deterministic.
    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_str_dict[k]))
        for k in sorted(hparam_str_dict.keys())
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    # Only host 0 owns a summary writer; every later use of `summary_writer`
    # is guarded by the same `jax.host_id() == 0` check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables. Character ids start at 1; id 0 is left for padding.
    id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
    char_id_table = {char: char_id for char_id, char in id_char_table.items()}
    id_token_table, token_id_table = dsl_tokens.build_token_tables()
    io_vocab_size = len(char_id_table) + 1  # For padding.
    program_vocab_size = len(token_id_table) + 1

    bos_token = token_id_table[dsl.BOS]
    eos_token = token_id_table[dsl.EOS]

    def decode_io(inputs, outputs):
        """Decode io example tokens.

        Returns:
            A tuple `(inps, outs, io_string)`: the decoded input strings, the
            decoded output strings, and a single display string of the form
            `inp < out > inp < out ...`.
        """
        def decode_str(s):
            """Decode string tokens, skipping non-positive (padding) ids."""
            return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

        io_string = ''
        inps, outs = [], []
        for inp, out in zip(inputs, outputs):
            inps.append(decode_str(inp))
            outs.append(decode_str(out))
            io_string += inps[-1] + ' < ' + outs[-1] + ' > '
        return inps, outs, io_string[:-3]  # Remove last separator.

    def decode_program(program):
        """Decode program tokens.

        Returns:
            A tuple `(program, string)`; `(None, '')` if decoding fails.
        """
        # Keep tokens up to and including the first EOS.
        # NOTE(review): if no EOS is present `np.argmax` returns 0, so only
        # the first token is kept -- confirm this is intended.
        program = program[:np.argmax(program == eos_token) + 1].astype(
            np.int32)
        try:
            p = dsl.decode_program(program, id_token_table)
            return p, p.to_string()
        except:  # pylint: disable=bare-except
            return None, ''  # Program does not compile.

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    dataset = input_pipeline.create_dataset_from_tf_record(
        FLAGS.dataset_filepattern, token_id_table, char_id_table)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(io_shape[1:], io_shape[1:],
                                                  program_shape[1:]),
                                   drop_remainder=True)
    # Split evaluation and training: the first `num_eval_steps` batches are
    # held out for evaluation, the rest are repeated for training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch of predict dataset to handle beam search.
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)),
        padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
    train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
    train_iter = train_ds.as_numpy_iterator()

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=io_vocab_size,
        output_vocab_size=program_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_token)
    # Evaluation and prediction reuse the training config with a few
    # fields overridden.
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(shift=False,
                                          deterministic=True,
                                          decode=True)

    rng = jax.random.PRNGKey(FLAGS.seed)
    # Fold in the host id so each host gets a different RNG stream.
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    # Initialize parameters using dummy all-ones inputs of the right shapes.
    m = models.ProgramTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(program_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
    p_train_step = jax.pmap(functools.partial(
        train_lib.train_step,
        learning_rate_fn=learning_rate_fn,
        config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(train_lib.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_lib.initialize_cache,
        max_decode_len=FLAGS.max_program_length,
        config=predict_config),
                            axis_name='batch')
    # Positional args 4, 5, 6 (eos token, max decode length, beam size at the
    # call site below) are passed as static broadcasted values.
    p_pred_step = jax.pmap(functools.partial(train_lib.predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, 5, 6))

    # Main Train Loop
    # ---------------------------------------------------------------------------
    # One RNG key per local device.
    train_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    metrics_all = []
    tick = time.time()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, programs = common_utils.shard(next(train_iter))

        optimizer, metrics, train_rngs = p_train_step(optimizer,
                                                      inputs,
                                                      outputs,
                                                      programs,
                                                      train_rng=train_rngs)
        metrics_all.append(metrics)

        # Save a Checkpoint (periodically and on the final step).
        if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
                or step == FLAGS.num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling: skip step 0 and every step that is not a
        # multiple of `log_freq`.
        if not step or step % FLAGS.log_freq != 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums)
        summary['learning_rate'] = lr
        # Calculate (clipped) perplexity after averaging log-perplexities:
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

        if jax.host_id() == 0:
            logging.info('Train in step: %d, loss: %.4f', step,
                         summary['loss'])
            tock = time.time()
            steps_per_sec = FLAGS.log_freq / (tock - tick)
            tick = tock
            summary_writer.scalar('train/steps per second', steps_per_sec,
                                  step)
            for key, val in summary.items():
                summary_writer.scalar('train/' + key, val, step)
            summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        metrics_all = []

        # Evaluation Metrics
        logging.info('Gathering evaluation metrics.')
        t_evaluation_start = time.time()
        eval_metrics = []
        for batches in eval_ds.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batches)

            metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
            eval_metrics.append(metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)

        if jax.host_id() == 0:
            logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                         time.time() - t_evaluation_start, step,
                         eval_summary['loss'])
            for key, val in eval_summary.items():
                summary_writer.scalar('eval/' + key, val, step)
            summary_writer.flush()

        # Beam search metrics.
        logging.info('Gathering beam search metrics.')
        for beam_size in [10, 100]:
            t_inference_start = time.time()
            pred_acc = 0
            pred_denominator = 0

            ios, targets, predictions = [], [], []
            for batches in predict_ds.as_numpy_iterator():
                pred_batch = batches
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch[0].shape[0]
                if cur_pred_batch_size % n_devices:
                    # Round up to the next multiple of the device count so
                    # the batch can be sharded evenly.
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    # pylint: disable=cell-var-from-loop
                    pred_batch = jax.tree_map(
                        lambda x: train_lib.pad_examples(x, padded_size),
                        pred_batch)
                inputs, outputs, programs = common_utils.shard(pred_batch)

                cache = p_init_cache(inputs, outputs, programs)
                predicted = p_pred_step(optimizer.target, inputs, outputs,
                                        cache, eos_token, programs.shape[-1],
                                        beam_size)
                predicted = train_lib.tohost(predicted)
                inputs, outputs, programs = map(train_lib.tohost,
                                                (inputs, outputs, programs))

                pred_denominator += programs.shape[0]
                for i, beams in enumerate(predicted):
                    inps, outs, io_string = decode_io(inputs[i], outputs[i])
                    p, p_score = train_lib.eval_predicted(
                        beams,
                        inps,
                        outs,
                        parse_beam_fn=lambda x: decode_program(x)[0])
                    # Counted as a success only when the score reaches the
                    # number of io examples for the task.
                    if p_score >= len(inps):
                        pred_acc += 1
                    ios.append(io_string)
                    targets.append(decode_program(programs[i])[1])
                    predictions.append(p.to_string() if p else '')

            all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
                jax.tree_map(np.array, (pred_acc, pred_denominator)))

            # Record beam search results as text summaries: 8 examples drawn
            # at random (with replacement, the `np.random.choice` default).
            message = []
            for n in np.random.choice(np.arange(len(predictions)), 8):
                text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                        f'predicted: {predictions[n]}\n\n')
                message.append(text)

            # Write to tensorboard.
            if jax.host_id() == 0:
                logging.info(
                    'Prediction time (beam %d): %.4f s step %d, score %.4f.',
                    beam_size,
                    time.time() - t_inference_start, step,
                    all_pred_acc / all_pred_denominator)
                summary_writer.scalar('predict/score-{}'.format(beam_size),
                                      all_pred_acc / all_pred_denominator,
                                      step)
                summary_writer.text('samples-{}'.format(beam_size),
                                    '\n------\n'.join(message), step)
                summary_writer.flush()
Example #7
0
    def test_train(self):
        """Train a small model on the test dataset and check eval accuracy.

        Runs 1000 training steps on the bundled test TFRecords, evaluates on
        one pass over the same dataset, and asserts (on host 0) that the
        resulting accuracy exceeds 0.1.
        """
        tf.enable_v2_behavior()

        # Fix every source of randomness.
        tf.random.set_seed(0)
        np.random.seed(0)
        random.seed(0)

        test_dir = os.path.dirname(__file__)
        dataset_filepattern = os.path.join(
            test_dir,
            'tasks/robust_fill/dataset/test_dataset/program_tasks.tf_records-*'
        )

        print('dataset_filepattern = {}'.format(dataset_filepattern))

        # Small problem sizes for the test.
        batch_size = 4
        num_strings_per_task = 4
        max_characters = 10
        max_program_length = 15
        num_train_steps = 1000

        # Per-example shapes shared by the input pipeline and model init.
        io_example_shape = (num_strings_per_task, max_characters)
        program_example_shape = (max_program_length, )

        # Build token tables (character ids start at 1; 0 is padding).
        id_char_table = dict(enumerate(dsl.CHARACTER, start=1))
        char_id_table = {
            char: char_id for char_id, char in id_char_table.items()
        }
        _, token_id_table = dsl_tokens.build_token_tables()
        io_vocab_size = len(char_id_table) + 1  # For padding.
        program_vocab_size = len(token_id_table) + 1

        bos_token = token_id_table[dsl.BOS]

        # Load dataset.
        dataset = input_pipeline.create_dataset_from_tf_record(
            dataset_filepattern, token_id_table, char_id_table)
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes=(io_example_shape, io_example_shape,
                           program_example_shape),
            drop_remainder=True)
        train_batches = dataset.repeat().as_numpy_iterator()

        # Model configuration; evaluation differs only in being deterministic.
        train_config = models.TransformerConfig(
            vocab_size=io_vocab_size,
            output_vocab_size=program_vocab_size,
            shift=True,
            emb_dim=32,
            num_heads=4,
            num_layers=2,
            qkv_dim=32,
            mlp_dim=32,
            max_len=max(max_characters, max_program_length),
            deterministic=False,
            decode=False,
            bos_token=bos_token)
        eval_config = train_config.replace(deterministic=True)

        rng = jax.random.PRNGKey(0)
        rng, init_rng = jax.random.split(rng)

        # Initialize parameters from dummy all-ones inputs.
        model = models.ProgramTransformer(eval_config)
        dummy_io = jnp.ones((batch_size, ) + io_example_shape, jnp.float32)
        dummy_programs = jnp.ones((batch_size, ) + program_example_shape,
                                  jnp.float32)
        initial_variables = jax.jit(model.init)(init_rng, dummy_io, dummy_io,
                                                dummy_programs)

        optimizer = optim.Adam(
            1e-2,
            beta1=0.9,
            beta2=0.98,
            eps=1e-9,
            weight_decay=0.1,
        ).create(initial_variables['params'])

        del initial_variables  # Don't keep a copy of the initial model.

        optimizer = jax_utils.replicate(optimizer)

        learning_rate_fn = train_lib.create_learning_rate_scheduler(
            base_learning_rate=1e-2)
        train_step_fn = functools.partial(train_lib.train_step,
                                          learning_rate_fn=learning_rate_fn,
                                          config=train_config)
        eval_step_fn = functools.partial(train_lib.eval_step,
                                         config=eval_config)
        p_train_step = jax.pmap(train_step_fn, axis_name='batch')
        p_eval_step = jax.pmap(eval_step_fn, axis_name='batch')

        # Training loop, with one RNG key per local device.
        device_rngs = jax.random.split(rng, jax.local_device_count())
        del rng

        for _ in range(num_train_steps):
            inputs, outputs, programs = common_utils.shard(
                next(train_batches))
            optimizer, _, device_rngs = p_train_step(optimizer,
                                                     inputs,
                                                     outputs,
                                                     programs,
                                                     train_rng=device_rngs)

        # Evaluation: a single pass over the (non-repeated) dataset.
        eval_metrics = []
        for batch in dataset.as_numpy_iterator():
            inputs, outputs, programs = common_utils.shard(batch)
            eval_metrics.append(
                p_eval_step(optimizer.target, inputs, outputs, programs))

        # Sum every metric, then normalize by the summed denominator.
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(lambda total: total / eval_denominator,
                                    eval_metrics_sums)

        if jax.host_id() == 0:
            self.assertGreater(eval_summary['accuracy'], 0.1)
Example #8
0
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparmaters
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern,
      token_id_table,
      char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretrainined
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warn('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)

  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()

    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast,
                                            beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
# Sequence-length limits.
max_program_length = 100
max_characters = 200

# Model size hyperparameters.
embedding_dim = 256
hidden_dim = 512
num_heads = 4
num_layers = 3

io_shape = (per_device_batch_size, num_strings_per_task, max_characters)
program_shape = (per_device_batch_size, max_program_length)

# Character ids start at 1; the vocab sizes below add one extra slot for
# padding.
id_char_table = dict(enumerate(dsl.CHARACTER, start=1))
char_id_table = {char: char_id for char_id, char in id_char_table.items()}
id_token_table, token_id_table = dsl_tokens.build_token_tables()
io_vocab_size = 1 + len(char_id_table)  # For padding.
program_vocab_size = 1 + len(token_id_table)

bos_token = token_id_table[dsl.BOS]
eos_token = token_id_table[dsl.EOS]


# Names of every experiment except the NONE placeholder.
DATASETS = [
    experiment.name
    for experiment in exp_module.Experiment
    if experiment != exp_module.Experiment.NONE
]
# DATASETS = [DATASETS[0]]  # TODO(kshi): temporary

# Model-specific hyperparameters.
# (attention_mask_type, use_relative_attention, bos_special_attention)
MODELS = []
attention_mask_types = [
Example #10
0
def main(_):
    """Train, evaluate and checkpoint a DecomposeAttentionTransformer.

    The run is configured entirely through `FLAGS`. The function:
      1. Seeds TF / NumPy / `random` and creates `FLAGS.save_dir`.
      2. Builds the spec and program token tables for the dataset type.
      3. Builds train / eval / predict / test `tf.data` pipelines.
      4. Creates the model, the Adam optimizer, and optionally restores a
         checkpoint (fast-forwarding the data and dropout RNG accordingly).
      5. Runs the training loop, periodically logging training metrics,
         evaluation metrics and beam-search accuracy, and saving checkpoints.

    Args:
      _: Unused positional argument (presumably the argv list handed over by
        the app runner -- confirm against the call site).

    Raises:
      ValueError: If `FLAGS.dataset_type` or `FLAGS.model_type` is not
        handled, or if `FLAGS.dataset_filepattern` is empty.
      NotImplementedError: If `FLAGS.dataset_type` is 'scan' when the dataset
        is created.
      AssertionError: If fast decoding is requested, or if the restored step
        is smaller than `FLAGS.finetune_start_step`.
    """
    tf.enable_v2_behavior()

    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    # Human-readable hyperparameter string used to name the tensorboard and
    # checkpoint sub-directories.
    hparam_str_dict = json.loads(FLAGS.xm_parameters)
    hparam_str = ','.join([
        '%s=%s' % (shorten(k), str(v))
        for k, v in hparam_str_dict.items()
    ])

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    # Only host 0 writes summaries; every later use of `summary_writer` is
    # guarded by the same `jax.host_id() == 0` check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    batch_size = FLAGS.per_device_batch_size * n_devices
    io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
                FLAGS.max_characters)
    predict_io_shape = (FLAGS.per_device_batch_size,
                        FLAGS.num_strings_per_task,
                        FLAGS.predict_max_characters)
    target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)

    # Setup DSL
    # ---------------------------------------------------------------------------

    # Build token tables.
    # NOTE(review): 'robust_fill_base' is accepted here, but the decode helpers
    # and the dataset creation below only handle 'robust_fill' -- confirm
    # whether 'robust_fill_base' is meant to be supported end to end.
    if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
        spec_vocab = robust_fill_dsl.CHARACTER + input_pipeline.SEPARATOR_TOKEN
        # Ids 0, 1 and 2 are reserved for padding, BOS and EOS respectively.
        spec_id_token_table = {
            i + 3: token
            for i, token in enumerate(spec_vocab)
        }
        bos_id = 1
        eos_id = 2
        spec_id_token_table[bos_id] = robust_fill_dsl.BOS
        spec_id_token_table[eos_id] = robust_fill_dsl.EOS
        spec_token_id_table = {
            token: token_id
            for token_id, token in spec_id_token_table.items()
        }
        spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
        program_id_token_table, _ = dsl_tokens.build_token_tables()
        program_vocab_size = len(program_id_token_table) + 1
    elif FLAGS.dataset_type == 'scan':
        # TODO(jxihong): Scan is not handled yet.
        raise ValueError('Unhandled dataset_type: {}'.format(
            FLAGS.dataset_type))
    else:
        raise ValueError('Unhandled dataset_type: {}'.format(
            FLAGS.dataset_type))

    # Parse io and program token sequences (for eval).
    def decode_io(inputs, outputs):
        """Convert from int tensors to strings."""
        if FLAGS.dataset_type == 'robust_fill':

            def decode_str(s):
                """Decode string tokens."""
                return ''.join(
                    [spec_id_token_table[t_id] for t_id in s if t_id > 0])

            inps, outs = [], []
            for inp, out in zip(inputs, outputs):
                inps.append(decode_str(inp))
                outs.append(decode_str(out))
            return inps, outs

        elif FLAGS.dataset_type == 'scan':

            def decode_str(s):
                """Decode string tokens."""
                return ' '.join(
                    [spec_id_token_table[t_id] for t_id in s if t_id > 0])

            inps = [decode_str(inp) for inp in inputs]
            dummy_outs = [''] * len(inps)
            return inps, dummy_outs

        else:
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))

    def decode_spec(target):
        """Convert from int tensor to a string."""
        # Keep everything before the first EOS. If there is no EOS, argmax
        # returns 0 and the result is empty.
        target = target[:np.argmax(target == eos_id)].astype(np.int32)

        if FLAGS.dataset_type == 'robust_fill':
            target = target[target != bos_id].tolist()
            return ''.join(
                [spec_id_token_table[t_id] for t_id in target if t_id > 0])
        elif FLAGS.dataset_type == 'scan':
            # TODO(jxihong): Scan is not handled yet.
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))
        else:
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))

    def decode_program(program):
        """Decode program tokens into a program (program object or string)."""
        # Keep everything up to and including the first EOS.
        program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

        if FLAGS.dataset_type == 'robust_fill':
            # Returns either a Concat program object, or None.
            program = program[program != bos_id].tolist()
            try:
                return robust_fill_dsl.decode_program(program,
                                                      program_id_token_table)
            # Best-effort decode: any decoding error means "no program", but
            # KeyboardInterrupt / SystemExit must still propagate.
            except Exception:  # pylint: disable=broad-except
                return None  # Program does not compile.
        elif FLAGS.dataset_type == 'scan':
            # Returns a string.
            program = program[jnp.logical_and(program != bos_id,
                                              program != eos_id)].tolist()
            return ' '.join(scan_vocab.decode(program, program_id_token_table))
        else:
            raise ValueError('Unhandled dataset_type: {}'.format(
                FLAGS.dataset_type))

    def decode_program_str(program):  # pylint: disable=unused-variable
        """Decode program tokens into a string."""
        decoded = decode_program(program)
        if FLAGS.dataset_type == 'robust_fill':
            try:
                return decoded.to_string()  # pytype: disable=attribute-error
            # `decoded` is None when the program did not compile.
            except Exception:  # pylint: disable=broad-except
                return 'did not compile'
        else:
            assert isinstance(decoded,
                              str), '{} should be string'.format(decoded)
            return decoded

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    if not FLAGS.dataset_filepattern:
        raise ValueError('Must specify filepattern to dataset.')

    # Training dataset.
    logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
    padded_shapes = {
        'inputs': io_shape[1:],
        'outputs': io_shape[1:],
        'target': target_shape[1:],
    }
    logging.info('padded_shapes: %s', padded_shapes)

    if FLAGS.dataset_type == 'robust_fill':
        if FLAGS.model_type == 'spec_decomposer_model':
            create_dataset_fn = input_pipeline.create_robust_fill_dataset_for_spec_decomposer_model
        elif FLAGS.model_type == 'synthesizer_model':
            create_dataset_fn = input_pipeline.create_robust_fill_dataset_for_synthesizer_model
        else:
            raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')

    elif FLAGS.dataset_type == 'scan':
        raise NotImplementedError()  # TODO(kshi): Implement.
        # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
    else:
        raise ValueError('Unhandled dataset_type: {}'.format(
            FLAGS.dataset_type))

    dataset = create_dataset_fn(FLAGS.dataset_filepattern, spec_token_id_table,
                                FLAGS.num_strings_per_task)
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=padded_shapes,
                                   drop_remainder=True)
    # Split evaluation and training.
    eval_ds = dataset.take(FLAGS.num_eval_steps)
    # Decrease batch of predict dataset to handle beam search.
    predict_padded_shapes = padded_shapes.copy()
    predict_padded_shapes['inputs'] = predict_io_shape[1:]
    predict_padded_shapes['outputs'] = predict_io_shape[1:]
    logging.info('predict_padded_shapes: %s', predict_padded_shapes)
    predict_ds = eval_ds.unbatch().padded_batch(
        int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)
    train_ds = dataset.skip(FLAGS.num_eval_steps)
    if FLAGS.train_set_batches > 0:
        train_ds = train_ds.take(FLAGS.train_set_batches)
    train_ds = train_ds.repeat()

    test_dataset = create_dataset_fn(FLAGS.test_dataset_filepattern,
                                     spec_token_id_table,
                                     FLAGS.num_strings_per_task)
    test_dataset = test_dataset.padded_batch(
        batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False)
    # A short test set for periodic checks and a longer one for the last step.
    quick_test_dataset = (test_dataset.take(
        FLAGS.num_quick_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))
    final_test_dataset = (test_dataset.take(
        FLAGS.num_final_test_steps).unbatch().padded_batch(
            int(np.ceil(batch_size / 10)),
            padded_shapes=predict_padded_shapes))

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    if FLAGS.model_type == 'spec_decomposer_model':
        output_vocab_size = spec_vocab_size
    elif FLAGS.model_type == 'synthesizer_model':
        output_vocab_size = program_vocab_size
    else:
        raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')

    base_config = base_models.TransformerConfig(
        vocab_size=spec_vocab_size,
        output_vocab_size=output_vocab_size,
        shift=True,
        emb_dim=FLAGS.embedding_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.embedding_dim,
        mlp_dim=FLAGS.hidden_dim,
        max_len=max(FLAGS.max_characters, FLAGS.max_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        use_relative_attention=FLAGS.use_relative_attention,
        deterministic=False,
        decode=False,
        bos_token=bos_id,
        num_input_relative_position_buckets=FLAGS.num_position_buckets,
        max_input_distance=FLAGS.max_distance,
        num_output_relative_position_buckets=FLAGS.num_position_buckets,
        max_output_distance=FLAGS.max_distance,
        num_input_cross_output_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_input_cross_output_distance=FLAGS.max_distance,
        num_program_relative_position_buckets=FLAGS.num_position_buckets,
        max_program_distance=FLAGS.max_distance,
        num_program_cross_embed_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_program_cross_embed_distance=(
            FLAGS.max_program_cross_embed_distance),
        num_flat_encoding_relative_position_buckets=(
            FLAGS.num_position_buckets),
        max_flat_encoding_distance=FLAGS.max_distance)
    train_config = models.DecomposeAttentionTransformerConfig(
        base_config=base_config,
        dataset_type=FLAGS.dataset_type,
        flat_encoded_self_attention=FLAGS.flat_encoded_self_attention)
    eval_config = train_config.replace(base_config=base_config.replace(
        deterministic=True))
    predict_config = train_config.replace(base_config=base_config.replace(
        shift=False,
        deterministic=True,
        decode=not FLAGS.slow_decode,
        max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)))

    # Per-host RNG: fold in the host id so hosts use different streams.
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = jax.random.split(rng)

    dropout_rng = jax.random.split(rng, jax.local_device_count())
    del rng

    m = models.DecomposeAttentionTransformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(io_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    optimizer_def = optim.Adam(FLAGS.lr,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    del initial_variables  # Don't keep a copy of the initial model.

    start_step = 0
    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)
        logging.info('Found model checkpointed at step %d.', start_step)
        if FLAGS.finetune_start_step > 0:
            # The message mirrors the assertion below (>=, not ==).
            logging.info(
                'Checking that start_step (%s) >= finetune_start_step (%s)',
                start_step, FLAGS.finetune_start_step)
            assert start_step >= FLAGS.finetune_start_step
            steps_to_skip = start_step - FLAGS.finetune_start_step
        else:
            steps_to_skip = start_step

        # TODO(kshi): It is likely that this code can lead to the job stalling for
        # 10+ hours when restarting from a checkpoint that had been trained a long
        # time, possibly because dataset skipping is slow.
        logging.info('Skipping %s steps...', steps_to_skip)
        train_ds = train_ds.skip(steps_to_skip)
        # Advance the dropout RNG once per skipped step, using the same
        # split-and-keep-second-key update as this dummy step.
        dummy_p_train_step = jax.pmap(
            lambda dropout_rng: jax.random.split(dropout_rng)[1])
        for _ in range(steps_to_skip):
            dropout_rng = dummy_p_train_step(dropout_rng)
        logging.info('Finished skipping steps')
        logging.info('Host %s has dropout_rng = %s', jax.host_id(),
                     dropout_rng)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    # TODO(jxihong): Implement fast decoding.
    assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

    if FLAGS.finetune_start_step <= 0:
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr)
    else:
        # Constant LR for finetuning.
        learning_rate_fn = create_learning_rate_scheduler(
            base_learning_rate=FLAGS.lr, factors='constant')
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn, config=train_config),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step,
                                             eos_token=eos_id,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=FLAGS.max_target_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(functools.partial(
        predict_step,
        eos_token=eos_id,
        max_decode_len=FLAGS.max_target_length,
        config=predict_config,
        slow_decode=FLAGS.slow_decode),
                           axis_name='batch',
                           static_broadcasted_argnums=(4, ))

    # Main Train Loop
    # ---------------------------------------------------------------------------

    logging.info('Starting training!')
    metrics_all = []
    tick = time.time()
    train_iter = train_ds.as_numpy_iterator()
    for step in range(start_step, FLAGS.num_train_steps):
        inputs, outputs, targets = load_data(next(train_iter))

        optimizer, metrics, dropout_rng = p_train_step(optimizer,
                                                       inputs,
                                                       outputs,
                                                       targets,
                                                       dropout_rng=dropout_rng)
        metrics_all.append(metrics)
        is_last_step = step == FLAGS.num_train_steps - 1

        # Periodic metric handling.

        # Training Metrics
        if (step and step % FLAGS.log_freq == 0) or is_last_step:
            logging.info('Gathering training metrics.')
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums)
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)

            if jax.host_id() == 0:
                logging.info('Train in step: %d, loss: %.4f', step,
                             summary['loss'])
                tock = time.time()
                steps_per_sec = FLAGS.log_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('train/steps per second', steps_per_sec,
                                      step)
                for key, val in summary.items():
                    summary_writer.scalar('train/' + key, val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

        # Evaluation Metrics
        if (step and step % FLAGS.eval_freq == 0) or is_last_step:
            logging.info('Gathering evaluation metrics.')
            t_evaluation_start = time.time()
            eval_metrics = []
            for batches in eval_ds.as_numpy_iterator():
                inputs, outputs, targets = load_data(batches)

                metrics = p_eval_step(optimizer.target, inputs, outputs,
                                      targets)
                eval_metrics.append(metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            if jax.host_id() == 0:
                logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                             time.time() - t_evaluation_start, step,
                             eval_summary['loss'])
                for key, val in eval_summary.items():
                    summary_writer.scalar('eval/' + key, val, step)
                summary_writer.flush()

        # Beam search metrics.
        if (step and step % FLAGS.predict_freq == 0) or is_last_step:
            logging.info('Gathering beam search metrics.')
            test_ds = final_test_dataset if is_last_step else quick_test_dataset

            # `pred_dataset` (rather than `dataset`) so the loop does not
            # clobber the training dataset variable defined above.
            for pred_dataset, predict_or_test in [(predict_ds, 'predict'),
                                                  (test_ds, 'test')]:

                for beam_size in [1, 10]:
                    t_inference_start = time.time()
                    total_successes = 0
                    total_denominator = 0

                    ios, targets_list, predictions, top_of_beams, scores = ([],
                                                                            [],
                                                                            [],
                                                                            [],
                                                                            [])
                    for batches in pred_dataset.as_numpy_iterator():
                        pred_batch = batches
                        # Handle final odd-sized batch by padding instead of dropping it.
                        cur_pred_batch_size = pred_batch['inputs'].shape[0]
                        if cur_pred_batch_size % n_devices:
                            padded_size = int(
                                np.ceil(cur_pred_batch_size / n_devices) *
                                n_devices)
                            # pylint: disable=cell-var-from-loop
                            pred_batch = jax.tree_map(
                                lambda x: pad_examples(x, padded_size),
                                pred_batch)
                        inputs, outputs, targets = load_data(pred_batch)

                        cache = (p_init_cache(inputs, outputs, targets)
                                 if not FLAGS.slow_decode else None)
                        predicted = p_pred_step(optimizer.target, inputs,
                                                outputs, cache, beam_size)
                        predicted = tohost(predicted)
                        inputs, outputs, targets = map(
                            tohost, (inputs, outputs, targets))

                        # Score only the real examples: rows past
                        # `cur_pred_batch_size` are padding added above and
                        # would otherwise skew the accuracy.
                        # NOTE(review): assumes `pad_examples` appends the
                        # padding rows at the end -- confirm.
                        for i, beams in enumerate(
                                predicted[:cur_pred_batch_size]):
                            inps, outs = decode_io(inputs[i], outputs[i])

                            if FLAGS.model_type == 'spec_decomposer_model':
                                ground_truth = decode_spec(targets[i])
                                best_prediction, score = eval_predicted_spec_decomposer_model(
                                    beams, ground_truth, decode_spec)
                                decode_to_str_fn = decode_spec
                            elif FLAGS.model_type == 'synthesizer_model':
                                ground_truth = decode_program_str(targets[i])
                                best_prediction, score = eval_predicted_synthesizer_model(
                                    beams, inps, outs, decode_program)
                                decode_to_str_fn = decode_program_str
                            else:
                                raise ValueError(
                                    f'Unknown model type {FLAGS.model_type}')

                            if score > 0:
                                total_successes += 1
                            total_denominator += 1

                            beams_target = [
                                decode_to_str_fn(beam) for beam in beams
                            ]

                            ios.append(' ; '.join(map(str, zip(inps, outs))))
                            targets_list.append(ground_truth)
                            predictions.append(best_prediction)
                            scores.append(score)
                            logging.info('')
                            logging.info('ios: %s', ios[-1])
                            logging.info('targets[%s]: %s', i, targets[i])
                            logging.info('ground_truth: %s', ground_truth)
                            logging.info('predicted beam: %s',
                                         '\n'.join(beams_target))
                            logging.info('best_prediction: %s',
                                         best_prediction)
                            logging.info('score: %s', score)
                            logging.info('beams: %s', beams)

                            if not ground_truth:
                                logging.warning('ground_truth is empty!')

                            # The last four beams, in reverse order.
                            top_of_beam = []
                            for index, beam in enumerate(beams[:-5:-1]):
                                top_of_beam.append(
                                    'index: {}, decoded: {}, tokens: {}'.
                                    format(index, decode_to_str_fn(beam),
                                           beam))
                            top_of_beams.append('\n\n'.join(top_of_beam))

                    all_total_successes, all_total_denominator = per_host_sum_pmap(
                        jax.tree_map(np.array,
                                     (total_successes, total_denominator)))

                    # Record beam search results as text summaries.
                    # `np.random.choice` raises on an empty population, so
                    # skip sampling when the dataset yielded no examples.
                    if predictions:
                        sample_indices = np.random.choice(
                            np.arange(len(predictions)), 8)
                    else:
                        sample_indices = []
                    message = []
                    for n in sample_indices:
                        text = (
                            f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n'
                            f'predicted: {predictions[n]}\n\n'
                            f'score: {scores[n]}\n\n'
                            f'top of beam:\n\n{top_of_beams[n]}\n\n')
                        message.append(text)

                    # Write to tensorboard.
                    if jax.host_id() == 0:
                        accuracy = 100 * all_total_successes / all_total_denominator
                        logging.info(
                            '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                            predict_or_test, step, beam_size,
                            all_total_successes, all_total_denominator,
                            accuracy,
                            time.time() - t_inference_start)
                        summary_writer.scalar(
                            '{}/beam-size-{}'.format(predict_or_test,
                                                     beam_size), accuracy,
                            step)

                        summary_writer.text(
                            '{}-samples-beam-{}'.format(
                                predict_or_test, beam_size),
                            '\n------\n'.join(message), step)
                        summary_writer.flush()

        # Save a Checkpoint. Do this at the end of the training loop, so that if a
        # worker is descheduled during a round of prediction (which takes a while),
        # we will redo prediction upon restarting (to avoid losing data).
        if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(
                    os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
                    jax_utils.unreplicate(optimizer), step)