def initialize_cache(inputs, outputs, programs, max_decode_len, config):
  """Initialize a cache for a given input shape and max decode length."""
  target_shape = (programs.shape[0], max_decode_len)
  initial_variables = models.ProgramTransformer(config).init(
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, config.dtype),
      jnp.ones(outputs.shape, config.dtype),
      jnp.ones(target_shape, config.dtype))
  return initial_variables['cache']
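# Illustrative sketch (not part of the training code above): why a dummy
# forward pass with decode=True yields a 'cache' collection. Flax attention
# layers allocate their autoregressive cache variables on the first call, so
# init() returns them alongside 'params'. _TinyDecoderSketch is a hypothetical
# stand-in for models.ProgramTransformer, used only for illustration.
import jax
import jax.numpy as jnp
from flax import linen as nn


class _TinyDecoderSketch(nn.Module):
  """Minimal module whose attention layer allocates a decoding cache."""

  @nn.compact
  def __call__(self, x):
    # With decode=True, the layer creates 'cached_key', 'cached_value', and
    # 'cache_index' variables in the 'cache' collection.
    attn = nn.MultiHeadDotProductAttention(
        num_heads=2, qkv_features=8, decode=True)
    return attn(x, x)


def _cache_init_sketch():
  variables = _TinyDecoderSketch().init(
      jax.random.PRNGKey(0), jnp.ones((1, 4, 8), jnp.float32))
  # The 'cache' collection is the analogue of what initialize_cache returns.
  return variables['cache']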
def predict_step(params, inputs, outputs, cache, eos_token, max_decode_len,
                 beam_size, config):
  """Predict programs with fast decoding beam search on a batch."""
  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  flat_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(config).apply(
          {'params': params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(config).apply(
        {'params': params, 'cache': flat_cache},
        flat_ids,
        flat_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      bos_token=config.bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs
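# Illustrative sketch of the "expanded in-place rather than tiled" layout that
# the comment in predict_step describes. decode.flat_batch_beam_expand is
# assumed to implement the first pattern; the snippet below only demonstrates
# the expected difference in plain jax.numpy and is not used by the code above.
import jax.numpy as jnp


def _beam_expand_sketch():
  x = jnp.array([[1, 2], [3, 4]])  # [batch=2, d=2]
  beam_size = 3
  # Per-item expansion: [batch, d] -> [batch, beam, d] -> [batch * beam, d].
  # Rows come out grouped by batch item:
  #   [1 2], [1 2], [1 2], [3 4], [3 4], [3 4]
  expanded = jnp.repeat(x[:, None, :], beam_size, axis=1).reshape(-1, 2)
  # Whole-batch tiling would interleave items instead:
  #   [1 2], [3 4], [1 2], [3 4], [1 2], [3 4]
  tiled = jnp.tile(x, (beam_size, 1))
  return expanded, tiled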
def loss_fn(params):
  """Loss function used for training."""
  logits = models.ProgramTransformer(config).apply(
      {'params': params},
      inputs,
      outputs,
      programs,
      rngs={'dropout': train_rng})
  loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)
  mean_loss = loss / weight_sum
  return mean_loss, logits
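# compute_weighted_cross_entropy is defined elsewhere in the codebase. The
# sketch below is an assumption about the usual contract (as in the Flax WMT
# example): a padding-masked token-level cross entropy returning the summed
# loss and the sum of weights, so loss_fn can take their ratio.
import jax
import jax.numpy as jnp
from flax.training import common_utils


def _weighted_cross_entropy_sketch(logits, targets, weights):
  """Hypothetical stand-in for compute_weighted_cross_entropy."""
  vocab_size = logits.shape[-1]
  onehot_targets = common_utils.onehot(targets, vocab_size)
  log_probs = jax.nn.log_softmax(logits)
  loss = -jnp.sum(onehot_targets * log_probs, axis=-1)  # [batch, length]
  loss = loss * weights  # Zero out padding positions.
  return jnp.sum(loss), jnp.sum(weights)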
def loss_fn(params, lp_params):
  """Loss function used for training."""
  latent_logits = models.ProgramTransformer(lp_config).apply(
      {'params': lp_params},
      inputs,
      outputs,
      latent_indices,
      rngs={'dropout': train_rng})
  latent_loss, latent_weight_sum = compute_weighted_cross_entropy(
      latent_logits, latent_indices, latent_weights)

  # End-to-end prediction.
  encoded = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      inputs,
      outputs,
      mutable=False,
      rngs={'dropout': train_rng},
      method=models.LatentProgramTransformer.encode)
  latents = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      latent_logits,
      mutable=False,
      rngs={'dropout': train_rng},
      method=models.LatentProgramTransformer.quantize)
  logits = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      programs,
      latents,
      encoded,
      latents_mask,
      encoded_mask,
      mutable=False,
      rngs={'dropout': train_rng},
      method=models.LatentProgramTransformer.decode)

  loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)

  mean_loss = latent_loss / latent_weight_sum
  if not pretrain:
    mean_loss += loss / weight_sum
  return mean_loss, (logits, latent_logits)
def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
  """Token slice to logits from decoder model."""
  # --> [batch * beam, 1, vocab]
  flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
      {'params': lp_params, 'cache': flat_lp_cache},
      flat_ids,
      flat_lp_encoded,
      flat_encoded_padding_mask,
      mutable=['cache'],
      method=models.ProgramTransformer.decode)
  new_flat_lp_cache = new_vars['cache']
  # Remove singleton sequence-length dimension:
  # [batch * beam, 1, vocab] --> [batch * beam, vocab]
  flat_logits = flat_logits.squeeze(axis=1)
  return flat_logits, new_flat_lp_cache
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value
  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode IO example tokens."""

    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Use a smaller batch for the predict dataset to leave room for beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_lib.train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_lib.initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(train_lib.predict_step, config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(4, 5, 6))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time() - t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: train_lib.pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = p_init_cache(inputs, outputs, programs)
        predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                eos_token, programs.shape[-1], beam_size)
        predicted = train_lib.tohost(predicted)
        inputs, outputs, programs = map(train_lib.tohost,
                                        (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_score = train_lib.eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])

          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')

      all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
def test_train(self):
  tf.enable_v2_behavior()

  tf.random.set_seed(0)
  np.random.seed(0)
  random.seed(0)

  dataset_filepattern = os.path.join(
      os.path.dirname(__file__),
      'tasks/robust_fill/dataset/test_dataset/program_tasks.tf_records-*')
  print('dataset_filepattern = {}'.format(dataset_filepattern))

  batch_size = 4
  num_strings_per_task = 4
  max_characters = 10
  max_program_length = 15

  # Build token tables.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  _, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1
  bos_token = token_id_table[dsl.BOS]

  # Load dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=((num_strings_per_task, max_characters),
                     (num_strings_per_task, max_characters),
                     (max_program_length,)),
      drop_remainder=True)
  dataset_iter = dataset.repeat().as_numpy_iterator()

  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=32,
      num_heads=4,
      num_layers=2,
      qkv_dim=32,
      mlp_dim=32,
      max_len=max(max_characters, max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)

  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)

  m = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones((batch_size, num_strings_per_task, max_characters),
               jnp.float32),
      jnp.ones((batch_size, num_strings_per_task, max_characters),
               jnp.float32),
      jnp.ones((batch_size, max_program_length), jnp.float32))

  optimizer_def = optim.Adam(
      1e-2, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=0.1)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=1e-2)
  p_train_step = jax.pmap(
      functools.partial(
          train_lib.train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')

  # Training loop.
  start_step = 0
  rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  for _ in range(start_step, 1000):
    inputs, outputs, programs = common_utils.shard(next(dataset_iter))
    optimizer, _, rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=rngs)

  # Evaluation.
  eval_metrics = []
  for batches in dataset.as_numpy_iterator():
    inputs, outputs, programs = common_utils.shard(batches)
    metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
    eval_metrics.append(metrics)

  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
  eval_denominator = eval_metrics_sums.pop('denominator')
  eval_summary = jax.tree_map(
      lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
      eval_metrics_sums)

  if jax.host_id() == 0:
    self.assertGreater(eval_summary['accuracy'], 0.1)
def predict_step(state, inputs, outputs, cache, lp_cache, beam_size,
                 bos_token, eos_token, max_decode_len, config, lp_config):
  """Predict programs with fast decoding beam search on a batch."""
  params = state.optimizer.target
  lp_params = state.lp_optimizer.target

  # Split the beam budget between latent sequences and programs.
  per_latent_beam_size = beam_size // FLAGS.latent_beam_size
  beam_size = FLAGS.latent_beam_size * per_latent_beam_size

  flat_lp_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(lp_config).apply(
          {'params': lp_params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      FLAGS.latent_beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, FLAGS.latent_beam_size)

  def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
        {'params': lp_params, 'cache': flat_lp_cache},
        flat_ids,
        flat_lp_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_lp_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_lp_cache

  # Step 1: Beam-search over latent tokens.
  latent_beam_seqs, _ = decode.beam_search(
      inputs,
      lp_cache,
      tokens_ids_to_latent_logits,
      beam_size=FLAGS.latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=np.ceil(max_decode_len / 2**FLAGS.c).astype(np.int32))

  flat_latent_seqs = decode.flat_batch_beam_expand(
      decode.flatten_beam_dim(latent_beam_seqs), per_latent_beam_size)
  # Quantize the predicted latent codes.
  flat_latents = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      flat_latent_seqs,
      mutable=False,
      method=models.LatentProgramTransformer.quantize)

  flat_encoded = decode.flat_batch_beam_expand(
      models.LatentProgramTransformer(config).apply(
          {'params': params, 'vqvae': state.model_state},
          inputs,
          outputs,
          mutable=False,
          method=models.LatentProgramTransformer.encode),
      beam_size)

  # Padding masks.
  flat_latents_mask = jnp.where(
      jnp.logical_and(flat_latent_seqs > 0, flat_latent_seqs != eos_token),
      1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state,
            'cache': flat_cache
        },
        flat_ids,
        flat_latents,
        flat_encoded,
        flat_latents_mask,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.LatentProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Step 2: Beam-search over program tokens.
  per_latent_inputs = decode.flat_batch_beam_expand(
      inputs, FLAGS.latent_beam_size)
  per_latent_cache = jax.tree_map(
      lambda x: decode.flat_batch_beam_expand(x, FLAGS.latent_beam_size),
      cache)
  beam_seqs, _ = decode.beam_search(
      per_latent_inputs,
      per_latent_cache,
      tokens_ids_to_logits,
      beam_size=per_latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)
  # Collapse both beam dimensions into one.
  beam_seqs = beam_seqs.reshape(
      (inputs.shape[0], beam_size) + beam_seqs.shape[2:])
  latent_beam_seqs = jnp.repeat(
      latent_beam_seqs, per_latent_beam_size, axis=1)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs, latent_beam_seqs
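# Illustrative shape check mirroring the reshape/repeat bookkeeping at the end
# of predict_step above: the program-level beam search runs on a flattened
# batch of size batch * latent_beam, so its output carries two beam factors
# that are collapsed into a single beam axis, and the latent sequences are
# repeated so the two outputs stay aligned row-for-row. Shapes are made up for
# the sketch; only jax.numpy is used.
import jax.numpy as jnp


def _beam_collapse_sketch():
  batch, latent_beam, per_latent_beam, length = 2, 3, 4, 5
  beam_size = latent_beam * per_latent_beam
  # Second beam search output: [batch * latent_beam, per_latent_beam, length].
  beam_seqs = jnp.zeros((batch * latent_beam, per_latent_beam, length))
  beam_seqs = beam_seqs.reshape((batch, beam_size) + beam_seqs.shape[2:])
  # Latent beams: [batch, latent_beam, latent_length] -> repeat per program beam.
  latent_beam_seqs = jnp.zeros((batch, latent_beam, 7))
  latent_beam_seqs = jnp.repeat(latent_beam_seqs, per_latent_beam, axis=1)
  assert beam_seqs.shape[:2] == latent_beam_seqs.shape[:2] == (batch, beam_size)
  return beam_seqs, latent_beam_seqs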
def eval_step(state, inputs, outputs, programs, bos_token, eos_token,
              config, lp_config):
  """Evaluate on a batch of program tasks."""
  params = state.optimizer.target
  lp_params = state.lp_optimizer.target

  weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)
  # Embedding mask for autoencoding.
  emb_mask = jnp.ones((1, FLAGS.latent_vocab_size),
                      jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

  ae_logits, vq = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      inputs,
      outputs,
      programs,
      emb_mask,
      mutable=False)

  # Postprocess latent indices.
  latent_indices = add_eos_token(vq['latent_indices'], eos_token)
  latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)

  encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  # Additionally mask out the eos token in the latents.
  latents_mask = jnp.where(
      jnp.logical_and(latent_indices > 0, latent_indices != eos_token),
      1, 0).astype(jnp.float32)

  latent_logits = models.ProgramTransformer(lp_config).apply(
      {'params': lp_params}, inputs, outputs, latent_indices)

  encoded = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      inputs,
      outputs,
      mutable=False,
      method=models.LatentProgramTransformer.encode)
  latents = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      latent_logits,
      mutable=False,
      method=models.LatentProgramTransformer.quantize)
  logits = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      programs,
      latents,
      encoded,
      latents_mask,
      encoded_mask,
      mutable=False,
      method=models.LatentProgramTransformer.decode)

  metrics = compute_metrics(logits, programs, weights)
  metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
  latent_metrics = compute_metrics(
      latent_logits, latent_indices, latent_weights, prefix='latent_')
  return metrics, latent_metrics
def eval_step(params, inputs, outputs, programs, config):
  """Evaluate the model on a batch of program tasks."""
  weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)
  logits = models.ProgramTransformer(config).apply(
      {'params': params}, inputs, outputs, programs)
  return compute_metrics(logits, programs, weights)
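# compute_metrics is defined elsewhere; the sketch below is an assumption about
# the contract the training loop relies on: per-batch sums for 'loss' and
# 'accuracy' plus a 'denominator' (the weight sum), which the caller pops and
# divides by after aggregating metrics across steps and devices.
import jax
import jax.numpy as jnp
from flax.training import common_utils


def _compute_metrics_sketch(logits, targets, weights):
  """Hypothetical stand-in for compute_metrics."""
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * jax.nn.log_softmax(logits), axis=-1)
  acc = jnp.equal(jnp.argmax(logits, axis=-1), targets).astype(jnp.float32)
  return {
      'loss': jnp.sum(loss * weights),
      'accuracy': jnp.sum(acc * weights),
      'denominator': jnp.sum(weights),
  }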