def main(_):
  """Sample random RobustFill tasks and write them to a single TFRecord shard.

  Seeds TensorFlow, NumPy and Python `random` with `FLAGS.seed`, then writes
  `FLAGS.num_tasks` tasks from `sample_random.random_task`, each serialized
  with `serialize_example`, to
  `<FLAGS.save_dir>/program_tasks.tf_records-00000-of-00001`.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  _, token_id_table = dsl_tokens.build_token_tables()

  # `makedirs` also creates any missing parent directories; a plain `mkdir`
  # fails when the parent of `save_dir` does not exist yet.
  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  worker_fname = os.path.join(FLAGS.save_dir,
                              'program_tasks.tf_records-00000-of-00001')

  # Write the `tf.Example` observations to the file.
  with tf.io.TFRecordWriter(worker_fname) as writer:
    for _ in range(FLAGS.num_tasks):
      task = sample_random.random_task(
          max_expressions=FLAGS.max_expressions,
          min_expressions=FLAGS.min_expressions,
          max_k=5,
          max_input_tokens=10,
          max_input_length=FLAGS.max_characters,
          max_output_length=FLAGS.max_characters,
          num_examples=FLAGS.num_strings_per_task,
      )
      example = serialize_example(task, token_id_table)
      writer.write(example)
def main(_):
  """Write generated tasks both as whole programs and as decomposition data.

  For each of `FLAGS.num_tasks` tasks, the record produced by
  `serialize_entire_program_example` goes to the `entire_programs_*` file, and
  every record yielded by `serialize_decomposition_examples` goes to the
  `decomposition_data_*` file. Both files are created in `FLAGS.save_dir` and
  carry `FLAGS.split` plus a single-shard suffix in their names.

  Raises:
    ValueError: if an experiment other than NONE is selected and `FLAGS.split`
      is not one of 'train', 'valid', 'test' or 'finetune'.
  """
  tf.enable_v2_behavior()

  seed = FLAGS.seed
  if seed is not None:
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

  token_id_table = dsl_tokens.build_token_tables()[1]

  save_dir = FLAGS.save_dir
  if not gfile.isdir(save_dir):
    gfile.makedirs(save_dir)

  # Everything is written to a single shard.
  shard_id, total_shards = 0, 1
  entire_programs_fname, decomposition_data_fname = (
      os.path.join(save_dir,
                   template.format(FLAGS.split, shard_id, total_shards))
      for template in ('entire_programs_{}.tf_records-{:05d}-of-{:05d}',
                       'decomposition_data_{}.tf_records-{:05d}-of-{:05d}'))

  # Write the `tf.Example` observations to the files.
  with tf.io.TFRecordWriter(entire_programs_fname) as entire_programs_writer:
    with tf.io.TFRecordWriter(
        decomposition_data_fname) as decomposition_data_writer:
      for task_index in range(FLAGS.num_tasks):
        if FLAGS.experiment != exp_module.Experiment.NONE.name:
          split = FLAGS.split
          if split == 'finetune':
            # Odd task indices get is_train=True, even ones is_train=False.
            is_train = bool(task_index % 2)
          elif split in ('train', 'valid'):
            is_train = True
          elif split == 'test':
            is_train = False
          else:
            raise ValueError('Unhandled split: {}'.format(split))
          task = generate_task_for_experiment(FLAGS.experiment, is_train)
        else:
          task = sample_random.random_task(
              max_expressions=FLAGS.max_expressions,
              min_expressions=FLAGS.min_expressions,
              max_k=3,
              max_input_tokens=5,
              max_input_length=FLAGS.max_input_length,
              num_examples=FLAGS.num_strings_per_task)

        entire_programs_writer.write(
            serialize_entire_program_example(task, token_id_table))
        for decomposition_example in serialize_decomposition_examples(
            task, token_id_table):
          decomposition_data_writer.write(decomposition_example)
def test_decode(self):
  """Encodes a program, decodes the tokens back, and runs the result."""
  id_to_token, token_to_id = tokens.build_token_tables()
  # Both lookup tables should cover the same number of entries.
  self.assertEqual(len(token_to_id), len(id_to_token))

  original_program = dsl.Concat(
      dsl.Compose(
          dsl.Replace(' ', ','),
          dsl.GetSpan(dsl.Type.PROP_CASE, 1, dsl.Boundary.START,
                      dsl.Type.PROP_CASE, 4, dsl.Boundary.END)),
      dsl.ConstStr('.'),
      dsl.GetToken(dsl.Type.PROP_CASE, -1))
  token_ids = original_program.encode(token_to_id)
  # The encoding must be terminated by the EOS token.
  self.assertEqual(token_ids[-1], token_to_id[dsl.EOS])

  roundtripped = dsl.decode_program(token_ids, id_to_token)
  cases = (
      ('Jacob Ethan James Alexander Michael',
       'Jacob,Ethan,James,Alexander.Michael'),
      ('Earth Fire Wind Water Pluto Sun', 'Earth,Fire,Wind,Water.Sun'),
  )
  for source_text, expected_text in cases:
    self.assertEqual(roundtripped(source_text), expected_text)
def main(_):
  """Generate program tasks and write them to a single TFRecord shard.

  Writes `FLAGS.num_tasks` tasks, each serialized with `serialize_example`, to
  `<FLAGS.save_dir>/program_tasks_<FLAGS.split>.tf_records-00000-of-00001`.
  For the NONE experiment, tasks come from `sample_random.random_task`.
  Otherwise they come from `generate_task_for_experiment`, with an `is_train`
  flag derived from `FLAGS.split`.

  Raises:
    ValueError: if an experiment other than NONE is selected and `FLAGS.split`
      is not one of 'train', 'valid', 'test' or 'finetune'.
  """
  tf.enable_v2_behavior()

  if FLAGS.seed is not None:
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

  _, token_id_table = dsl_tokens.build_token_tables()

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  worker_fname = os.path.join(
      FLAGS.save_dir,
      'program_tasks_{}.tf_records-00000-of-00001'.format(FLAGS.split))

  # `FLAGS.experiment` may hold the experiment's name (a string) rather than
  # an `Experiment` member. Comparing a string to the enum member is always
  # False, so compare by name; `getattr` accepts either form.
  experiment_name = getattr(FLAGS.experiment, 'name', FLAGS.experiment)
  is_none_experiment = experiment_name == exp_module.Experiment.NONE.name

  # Write the `tf.Example` observations to the file.
  with tf.io.TFRecordWriter(worker_fname) as writer:
    for i in range(FLAGS.num_tasks):
      if is_none_experiment:
        task = sample_random.random_task(
            max_expressions=FLAGS.max_expressions,
            min_expressions=FLAGS.min_expressions,
            max_k=3,
            max_input_tokens=5,
            max_input_length=FLAGS.max_input_length,
            num_examples=FLAGS.num_strings_per_task)
      else:
        if FLAGS.split in ['train', 'valid']:
          is_train = True
        elif FLAGS.split == 'test':
          is_train = False
        elif FLAGS.split == 'finetune':
          # Odd task indices get is_train=True, even ones is_train=False.
          is_train = bool(i % 2)
        else:
          raise ValueError('Unhandled split: {}'.format(FLAGS.split))
        task = generate_task_for_experiment(FLAGS.experiment, is_train)

      example = serialize_example(task, token_id_table)
      writer.write(example)
def main(_):
  """Train a DecomposeAttentionTransformer and periodically evaluate it.

  Work done inside the training loop:
    * checkpoints every `FLAGS.checkpoint_freq` steps and on the last step;
    * logs training metrics every `FLAGS.log_freq` steps;
    * logs eval-loss metrics every `FLAGS.eval_freq` steps;
    * every `FLAGS.predict_freq` steps, measures beam-search accuracy on the
      held-out eval split (reported as 'predict') and on the test dataset
      (reported as 'test').
  The last step runs all four. TensorBoard summaries are written from host 0
  only.

  Raises:
    ValueError: on an invalid `bos_special_attention` configuration or when
      `FLAGS.dataset_filepattern` is empty.
    NotImplementedError: when `FLAGS.slow_decode` is False (fast decoding is
      not implemented).
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  # BOS special attention only makes sense if we are using relative attention
  # and it's not the baseline.
  if FLAGS.bos_special_attention and (not FLAGS.use_relative_attention or
                                      FLAGS.attention_mask_type == 'baseline'):
    raise ValueError(
        "bos_special_attention doesn't work when use_relative_attention={} and "
        'attention_mask_type={}'.format(FLAGS.use_relative_attention,
                                        FLAGS.attention_mask_type))

  # Only slow (non-cached) decoding is supported. Check this before model
  # initialization, checkpoint restoration and step skipping (which can take a
  # long time), so a misconfigured job fails immediately. An explicit raise is
  # used because `assert` is stripped under `python -O`.
  # TODO(jxihong): Implement fast decoding.
  if not FLAGS.slow_decode:
    raise NotImplementedError('Fast decoding is not implemented yet.')

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  # An empty or unset `xm_parameters` gives an empty hparam string instead of
  # a JSON decoding error.
  hparam_str_dict = (json.loads(FLAGS.xm_parameters)
                     if FLAGS.xm_parameters else {})
  hparam_str = ','.join([
      '%s=%s' % (shorten(k), str(hparam_str_dict[k]))
      for k in hparam_str_dict.keys()
  ])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens (truncated at EOS, BOS removed); None on failure."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    program = program[program != bos_token]

    try:
      return dsl.decode_program(program.tolist(), id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = (io_shape[1:], io_shape[1:], program_shape[1:])
  logging.info('padded_shapes: %s', padded_shapes)
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=padded_shapes,
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_padded_shapes = (predict_io_shape[1:],
                           predict_io_shape[1:],
                           program_shape[1:])
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.test_dataset_filepattern, token_id_table, char_id_table)
  test_dataset = test_dataset.padded_batch(
      batch_size,
      padded_shapes=predict_padded_shapes,
      drop_remainder=False)
  quick_test_dataset = (test_dataset
                        .take(FLAGS.num_quick_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))
  final_test_dataset = (test_dataset
                        .take(FLAGS.num_final_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  # `default_config` only supplies the default maximum relative distances,
  # which cap `FLAGS.max_distance` below.
  default_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size, output_vocab_size=program_vocab_size)
  base_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_token,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=min(
          FLAGS.max_distance, default_config.max_input_distance),
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=min(
          FLAGS.max_distance, default_config.max_output_distance),
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=min(
          FLAGS.max_distance,
          default_config.max_input_cross_output_distance),
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=min(
          FLAGS.max_distance, default_config.max_program_distance),
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=min(
          FLAGS.max_distance,
          default_config.max_program_cross_embed_distance),
      bidirectional_program_attention=FLAGS.bidirectional_program_attention)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      attention_mask_type=FLAGS.attention_mask_type,
      bos_special_attention=FLAGS.bos_special_attention)
  eval_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(deterministic=True),
      attention_mask_type=FLAGS.attention_mask_type,
      bos_special_attention=FLAGS.bos_special_attention)
  predict_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config.replace(
          shift=False, deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.max_characters, FLAGS.max_program_length,
                      FLAGS.predict_max_characters)),
      attention_mask_type=FLAGS.attention_mask_type,
      bos_special_attention=FLAGS.bos_special_attention)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  # One dropout key per local device.
  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info(
          'Checking that start_step (%s) == finetune_start_step (%s)',
          start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    # Advance the dropout keys exactly as `steps_to_skip` training steps would.
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr,
        factors='constant')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step,
                        eos_token=eos_token,
                        config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))  # beam_size is static.

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, programs, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, programs = common_utils.shard(batches)

        metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time() - t_evaluation_start, step,
                     eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for dataset, predict_or_test in [(predict_ds, 'predict'),
                                       (test_ds, 'test')]:

        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0
          # Keyed by the number of expressions in the target program.
          pred_successes = collections.defaultdict(int)
          pred_denominators = collections.defaultdict(int)

          ios, targets, predictions, top_of_beams = [], [], [], []
          for batches in dataset.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch[0].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, programs = common_utils.shard(pred_batch)

            cache = (p_init_cache(inputs, outputs, programs)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, programs = map(tohost,
                                            (inputs, outputs, programs))

            for i, beams in enumerate(predicted):
              inps, outs = decode_io(inputs[i], outputs[i])
              p, p_score = eval_predicted(
                  beams, inps, outs, parse_beam_fn=decode_program)

              # Split by length of program.
              program = programs[i]
              num_expressions = len(decode_program(program).expressions)
              pred_denominators[num_expressions] += 1
              total_denominator += 1
              if p_score >= len(inps):
                pred_successes[num_expressions] += 1
                total_successes += 1

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets.append(decode_program(programs[i]).to_string())
              try:
                predictions.append(p.to_string())
              except:  # pylint: disable=bare-except
                predictions.append('Did not compile')
              logging.info('ios: %s', ios[-1])
              logging.info('target: %s', targets[-1])
              beams_log = []
              for beam in beams:
                try:
                  beams_log.append(decode_program(beam).to_string())
                except:  # pylint: disable=bare-except
                  beams_log.append('Did not compile')
              logging.info('predicted beam: %s', '\n'.join(beams_log))

              # The last four beams, in reverse order.
              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                try:
                  decoded_program = decode_program(beam).to_string()
                except:  # pylint: disable=bare-except
                  decoded_program = 'Did not compile'
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decoded_program, beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))
          all_pred_successes, all_pred_denominators = per_host_sum_pmap(
              jax.tree_map(np.array, (pred_successes, pred_denominators)))

          # Record beam search results as text summaries.
          message = []
          for n in np.random.choice(np.arange(len(predictions)), 8):
            text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                    f'predicted: {predictions[n]}\n\n'
                    f'top of beam:\n\n{top_of_beams[n]}\n\n')
            message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size,
                all_total_successes, all_total_denominator,
                accuracy, time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)

            for length in sorted(all_pred_successes.keys()):
              this_length_accuracy = (
                  100 * all_pred_successes[length]
                  / all_pred_denominators[length])
              logging.info(' accuracy for length %s: %s / %s = %.2f%%',
                           length, all_pred_successes[length],
                           all_pred_denominators[length],
                           this_length_accuracy)
              summary_writer.scalar(
                  '{}-by-length/beam-size-{}-length-{}'.format(
                      predict_or_test, beam_size, length),
                  this_length_accuracy, step)

            summary_writer.text(
                '{}-samples-beam-{}'.format(predict_or_test, beam_size),
                '\n------\n'.join(message), step)
            summary_writer.flush()
def main(_):
  """Train a ProgramTransformer on RobustFill data and periodically evaluate.

  Checkpoints every `FLAGS.checkpoint_freq` steps and on the last step. Every
  `FLAGS.log_freq` steps (excluding step 0) it logs training metrics, then
  eval-loss metrics, then beam-search accuracy (beam sizes 10 and 100) on the
  held-out eval split. TensorBoard summaries are written from host 0 only.

  Raises:
    ValueError: when `FLAGS.dataset_filepattern` is empty.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  # `makedirs` also creates any missing parent directories; a plain `mkdir`
  # fails when the parent of `save_dir` does not exist yet.
  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters. Values from `xm_parameters` never override the
  # explicit seed/lr entries above.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io example tokens; also return a ' < ... > '-joined string."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens; returns (program, string) or (None, '')."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_lib.train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_lib.initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(train_lib.predict_step, config=predict_config),
      axis_name='batch',
      # eos_token, max decode length and beam_size are static arguments.
      static_broadcasted_argnums=(4, 5, 6))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  # One training key per local device.
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time() - t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: train_lib.pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = p_init_cache(inputs, outputs, programs)
        predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                eos_token, programs.shape[-1], beam_size)
        predicted = train_lib.tohost(predicted)
        inputs, outputs, programs = map(train_lib.tohost,
                                        (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_score = train_lib.eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')

      all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info(
            'Prediction time (beam %d): %.4f s step %d, score %.4f.',
            beam_size, time.time() - t_inference_start, step,
            all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
def test_train(self):
  """Trains a tiny ProgramTransformer and checks eval accuracy exceeds 0.1."""
  tf.enable_v2_behavior()

  # Seed every RNG source (TF, NumPy, Python) so the test is deterministic.
  for seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    seed_fn(0)

  dataset_filepattern = os.path.join(
      os.path.dirname(__file__),
      'tasks/robust_fill/dataset/test_dataset/program_tasks.tf_records-*')
  print('dataset_filepattern = {}'.format(dataset_filepattern))

  batch_size = 4
  num_strings_per_task = 4
  max_characters = 10
  max_program_length = 15
  io_padded_shape = (num_strings_per_task, max_characters)

  # Character ids start at 1 so that id 0 is left free for padding.
  id_char_table = dict(enumerate(dsl.CHARACTER, start=1))
  char_id_table = {ch: char_id for char_id, ch in id_char_table.items()}
  _, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # +1 for the padding id.
  program_vocab_size = len(token_id_table) + 1
  bos_token = token_id_table[dsl.BOS]

  # Batched (and padded) dataset; a trailing partial batch is dropped.
  dataset = input_pipeline.create_dataset_from_tf_record(
      dataset_filepattern, token_id_table, char_id_table).padded_batch(
          batch_size,
          padded_shapes=(io_padded_shape, io_padded_shape,
                         (max_program_length,)),
          drop_remainder=True)
  dataset_iter = dataset.repeat().as_numpy_iterator()

  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=32,
      num_heads=4,
      num_layers=2,
      qkv_dim=32,
      mlp_dim=32,
      max_len=max(max_characters, max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)

  rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
  model = models.ProgramTransformer(eval_config)
  io_placeholder = jnp.ones((batch_size,) + io_padded_shape, jnp.float32)
  initial_variables = jax.jit(model.init)(
      init_rng,
      io_placeholder,
      io_placeholder,
      jnp.ones((batch_size, max_program_length), jnp.float32))

  optimizer = optim.Adam(
      1e-2, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=0.1).create(
          initial_variables['params'])
  del initial_variables  # Avoid holding a second copy of the parameters.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=1e-2)
  p_train_step = jax.pmap(
      functools.partial(train_lib.train_step,
                        learning_rate_fn=learning_rate_fn,
                        config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')

  # Train for 1000 steps with one RNG key per local device.
  rngs = jax.random.split(rng, jax.local_device_count())
  del rng
  for _ in range(1000):
    step_inputs, step_outputs, step_programs = common_utils.shard(
        next(dataset_iter))
    optimizer, _, rngs = p_train_step(
        optimizer, step_inputs, step_outputs, step_programs, train_rng=rngs)

  # Evaluate with one pass over the (non-repeated) dataset.
  per_batch_metrics = []
  for batch in dataset.as_numpy_iterator():
    eval_inputs, eval_outputs, eval_programs = common_utils.shard(batch)
    per_batch_metrics.append(
        p_eval_step(optimizer.target, eval_inputs, eval_outputs,
                    eval_programs))
  metric_sums = jax.tree_map(jnp.sum,
                             common_utils.get_metrics(per_batch_metrics))
  denominator = metric_sums.pop('denominator')
  eval_summary = jax.tree_map(lambda total: total / denominator, metric_sums)

  if jax.host_id() == 0:
    self.assertGreater(eval_summary['accuracy'], 0.1)
def main(_):
  """Trains and evaluates a `DecomposeExpandingLayerTransformer` synthesizer.

  Builds the dataset, model and optimizer(s) from FLAGS, optionally restores a
  checkpoint and/or pretrained weights, then runs the training loop. The loop
  saves checkpoints periodically and, every `FLAGS.log_freq` steps (and on the
  last pretraining / training step), logs training metrics, evaluation metrics
  and beam-search prediction scores. TensorBoard output is written by host 0
  only.

  Args:
    _: Unused positional argument (presumably the argv list from the app
      runner; the caller is not visible here).

  Raises:
    ValueError: If `FLAGS.dataset_filepattern` is empty.
    RuntimeError: If `FLAGS.match_split_encoding` is set during pretraining but
      no pretrained split model could be loaded.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters. Extra keys from --xm_parameters are merged in, but
  # they never override the explicit seed/lr entries above.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  # Sorted keys make the run name (used for the tb/ and checkpoints/ dirs)
  # independent of the order in --xm_parameters.
  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only host 0 owns a summary writer; every later use is also host-0 guarded.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables. Character ids start at 1; id 0 is reserved for padding.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      # Ids <= 0 are padding and are skipped.
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs. Each partial program is cut at its
    # first EOS; note np.argmax returns 0 when no EOS is present, so such a
    # partial program contributes no tokens.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    # Re-append a single EOS to terminate the concatenated program.
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern,
      token_id_table,
      char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training: the first `num_eval_steps` batches are held
  # out for evaluation and never used for training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  # Per-host RNG: the base seed is folded with the host id.
  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config, num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  # Parameters are initialized from all-ones dummy inputs of the right shapes.
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    # NOTE(review): presumably skips past the step stored in the checkpoint
    # so that step is not trained twice.
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      """Selects which parameters the pretraining optimizer may update."""
      del value
      # Parameters under a frozen encoder/decoder are excluded from updates.
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      # The checkpoint path is formatted with the number of expressions per
      # partial program.
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          # Load the pretrained checkpoint into a single-partial-program model
          # whose parameters are later passed as `split_params`.
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config, num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warn('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  # Positional argument 4 is static (presumably the beam size, which is not
  # passed explicitly here — predict_and_compute_score is not visible).
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  # Same as p_pred_step but decoding with the split (pretrained) encoding.
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)

  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    # Pretraining steps update `pretrain_optimizer`; afterwards `optimizer`.
    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)

    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    # NOTE(review): maybe_copy_model_from_pretraining is not visible here;
    # presumably it syncs `optimizer` with the pretrained weights while
    # pretraining is active. It is called before every checkpoint/eval.
    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling. Metrics are reported at step > 0 every
    # `log_freq` steps, and also on the last step / last pretraining step.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step and
                    not is_last_pretrain_step):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      # Assumes exactly `log_freq` steps since the last report; reporting on
      # the last (pre)training step can make this approximate.
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()

    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()

      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            pred_acc)
        beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search'
        summary_writer.scalar(
            'predict-{}/score-{}-{}'.format(slow_or_fast,
                                            beam_search_or_bfs,
                                            beam_size),
            pred_acc, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()

      # During pretraining with matched split encodings, also score decoding
      # with the split model's encoding (p_split_pred_step).
      if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding:
        pred_acc, message = predict_and_compute_score(
            p_pred_step=p_split_pred_step,
            p_init_cache=p_init_cache,
            target=optimizer.target,
            predict_ds=predict_ds,
            decode_io=decode_io,
            decode_program=decode_program,
            beam_size=beam_size,
            num_partial_programs=FLAGS.num_partial_programs,
            use_best_first_search=FLAGS.best_first_search,
            slow_decode=FLAGS.slow_decode)

        # Write to tensorboard.
        if jax.host_id() == 0:
          slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
          beam_search_or_bfs = ('bfs' if FLAGS.best_first_search
                                else 'beam-search')
          summary_writer.scalar(
              'predict-split-{}/score-{}-{}'.format(slow_or_fast,
                                                    beam_search_or_bfs,
                                                    beam_size),
              pred_acc, step)
          summary_writer.text('samples-split-{}'.format(beam_size),
                              '\n------\n'.join(message), step)
          summary_writer.flush()
# Fixed data / model hyperparameters. `per_device_batch_size` and
# `num_strings_per_task` are used below but defined earlier in the file (not
# visible in this chunk).
max_program_length = 100
max_characters = 200
embedding_dim = 256
hidden_dim = 512
num_heads = 4
num_layers = 3

# Per-device input/output and program tensor shapes.
io_shape = (per_device_batch_size, num_strings_per_task, max_characters)
program_shape = (per_device_batch_size, max_program_length)

# Character ids start at 1 so that id 0 stays free for padding.
id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
# Inverse mapping: character -> id.
char_id_table = {char: id for id, char in id_char_table.items()}
id_token_table, token_id_table = dsl_tokens.build_token_tables()
io_vocab_size = len(char_id_table) + 1  # For padding.
program_vocab_size = len(token_id_table) + 1

bos_token = token_id_table[dsl.BOS]
eos_token = token_id_table[dsl.EOS]

# Names of every `exp_module.Experiment` member except `NONE`.
DATASETS = [e.name for e in exp_module.Experiment
            if e != exp_module.Experiment.NONE]
# DATASETS = [DATASETS[0]]  # TODO(kshi): temporary

# Model-specific hyperparameters.
# Each entry is a tuple:
# (attention_mask_type, use_relative_attention, bos_special_attention)
MODELS = []
attention_mask_types = [
def main(_):
  """Trains and evaluates a `DecomposeAttentionTransformer` model.

  Depending on `FLAGS.model_type` the model is trained either as a spec
  decomposer (predicting spec tokens) or as a synthesizer (predicting program
  tokens). The loop periodically logs training metrics, evaluation metrics and
  beam-search accuracy on the held-out predict split and on the test dataset,
  and saves checkpoints. TensorBoard summaries are written by host 0 only.

  Args:
    _: Unused positional argument (presumably the argv list from the app
      runner; the caller is not visible here).

  Raises:
    ValueError: For an unhandled `FLAGS.dataset_type` or `FLAGS.model_type`,
      or when `FLAGS.dataset_filepattern` is empty.
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.makedirs(FLAGS.save_dir)

  # --xm_parameters is optional: `json.loads(None)` would raise a TypeError.
  # Key order is kept as given (not sorted) so the run name, and therefore the
  # tb/ and checkpoints/ directories of existing runs, stays unchanged.
  hparam_str_dict = (
      json.loads(FLAGS.xm_parameters) if FLAGS.xm_parameters else {})
  hparam_str = ','.join([
      '%s=%s' % (shorten(k), str(hparam_str_dict[k]))
      for k in hparam_str_dict.keys()
  ])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only host 0 owns a summary writer; every later use is also host-0 guarded.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  predict_io_shape = (FLAGS.per_device_batch_size,
                      FLAGS.num_strings_per_task,
                      FLAGS.predict_max_characters)
  target_shape = (FLAGS.per_device_batch_size, FLAGS.max_target_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables. Spec ids: 0 is padding, 1 is BOS, 2 is EOS, and the
  # vocabulary starts at 3.
  # NOTE(review): 'robust_fill_base' is accepted here but rejected by the
  # dataset setup below, and the decode_* helpers only handle 'robust_fill'.
  if FLAGS.dataset_type in ['robust_fill', 'robust_fill_base']:
    spec_vocab = robust_fill_dsl.CHARACTER + input_pipeline.SEPARATOR_TOKEN
    spec_id_token_table = {
        i + 3: token for i, token in enumerate(spec_vocab)
    }
    bos_id = 1
    eos_id = 2
    spec_id_token_table[bos_id] = robust_fill_dsl.BOS
    spec_id_token_table[eos_id] = robust_fill_dsl.EOS
    spec_token_id_table = {
        token: token_id for token_id, token in spec_id_token_table.items()
    }
    spec_vocab_size = len(spec_token_id_table) + 1  # For padding.
    program_id_token_table, _ = dsl_tokens.build_token_tables()
    program_vocab_size = len(program_id_token_table) + 1
  elif FLAGS.dataset_type == 'scan':
    # TODO(jxihong): Scan is not handled yet.
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Convert from int tensors to strings."""
    if FLAGS.dataset_type == 'robust_fill':
      def decode_str(s):
        """Decode string tokens."""
        # Ids <= 0 are padding and are skipped.
        return ''.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps, outs = [], []
      for inp, out in zip(inputs, outputs):
        inps.append(decode_str(inp))
        outs.append(decode_str(out))
      return inps, outs

    elif FLAGS.dataset_type == 'scan':
      def decode_str(s):
        """Decode string tokens."""
        return ' '.join([spec_id_token_table[t_id] for t_id in s if t_id > 0])

      inps = [decode_str(inp) for inp in inputs]
      dummy_outs = [''] * len(inps)
      return inps, dummy_outs

    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_spec(target):
    """Convert from int tensor to a string."""
    # Cut at the first EOS. np.argmax returns 0 when there is no EOS, in which
    # case the decoded spec is empty.
    target = target[:np.argmax(target == eos_id)].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      target = target[target != bos_id].tolist()
      return ''.join([spec_id_token_table[t_id] for t_id in target if t_id > 0])
    elif FLAGS.dataset_type == 'scan':
      # TODO(jxihong): Scan is not handled yet.
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program(program):
    """Decode program tokens into a program (program object or string)."""
    # Keep tokens up to and including the first EOS.
    program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

    if FLAGS.dataset_type == 'robust_fill':
      # Returns either a Concat program object, or None.
      program = program[program != bos_id].tolist()
      try:
        return robust_fill_dsl.decode_program(program, program_id_token_table)
      except Exception:  # pylint: disable=broad-except
        return None  # Program does not compile.
    elif FLAGS.dataset_type == 'scan':
      # Returns a string. `program` is a host numpy array, so mask with numpy.
      program = program[np.logical_and(program != bos_id,
                                       program != eos_id)].tolist()
      return ' '.join(scan_vocab.decode(program, program_id_token_table))
    else:
      raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  def decode_program_str(program):  # pylint: disable=unused-variable
    """Decode program tokens into a string."""
    decoded = decode_program(program)
    if FLAGS.dataset_type == 'robust_fill':
      try:
        return decoded.to_string()  # pytype: disable=attribute-error
      except Exception:  # pylint: disable=broad-except
        # e.g. `decoded` is None because the program did not compile.
        return 'did not compile'
    else:
      assert isinstance(decoded, str), '{} should be string'.format(decoded)
      return decoded

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  logging.info('Loading dataset from %s', FLAGS.dataset_filepattern)
  padded_shapes = {
      'inputs': io_shape[1:],
      'outputs': io_shape[1:],
      'target': target_shape[1:],
  }
  logging.info('padded_shapes: %s', padded_shapes)

  if FLAGS.dataset_type == 'robust_fill':
    if FLAGS.model_type == 'spec_decomposer_model':
      create_dataset_fn = (
          input_pipeline.create_robust_fill_dataset_for_spec_decomposer_model)
    elif FLAGS.model_type == 'synthesizer_model':
      create_dataset_fn = (
          input_pipeline.create_robust_fill_dataset_for_synthesizer_model)
    else:
      raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')

  elif FLAGS.dataset_type == 'scan':
    raise NotImplementedError()  # TODO(kshi): Implement.
    # create_dataset_fn = input_pipeline.create_scan_dataset_from_tf_record
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))

  dataset = create_dataset_fn(FLAGS.dataset_filepattern, spec_token_id_table,
                              FLAGS.num_strings_per_task)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=padded_shapes,
      drop_remainder=True)
  # Split evaluation and training: the first `num_eval_steps` batches are held
  # out for evaluation and never used for training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_padded_shapes = padded_shapes.copy()
  predict_padded_shapes['inputs'] = predict_io_shape[1:]
  predict_padded_shapes['outputs'] = predict_io_shape[1:]
  logging.info('predict_padded_shapes: %s', predict_padded_shapes)
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=predict_padded_shapes)
  train_ds = dataset.skip(FLAGS.num_eval_steps)
  if FLAGS.train_set_batches > 0:
    train_ds = train_ds.take(FLAGS.train_set_batches)
  train_ds = train_ds.repeat()

  test_dataset = create_dataset_fn(FLAGS.test_dataset_filepattern,
                                   spec_token_id_table,
                                   FLAGS.num_strings_per_task)
  test_dataset = test_dataset.padded_batch(
      batch_size,
      padded_shapes=predict_padded_shapes,
      drop_remainder=False)
  # A small test subset for periodic checks, a larger one for the last step.
  quick_test_dataset = (test_dataset
                        .take(FLAGS.num_quick_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))
  final_test_dataset = (test_dataset
                        .take(FLAGS.num_final_test_steps)
                        .unbatch()
                        .padded_batch(int(np.ceil(batch_size / 10)),
                                      padded_shapes=predict_padded_shapes))

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  if FLAGS.model_type == 'spec_decomposer_model':
    output_vocab_size = spec_vocab_size
  elif FLAGS.model_type == 'synthesizer_model':
    output_vocab_size = program_vocab_size
  else:
    raise ValueError(f'Unhandled model_type: {FLAGS.model_type}')

  base_config = base_models.TransformerConfig(
      vocab_size=spec_vocab_size,
      output_vocab_size=output_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_id,
      num_input_relative_position_buckets=FLAGS.num_position_buckets,
      max_input_distance=FLAGS.max_distance,
      num_output_relative_position_buckets=FLAGS.num_position_buckets,
      max_output_distance=FLAGS.max_distance,
      num_input_cross_output_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_input_cross_output_distance=FLAGS.max_distance,
      num_program_relative_position_buckets=FLAGS.num_position_buckets,
      max_program_distance=FLAGS.max_distance,
      num_program_cross_embed_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_program_cross_embed_distance=FLAGS.max_program_cross_embed_distance,
      num_flat_encoding_relative_position_buckets=(
          FLAGS.num_position_buckets),
      max_flat_encoding_distance=FLAGS.max_distance)
  train_config = models.DecomposeAttentionTransformerConfig(
      base_config=base_config,
      dataset_type=FLAGS.dataset_type,
      flat_encoded_self_attention=FLAGS.flat_encoded_self_attention)
  eval_config = train_config.replace(
      base_config=base_config.replace(deterministic=True))
  predict_config = train_config.replace(
      base_config=base_config.replace(
          shift=False,
          deterministic=True,
          decode=not FLAGS.slow_decode,
          max_len=max(FLAGS.predict_max_characters, FLAGS.max_target_length)))

  # Per-host RNG: the base seed is folded with the host id.
  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  # One dropout key per local device.
  dropout_rng = jax.random.split(rng, jax.local_device_count())
  del rng

  m = models.DecomposeAttentionTransformer(eval_config)
  # Parameters are initialized from all-ones dummy inputs of the right shapes.
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if FLAGS.finetune_start_step > 0:
      logging.info(
          'Checking that start_step (%s) >= finetune_start_step (%s)',
          start_step, FLAGS.finetune_start_step)
      assert start_step >= FLAGS.finetune_start_step
      steps_to_skip = start_step - FLAGS.finetune_start_step
    else:
      steps_to_skip = start_step

    # TODO(kshi): It is likely that this code can lead to the job stalling for
    # 10+ hours when restarting from a checkpoint that had been trained a long
    # time, possibly because dataset skipping is slow.
    logging.info('Skipping %s steps...', steps_to_skip)
    train_ds = train_ds.skip(steps_to_skip)
    # Advance the dropout keys the same way one train step would, so a resumed
    # run continues the RNG sequence.
    dummy_p_train_step = jax.pmap(
        lambda dropout_rng: jax.random.split(dropout_rng)[1])
    for _ in range(steps_to_skip):
      dropout_rng = dummy_p_train_step(dropout_rng)
    logging.info('Finished skipping steps')
    logging.info('Host %s has dropout_rng = %s', jax.host_id(), dropout_rng)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  if FLAGS.finetune_start_step <= 0:
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr)
  else:
    # Constant LR for finetuning.
    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.lr,
        factors='constant')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, eos_token=eos_id, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config),
      axis_name='batch')
  # Positional argument 4 (the beam size at the call site below) is static.
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_id,
          max_decode_len=FLAGS.max_target_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------

  logging.info('Starting training!')
  metrics_all = []
  tick = time.time()
  train_iter = train_ds.as_numpy_iterator()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, targets = load_data(next(train_iter))

    optimizer, metrics, dropout_rng = p_train_step(
        optimizer, inputs, outputs, targets, dropout_rng=dropout_rng)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Periodic metric handling.

    # Training Metrics
    if (step and step % FLAGS.log_freq == 0) or is_last_step:
      logging.info('Gathering training metrics.')
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

      if jax.host_id() == 0:
        logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
        tock = time.time()
        steps_per_sec = FLAGS.log_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('train/steps per second', steps_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar('train/' + key, val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

    # Evaluation Metrics
    if (step and step % FLAGS.eval_freq == 0) or is_last_step:
      logging.info('Gathering evaluation metrics.')
      t_evaluation_start = time.time()
      eval_metrics = []
      for batches in eval_ds.as_numpy_iterator():
        inputs, outputs, targets = load_data(batches)

        metrics = p_eval_step(optimizer.target, inputs, outputs, targets)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      if jax.host_id() == 0:
        logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                     time.time() - t_evaluation_start, step,
                     eval_summary['loss'])
        for key, val in eval_summary.items():
          summary_writer.scalar('eval/' + key, val, step)
        summary_writer.flush()

    # Beam search metrics.
    if (step and step % FLAGS.predict_freq == 0) or is_last_step:
      logging.info('Gathering beam search metrics.')
      test_ds = final_test_dataset if is_last_step else quick_test_dataset

      for pred_dataset, predict_or_test in [(predict_ds, 'predict'),
                                            (test_ds, 'test')]:

        for beam_size in [1, 10]:
          t_inference_start = time.time()
          total_successes = 0
          total_denominator = 0
          ios, targets_list, predictions, top_of_beams, scores = (
              [], [], [], [], [])
          for batches in pred_dataset.as_numpy_iterator():
            pred_batch = batches
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
              padded_size = int(
                  np.ceil(cur_pred_batch_size / n_devices) * n_devices)
              # pylint: disable=cell-var-from-loop
              pred_batch = jax.tree_map(
                  lambda x: pad_examples(x, padded_size), pred_batch)
            inputs, outputs, targets = load_data(pred_batch)

            cache = (p_init_cache(inputs, outputs, targets)
                     if not FLAGS.slow_decode else None)
            predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                    beam_size)
            predicted = tohost(predicted)
            inputs, outputs, targets = map(tohost, (inputs, outputs, targets))

            for i, beams in enumerate(predicted):
              inps, outs = decode_io(inputs[i], outputs[i])

              if FLAGS.model_type == 'spec_decomposer_model':
                ground_truth = decode_spec(targets[i])
                best_prediction, score = eval_predicted_spec_decomposer_model(
                    beams, ground_truth, decode_spec)
                decode_to_str_fn = decode_spec
              elif FLAGS.model_type == 'synthesizer_model':
                ground_truth = decode_program_str(targets[i])
                best_prediction, score = eval_predicted_synthesizer_model(
                    beams, inps, outs, decode_program)
                decode_to_str_fn = decode_program_str
              else:
                raise ValueError(f'Unknown model type {FLAGS.model_type}')

              # An example counts as solved when its score is positive.
              if score > 0:
                total_successes += 1
              total_denominator += 1

              beams_target = [decode_to_str_fn(beam) for beam in beams]

              ios.append(' ; '.join(map(str, zip(inps, outs))))
              targets_list.append(ground_truth)
              predictions.append(best_prediction)
              scores.append(score)
              logging.info('')
              logging.info('ios: %s', ios[-1])
              logging.info('targets[%s]: %s', i, targets[i])
              logging.info('ground_truth: %s', ground_truth)
              logging.info('predicted beam: %s', '\n'.join(beams_target))
              logging.info('best_prediction: %s', best_prediction)
              logging.info('score: %s', score)
              logging.info('beams: %s', beams)

              if not ground_truth:
                logging.warning('ground_truth is empty!')

              # The last 4 beams, in reverse order.
              top_of_beam = []
              for index, beam in enumerate(beams[:-5:-1]):
                top_of_beam.append('index: {}, decoded: {}, tokens: {}'.format(
                    index, decode_to_str_fn(beam), beam))
              top_of_beams.append('\n\n'.join(top_of_beam))

          all_total_successes, all_total_denominator = per_host_sum_pmap(
              jax.tree_map(np.array, (total_successes, total_denominator)))

          # Record beam search results as text summaries. Guard the empty case:
          # np.random.choice raises ValueError when sampling from an empty
          # array (e.g. an empty test dataset).
          message = []
          if predictions:
            for n in np.random.choice(np.arange(len(predictions)), 8):
              text = (f'ios: {ios[n]}\n\ntarget: {targets_list[n]}\n\n'
                      f'predicted: {predictions[n]}\n\n'
                      f'score: {scores[n]}\n\n'
                      f'top of beam:\n\n{top_of_beams[n]}\n\n')
              message.append(text)

          # Write to tensorboard.
          if jax.host_id() == 0:
            accuracy = 100 * all_total_successes / all_total_denominator
            logging.info(
                '%s results, step %d, beam size %d: %s / %s = %.2f%% (%.2f s)',
                predict_or_test, step, beam_size,
                all_total_successes, all_total_denominator,
                accuracy, time.time() - t_inference_start)
            summary_writer.scalar(
                '{}/beam-size-{}'.format(predict_or_test, beam_size),
                accuracy, step)

            summary_writer.text('{}-samples-beam-{}'.format(predict_or_test,
                                                            beam_size),
                                '\n------\n'.join(message), step)
            summary_writer.flush()

    # Save a Checkpoint. Do this at the end of the training loop, so that if a
    # worker is descheduled during a round of prediction (which takes a while),
    # we will redo prediction upon restarting (to avoid losing data).
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer), step)