def predict_step(params, inputs, outputs, cache, eos_token, max_decode_len,
                 beam_size, config):
  """Predict translation with fast decoding beam search on a batch."""
  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  flat_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(config).apply(
          {'params': params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(config).apply(
        {'params': params, 'cache': flat_cache},
        flat_ids,
        flat_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      bos_token=config.bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs
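
# A minimal sketch (not part of the original code) of what "expanded in-place
# rather than tiled" means for the beam dimension above. The helper name is
# hypothetical; the real logic lives in decode.flat_batch_beam_expand.
import jax.numpy as jnp


def _flat_batch_beam_expand_sketch(x, beam_size):
  """Repeats each batch item `beam_size` times consecutively along axis 0.

  With beam_size=2, a batch [a, b] becomes [a, a, b, b] (in-place expansion),
  not the tiled interleaving [a, b, a, b].
  """
  return jnp.repeat(x, beam_size, axis=0)
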
def predict_step(state, inputs, outputs, cache, lp_cache, beam_size,
                 bos_token, eos_token, max_decode_len, config, lp_config):
  """Predict translation with fast decoding beam search on a batch."""
  params = state.optimizer.target
  lp_params = state.lp_optimizer.target

  # Split beam over latent sequences and programs.
  per_latent_beam_size = beam_size // FLAGS.latent_beam_size
  beam_size = FLAGS.latent_beam_size * per_latent_beam_size

  flat_lp_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(lp_config).apply(
          {'params': lp_params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      FLAGS.latent_beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, FLAGS.latent_beam_size)

  def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
        {'params': lp_params, 'cache': flat_lp_cache},
        flat_ids,
        flat_lp_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_lp_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_lp_cache

  # Step 1: Beam-search over latent tokens.
  latent_beam_seqs, _ = decode.beam_search(
      inputs,
      lp_cache,
      tokens_ids_to_latent_logits,
      beam_size=FLAGS.latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=np.ceil(max_decode_len / 2**FLAGS.c).astype(np.int32))

  flat_latent_seqs = decode.flat_batch_beam_expand(
      decode.flatten_beam_dim(latent_beam_seqs), per_latent_beam_size)
  # Quantize the predicted latent codes.
  flat_latents = models.LatentProgramTransformer(config).apply(
      {'params': params, 'vqvae': state.model_state},
      flat_latent_seqs,
      mutable=False,
      method=models.LatentProgramTransformer.quantize)

  flat_encoded = decode.flat_batch_beam_expand(
      models.LatentProgramTransformer(config).apply(
          {'params': params, 'vqvae': state.model_state},
          inputs,
          outputs,
          mutable=False,
          method=models.LatentProgramTransformer.encode),
      beam_size)

  # Padding masks.
  flat_latents_mask = jnp.where(
      jnp.logical_and(flat_latent_seqs > 0, flat_latent_seqs != eos_token),
      1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.LatentProgramTransformer(config).apply(
        {'params': params, 'vqvae': state.model_state, 'cache': flat_cache},
        flat_ids,
        flat_latents,
        flat_encoded,
        flat_latents_mask,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.LatentProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Step 2: Beam-search over program tokens.
  per_latent_inputs = decode.flat_batch_beam_expand(
      inputs, FLAGS.latent_beam_size)
  per_latent_cache = jax.tree_map(
      lambda x: decode.flat_batch_beam_expand(x, FLAGS.latent_beam_size),
      cache)
  beam_seqs, _ = decode.beam_search(
      per_latent_inputs,
      per_latent_cache,
      tokens_ids_to_logits,
      beam_size=per_latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)

  # Collapse both beam dimensions into one.
  beam_seqs = beam_seqs.reshape(
      (inputs.shape[0], beam_size) + beam_seqs.shape[2:])
  latent_beam_seqs = jnp.repeat(latent_beam_seqs, per_latent_beam_size, axis=1)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs, latent_beam_seqs
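
# Illustrative shape check (not part of the original code, sizes are made up)
# for how the two beam dimensions above are collapsed into one. The second
# beam search sees a flattened batch of size batch * latent_beams and returns
# per-latent beams, which the reshape folds back under each example.
import jax.numpy as jnp

batch, latent_beams, per_latent_beams, program_len = 4, 2, 3, 5
raw_beam_seqs = jnp.zeros(
    (batch * latent_beams, per_latent_beams, program_len))
collapsed = raw_beam_seqs.reshape(
    (batch, latent_beams * per_latent_beams) + raw_beam_seqs.shape[2:])
assert collapsed.shape == (4, 6, 5)
# After jnp.repeat(latent_beam_seqs, per_latent_beams, axis=1), row k of the
# collapsed program beams lines up with row k of the repeated latent beams.
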
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 beam_size,
                 num_partial_programs,
                 max_decode_len,
                 eos_token,
                 config,
                 use_expanding_layer,
                 slow_decode=False,
                 use_split_encoding=False,
                 split_params=None,
                 split_outputs=None):
  """Predict translation with fast decoding beam search on a batch."""
  per_partial_beam_size = max(beam_size // num_partial_programs, 1)
  m = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer)

  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  if use_split_encoding:
    # Use pretrained split model to get encodings.
    assert (split_params is not None) and (split_outputs is not None)

    split_inputs = models.add_and_tile_dim(inputs, num_partial_programs,
                                           axis=1)
    # split_outputs shape == [batch_size, num_partial, num_io, length]
    split_outputs = jnp.swapaxes(split_outputs, 1, 2)

    encoded = models.DecomposeExpandingLayerTransformer(
        config=config, num_partial_programs=1,
        use_expanding_layer=False).apply(
            {'params': split_params},
            split_inputs,
            split_outputs,
            method=models.DecomposeExpandingLayerTransformer.encode)
    flat_encoded = decode.flat_batch_beam_expand(encoded,
                                                 per_partial_beam_size)

    encoded_padding_mask = jnp.where(
        split_outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, per_partial_beam_size)
  else:
    flat_encoded = decode.flat_batch_beam_expand(
        m.apply(
            {'params': params},
            inputs,
            outputs,
            method=models.DecomposeExpandingLayerTransformer.encode),
        per_partial_beam_size)
    flat_encoded = m.apply(
        {'params': params},
        flat_encoded,
        method=models.DecomposeExpandingLayerTransformer.decompose)

    encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, per_partial_beam_size)
    flat_encoded_padding_mask = models.add_and_tile_dim(
        flat_encoded_padding_mask, num_partial_programs, axis=1)

  if slow_decode:
    def tokens_ids_to_logits(flat_ids, i):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=use_expanding_layer).apply(
              {'params': params},
              flat_ids,
              flat_encoded[:, i],
              flat_encoded_padding_mask[:, i],
              method=models.DecomposeExpandingLayerTransformer.decode)
      return flat_logits
  else:
    def tokens_ids_to_logits(flat_ids, flat_cache, i):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits, new_vars = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=use_expanding_layer).apply(
              {'params': params, 'cache': flat_cache},
              flat_ids,
              flat_encoded[:, i],
              flat_encoded_padding_mask[:, i],
              mutable=['cache'],
              method=models.DecomposeExpandingLayerTransformer.decode)
      new_flat_cache = new_vars['cache']
      # Remove singleton sequence-length dimension:
      # [batch * beam, 1, vocab] --> [batch * beam, vocab]
      flat_logits = flat_logits.squeeze(axis=1)
      return flat_logits, new_flat_cache

  # Perform beam search independently for each partial program.
  all_beam_seqs = []
  all_beam_log_probs = []
  for i in range(num_partial_programs):
    beam_seqs, beam_log_probs = decode.beam_search(
        inputs,
        cache,
        functools.partial(tokens_ids_to_logits, i=i),
        beam_size=per_partial_beam_size,
        alpha=0.6,
        bos_token=config.bos_token,
        eos_token=eos_token,
        max_decode_len=max_decode_len,
        slow_decode=slow_decode)
    all_beam_seqs.append(beam_seqs)
    all_beam_log_probs.append(beam_log_probs)

  all_beam_seqs = jnp.stack(all_beam_seqs, axis=1)
  all_beam_log_probs = jnp.stack(all_beam_log_probs, axis=1)

  # all_beam_seqs shape == [batch, n_partial, n_beam_per_partial, length]
  assert len(all_beam_seqs.shape) == 4
  # all_beam_log_probs shape == [batch, n_partial, n_beam_per_partial]
  assert len(all_beam_log_probs.shape) == 3

  # Sort beams in order of decreasing probability.
  order = jnp.argsort(all_beam_log_probs, axis=2)[:, :, ::-1]
  all_beam_log_probs = jnp.take_along_axis(all_beam_log_probs, order, axis=2)
  all_beam_seqs = jnp.take_along_axis(all_beam_seqs,
                                      order[Ellipsis, jnp.newaxis],
                                      axis=2)

  return all_beam_seqs, all_beam_log_probs
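
# Toy demonstration (illustrative only, made-up numbers) of the
# argsort / take_along_axis pattern used above to sort beams by decreasing
# log-probability while keeping sequences and scores aligned.
import jax.numpy as jnp

log_probs = jnp.array([[[-3.0, -1.0, -2.0]]])   # [batch=1, n_partial=1, n_beam=3]
seqs = jnp.array([[[[7, 7], [9, 9], [8, 8]]]])  # [1, 1, 3, length=2]
order = jnp.argsort(log_probs, axis=2)[:, :, ::-1]  # best beam first
sorted_log_probs = jnp.take_along_axis(log_probs, order, axis=2)
sorted_seqs = jnp.take_along_axis(seqs, order[Ellipsis, jnp.newaxis], axis=2)
# sorted_log_probs == [[[-1., -2., -3.]]] and sorted_seqs follows the same
# beam order: [[[[9, 9], [8, 8], [7, 7]]]].
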
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 beam_size,
                 eos_token,
                 max_decode_len,
                 config,
                 dropout_rng=None,
                 slow_decode=True):
  """Predict translation with fast decoding beam search on a batch."""
  # This code is necessary to experiment with using dropout during prediction,
  # but we don't normally use dropout here.
  dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  flat_encoded = decode.flat_batch_beam_expand(
      models.DecomposeAttentionTransformer(config).apply(
          {'params': params},
          inputs,
          outputs,
          method=models.DecomposeAttentionTransformer.encode,
          rngs={'dropout': dropout_rng}),
      beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  if slow_decode:
    def tokens_ids_to_logits(flat_ids):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits = models.DecomposeAttentionTransformer(
          config=config).apply(
              {'params': params},
              flat_ids,
              flat_encoded,
              flat_encoded_padding_mask,
              method=models.DecomposeAttentionTransformer.decode,
              rngs={'dropout': dropout_rng})
      return flat_logits
  else:
    def tokens_ids_to_logits(flat_ids, flat_cache):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits, new_vars = models.DecomposeAttentionTransformer(
          config=config).apply(
              {'params': params, 'cache': flat_cache},
              flat_ids,
              flat_encoded,
              flat_encoded_padding_mask,
              mutable=['cache'],
              method=models.DecomposeAttentionTransformer.decode,
              rngs={'dropout': dropout_rng})
      new_flat_cache = new_vars['cache']
      # Remove singleton sequence-length dimension:
      # [batch * beam, 1, vocab] --> [batch * beam, vocab]
      flat_logits = flat_logits.squeeze(axis=1)
      return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      bos_token=config.base_config.bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len,
      slow_decode=slow_decode)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs, new_dropout_rng
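
# A hedged usage sketch (not from the original code): predict_step functions
# like the ones above are commonly bound to their static arguments with
# functools.partial and replicated across devices with jax.pmap. The helper
# name and the exact argument handling here are assumptions for illustration.
import functools

import jax


def _make_pmapped_predict_step(predict_config, eos_token, max_decode_len):
  # beam_size (positional argument 4 after the partial bindings) is a plain
  # Python int used in shape computations, so it is kept static for pmap.
  return jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_token,
          max_decode_len=max_decode_len,
          config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(4,))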