def sample(self, inputs, temperature=1.0):
    """Draws symbol ids from the model outputs held in `inputs`.

    Args:
      inputs: Model outputs (presumably flattened logits — TODO confirm);
        `self._unflatten_inputs` restores their structure before sampling.
      temperature: Sampling temperature. Exactly 0.0 takes `jnp.argmax` over
        the last axis; any other value is delegated to
        `tl.logsoftmax_sample`.

    Returns:
      The sampled symbol ids.
    """
    logits = self._unflatten_inputs(inputs)
    if temperature != 0.0:
        # No LogSoftmax is needed before sampling: softmax normalization only
        # subtracts a constant from every logit, and sampling takes a max over
        # logits plus noise, so it is invariant to adding a constant.
        return tl.logsoftmax_sample(logits, temperature)
    return jnp.argmax(logits, axis=-1)
def autoregressive_sample_stream(model, inputs=None, batch_size=1,
                                 temperature=1.0, start_id=0, accelerate=True):
    """Yields samples from `model`, in autoregressive language model fashion.

    This function uses `model` to generate outputs one position at a time, with
    access to inputs for the current position and all preceding positions. The
    new output becomes the next position's input, and each further step of the
    returned generator repeats the process for the next position, indefinitely.

    Inputs and outputs always come in batches, even if size 1. If `inputs` is
    present, it must have shape (`batch_size`, inputs_sequence_length). Each
    output in the stream has shape (`batch_size`,).

    Args:
      model: A layer object (subclass of `trax.layers.Layer`) created in
          `'predict'` mode and initialized from trained weights. The model
          must have a structure that allows it to run as an autoregressive
          one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`).
      inputs: Sequence of symbols the model sees as input the first time it
          generates an output. If None, the model generates the first output
          based on just the start symbol.
      batch_size: Number of sequences to generate in parallel as a batch.
      temperature: Parameter that controls the sharpness of the softmax that
          feeds the sampling process. Values range from 0.0 (all probability
          mass goes to one candidate; like an argmax) to positive infinity
          (all candidates have equal probability).
      start_id: Integer representing the start symbol for the autoregressive
          process.
      accelerate: If True, create an accelerated version of `model` and use it
          for generating outputs.

    Yields:
      Tensor of integers with shape (`batch_size`,), representing the batch of
      outputs for the next position in the stream.

    Raises:
      ValueError: If `inputs` is given and its first dimension differs from
          `batch_size`. Because this is a generator, the error surfaces on the
          first `next()` call, not when the function is called.
    """
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match '
                         f'batch_size arg ({batch_size}).')
    fast_model = tl.Accelerate(model) if accelerate else model
    start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    if model.n_in == 1 and inputs is not None:
        # Single-input model: the inputs are used as a prefix after the start
        # symbol.
        current_symbols = np.concatenate([start_symbol, inputs], axis=1)
    else:
        current_symbols = start_symbol
    while True:
        if model.n_in > 1 and inputs is not None:
            # Multi-input model: logits are the first element of its outputs.
            logits = fast_model((inputs, current_symbols))[0]
        else:
            logits = fast_model(current_symbols)
        # Only the logits at the last position are used to pick the next symbol.
        sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature)
        yield sample
        # NOTE: Because the model is autoregressive and in 'predict' mode, its
        # history is cached in the model state and the next input is the single
        # symbol just sampled, reshaped to (batch_size, 1).
        current_symbols = sample[:, None]
def generate_output(model, inputs, limit=155, start_id=0, eos_id=1):
    """Greedily decodes a single sequence from `model`, one token at a time.

    The whole prefix decoded so far is fed to the model on every step, so the
    model is not required to cache decoding state between calls.

    Args:
      model: Callable taking a tuple `(inputs, prefix)`, where `prefix` is an
          int32 array of shape (1, prefix_length). The first element of its
          output is read as rank-3 logits indexed [batch, position, vocab].
      inputs: Passed to the model unchanged on every step (assumed to carry a
          leading batch axis of size 1 — TODO confirm against callers).
      limit: Maximum number of tokens to generate.
      start_id: Token id that begins every decoded sequence.
      eos_id: Token id that stops decoding; it is included in the result when
          generated.

    Returns:
      List of ints: `start_id` followed by each generated token id.
    """
    symbols = [start_id]
    for _ in range(limit):
        # The model expects a leading batch axis, so wrap the prefix as (1, L).
        prefix = np.array([symbols], dtype=np.int32)
        logits = model((inputs, prefix))[0]
        # temperature=0 makes the sample an argmax over the vocabulary.
        next_id = int(tl.logsoftmax_sample(logits[:, -1, :], temperature=0)[0])
        symbols.append(next_id)
        if next_id == eos_id:
            break
    return symbols
def next_symbol(NMTAttn_model, input_tokens, cur_output_tokens, temperature):
    """Samples the next output token from an attention-based NMT model.

    The tokens decoded so far are right-padded with zeros up to the smallest
    power of two strictly greater than their count. They are given a leading
    batch axis and passed to the model together with `input_tokens`. The
    prediction is read at the position right after the last decoded token.

    Args:
      NMTAttn_model: Callable taking `(input_tokens, padded_outputs)` and
          returning a pair. The first element of the pair is indexed as
          [batch, position, vocab] (presumably log-probabilities — TODO
          confirm).
      input_tokens: Tokenized source, passed to the model unchanged.
      cur_output_tokens: Python list of token ids decoded so far.
      temperature: Sampling temperature forwarded to `tl.logsoftmax_sample`.

    Returns:
      Tuple `(symbol, log_prob)`: the sampled token id as an int, and the
      model's score for that token as a float.
    """
    n_decoded = len(cur_output_tokens)
    # Target length is >= n_decoded + 1, so index n_decoded always exists.
    # The power-of-two bucketing presumably limits the number of distinct
    # input shapes.
    target_len = 2 ** int(np.ceil(np.log2(n_decoded + 1)))
    zero_fill = [0] * (target_len - n_decoded)
    batched_outputs = np.expand_dims(cur_output_tokens + zero_fill, axis=0)
    model_out, _ = NMTAttn_model((input_tokens, batched_outputs))
    next_scores = model_out[0, n_decoded, :]
    chosen = int(tl.logsoftmax_sample(next_scores, temperature))
    return chosen, float(next_scores[chosen])
def autoregressive_sample_stream(model, inputs=None, batch_size=1,
                                 temperature=1.0, start_id=0, accelerate=True):
    """Stream autoregressive samples from the provided model.

    Note that the provided model should be an autoregressive model initialized
    in 'predict' mode. In this mode, a model takes the outputs it is generating
    one-by-one (instead of taking them all at once, as, e.g., during training).
    Model state is used to store the intermediate information needed, and
    usually the model performs inference in this mode faster than in 'eval'
    mode.

    Args:
      model: instance of trax.Layer, the model to sample from (at mode='predict')
      inputs: optional tensor [batch_size, M]: inputs to provide to the model;
        for language models (with n_in=1) we use inputs as prefix to the model
      batch_size: number of sequences to sample in parallel (default: 1)
      temperature: sampling temperature (default: 1.0)
      start_id: int, id for the start symbol fed at the beginning (default: 0)
      accelerate: whether to accelerate the model before decoding (default: True)

    Yields:
      Tensor of ints of shape [batch_size] containing subsequent
      autoregressive samples from the model.

    Raises:
      ValueError: if `inputs` is given and inputs.shape[0] != batch_size. As
        this is a generator, it is raised on the first `next()` call.
    """
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
    fast_model = tl.Accelerate(model) if accelerate else model
    cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    if inputs is not None and model.n_in == 1:  # use inputs as prefix
        cur_symbol = np.concatenate([cur_symbol, inputs], axis=1)
    while True:
        model_input = cur_symbol
        if inputs is not None and model.n_in > 1:
            model_input = (inputs, cur_symbol)
        logits = fast_model(model_input)
        if inputs is not None and model.n_in > 1:
            logits = logits[0]  # Pick first element from model output (a pair here)
        # Only the logits at the last position are used to pick the next symbol.
        sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature)
        yield sample
        # Note: we're using 'predict' mode autoregressive models here, so history
        # is cached in the model state and we are only feeding one symbol next.
        cur_symbol = sample[:, None]
def autoregressive_sample_stream(model, inputs=None, batch_size=1,
                                 temperature=1.0, start_id=0, accelerate=True,
                                 eval_mode=False, eval_min_length=1):
    """Yields samples from `model`, in autoregressive language model fashion.

    This function uses `model` to generate outputs one position at a time, with
    access to inputs for the current position and all preceding positions. The
    new output becomes the next position's input, and each further step of the
    returned generator repeats the process for the next position, indefinitely.

    Inputs and outputs always come in batches, even if size 1. If `inputs` is
    present, it must have shape (`batch_size`, inputs_sequence_length). Each
    output in the stream has shape (`batch_size`,).

    Args:
      model: A layer object (subclass of `trax.layers.Layer`) created in
          `'predict'` mode and initialized from trained weights. The model
          must have a structure that allows it to run as an autoregressive
          one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`),
          except if `eval_mode` is set -- any model can be sampled then,
          but the sampling process may be much slower.
      inputs: Sequence of symbols the model sees as input the first time it
          generates an output. If None, the model generates the first output
          based on just the start symbol.
      batch_size: Number of sequences to generate in parallel as a batch.
      temperature: Parameter that controls the sharpness of the softmax that
          feeds the sampling process. Values range from 0.0 (all probability
          mass goes to one candidate; like an argmax) to positive infinity
          (all candidates have equal probability).
      start_id: Integer representing the start symbol for the autoregressive
          process, or array of shape (`batch_size`, 1) of such integers.
      accelerate: If True, create an accelerated version of `model` and use it
          for generating outputs.
      eval_mode: If True, assume the model is created in `eval` mode and sample
          by collecting all previous outputs and passing the whole tensor.
      eval_min_length: If set, the minimum length to pad to in eval mode.

    Yields:
      Tensor of integers with shape (`batch_size`,), representing the batch of
      outputs for the next position in the stream.

    Raises:
      ValueError: If `inputs` is given and its first dimension differs from
          `batch_size`. Because this is a generator, the error surfaces on the
          first `next()` call, not when the function is called.
    """
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size ({inputs.shape[0]}) does not match '
            f'batch_size arg ({batch_size}).')
    fast_model = tl.Accelerate(model) if accelerate else model
    if np.isscalar(start_id):
        start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    else:
        start_symbol = start_id
    if model.n_in == 1 and inputs is not None:
        current_symbols = np.concatenate([start_symbol, inputs], axis=1)
    else:
        current_symbols = start_symbol

    if eval_mode:
        # No start symbol needed in eval mode: an initial symbol is added when
        # the sequence is run through the model.
        current_symbols = current_symbols[:, 1:]
    while True:
        if eval_mode:
            # Pad to a power-of-2 length. The "+ 1" reserves room for the
            # initial symbol that will be added.
            min_len = max(eval_min_length, current_symbols.shape[1] + 1)
            pad_len = int(2**np.ceil(np.log2(min_len))) - current_symbols.shape[1]
            unpadded_symbols = current_symbols
            current_symbols = np.pad(current_symbols, [[0, 0], [0, pad_len]],
                                     mode='constant')
            # pad_len >= 1, so -pad_len is the first padded position (index
            # unpadded length). There is no -1 because the initial symbol
            # shifts positions by one.
            last_index = -pad_len
        else:
            last_index = -1

        # Run the model.
        if model.n_in > 1 and inputs is not None:
            logits = fast_model((inputs, current_symbols))[0]
        else:
            logits = fast_model(current_symbols)
        logits = tl.log_softmax(logits[:, last_index, :])
        sample = tl.logsoftmax_sample(logits, temperature=temperature)
        yield sample
        if eval_mode:
            # Grow the unpadded history; it is re-padded on the next step.
            current_symbols = np.concatenate(
                [unpadded_symbols, sample[:, None]], axis=1)
        else:
            # NOTE: Because the model is autoregressive and in 'predict' mode,
            # its history is cached in the model state and the next input is
            # the single symbol just sampled.
            current_symbols = sample[:, None]