def initialize_cache(inputs, outputs, targets, max_decode_len, config):
  """Initialize a cache for a given input shape and max decode length."""
  target_shape = (targets.shape[0], max_decode_len)
  dtype = config.base_config.dtype
  initial_variables = models.DecomposeAttentionTransformer(config).init(
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, dtype),
      jnp.ones(outputs.shape, dtype),
      jnp.ones(target_shape, dtype))
  return initial_variables['cache']
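# A minimal usage sketch of `initialize_cache` (the shapes below are made-up
# examples, not the real flag values). The returned pytree holds the decoder's
# autoregressive state, e.g. per-layer key/value buffers sized to
# `max_decode_len`:
#
#   cache = initialize_cache(
#       inputs=jnp.ones((8, 4, 120), jnp.float32),
#       outputs=jnp.ones((8, 4, 200), jnp.float32),
#       targets=jnp.ones((8, 10), jnp.float32),
#       max_decode_len=128,
#       config=predict_config)  # a config built with decode=True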
def eval_step(params, inputs, outputs, targets, eos_token, config):
  """Collect metrics for evaluation during training."""
  weights = jnp.where(
      jnp.logical_and(
          targets > 0,
          jnp.logical_and(targets != config.base_config.bos_token,
                          targets != eos_token)),
      1, 0).astype(jnp.float32)
  logits = models.DecomposeAttentionTransformer(config).apply(
      {'params': params}, inputs, outputs, targets)
  return compute_metrics(logits, targets, weights)
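# For intuition, the `weights` mask in `eval_step` zeroes out padding (0),
# BOS, and EOS positions so they do not count toward the metrics. A toy
# example, assuming bos_token=1 and eos_token=2:
#
#   targets = jnp.array([1, 5, 7, 2, 0, 0])
#   weights  # -> [0., 1., 1., 0., 0., 0.]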
def loss_fn(params):
  """Loss function used for training."""
  logits = models.DecomposeAttentionTransformer(config).apply(
      {'params': params},
      inputs,
      outputs,
      targets,
      rngs={'dropout': dropout_rng})
  loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights)
  mean_loss = loss / weight_sum
  return mean_loss, logits
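# `loss_fn` closes over `inputs`, `outputs`, `targets`, `weights`, `config`,
# and `dropout_rng` from an enclosing `train_step` (not shown in this
# section). A minimal sketch of how such a step would drive it, assuming the
# flax.optim API used elsewhere in this file:
#
#   def train_step(optimizer, inputs, outputs, targets,
#                  learning_rate_fn, config, dropout_rng=None):
#     weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
#     dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
#
#     def loss_fn(params):
#       ...  # as defined above
#
#     lr = learning_rate_fn(optimizer.state.step)
#     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
#     (_, logits), grad = grad_fn(optimizer.target)
#     grad = jax.lax.pmean(grad, 'batch')  # average grads across devices
#     new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
#     metrics = compute_metrics(logits, targets, weights)
#     metrics['learning_rate'] = lr
#     return new_optimizer, metrics, new_dropout_rng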
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = json.loads(FLAGS.xm_parameters)
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in hparam_str_dict.keys()])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
    spec_vocab = robust_fill_dsl.CHARACTER + input_pipeline.SEPARATOR_TOKEN
    spec_id_token_table = {i + 3: token for i, token in enumerate(spec_vocab)}
    bos_id = 1
    eos_id = 2
    spec_id_token_table[bos_id] = robust_fill_dsl.BOS
    spec_id_token_table[eos_id] = robust_fill_dsl.EOS
    spec_token_id_table = {
        token: id for id, token in spec_id_token_table.items()
    }
    spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
    program_id_token_table, _ = dsl_tokens.build_token_tables()
    program_vocab_size = len(program_id_token_table) + 1
  elif FLAGS.dataset_type == 'scan':
    # TODO(jxihong): Scan is not handled yet.
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Convert from int tensors to strings."""
    if FLAGS.dataset_type == 'robust_fill':
      def decode_str(s):
        """Decode string tokens."""
        return ''.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps, outs = [], []
      for inp, out in zip(inputs, outputs):
        inps.append(decode_str(inp))
        outs.append(decode_str(out))
      return inps, outs

    elif FLAGS.dataset_type == 'scan':
      def decode_str(s):
        """Decode string tokens."""
        return ' '.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps = [decode_str(inp) for inp in inputs]
      dummy_outs = [''] * len(inps)
      return inps, dummy_outs

    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_spec(target):
    """Convert from int tensor to a string."""
    target = target[:np.argmax(target == eos_id)].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      target = target[target != bos_id].tolist()
      return ''.join(
          [spec_id_token_table[t_id] for t_id in target if t_id > 0])
    elif FLAGS.dataset_type == 'scan':
      # TODO(jxihong): Scan is not handled yet.
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program(program):
    """Decode program tokens into a program (program object or string)."""
    program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      # Returns either a Concat program object, or None.
      program = program[program != bos_id].tolist()
      try:
        return robust_fill_dsl.decode_program(program, program_id_token_table)
      except:  # pylint: disable=bare-except
        return None  # Program does not compile.
    elif FLAGS.dataset_type == 'scan':
      # Returns a string.
      program = program[jnp.logical_and(program != bos_id,
                                        program != eos_id)].tolist()
      return ' '.join(scan_vocab.decode(program, program_id_token_table))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program_str(program):  # pylint: disable=unused-variable
    """Decode program tokens into a string."""
    decoded = decode_program(program)
    if FLAGS.dataset_type == 'robust_fill':
      try:
        return decoded.to_string()  # pytype: disable=attribute-error
      except:  # pylint: disable=bare-except
        return 'did not compile'
    else:
      assert isinstance(decoded, str), '{} should be string'.format(decoded)
      return decoded

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = {
      'inputs': io_shape[1:],
      'outputs': io_shape[1:],
      'target': target_shape[1:],
  }
  logging.info('padded_shapes: %s', padded_shapes)

  if FLAGS.dataset_type == 'robust_fill':
    if FLAGS.model_type == 'spec_decomposer_model':
      create_dataset_fn = (
          input_pipeline.create_robust_fill_dataset_for_spec_decomposer_model)
    elif FLAGS.model_type == 'synthesizer_model':
      create_dataset_fn = (
          input_pipeline.create_robust_fill_dataset_for_synthesizer_model)
    else:
      raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')
  elif FLAGS.dataset_type == 'scan':
    raise NotImplementedError()  # TODO(kshi): Implement.
    # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  dataset = create_dataset_fn(FLAGS.dataset_filepattern, spec_token_id_table,
                              FLAGS.num_strings_per_task)
  dataset = dataset.padded_batch(
      batch_size, padded_shapes=padded_shapes, drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_padded_shapes = padded_shapes.copy()
  predict_padded_shapes['inputs'] = predict_io_shape[1:]
  predict_padded_shapes['outputs'] = predict_io_shape[1:]
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)), padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = create_dataset_fn(FLAGS.test_dataset_filepattern,
                                   spec_token_id_table,
                                   FLAGS.num_strings_per_task)
  test_dataset = test_dataset.padded_batch(
      batch_size, padded_shapes=predict_padded_shapes, drop_remainder=False)
  quick_test_dataset = (
      test_dataset.take(FLAGS.num_quick_test_steps)
      .unbatch()
      .padded_batch(int(np.ceil(batch_size / 10)),
                    padded_shapes=predict_padded_shapes))
  final_test_dataset = (
      test_dataset.take(FLAGS.num_final_test_steps)
      .unbatch()
      .padded_batch(int(np.ceil(batch_size / 10)),
                    padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  if FLAGS.model_type == 'spec_decomposer_model':
    output_vocab_size = spec_vocab_size
  elif FLAGS.model_type == 'synthesizer_model':
    output_vocab_size = program_vocab_size
  else:
    raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')

  base_config = base_models.TransformerConfig(
      vocab_size=spec_vocab_size,
      output_vocab_size=output_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_id,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=FLAGS.max_distance,
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=FLAGS.max_distance,
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=FLAGS.max_distance,
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=FLAGS.max_distance,
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=FLAGS.max_program_cross_embed_distance,
      num_flat_encoding_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_flat_encoding_distance=FLAGS.max_distance)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      dataset_type=FLAGS.dataset_type,
      flat_encoded_self_attention=FLAGS.flat_encoded_self_attention)
  eval_config = train_config.replace(
      base_config=base_config.replace(deterministic=True))
  predict_config = train_config.replace(
      base_config=base_config.replace(
          shift=False,
          deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)))

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info('Checking that start_step (%s) >= finetune_start_step (%s)',
                   start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr, factors='constant')

  p_train_step = jax.pmap(
      functools.partial(
          train_step, learning_rate_fn=learning_rate_fn, config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, eos_token=eos_id, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_id,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, targets = load_data(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, targets, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, targets = load_data(batches)
        metrics = p_eval_step(optimizer.target, inputs, outputs, targets)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time() - t_evaluation_start, step,
                     eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for dataset, predict_or_test in [(predict_ds, 'predict'),
                                       (test_ds, 'test')]:
        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0

          ios, targets_list, predictions, top_of_beams, scores = (
              [], [], [], [], [])
          for batches in dataset.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, targets = load_data(pred_batch)

            cache = (p_init_cache(inputs, outputs, targets)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, targets = map(tohost, (inputs, outputs, targets))

            for i, beams in enumerate(predicted):
              inps, outs = decode_io(inputs[i], outputs[i])

              if FLAGS.model_type == 'spec_decomposer_model':
                ground_truth = decode_spec(targets[i])
                best_prediction, score = eval_predicted_spec_decomposer_model(
                    beams, ground_truth, decode_spec)
                decode_to_str_fn = decode_spec
              elif FLAGS.model_type == 'synthesizer_model':
                ground_truth = decode_program_str(targets[i])
                best_prediction, score = eval_predicted_synthesizer_model(
                    beams, inps, outs, decode_program)
                decode_to_str_fn = decode_program_str
              else:
                raise ValueError(f'Unknown model type {FLAGS.model_type}')

              if score > 0:
                total_successes += 1
              total_denominator += 1

              beams_target = [decode_to_str_fn(beam) for beam in beams]

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets_list.append(ground_truth)
              predictions.append(best_prediction)
              scores.append(score)
              logging.info('')
              logging.info('ios: %s', ios[-1])
              logging.info('targets[%s]: %s', i, targets[i])
              logging.info('ground_truth: %s', ground_truth)
              logging.info('predicted beam: %s', '\n'.join(beams_target))
              logging.info('best_prediction: %s', best_prediction)
              logging.info('score: %s', score)
              logging.info('beams: %s', beams)

              if not ground_truth:
                logging.warn('ground_truth is empty!')

              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decode_to_str_fn(beam), beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))

          # Record beam search results as text summaries.
          message = []
          for n in np.random.choice(np.arange(len(predictions)), 8):
            text = (f'ios: {ios[n]}\n\n'
                    f'target: {targets_list[n]}\n\n'
                    f'predicted: {predictions[n]}\n\n'
                    f'score: {scores[n]}\n\n'
                    f'top of beam:\n\n{top_of_beams[n]}\n\n')
            message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size, all_total_successes,
                all_total_denominator, accuracy,
                time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)

            summary_writer.text(
                '{}-samples-beam-{}'.format(predict_or_test, beam_size),
                '\n------\n'.join(message), step)
            summary_writer.flush()

    # Save a Checkpoint. Do this at the end of the training loop, so that if a
    # worker is descheduled during a round of prediction (which takes a while),
    # we will redo prediction upon restarting (to avoid losing data).
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)
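# Helpers referenced above (`load_data`, `tohost`, `pad_examples`,
# `per_host_sum_pmap`, `shorten`, `create_learning_rate_scheduler`, ...) are
# defined elsewhere in this file. As a rough sketch of two of them, following
# the pattern in the Flax WMT example (an assumption, not copied from this
# codebase):
#
#   def tohost(x):
#     """Collect batches from all devices to host and flatten batch dims."""
#     n_device, n_batch, *remaining_dims = x.shape
#     return np.array(x).reshape(
#         (n_device * n_batch,) + tuple(remaining_dims))
#
#   def pad_examples(x, desired_batch_size):
#     """Pad a batch up to desired_batch_size by repeating its last example."""
#     batch_pad = desired_batch_size - x.shape[0]
#     tile_dims = [1] * x.ndim
#     tile_dims[0] = batch_pad
#     return np.concatenate([x, np.tile(x[-1:], tile_dims)], axis=0)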
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 beam_size,
                 eos_token,
                 max_decode_len,
                 config,
                 slow_decode=True):
  """Predict spec or program tokens with beam search on a batch."""
  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  flat_encoded = decode.flat_batch_beam_expand(
      models.DecomposeAttentionTransformer(config).apply(
          {'params': params},
          inputs,
          outputs,
          method=models.DecomposeAttentionTransformer.encode),
      beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  if slow_decode:
    def tokens_ids_to_logits(flat_ids):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits = models.DecomposeAttentionTransformer(config=config).apply(
          {'params': params},
          flat_ids,
          flat_encoded,
          flat_encoded_padding_mask,
          method=models.DecomposeAttentionTransformer.decode)
      return flat_logits
  else:
    def tokens_ids_to_logits(flat_ids, flat_cache):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits, new_vars = models.DecomposeAttentionTransformer(
          config=config).apply(
              {'params': params, 'cache': flat_cache},
              flat_ids,
              flat_encoded,
              flat_encoded_padding_mask,
              mutable=['cache'],
              method=models.DecomposeAttentionTransformer.decode)
      new_flat_cache = new_vars['cache']
      # Remove singleton sequence-length dimension:
      # [batch * beam, 1, vocab] --> [batch * beam, vocab]
      flat_logits = flat_logits.squeeze(axis=1)
      return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      bos_token=config.base_config.bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len,
      slow_decode=slow_decode)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs
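# For reference, `decode.flat_batch_beam_expand` turns a [batch, ...] array
# into [batch * beam_size, ...] with each batch item's beams contiguous; a
# behavioral sketch (assumed equivalent, not the actual implementation):
#
#   def flat_batch_beam_expand(x, beam_size):
#     return jnp.repeat(x, beam_size, axis=0)
#
#   flat_batch_beam_expand(jnp.array([[1], [2]]), beam_size=2)
#   # -> [[1], [1], [2], [2]]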