def load_model(dataset_name, attention_mask_type, use_relative_attention,
               bos_special_attention, predict_config):
  """Loads a checkpoint.

  Builds a freshly initialized Adam optimizer as a restore template, restores
  the checkpoint for the given hyperparameter combination from
  `FLAGS.train_directory`, and replicates it across local devices.

  NOTE(review): `io_shape` and `program_shape` are not parameters; they are
  looked up from an enclosing/module scope. Confirm they exist there (in this
  file they are also locals of `main`).

  Args:
    dataset_name: substituted into the `train-{}` directory component.
    attention_mask_type: substituted into the `amt=` path component.
    use_relative_attention: substituted into the `ura=` path component.
    bos_special_attention: substituted into the `bsa=` path component.
    predict_config: config used to construct the model for initialization.

  Returns:
    The restored optimizer, replicated with `jax_utils.replicate`.
  """
  _, init_key = jax.random.split(jax.random.PRNGKey(0))
  model = models.DecomposeAttentionTransformer(predict_config)
  # The same key seeds both the 'params' and 'dropout' streams; only the
  # variable shapes matter because the values are overwritten by the restore.
  variables = jax.jit(model.init)(
      {'params': init_key, 'dropout': init_key},
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam = optim.Adam(
      1e-3, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=1e-1)
  template = adam.create(variables['params'])

  run_subdir = (
      'train-{}/checkpoints/'
      'amt={},bsa={},ed=256,hd=512,l=0.001,nh=4,nl=3,s=0,ura={}/').format(
          dataset_name, attention_mask_type, bos_special_attention,
          use_relative_attention)
  checkpoint_fname = os.path.join(FLAGS.train_directory, run_subdir)
  logging.info('Loading checkpoint: %s', checkpoint_fname)

  restored = checkpoints.restore_checkpoint(checkpoint_fname, template)
  logging.info('Found model checkpointed at step %s.',
               int(restored.state.step))
  return jax_utils.replicate(restored)
def initialize_cache(inputs, outputs, programs, max_decode_len, config):
  """Initialize a cache for a given input shape and max decode length.

  Initializes the model on all-ones tensors whose shapes mirror `inputs`,
  `outputs`, and a `(programs.shape[0], max_decode_len)` target, and returns
  only the resulting 'cache' variable collection.
  """
  dtype = config.base_config.dtype
  decode_shape = (programs.shape[0], max_decode_len)
  model = models.DecomposeAttentionTransformer(config)
  variables = model.init(
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, dtype),
      jnp.ones(outputs.shape, dtype),
      jnp.ones(decode_shape, dtype))
  return variables['cache']
def eval_step(params, inputs, outputs, programs, eos_token, config):
  """Computes evaluation metrics for one batch.

  Only positions that are non-padding (> 0) and are neither BOS
  (`config.bos_token`) nor `eos_token` receive weight 1 in the metrics.
  """
  not_special = jnp.logical_and(programs != config.bos_token,
                                programs != eos_token)
  is_real_token = jnp.logical_and(programs > 0, not_special)
  weights = jnp.where(is_real_token, 1, 0).astype(jnp.float32)

  model = models.DecomposeAttentionTransformer(config)
  logits = model.apply({'params': params}, inputs, outputs, programs)
  return compute_metrics(logits, programs, weights)
def tokens_ids_to_logits(flat_ids):
  """Token slice to logits from decoder model (no cache).

  Free variables `config`, `params`, `flat_encoded`, and
  `flat_encoded_padding_mask` come from an enclosing scope.
  NOTE(review): output documented upstream as [batch * beam, 1, vocab];
  for the cache-less path the sequence axis may be the full prefix length —
  confirm against `DecomposeAttentionTransformer.decode`.
  """
  decoder = models.DecomposeAttentionTransformer(config=config)
  return decoder.apply(
      {'params': params},
      flat_ids,
      flat_encoded,
      flat_encoded_padding_mask,
      method=models.DecomposeAttentionTransformer.decode)
def loss_fn(params):
  """Loss function used for training.

  Uses `config`, `inputs`, `outputs`, `programs`, `weights`, and
  `dropout_rng` from an enclosing scope.

  Returns:
    `(mean_loss, logits)` where `mean_loss` is the weighted cross-entropy sum
    divided by the summed weights.
  """
  model = models.DecomposeAttentionTransformer(config)
  logits = model.apply({'params': params},
                       inputs,
                       outputs,
                       programs,
                       rngs={'dropout': dropout_rng})
  total_loss, total_weight = compute_weighted_cross_entropy(
      logits, programs, weights)
  return total_loss / total_weight, logits
def tokens_ids_to_logits(flat_ids, flat_cache):
  """Token slice to logits from decoder model, threading the decode cache.

  Free variables `config`, `params`, `flat_encoded`, and
  `flat_encoded_padding_mask` come from an enclosing scope.

  Returns:
    `(flat_logits, new_flat_cache)`; the singleton sequence axis of the
    decoder output is squeezed away ([batch * beam, 1, vocab] ->
    [batch * beam, vocab]).
  """
  decoder = models.DecomposeAttentionTransformer(config=config)
  step_logits, mutated = decoder.apply(
      {'params': params, 'cache': flat_cache},
      flat_ids,
      flat_encoded,
      flat_encoded_padding_mask,
      mutable=['cache'],
      method=models.DecomposeAttentionTransformer.decode)
  return step_logits.squeeze(axis=1), mutated['cache']
def eval_step(params, inputs, outputs, programs, eos_token, config,
              dropout_rng=None):
  """Collect metrics for evaluation during training.

  Only positions that are non-padding (> 0) and are neither BOS
  (`config.base_config.bos_token`) nor `eos_token` contribute to the metrics.

  Args:
    params: model parameters.
    inputs: input example tokens.
    outputs: output example tokens.
    programs: target program tokens.
    eos_token: id of the end-of-sequence token.
    config: model config (`config.base_config.bos_token` is read).
    dropout_rng: optional PRNG key. When given, it is split and one half is
      used as the model's 'dropout' stream (to experiment with dropout during
      evaluation). When None (the default), the model is applied without a
      dropout RNG.

  Returns:
    If `dropout_rng` is given: `(metrics, new_dropout_rng)`.
    If `dropout_rng` is None: just `metrics`, which is what the training
    loop's periodic evaluation (called without an RNG) expects.
  """
  weights = jnp.where(
      jnp.logical_and(
          programs > 0,
          jnp.logical_and(programs != config.base_config.bos_token,
                          programs != eos_token)),
      1, 0).astype(jnp.float32)
  model = models.DecomposeAttentionTransformer(config)

  if dropout_rng is None:
    # Previously this path crashed: jax.random.split(None) raises. Without an
    # RNG, no 'dropout' stream is supplied; a config that actually applies
    # dropout will then raise an explicit error from the model.
    logits = model.apply({'params': params}, inputs, outputs, programs)
    return compute_metrics(logits, programs, weights)

  # This code is necessary to experiment with using dropout during evaluation,
  # but we don't normally use dropout here.
  dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
  logits = model.apply(
      {'params': params},
      inputs,
      outputs,
      programs,
      rngs={'dropout': dropout_rng})
  return compute_metrics(logits, programs, weights), new_dropout_rng
def main(_):
  """Trains and evaluates a DecomposeAttentionTransformer.

  Builds training/eval/predict/test datasets from `FLAGS`, constructs the model
  and Adam optimizer, and optionally restores a checkpoint. When finetuning, it
  also fast-forwards the training data and dropout RNGs. It then runs the
  training loop, periodically writing checkpoints, training metrics, eval
  metrics and beam-search accuracy (overall and per program length) to
  TensorBoard.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  # BOS special attention only makes sense if we are using relative attention
  # and it's not the baseline.
  if FLAGS.bos_special_attention and (not FLAGS.use_relative_attention or
                                      FLAGS.attention_mask_type == 'baseline'):
    raise ValueError(
        "bos_special_attention doesn't work when use_relative_attention={} and "
        'attention_mask_type={}'.format(FLAGS.use_relative_attention,
                                        FLAGS.attention_mask_type))

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  # `xm_parameters` may be unset outside a hyperparameter sweep. Calling
  # json.loads on an empty or None value raises, so fall back to no params.
  # Key order is kept as-is so existing checkpoint directory names still match.
  hparam_str_dict = (json.loads(FLAGS.xm_parameters)
                     if FLAGS.xm_parameters else {})
  hparam_str = ','.join([
      '%s=%s' % (shorten(k), str(hparam_str_dict[k]))
      for k in hparam_str_dict.keys()
  ])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: c_id for c_id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens; returns None if the program does not compile."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    program = program[program != bos_token]

    try:
      return dsl.decode_program(program.tolist(), id_token_table)
    except Exception:  # pylint: disable=broad-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = (io_shape[1:], io_shape[1:], program_shape[1:])
  logging.info('padded_shapes: %s', padded_shapes)
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=padded_shapes,
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_padded_shapes = (predict_io_shape[1:],
                           predict_io_shape[1:],
                           program_shape[1:])
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.test_dataset_filepattern, token_id_table, char_id_table)
  test_dataset = test_dataset.padded_batch(
      batch_size,
      padded_shapes=predict_padded_shapes,
      drop_remainder=False)
  quick_test_dataset = (test_dataset
                        .take(FLAGS.num_quick_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))
  final_test_dataset = (test_dataset
                        .take(FLAGS.num_final_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  # Only used to read the default maximum relative distances.
  default_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size, output_vocab_size=program_vocab_size)
  base_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_token,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=min(
          FLAGS.max_distance, default_config.max_input_distance),
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=min(
          FLAGS.max_distance, default_config.max_output_distance),
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=min(
          FLAGS.max_distance, default_config.max_input_cross_output_distance),
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=min(
          FLAGS.max_distance, default_config.max_program_distance),
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=min(
          FLAGS.max_distance, default_config.max_program_cross_embed_distance),
      bidirectional_program_attention=FLAGS.bidirectional_program_attention)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      attention_mask_type=FLAGS.attention_mask_type,
      bos_special_attention=FLAGS.bos_special_attention)
  eval_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(deterministic=True),
      attention_mask_type=FLAGS.attention_mask_type,
      bos_special_attention=FLAGS.bos_special_attention)
  predict_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(
          shift=False, deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.max_characters, FLAGS.max_program_length,
                      FLAGS.predict_max_characters)),
      attention_mask_type=FLAGS.attention_mask_type,
      bos_special_attention=FLAGS.bos_special_attention)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  # One dropout key per local device (consumed by the pmapped train step).
  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info(
          'Checking that start_step (%s) >= finetune_start_step (%s)',
          start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    # Advance the per-device dropout keys once per skipped step so the resumed
    # run consumes the same key sequence as an uninterrupted one.
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr,
        factors='constant')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step,
                        eos_token=eos_token,
                        config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))  # beam_size must be a static int.

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, programs, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, programs = common_utils.shard(batches)

        # NOTE(review): eval_step is called without a dropout_rng here; the
        # eval_step in scope must accept that and return only the metrics.
        metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time() - t_evaluation_start, step,
                     eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for beam_ds, predict_or_test in [(predict_ds, 'predict'),
                                       (test_ds, 'test')]:

        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0
          # Keyed by number of expressions in the target program.
          pred_successes = collections.defaultdict(int)
          pred_denominators = collections.defaultdict(int)

          ios, targets, predictions, top_of_beams = [], [], [], []
          for batches in beam_ds.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch[0].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, programs = common_utils.shard(pred_batch)

            cache = (p_init_cache(inputs, outputs, programs)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

            # Score only the real examples. Rows past cur_pred_batch_size are
            # the padding added above and must not be counted.
            for i, beams in enumerate(predicted[:cur_pred_batch_size]):
              inps, outs = decode_io(inputs[i], outputs[i])
              p, p_score = eval_predicted(
                  beams, inps, outs, parse_beam_fn=decode_program)

              # Split by length of program.
              program = programs[i]
              num_expressions = len(
                  decode_program(program).expressions)
              pred_denominators[num_expressions] += 1
              total_denominator += 1
              if p_score >= len(inps):
                pred_successes[num_expressions] += 1
                total_successes += 1

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets.append(decode_program(programs[i]).to_string())
              try:
                predictions.append(p.to_string())
              except Exception:  # pylint: disable=broad-except
                predictions.append('Did not compile')
              logging.info('ios: %s', ios[-1])
              logging.info('target: %s', targets[-1])
              beams_log = []
              for beam in beams:
                try:
                  beams_log.append(decode_program(beam).to_string())
                except Exception:  # pylint: disable=broad-except
                  beams_log.append('Did not compile')
              logging.info('predicted beam: %s', '\n'.join(beams_log))

              # Last four beams, highest-ranked first.
              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                try:
                  decoded_program = decode_program(beam).to_string()
                except Exception:  # pylint: disable=broad-except
                  decoded_program = 'Did not compile'
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decoded_program, beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))
          all_pred_successes, all_pred_denominators = per_host_sum_pmap(
              jax.tree_map(np.array, (pred_successes, pred_denominators)))

          # Record beam search results as text summaries.
          message = []
          # np.random.choice raises on an empty population (no examples).
          if predictions:
            for n in np.random.choice(np.arange(len(predictions)), 8):
              text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                      f'predicted: {predictions[n]}\n\n'
                      f'top of beam:\n\n{top_of_beams[n]}\n\n')
              message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size, all_total_successes,
                all_total_denominator, accuracy,
                time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)

            # Iterate over every observed length (the denominators). A length
            # with zero successes has no key in the successes dict and would
            # otherwise be silently omitted from the breakdown.
            for length in sorted(all_pred_denominators.keys()):
              length_successes = all_pred_successes.get(length, 0)
              this_length_accuracy = (
                  100 * length_successes / all_pred_denominators[length])
              logging.info(
                  ' accuracy for length %s: %s / %s = %.2f%%',
                  length, length_successes, all_pred_denominators[length],
                  this_length_accuracy)
              summary_writer.scalar(
                  '{}-by-length/beam-size-{}-length-{}'.format(
                      predict_or_test, beam_size, length),
                  this_length_accuracy, step)

            summary_writer.text(
                '{}-samples-beam-{}'.format(predict_or_test, beam_size),
                '\n------\n'.join(message), step)
            summary_writer.flush()
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 beam_size,
                 eos_token,
                 max_decode_len,
                 config,
                 slow_decode=True):
  """Predict programs for a batch with beam search.

  The encoder runs once. Its output and the padding mask derived from
  `outputs > 0` are then expanded in place via `decode.flat_batch_beam_expand`,
  so the decoder sees a batch of batch_size * beam_size. With `slow_decode`
  the decoder is called without a cache; otherwise it threads a mutable
  'cache' collection and squeezes the singleton sequence axis of its logits.

  Returns:
    Beam sequences from `decode.beam_search`. Per the original comment, these
    are [n_batch, n_beam, n_length], with the beam dimension sorted in
    increasing order of log-probability.
  """
  model_cls = models.DecomposeAttentionTransformer

  encoded = model_cls(config).apply(
      {'params': params}, inputs, outputs, method=model_cls.encode)
  flat_encoded = decode.flat_batch_beam_expand(encoded, beam_size)

  padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      padding_mask, beam_size)

  def uncached_logits(flat_ids):
    """Decoder logits for the given ids, without a cache."""
    return model_cls(config=config).apply(
        {'params': params},
        flat_ids,
        flat_encoded,
        flat_encoded_padding_mask,
        method=model_cls.decode)

  def cached_logits(flat_ids, flat_cache):
    """One cached decoder step: [batch * beam, vocab] logits + new cache."""
    step_logits, mutated = model_cls(config=config).apply(
        {'params': params, 'cache': flat_cache},
        flat_ids,
        flat_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=model_cls.decode)
    return step_logits.squeeze(axis=1), mutated['cache']

  logits_fn = uncached_logits if slow_decode else cached_logits

  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      logits_fn,
      beam_size=beam_size,
      alpha=0.6,
      bos_token=config.base_config.bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len,
      slow_decode=slow_decode)
  return beam_seqs
def main(_):
  """Trains and evaluates a DecomposeAttentionTransformer (base config).

  Builds training/eval/predict datasets from `FLAGS`, constructs the model and
  Adam optimizer, and optionally restores a checkpoint. Every `log_freq`
  steps (and on the last step) it writes training metrics, eval metrics and
  beam-search accuracy for several beam sizes to TensorBoard. Checkpoints are
  saved every `checkpoint_freq` steps.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  # makedirs (not mkdir) so missing parent directories are created too.
  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters; seed and lr from the flags take precedence.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k])) for k in
                         sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: c_id for c_id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens; returns None if the program does not compile."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    program = program[program != bos_token]

    try:
      return dsl.decode_program(program, id_token_table)
    except Exception:  # pylint: disable=broad-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step,
                        eos_token=eos_token,
                        config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))  # beam_size must be a static int.

  # Main Train Loop
  # ---------------------------------------------------------------------------
  # One training key per local device.
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step):
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time() - t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = (p_init_cache(inputs, outputs, programs)
                 if not FLAGS.slow_decode else None)
        predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                beam_size)
        predicted = tohost(predicted)
        inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

        # Count and score only the real examples. Rows past
        # cur_pred_batch_size are the padding added above.
        pred_denominator += cur_pred_batch_size
        for i, beams in enumerate(predicted[:cur_pred_batch_size]):
          inps, outs = decode_io(inputs[i], outputs[i])
          p, p_score = eval_predicted(
              beams, inps, outs,
              parse_beam_fn=decode_program)
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(' ; '.join(map(str, zip(inps, outs))))
          targets.append(decode_program(programs[i]).to_string())
          try:
            predictions.append(p.to_string())
          except Exception:  # pylint: disable=broad-except
            predictions.append('')
          logging.info('ios: %s', ios[-1])
          logging.info('target: %s', targets[-1])
          beams_log = []
          for beam in beams:
            try:
              beams_log.append(decode_program(beam).to_string())
            except Exception:  # pylint: disable=broad-except
              beams_log.append('None')
          logging.info('predicted beam: %s', '\n'.join(beams_log))

      all_pred_acc, all_pred_denominator = per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      # np.random.choice raises on an empty population (no examples).
      if predictions:
        for n in np.random.choice(np.arange(len(predictions)), 8):
          text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                  f'predicted: {predictions[n]}\n\n')
          message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            all_pred_acc / all_pred_denominator)
        summary_writer.scalar(
            'predict-{}/score-{}'.format(slow_or_fast, beam_size),
            all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()