def loss_fn(params, lp_params):
  """Loss function used for training.

  Combines the latent-predictor cross-entropy with the end-to-end program
  cross-entropy (the latter only once pretraining is over).
  """
  # Latent predictor: predict the discrete latent index sequence from IO.
  latent_logits = models.ProgramTransformer(lp_config).apply(
      {'params': lp_params},
      inputs,
      outputs,
      latent_indices,
      rngs={'dropout': train_rng})
  latent_loss, latent_weight_sum = compute_weighted_cross_entropy(
      latent_logits, latent_indices, latent_weights)

  # End-to-end prediction.
  model = models.LatentProgramTransformer(config)
  bound = {'params': params, 'vqvae': state.model_state}

  def run(method, *args):
    # All sub-calls share the same variables and dropout rng and are
    # read-only (mutable=False) during the loss computation.
    return model.apply(
        bound,
        *args,
        mutable=False,
        rngs={'dropout': train_rng},
        method=method)

  encoded = run(models.LatentProgramTransformer.encode, inputs, outputs)
  latents = run(models.LatentProgramTransformer.quantize, latent_logits)
  logits = run(models.LatentProgramTransformer.decode, programs, latents,
               encoded, latents_mask, encoded_mask)
  loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)

  mean_loss = latent_loss / latent_weight_sum
  if not pretrain:
    mean_loss += loss / weight_sum
  return mean_loss, (logits, latent_logits)
def ae_loss_fn(params):
  """Loss function used for training autoencoder."""
  # One forward pass through the autoencoder; the 'vqvae' collection is
  # marked mutable so its updated state can be returned to the caller.
  (logits, vq), mutated = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      inputs,
      outputs,
      programs,
      emb_mask,
      pretrain=pretrain,
      mutable=['vqvae'],
      rngs={'dropout': train_rng})
  program_loss, program_weight_sum = compute_weighted_cross_entropy(
      logits, programs, weights)

  # Count non-padding latent positions to normalize the VQ loss.
  vq_weight_sum = jnp.sum(
      jnp.where(vq['latent_indices'] > 0, 1, 0).astype(jnp.float32))
  # Add EOS token for latent predictor loss.
  latent_indices = add_eos_token(vq['latent_indices'], eos_token)

  mean_loss = (program_loss / program_weight_sum
               + vq['loss'] / vq_weight_sum)
  return mean_loss, (mutated['vqvae'], logits, latent_indices)
def initialize_cache(inputs, outputs, programs, max_decode_len, config,
                     lp_config):
  """Initialize a cache for a given input shape and max decode length."""
  target_shape = (programs.shape[0], max_decode_len)

  def cache_of(model, dtype):
    # A dummy init pass exists only to materialize the autoregressive
    # 'cache' collection for the given shapes; parameters are discarded.
    variables = model.init(
        jax.random.PRNGKey(0),
        jnp.ones(inputs.shape, dtype),
        jnp.ones(outputs.shape, dtype),
        jnp.ones(target_shape, dtype))
    return variables['cache']

  return (cache_of(models.LatentProgramTransformer(config), config.dtype),
          cache_of(models.ProgramTransformer(lp_config), lp_config.dtype))
def tokens_ids_to_logits(flat_ids, flat_cache):
  """Token slice to logits from decoder model."""
  bound = {
      'params': params,
      'vqvae': state.model_state,
      'cache': flat_cache,
  }
  # Decode one token slice; the result is [batch * beam, 1, vocab].
  logits_slice, mutated = models.LatentProgramTransformer(config).apply(
      bound,
      flat_ids,
      flat_latents,
      flat_encoded,
      flat_latents_mask,
      flat_encoded_padding_mask,
      mutable=['cache'],
      method=models.LatentProgramTransformer.decode)
  # Drop the singleton sequence-length axis:
  # [batch * beam, 1, vocab] --> [batch * beam, vocab].
  return logits_slice.squeeze(axis=1), mutated['cache']
def main(_):
  """Train the latent program transformer and periodically evaluate it.

  Sets up the DSL token tables, the input pipeline, model/optimizer state,
  and runs the pmapped train/eval/predict loop, writing metrics and beam
  search samples to TensorBoard on host 0.
  """
  tf.enable_v2_behavior()

  # Seed every RNG source used below for reproducibility.
  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Merge in hyperparameters supplied externally (without overriding
  # the explicitly-set seed/lr above).
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value
  # Deterministic, sorted hparam string used to namespace checkpoints and
  # TensorBoard directories.
  hparam_str = ','.join([
      '%s=%s' % (k, str(hparam_str_dict[k]))
      for k in sorted(hparam_str_dict.keys())
  ])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size, FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.  Character ids start at 1; id 0 is reserved for
  # padding in both vocabularies.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io examples tokens back into strings.

    Returns the list of input strings, the list of output strings, and a
    single "in < out > in < out" display string.
    """

    def decode_str(s):
      """Decode string tokens, skipping padding (id 0)."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens into a DSL program object and its string."""
    # Truncate just past the first EOS token (argmax returns the first
    # True position; if no EOS is present this keeps only token 0+1).
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  base_train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  # Eval disables dropout; predict additionally enables incremental decoding.
  base_eval_config = base_train_config.replace(
      deterministic=True, train_vq=False)
  base_predict_config = base_train_config.replace(
      shift=False, deterministic=True, train_vq=False, decode=True)
  train_config = models.LatentTransformerConfig(
      base_cfg=base_train_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)
  eval_config = models.LatentTransformerConfig(
      base_cfg=base_eval_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)
  predict_config = models.LatentTransformerConfig(
      base_cfg=base_predict_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)

  # Latent Predictor.
  lp_train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=FLAGS.latent_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  lp_eval_config = lp_train_config.replace(deterministic=True)
  lp_predict_config = lp_train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(0)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.LatentProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))
  lp_m = models.ProgramTransformer(lp_eval_config)
  lp_initial_variables = jax.jit(lp_m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])
  lp_optimizer = optimizer_def.create(lp_initial_variables['params'])

  state = TrainState(
      step=0,
      optimizer=optimizer,
      model_state=initial_variables['vqvae'],
      lp_optimizer=lp_optimizer)
  # Don't keep a copy of the initial model.
  del initial_variables, lp_initial_variables

  # One dropout rng per local device for the pmapped train step.
  train_rngs = jax.random.split(rng, jax.local_device_count())

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    state = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), state)
    # Grab last step.
    start_step = int(state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  # Argument 4 of train_step is the python bool `pretrain`, hence static.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          bos_token=bos_token,
          eos_token=eos_token,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          lp_config=lp_train_config),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          bos_token=bos_token,
          eos_token=eos_token,
          config=eval_config,
          lp_config=lp_eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch')
  # Argument 5 of predict_step is the python int `beam_size`, hence static.
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          bos_token=bos_token,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(5,))

  metrics_all = []
  latent_metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    # `step <= num_pretrain_steps` selects autoencoder pretraining mode.
    state, metrics, latent_metrics, train_rngs = p_train_step(
        state, inputs, outputs, programs,
        step <= FLAGS.num_pretrain_steps,
        train_rng=train_rngs)
    metrics, latent_metrics = jax.tree_map(np.array, (metrics, latent_metrics))
    metrics_all.append(metrics)
    latent_metrics_all.append(latent_metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(state),
            step)

    # Periodic metric handling: everything below runs only every log_freq
    # steps (and never at step 0).
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    latent_metrics_all = common_utils.get_metrics(latent_metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, latent_metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary.update(
        jax.tree_map(
            lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
            metrics_sums))

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f, acc: %.4f', step,
                   summary['loss'], summary['accuracy'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []
    latent_metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    latent_eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)
      all_metrics = p_eval_step(state, inputs, outputs, programs)
      metrics, latent_metrics = jax.tree_map(np.array, all_metrics)
      eval_metrics.append(metrics)
      latent_eval_metrics.append(latent_metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    latent_eval_metrics = common_utils.get_metrics(latent_eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, latent_eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary.update(
        jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums))

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f',
                   time.time() - t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 50, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions, latent_predictions = [], [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          pred_batch = jax.tree_map(
              lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache, lp_cache = p_init_cache(inputs, outputs, programs)
        predicted, latent_predicted = p_pred_step(state, inputs, outputs,
                                                  cache, lp_cache, beam_size)
        predicted, latent_predicted = map(tohost,
                                          (predicted, latent_predicted))
        inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_idx, p_score = eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          # A program counts as correct only if it solves all examples.
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')
          latent_predictions.append(' '.join(
              list(np.array(latent_predicted[i, p_idx]).astype(str))))

      # Aggregate accuracy counts across hosts.
      all_pred_acc, all_pred_denominator = per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n'
                f'latent_predicted: {latent_predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
def predict_step(state, inputs, outputs, cache, lp_cache, beam_size,
                 bos_token, eos_token, max_decode_len, config, lp_config):
  """Predict translation with fast decoding beam search on a batch.

  Two-stage search: first beam-search over discrete latent token sequences
  with the latent predictor, then, for each latent hypothesis, beam-search
  over program tokens with the latent program transformer.

  Returns:
    (beam_seqs, latent_beam_seqs): program beams of shape
    [n_batch, beam_size, ...] and the latent beam repeated to align with
    each program beam. Beam search returns beams sorted in increasing
    order of log-probability.
  """
  params = state.optimizer.target
  lp_params = state.lp_optimizer.target

  # Split beam over latent sequences and programs; the effective total
  # beam size is rounded down to a multiple of latent_beam_size.
  per_latent_beam_size = beam_size // FLAGS.latent_beam_size
  beam_size = FLAGS.latent_beam_size * per_latent_beam_size

  # Encode IO once with the latent predictor and tile across latent beams.
  flat_lp_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(lp_config).apply(
          {'params': lp_params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      FLAGS.latent_beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, FLAGS.latent_beam_size)

  def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
        {
            'params': lp_params,
            'cache': flat_lp_cache
        },
        flat_ids,
        flat_lp_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_lp_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_lp_cache

  # Step 1: Beam-search over latent tokens.  Latent sequences are shorter
  # than programs by the compression factor 2**c.
  latent_beam_seqs, _ = decode.beam_search(
      inputs,
      lp_cache,
      tokens_ids_to_latent_logits,
      beam_size=FLAGS.latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=np.ceil(max_decode_len / 2**FLAGS.c).astype(np.int32))
  # Flatten latent beams into the batch dim, then tile for program beams.
  flat_latent_seqs = decode.flat_batch_beam_expand(
      decode.flatten_beam_dim(latent_beam_seqs), per_latent_beam_size)

  # Quantize the predicted latent codes.
  flat_latents = models.LatentProgramTransformer(config).apply(
      {
          'params': params,
          'vqvae': state.model_state
      },
      flat_latent_seqs,
      mutable=False,
      method=models.LatentProgramTransformer.quantize)

  flat_encoded = decode.flat_batch_beam_expand(
      models.LatentProgramTransformer(config).apply(
          {
              'params': params,
              'vqvae': state.model_state
          },
          inputs,
          outputs,
          mutable=False,
          method=models.LatentProgramTransformer.encode),
      beam_size)

  # Padding masks.  EOS is masked out of the latents in addition to padding.
  flat_latents_mask = jnp.where(
      jnp.logical_and(flat_latent_seqs > 0, flat_latent_seqs != eos_token),
      1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state,
            'cache': flat_cache
        },
        flat_ids,
        flat_latents,
        flat_encoded,
        flat_latents_mask,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.LatentProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Step 2: Beam-search over program tokens, one search per latent beam.
  per_latent_inputs = decode.flat_batch_beam_expand(
      inputs, FLAGS.latent_beam_size)
  per_latent_cache = jax.tree_map(
      lambda x: decode.flat_batch_beam_expand(x, FLAGS.latent_beam_size),
      cache)
  beam_seqs, _ = decode.beam_search(
      per_latent_inputs,
      per_latent_cache,
      tokens_ids_to_logits,
      beam_size=per_latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)
  # Collapse both beam dimensions into one.
  beam_seqs = beam_seqs.reshape(
      (inputs.shape[0], beam_size) + beam_seqs.shape[2:])
  # Repeat latent beams so they align one-to-one with program beams.
  latent_beam_seqs = jnp.repeat(latent_beam_seqs, per_latent_beam_size, axis=1)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs, latent_beam_seqs
def eval_step(state, inputs, outputs, programs, bos_token, eos_token,
              config, lp_config):
  """Evaluate on batch of program tasks."""
  params = state.optimizer.target
  lp_params = state.lp_optimizer.target
  weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)

  # Embedding mask for autoencoding: zero out the padding, BOS, and EOS
  # columns of the latent vocabulary.
  emb_mask = jnp.ones(
      (1, FLAGS.latent_vocab_size),
      jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

  model = models.LatentProgramTransformer(config)
  bound = {'params': params, 'vqvae': state.model_state}

  # Full autoencoder pass over the gold programs.
  ae_logits, vq = model.apply(
      bound, inputs, outputs, programs, emb_mask, mutable=False)

  # Postprocess latent indices.
  latent_indices = add_eos_token(vq['latent_indices'], eos_token)
  latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)

  encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  # Additionally mask out eos token in latents.
  latents_mask = jnp.where(
      jnp.logical_and(latent_indices > 0, latent_indices != eos_token),
      1, 0).astype(jnp.float32)

  # Latent predictor logits against the autoencoder's latent targets.
  latent_logits = models.ProgramTransformer(lp_config).apply(
      {'params': lp_params}, inputs, outputs, latent_indices)

  # End-to-end decode path: encode IO, quantize predicted latents, decode.
  encoded = model.apply(
      bound, inputs, outputs,
      mutable=False,
      method=models.LatentProgramTransformer.encode)
  latents = model.apply(
      bound, latent_logits,
      mutable=False,
      method=models.LatentProgramTransformer.quantize)
  logits = model.apply(
      bound, programs, latents, encoded, latents_mask, encoded_mask,
      mutable=False,
      method=models.LatentProgramTransformer.decode)

  metrics = compute_metrics(logits, programs, weights)
  metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
  latent_metrics = compute_metrics(
      latent_logits, latent_indices, latent_weights, prefix='latent_')
  return metrics, latent_metrics