Exemplo n.º 1
0
  def __init__(self, loss_layer, optimizer, n_devices=None):
    """Sets up replicated optimizer state and the accelerated update function.

    Args:
      loss_layer: Layer whose `pure_fn` returns `(loss, state)`; it is
          differentiated with respect to its weights (argument 1).
      optimizer: Optimizer whose `slots` and `opt_params` are replicated
          across devices.
      n_devices: Number of devices to replicate across; if None (or 0),
          `fastmath.device_count()` is used.
    """
    self._loss_layer = loss_layer
    self._optimizer = optimizer
    self._n_devices = n_devices or fastmath.device_count()

    # Optimizer slots and opt_params may need to be replicated.
    self._slots, self._opt_params = tl.for_n_devices(
        (self._optimizer.slots, self._optimizer.opt_params), self._n_devices)

    # Accelerated version of the loss layer to replicate weights and state.
    # Use the resolved device count rather than the raw (possibly None)
    # argument, so weights/state are replicated over the same number of
    # devices as the optimizer slots above and the update function below.
    self._accelerated_loss_layer = tl.Accelerate(
        loss_layer, n_devices=self._n_devices)

    # Signature:
    # (batch, weights, state, rng) -> ((loss, state), gradients)
    self._forward_and_backward_fn = (
        fastmath.value_and_grad(
            loss_layer.pure_fn,
            argnums=1,  # arg1 of pure_fn: weights
            has_aux=True))  # return (loss, state), gradients

    # Signature:
    # (weights, slots), step, opt_params, batch, state, rng ->
    # (weights, slots), state, stats
    self._accelerated_update_fn = (
        _accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
        )
    )
Exemplo n.º 2
0
def autoregressive_sample_stream(model, inputs=None,
                                 batch_size=1, temperature=1.0,
                                 start_id=0, accelerate=True):
  """Yields samples from `model`, in autoregressive language model fashion.

  This function uses `model` to generate outputs one position at a time, with
  access to inputs for the current position and all preceding positions. The
  new output becomes the next position's input, and further calls to
  `autoregressive_sample_stream` repeat the process for successive positions
  indefinitely.

  Inputs and outputs always come in batches, even if size 1. If `inputs` is
  present, it must have shape (`batch_size`, inputs_sequence_length), and each
  output in the stream has shape (`batch_size`,).

  Args:
    model: A layer object (subclass of `trax.layers.Layer`) created in
        `'predict'` mode and initialized from trained weights. The model
        must have a structure that allows it to run as an autoregressive
        one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`).
    inputs: Sequence of symbols the model sees as input the first time it
        generates an output. If None, the model generates the first output
        based on just the start symbol. For single-input models
        (`model.n_in == 1`) the inputs are appended after the start symbol as
        a prefix; otherwise they are passed alongside the generated symbols.
    batch_size: Number of sequences to generate in parallel as a batch.
    temperature: Parameter that controls the sharpness of the softmax that
        feeds the sampling process. Values range from 0.0 (all probability mass
        goes to one candidate; like an argmax) to positive infinity (all
        candidates have equal probability).
    start_id: Integer representing the start symbol for the autoregressive
        process.
    accelerate: If True, create an accelerated version of `model` and use it
        for generating outputs.

  Yields:
    Tensor of integers with shape (`batch_size`,), representing the batch of
    outputs for the next position in the stream.

  Raises:
    ValueError: If `inputs` is given and its first dimension differs from
        `batch_size`. Because this is a generator, the error surfaces when the
        first sample is requested.
  """
  if inputs is not None and inputs.shape[0] != batch_size:
    raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match '
                     f'batch_size arg ({batch_size}).')

  fast_model = tl.Accelerate(model) if accelerate else model
  start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
  if model.n_in == 1 and inputs is not None:
    current_symbols = np.concatenate([start_symbol, inputs], axis=1)
  else:
    current_symbols = start_symbol

  # Multi-input models receive `inputs` on every call; this never changes
  # inside the loop, so decide it once.
  pass_inputs = model.n_in > 1 and inputs is not None

  while True:
    if pass_inputs:
      logits = fast_model((inputs, current_symbols))[0]
    else:
      logits = fast_model(current_symbols)
    sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature)
    yield sample
    # NOTE: Because the model is autoregressive and in 'predict' mode, its
    # history is cached in the model state and the next input is the single
    # symbol just sampled, reshaped to (batch_size, 1).
    current_symbols = sample[:, None]
Exemplo n.º 3
0
  def test_loss_layer_timing(self):
    """Benchmarks `tl.SparseDenseWithOptions` in 'predict' mode.

    Each configuration is run 150 times on a single (1, input) example; the
    first 50 calls are warm-up and excluded, the remaining 100 are timed.
    Per-configuration results are printed as they finish and again as a recap.
    """
    all_settings = [
        # The first run is sometimes slower, less reliable.
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False},

        {'output': 32000, 'input': 2048, 'prob': None,
         'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': False},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': False},

        {'output': 32000, 'input': 2048, 'prob': None,
         'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': True},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': True},
        {'output': 32000, 'input': 2048, 'prob': None,
         'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': True},
    ]

    messages = []
    for settings in all_settings:
      pred_model = tl.SparseDenseWithOptions(
          n_units=settings['output'],
          d_input=settings['input'],
          sparsity_type=settings['type'],
          sparsity=settings['sparsity'],
          d_lowrank=settings['lowrank'],
          prob_sparse=settings['prob'],
          use_bias=settings['use_bias'],
          mode='predict',
          )
      pred_model = tl.Accelerate(pred_model)

      shape1l = shapes.ShapeDtype((1, settings['input']))
      pred_model.init(input_signature=shape1l)
      inputs = np.ones((1, settings['input']))

      total_time = 0.0
      # Negative counters are warm-up iterations and are not accumulated.
      for counter in range(-50, 100):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        y = pred_model(inputs)
        elapsed_time = time.perf_counter() - start_time
        # NOTE(review): if the backend dispatches asynchronously, this may
        # measure dispatch rather than full compute time — confirm.
        # The shape check runs outside the timed region so its cost is not
        # counted.
        self.assertEqual(y.shape, (1, settings['output']))
        if counter >= 0:
          total_time += elapsed_time

      message = (
          '\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n'
          % (_size_of_model(pred_model), settings, total_time))
      messages.append(message)
      print(message)

    print('Final results (recap):')
    for message in messages:
      print(message)
Exemplo n.º 4
0
 def test_accelerated_same_result(self):
     """A Dense layer gives the same outputs with and without tl.Accelerate."""
     dense = tl.Dense(2)
     batch = np.random.uniform(size=(8, 7))
     dense.init(shapes.signature(batch))
     expected = dense(batch)
     actual = tl.Accelerate(dense)(batch)
     # Compare every (row, column) entry of the (8, 2) outputs.
     for row in range(8):
         for col in range(2):
             self.assertAlmostEqual(
                 float(expected[row, col]), float(actual[row, col]), places=4)
Exemplo n.º 5
0
    def test_accelerated_weighted_category_accuracy(self):
        """Test multi-device aggregation of weights.

        Every row of `model_outputs` has its argmax at category 2, so only the
        third example (target 2) is correct. The fourth example has weight 0,
        so the weighted accuracy over the three weighted examples is 1/3.
        """
        layer = tl.Accelerate(tl.WeightedCategoryAccuracy())
        weights = np.array([1., 1., 1., 0.])
        targets = np.array([0, 1, 2, 3])

        model_outputs = np.array([[.2, .1, .7, 0.], [.2, .1, .7, 0.],
                                  [.2, .1, .7, 0.], [.2, .1, .7, 0.]])
        accuracy = layer([model_outputs, targets, weights])
        # 1/3 is not exactly representable in floating point, and the device
        # computation may round differently (e.g. float32), so compare
        # approximately rather than with exact equality.
        self.assertAlmostEqual(float(np.mean(accuracy)), 1 / 3, places=6)
Exemplo n.º 6
0
 def test_chunk_memory(self):
     """Test chunking here to exercise accelerator memory usage.

     Runs a large Serial model wrapped in `tl.Chunk` (chunk size 256) both
     directly and through `tl.Accelerate`, and checks the output shapes.
     """
     n_rows = 16 * 1024
     expected_shape = (n_rows, 128)
     model = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(128))
     chunked = tl.Chunk(model, 256)
     x = np.random.uniform(size=(n_rows, 16))
     chunked.init(shapes.signature(x))
     plain_out = chunked(x)
     accelerated_out = tl.Accelerate(chunked)(x)
     self.assertEqual(plain_out.shape, expected_shape)
     self.assertEqual(accelerated_out.shape, expected_shape)
Exemplo n.º 7
0
def autoregressive_sample(model,
                          prefix=None,
                          inputs=None,
                          batch_size=1,
                          temperature=1.0,
                          start_id=0,
                          eos_id=1,
                          max_length=100,
                          accelerate=True):
    """Perform autoregressive sampling from the provided model.

    Args:
      model: instance of trax.Layer, the model to sample from (at mode='predict')
      prefix: optional tensor [batch_size, L]: prefix for decoding; during the
        first L steps the prefix symbols are emitted (and fed back) instead of
        sampled ones
      inputs: optional tensor [batch_size, M]: inputs to provide to the model
      batch_size: number of sequences to sample in parallel (default: 1)
      temperature: sampling temperature (default: 1.0)
      start_id: int, id for the start symbol fed at the beginning (default: 0)
      eos_id: int, id of the end-of-sequence symbol used to stop (default: 1);
        early stopping is currently only applied when batch_size == 1
      max_length: maximum length to sample (default: 100)
      accelerate: whether to accelerate the model before decoding (default: True)

    Returns:
      a tensor of ints of shape [batch_size, N] with N <= max_length containing
      the autoregressively sampled output from the model

    Raises:
      ValueError: if the first dimension of `prefix` or `inputs` differs from
        `batch_size`.
    """
    if prefix is not None and prefix.shape[0] != batch_size:
        raise ValueError(
            f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
    fast_model = tl.Accelerate(model) if accelerate else model
    cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    result = []
    for step in range(max_length):
        model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
        logits = fast_model(model_input)
        if inputs is not None:
            # Pick first element from model output (a pair here).
            logits = logits[0]
        if prefix is not None and step < prefix.shape[1]:  # Read from prefix.
            sample = prefix[:, step][:, None]
        else:
            sample = tl.gumbel_sample(logits, temperature=temperature)
        result.append(sample)
        # Note: we're using 'predict' mode autoregressive models here, so
        # history is cached in the model state and we only feed the next symbol.
        cur_symbol = sample
        # TODO(lukaszkaiser): extend stopping below to batch_sizes > 1.
        if batch_size == 1 and int(sample[0, 0]) == eos_id:
            break
    return np.concatenate(result, axis=1)
Exemplo n.º 8
0
def autoregressive_sample_stream(model, inputs=None,
                                 batch_size=1, temperature=1.0,
                                 start_id=0, accelerate=True):
  """Stream autoregressive samples from the provided model.

  Note that the provided model should be an autoregressive model initialized
  in 'predict' mode. In this mode, a model takes the outputs it is generating
  one-by-one (instead of taking them all at once, as, e.g., during training).
  Model state is used to store the intermediate information needed, and usually
  the model performs inference in this mode faster than in 'eval' mode.

  Args:
    model: instance of trax.Layer, the model to sample from (at mode='predict')
    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
      for language models (with n_in=1) we use inputs as prefix to the model
    batch_size: number of sequences to sample in parallel (default: 1)
    temperature: sampling temperature (default: 1.0)
    start_id: int, id for the start symbol fed at the beginning (default: 0)
    accelerate: whether to accelerate the model before decoding (default: True)

  Yields:
    Tensor of ints of shape [batch_size] containing subsequent
    autoregressive samples from the model.

  Raises:
    ValueError: if `inputs` is given and its first dimension differs from
      `batch_size` (raised when the first sample is requested, since this is a
      generator).
  """
  if inputs is not None and inputs.shape[0] != batch_size:
    raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
  fast_model = tl.Accelerate(model) if accelerate else model
  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
  if inputs is not None and model.n_in == 1:  # use inputs as prefix
    cur_symbol = np.concatenate([cur_symbol, inputs], axis=1)
  # Multi-input models take (inputs, symbols) and return a pair; this is
  # loop-invariant, so decide it once.
  pass_inputs = inputs is not None and model.n_in > 1
  while True:
    model_input = (inputs, cur_symbol) if pass_inputs else cur_symbol
    logits = fast_model(model_input)
    if pass_inputs:
      logits = logits[0]  # Pick first element from model output (a pair here)
    sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature)
    yield sample
    # Note: we're using 'predict' mode autoregressive models here, so history
    # is cached in the model state and we are only feeding one symbol next.
    cur_symbol = sample[:, None]
Exemplo n.º 9
0
def prepare_model(model_file, batch_size=1):
  """Prepare the model.

  Builds the model configured in gin under
  `trax.supervised.trainer_lib.train.model`, loads its weights from
  `model_file`, wraps it in `tl.Accelerate`, and loads the SentencePiece
  vocabulary.

  Args:
    model_file: checkpoint path to load weights from.
    batch_size: global batch size; it is divided by the number of local
      devices (with a minimum of 1) to build the input signature.

  Returns:
    A tuple `(vocab, model, initial_state)`, where `initial_state` is the
    accelerated model's state right after loading.
  """
  use_eval = FLAGS.use_eval_mode
  mode = 'eval' if use_eval else 'predict'
  print('Initializing the model in %s mode.' % mode, flush=True)

  # The model constructor comes from the gin config.
  model_fn = gin.query_parameter('trax.supervised.trainer_lib.train.model')
  model = model_fn.scoped_configurable_fn(mode=mode)

  # Eval mode uses a length-32 signature; predict mode uses length 1.
  decode_length = 32 if use_eval else 1
  per_device_batch = max(1, batch_size // jax.local_device_count())
  signature = shapes.ShapeDtype(
      (per_device_batch, decode_length), dtype=np.int32)
  model.init_from_file(
      model_file, weights_only=True, input_signature=(signature, signature))
  model = tl.Accelerate(model)
  initial_state = model.state

  vocab = t5_spc_vocab.SentencePieceVocabulary(data.DEFAULT_SPM_PATH)
  return vocab, model, initial_state
Exemplo n.º 10
0
    def __init__(self,
                 model_with_loss,
                 optimizer,
                 n_devices=None,
                 adasum=False):
        """Sets up replicated optimizer state and the accelerated update function.

        Args:
          model_with_loss: Layer whose `pure_fn` returns `(loss, state)`; it is
              differentiated with respect to its weights (argument 1).
          optimizer: Optimizer whose `slots` and `opt_params` are replicated
              across devices.
          n_devices: Number of devices to replicate across; if None (or 0),
              `fastmath.local_device_count()` is used.
          adasum: Forwarded to `_accelerate_update_fn`.
        """
        self._model_with_loss = model_with_loss
        self._optimizer = optimizer
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum

        # Optimizer slots and opt_params may need to be replicated; `tl.on_cpu`
        # presumably keeps the replicated copies in host memory — confirm.
        self._slots, self._opt_params = tl.on_cpu(
            tl.for_n_devices(
                (self._optimizer.slots, self._optimizer.opt_params),
                self._n_devices))

        # Accelerated version of model+loss to replicate weights and state.
        # Pass the resolved device count (not the raw, possibly-None argument)
        # so it matches the count used for the slots and the update function.
        self._accelerated_model_with_loss = tl.Accelerate(
            model_with_loss, n_devices=self._n_devices)

        # Signature:
        # (batch, weights, state, rng) -> ((loss, state), gradients)
        self._forward_and_backward_fn = (
            fastmath.value_and_grad(
                model_with_loss.pure_fn,
                argnums=1,  # arg1 of pure_fn: weights
                has_aux=True))  # return (loss, state), gradients

        # Signature:
        # (weights, slots), step, opt_params, batch, state, rng ->
        # (weights, slots), state, stats
        self._accelerated_update_fn = (_accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
            adasum=self._adasum))
Exemplo n.º 11
0
  def __init__(self, task,
               joint_model=None,
               optimizer=None,
               lr_schedule=lr.multifactor,
               batch_size=64,
               train_steps_per_epoch=500,
               supervised_evals_per_epoch=1,
               supervised_eval_steps=1,
               n_trajectories_per_epoch=50,
               max_slice_length=1,
               normalize_advantages=True,
               output_dir=None,
               n_replay_epochs=1):
    """Configures the joint trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      joint_model: callable that builds the Trax layer representing the joint
          policy and value model; it is called with `policy_distribution` and
          `mode` keyword arguments.
      optimizer: the optimizer to use to train the joint model.
      lr_schedule: zero-argument callable returning the learning rate schedule
          used to train the joint model.
      batch_size: batch size used to train the joint model.
      train_steps_per_epoch: how long to train the joint model in each RL epoch.
      supervised_evals_per_epoch: number of value trainer evaluations per RL
          epoch - only affects metric reporting.
      supervised_eval_steps: number of value trainer steps per evaluation -
          only affects metric reporting.
      n_trajectories_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      normalize_advantages: if True, then normalize advantages - currently
          implemented only in PPO.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      n_replay_epochs: how many last epochs to take into the replay buffer;
           > 1 only makes sense for off-policy algorithms.
    """
    super().__init__(
        task,
        n_trajectories_per_epoch=n_trajectories_per_epoch,
        output_dir=output_dir,
    )
    self._batch_size = batch_size
    self._train_steps_per_epoch = train_steps_per_epoch
    self._supervised_evals_per_epoch = supervised_evals_per_epoch
    self._supervised_eval_steps = supervised_eval_steps
    # NOTE(review): also passed to super().__init__ above; this may duplicate
    # what the base class stores — confirm.
    self._n_trajectories_per_epoch = n_trajectories_per_epoch
    self._max_slice_length = max_slice_length
    # Action distribution derived from the task's action space; it is bound
    # into the joint model below.
    self._policy_dist = distributions.create_distribution(task.action_space)
    # The schedule factory is called once here to obtain the schedule itself.
    self._lr_schedule = lr_schedule()
    self._optimizer = optimizer
    self._normalize_advantages = normalize_advantages
    self._n_replay_epochs = n_replay_epochs
    # Propagate the replay-buffer depth to the task.
    self._task.set_n_replay_epochs(n_replay_epochs)

    # Inputs to the joint model are produced by self.batches_stream.
    self._inputs = data.inputs.Inputs(
        train_stream=lambda _: self.batches_stream())

    # Bind the policy distribution now; `mode` is supplied later by callers
    # (the Trainer, and the eval model below).
    self._joint_model = functools.partial(
        joint_model,
        policy_distribution=self._policy_dist,
    )

    # This is the joint Trainer that will be used to train the policy model.
    # * inputs to the trainer come from self.batches_stream
    # * outputs are passed to self._joint_loss
    self._trainer = supervised.Trainer(
        model=self._joint_model,
        optimizer=self._optimizer,
        lr_schedule=self._lr_schedule,
        loss_fn=self.joint_loss,
        inputs=self._inputs,
        output_dir=output_dir,
        metrics={'joint_loss': self.joint_loss,
                 'advantage_mean': self.advantage_mean,
                 'advantage_norm': self.advantage_norm,
                 'value_loss': self.value_loss,
                 'explained_variance': self.explained_variance,
                 'log_probs_mean': self.log_probs_mean,
                 'preferred_move': self.preferred_move})
    # Separate eval-mode copy of the joint model, accelerated on one device.
    self._eval_model = tl.Accelerate(
        self._joint_model(mode='eval'), n_devices=1)
    # Initialize the eval model from the first batch of a fresh batch stream.
    example_batch = next(self.batches_stream())
    self._eval_model.init(example_batch)
Exemplo n.º 12
0
def autoregressive_sample(model,
                          prefix=None,
                          inputs=None,
                          batch_size=1,
                          temperature=1.0,
                          start_id=0,
                          eos_id=1,
                          max_length=100,
                          accelerate=True):
    """Perform autoregressive sampling from the provided model.

    Note that the provided model should be an autoregressive model initialized
    in 'predict' mode. In this mode, a model takes the outputs it is generating
    one-by-one (instead of taking them all at once, as, e.g., during training).
    Model state is used to store the intermediate information needed, and
    usually the model performs inference in this mode faster than in 'eval'
    mode.

    Args:
      model: instance of trax.Layer, the model to sample from (at mode='predict')
      prefix: optional tensor [batch_size, L]: prefix for decoding; during the
        first L steps the prefix symbols are emitted (and fed back) instead of
        sampled ones
      inputs: optional tensor [batch_size, M]: inputs to provide to the model
      batch_size: number of sequences to sample in parallel (default: 1)
      temperature: sampling temperature (default: 1.0)
      start_id: int, id for the start symbol fed at the beginning (default: 0)
      eos_id: int, id of the end-of-sequence symbol used to stop (default: 1);
        sampling stops once every batch position has emitted it at least once
      max_length: maximum length to sample (default: 100)
      accelerate: whether to accelerate the model before decoding (default: True)

    Returns:
      a tensor of ints of shape [batch_size, N] with N <= max_length containing
      the autoregressively sampled output from the model

    Raises:
      ValueError: if the first dimension of `prefix` or `inputs` differs from
        `batch_size`.
    """
    if prefix is not None and prefix.shape[0] != batch_size:
        raise ValueError(
            f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
    fast_model = tl.Accelerate(model) if accelerate else model
    cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    result = []
    # Batch positions at which EOS has been emitted at least once. A set keeps
    # membership O(1) and does not accumulate duplicates when EOS recurs.
    eos_seen = set()
    for i in range(max_length):
        model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
        logits = fast_model(model_input)
        if inputs is not None:
            # Pick first element from model output (a pair here).
            logits = logits[0]
        if prefix is not None and i < prefix.shape[1]:  # Read from prefix.
            sample = prefix[:, i][:, None]
        else:
            sample = tl.gumbel_sample(logits, temperature=temperature)
        result.append(sample)
        # Note: we're using 'predict' mode autoregressive models here, so
        # history is cached in the model state and we only feed the next symbol.
        cur_symbol = sample
        # Record which batch positions emitted EOS at this step.
        eos_seen.update(
            j for j in range(batch_size) if int(sample[j, 0]) == eos_id)
        # If EOS has been seen on all positions, stop.
        if len(eos_seen) == batch_size:
            break
    return np.concatenate(result, axis=1)
Exemplo n.º 13
0
def autoregressive_sample_stream(model,
                                 inputs=None,
                                 batch_size=1,
                                 temperature=1.0,
                                 start_id=0,
                                 accelerate=True,
                                 eval_mode=False,
                                 eval_min_length=1):
    """Yields samples from `model`, in autoregressive language model fashion.

    This function uses `model` to generate outputs one position at a time, with
    access to inputs for the current position and all preceding positions. The
    new output becomes the next position's input, and further calls to
    `autoregressive_sample_stream` repeat the process for successive positions
    indefinitely.

    Inputs and outputs always come in batches, even if size 1. If `inputs` is
    present, it must have shape (`batch_size`, inputs_sequence_length), and
    each output in the stream has shape (`batch_size`,).

    Args:
      model: A layer object (subclass of `trax.layers.Layer`) created in
          `'predict'` mode and initialized from trained weights. The model
          must have a structure that allows it to run as an autoregressive
          one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`),
          except if `eval_mode` is set -- any model can be sampled then,
          but the sampling process may be much slower.
      inputs: Sequence of symbols the model sees as input the first time it
          generates an output. If None, the model generates the first output
          based on just the start symbol.
      batch_size: Number of sequences to generate in parallel as a batch.
      temperature: Parameter that controls the sharpness of the softmax that
          feeds the sampling process. Values range from 0.0 (all probability
          mass goes to one candidate; like an argmax) to positive infinity (all
          candidates have equal probability).
      start_id: Integer representing the start symbol for the autoregressive
          process, or array of shape (`batch_size`, 1) of such integers.
      accelerate: If True, create an accelerated version of `model` and use it
          for generating outputs.
      eval_mode: If True, assume the model is created in `eval` mode and sample
          by collecting all previous outputs and passing the whole tensor.
      eval_min_length: If set, the minimum length to pad to in eval mode. The
          symbol history is padded to the smallest power of two that is at
          least max(eval_min_length, history_length + 1).

    Yields:
      Tensor of integers with shape (`batch_size`,), representing the batch of
      outputs for the next position in the stream.

    Raises:
      ValueError: If `inputs` is given and its first dimension differs from
          `batch_size` (raised when the first sample is requested, since this
          is a generator).
    """
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size ({inputs.shape[0]}) does not match '
            f'batch_size arg ({batch_size}).')

    fast_model = tl.Accelerate(model) if accelerate else model
    if np.isscalar(start_id):
        start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    else:
        start_symbol = start_id
    if model.n_in == 1 and inputs is not None:
        current_symbols = np.concatenate([start_symbol, inputs], axis=1)
    else:
        current_symbols = start_symbol

    if eval_mode:
        # No start symbol needed in eval mode.
        current_symbols = current_symbols[:, 1:]

    while True:
        # Pad inputs to power-of-2 length if needed.
        if eval_mode:
            # One extra symbol as an initial one will be added.
            l = max(eval_min_length, current_symbols.shape[1] + 1)
            pad_len = int(2**np.ceil(np.log2(l))) - current_symbols.shape[1]
            unpadded_symbols = current_symbols
            current_symbols = np.pad(current_symbols, [[0, 0], [0, pad_len]],
                                     mode='constant')
            last_index = -pad_len  # no -1 as the starting one will be added
        else:
            last_index = -1
        # Run the model.
        if model.n_in > 1 and inputs is not None:
            logits = fast_model((inputs, current_symbols))[0]
        else:
            logits = fast_model(current_symbols)
        logits = tl.log_softmax(logits[:, last_index, :])
        sample = tl.logsoftmax_sample(logits, temperature=temperature)
        yield sample
        if eval_mode:
            # Re-feed the whole (unpadded) history plus the new symbol.
            current_symbols = np.concatenate(
                [unpadded_symbols, sample[:, None]], axis=1)
        else:
            # NOTE: Because the model is autoregressive and in 'predict' mode,
            # its history is cached in the model state and the next input is
            # the single symbol just sampled.
            current_symbols = sample[:, None]
Exemplo n.º 14
0
def load_model(path):
    """Load an eval-mode NMTAttn model from a checkpoint and accelerate it.

    Args:
      path: Checkpoint file passed to `init_from_file` (with
        `weights_only=True`).

    Returns:
      The initialized model wrapped in `tl.Accelerate`.
    """
    nmt = NMTAttn(mode='eval')
    nmt.init_from_file(path, weights_only=True)
    return tl.Accelerate(nmt)
Exemplo n.º 15
0
  def _terraformer_decoding_time(self, settings):
    """Measures per-step decoding time of a predict-mode model.

    Builds a Terraformer or Transformer in 'predict' mode from `settings`,
    runs a short warm-up decode (for compilation), restores the model state
    captured right after initialization, then times 12 decoding steps.

    Args:
      settings: dict with keys 'model' ('terraformer' or 'transformer'),
        'attn' (a (layer, kwargs) pair used for encoder-decoder attention),
        'd_model', 'd_ff', 'n_heads', 'encoder_layers', 'decoder_layers',
        'vocab', 'ff_sparsity', 'ff_use_sru', 'loss_sparsity' and, for
        'terraformer' only, 'attention_layers'.

    Returns:
      A tuple (model_size, elapsed_times, peak_memory): `int` of
      `_size_of_model(pred_model)`, the list of 12 per-step wall-clock times
      in seconds, and the value of `_memory_usage()` after the timed run.
    """
    # Garbage collection influences the timing, so we turn it off. It is
    # turned back on in `finally`, so a failure part-way through (model
    # construction, compilation, a failed assertion) cannot leave GC disabled
    # for every test that runs afterwards in the same process.
    gc.disable()
    try:
      max_len = 16
      n_decode_steps = 12

      def _self_attention_fn():
        return functools.partial(
            tl.SelfAttention,
            predict_drop_len=2 * max_len,
            predict_mem_len=2 * max_len)

      def _causal_attention_fn():
        attn_layer, attn_kwargs = settings['attn']
        return functools.partial(
            attn_layer,
            max_inference_length=2 * max_len, **attn_kwargs)

      if settings['model'] == 'terraformer':
        pred_model = models.ConfigurableTerraformer(
            mode='predict',
            d_model=settings['d_model'],
            d_ff=settings['d_ff'],
            dropout=0.1,
            max_len=max_len,
            n_heads=settings['n_heads'],
            n_encoder_layers=settings['encoder_layers'],
            n_decoder_layers=settings['decoder_layers'],
            encoder_attention_type=_self_attention_fn(),
            encoder_decoder_attention_type=_causal_attention_fn(),
            input_vocab_size=settings['vocab'],
            ff_sparsity=settings['ff_sparsity'],
            ff_use_sru=settings['ff_use_sru'],
            ff_dropout=0.1,
            # ff_chunk_size=1024,
            # attention_chunk_size=1,
            n_decoder_attention_layers=settings['attention_layers'],
            loss_sparsity=settings['loss_sparsity'],
            pos_axial_shape=None,
            use_bfloat16=True,
        )
      elif settings['model'] == 'transformer':
        pred_model = models.ConfigurableTransformer(
            mode='predict',
            d_model=settings['d_model'],
            d_ff=settings['d_ff'],
            dropout=0.1,
            max_len=max_len,
            n_heads=settings['n_heads'],
            n_encoder_layers=settings['encoder_layers'],
            n_decoder_layers=settings['decoder_layers'],
            # encoder_attention_type=_self_attention_fn(),
            encoder_decoder_attention_type=_causal_attention_fn(),
            input_vocab_size=settings['vocab'],
            ff_sparsity=settings['ff_sparsity'],
            ff_use_sru=settings['ff_use_sru'],
            # ff_dropout=0.1,
            # ff_chunk_size=1024,
            # attention_chunk_size=1,
            # n_decoder_attention_layers=settings['attention_layers'],
            loss_sparsity=settings['loss_sparsity'],
            pos_axial_shape=None,
            # enc_dec_attention_sparsity=settings['enc_dec_sparsity'],
            # use_bfloat16=True,
        )
      else:
        # Same exception type as the former `assert False`, but with a message
        # and not stripped under `python -O`.
        self.fail('Unknown model type in settings: {}'.format(
            settings['model']))
      # We put acceleration outside of autoregressive_sample_stream, because
      # we want to have a separate run (separate input) for model compilation.
      pred_model = tl.Accelerate(pred_model)

      shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
      shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
      pred_model.init(input_signature=(shape1l, shape11))
      # Snapshot the state right after init so that the measured run starts
      # from the same state as the warm-up run did.
      original_state = copy.deepcopy(pred_model.state)

      inputs_warmup = np.zeros((1, max_len), dtype=np.int32)
      inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

      def _timed_decode(model_inputs, n_steps):
        """Decodes `n_steps` symbols; returns (samples, per-step seconds)."""
        samples, step_times = [], []
        current_time = time.time()
        sample_stream = decoding.autoregressive_sample_stream(
            pred_model, model_inputs, temperature=0.0, accelerate=False)
        for _, sample in zip(range(n_steps), sample_stream):
          samples.append(sample[:, None])  # to be sure that the result is computed
          current_time, start_time = time.time(), current_time
          step_times.append(current_time - start_time)
        return samples, step_times

      # This is a warm-up run, for compilation.
      _timed_decode(inputs_warmup, 4)

      # This is a real decoding timing run that we measure.
      pred_model.state = original_state
      result, elapsed_times = _timed_decode(inputs, n_decode_steps)
      peak_memory = _memory_usage()

      # The first two steps are excluded from the variance check.
      if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]):
        print('WARNING! High variance found in elapsed times! Settings: {} ; '
              'elapsed times: {} ; Probably more warm-up steps should be used, '
              'or model size should be increased.'.format(settings,
                                                          elapsed_times))
      # Check resulting shapes.
      s = np.concatenate(result, axis=1)
      self.assertEqual(s.shape[0], 1)
      self.assertEqual(s.shape[1], n_decode_steps)
      model_size = int(_size_of_model(pred_model))

      # We delete the model weights, because in some situations they won't be
      # deleted automatically.
      _recurrent_delete(pred_model.weights)
      return model_size, elapsed_times, peak_memory
    finally:
      gc.enable()
Exemplo n.º 16
0
    def __init__(
            self,
            loop,
            model=gin.REQUIRED,
            observation_serializer=gin.REQUIRED,
            action_serializer=gin.REQUIRED,
            eval_at=1000,
            eval_task=None,
            context_lengths=(1, ),
            horizon_lengths=(1, ),
            n_steps=1,
    ):
        """Initializes SerializedModelEvaluation.

        Args:
          loop: Instance of `trax.supervised.training.Loop`.
          model: Instance of `trax.rl.serialization_utils.SerializedModel`. It
            is wrapped in `tl.Accelerate` for evaluation and is initialized
            here with an int32 input of shape (batch_size, 1).
          observation_serializer: `trax.rl.space_serializer.Serializer` of the
            output sequence (observation sequence in RL environment models).
          action_serializer: `trax.rl.space_serializer.Serializer` of the
            input sequence (action sequence in RL environment models).
          eval_at: When to evaluate. Either an int N (evaluate at steps 1,
            N + 1, 2N + 1, ...; N == 1 means every step), or a collection of
            ints (step numbers), or a function int -> bool (step predicate).
          eval_task: Instance of `trax.supervised.training.EvalTask` with the
            evaluation data, or None. If not provided, the task will be taken
            from `loop`.
          context_lengths: List of lengths of the context sequence fed into the
            model before starting prediction (stored sorted ascending).
          horizon_lengths: List of lengths of the predicted sequence (stored
            sorted ascending).
          n_steps: Number of batches to run evaluation for.

        Raises:
          TypeError: If `eval_at` is not an int, a collection or a callable.
          ValueError: If `eval_task` is None and `loop` does not have exactly
            one eval task.
        """
        super().__init__(loop)

        self._model = tl.Accelerate(model)
        self._obs_serializer = observation_serializer
        self._act_serializer = action_serializer

        if isinstance(eval_at, int):
            # Fires at steps 1, 1 + eval_at, 1 + 2 * eval_at, ... Comparing
            # with `1 % eval_at` instead of a literal 1 keeps that schedule for
            # eval_at >= 2 and makes eval_at == 1 fire on every step (with a
            # literal 1 it could never fire, because step % 1 is always 0).
            self._eval_at = lambda step: step % eval_at == 1 % eval_at
        elif hasattr(eval_at, '__contains__'):
            # `__contains__` is the method behind the `in` operator, so this
            # accepts lists, tuples, sets, ranges, etc. of step numbers.
            self._eval_at = lambda step: step in eval_at
        elif callable(eval_at):
            self._eval_at = eval_at
        else:
            raise TypeError(f'Unsupported type for eval_at: {type(eval_at)}.')

        if eval_task is None:
            if len(loop.eval_tasks) != 1:
                raise ValueError(
                    'If eval_task is not provided, the number of eval_tasks registered '
                    'in Loop must be exactly 1.')
            eval_task = loop.eval_tasks[0]
        self._eval_task = eval_task

        self._context_lengths = list(sorted(context_lengths))
        self._horizon_lengths = list(sorted(horizon_lengths))
        self._n_steps = n_steps

        # Batch size is taken from the first element of the task's sample batch.
        self._batch_size = eval_task.sample_batch[0].shape[0]
        # Only the state half of init()'s (weights, state) pair is kept.
        (_, self._init_state) = model.init(
            shapes.ShapeDtype((self._batch_size, 1), dtype=np.int32))
Exemplo n.º 17
0
    def __init__(self,
                 task,
                 policy_model=None,
                 policy_optimizer=None,
                 policy_lr_schedule=lr.multifactor,
                 policy_batch_size=64,
                 policy_train_steps_per_epoch=500,
                 policy_evals_per_epoch=1,
                 policy_eval_steps=1,
                 n_eval_episodes=0,
                 only_eval=False,
                 max_slice_length=1,
                 output_dir=None,
                 **kwargs):
        """Configures the policy trainer.

        Args:
          task: RLTask instance, which defines the environment to train on.
          policy_model: Callable that builds the Trax policy model. It is
              partially applied with `policy_distribution=` (a distribution
              created from `task.action_space`), handed to
              `supervised.Trainer` as its `model`, and called with
              `mode='collect'` and `mode='eval'` to build the collection and
              evaluation models. Despite the None default, a callable must be
              supplied: `functools.partial` rejects None with a TypeError.
          policy_optimizer: the optimizer to use to train the policy model.
          policy_lr_schedule: zero-argument callable returning the learning
              rate schedule used to train the policy; it is called once here.
          policy_batch_size: batch size used to train the policy model.
          policy_train_steps_per_epoch: how long to train policy in each RL
              epoch.
          policy_evals_per_epoch: number of policy trainer evaluations per RL
              epoch - only affects metric reporting.
          policy_eval_steps: number of policy trainer steps per evaluation -
              only affects metric reporting.
          n_eval_episodes: number of episodes to play with policy at
              temperature 0 in each epoch -- used for evaluation only
          only_eval: If set to True, then trajectories are collected only
              for evaluation purposes, but they are not recorded.
          max_slice_length: the maximum length of trajectory slices to use.
          output_dir: Path telling where to save outputs (evals and
              checkpoints).
          **kwargs: arguments for the superclass RLTrainer.
        """
        super().__init__(task,
                         n_eval_episodes=n_eval_episodes,
                         output_dir=output_dir,
                         **kwargs)
        # NOTE(review): these attributes are set before policy_batches_stream()
        # is first called below; presumably that method reads some of them, so
        # keep this ordering — confirm against policy_batches_stream.
        self._policy_batch_size = policy_batch_size
        self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
        self._policy_evals_per_epoch = policy_evals_per_epoch
        self._policy_eval_steps = policy_eval_steps
        self._only_eval = only_eval
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.action_space)

        # Inputs to the policy model are produced by self.policy_batches_stream.
        self._policy_inputs = data.inputs.Inputs(
            train_stream=lambda _: self.policy_batches_stream())

        # Bind the action distribution so every model built below shares it.
        policy_model = functools.partial(
            policy_model,
            policy_distribution=self._policy_dist,
        )

        # This is the policy Trainer that will be used to train the policy model.
        # * inputs to the trainer come from self.policy_batches_stream
        # * outputs, targets and weights are passed to self.policy_loss
        self._policy_trainer = supervised.Trainer(
            model=policy_model,
            optimizer=policy_optimizer,
            lr_schedule=policy_lr_schedule(),
            loss_fn=self.policy_loss,
            inputs=self._policy_inputs,
            output_dir=output_dir,
            metrics=self.policy_metrics,
        )
        # Separate single-device accelerated models for trajectory collection
        # and for evaluation, both initialized from the signature of one batch
        # drawn from policy_batches_stream().
        self._policy_collect_model = tl.Accelerate(
            policy_model(mode='collect'), n_devices=1)
        policy_batch = next(self.policy_batches_stream())
        self._policy_collect_model.init(shapes.signature(policy_batch))
        self._policy_eval_model = tl.Accelerate(
            policy_model(mode='eval'), n_devices=1)  # Not collecting stats
        self._policy_eval_model.init(shapes.signature(policy_batch))
        # NOTE(review): reads the task's private `_initial_trajectories`.
        # Presumably, when no initial trajectories were requested, epoch 0 is
        # dropped and replaced with trajectories collected by this agent —
        # confirm against RLTask.
        if self._task._initial_trajectories == 0:
            self._task.remove_epoch(0)
            self._collect_trajectories()
Exemplo n.º 18
0
  def __init__(self, task,
               value_body=None,
               value_optimizer=None,
               value_lr_schedule=lr.multifactor,
               value_batch_size=64,
               value_train_steps_per_epoch=500,
               value_evals_per_epoch=1,
               value_eval_steps=1,
               exploration_rate=functools.partial(
                   lr.multifactor,
                   factors='constant * decay_every',
                   constant=1.,  # pylint: disable=redefined-outer-name
                   decay_factor=0.99,
                   steps_per_decay=1,
                   minimum=0.1),
               n_eval_episodes=0,
               only_eval=False,
               n_replay_epochs=1,
               max_slice_length=1,
               sync_freq=1000,
               scale_value_targets=True,
               output_dir=None,
               **kwargs):
    """Configures the value trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      value_body: Trax layer used as the `body` of the `models.Quality` value
          model, which is built with `n_actions=task.action_space.n`.
      value_optimizer: the optimizer to use to train the value model.
      value_lr_schedule: zero-argument callable returning the learning rate
          schedule used to train the value model; it is called once here.
      value_batch_size: batch size used to train the value model.
      value_train_steps_per_epoch: how long to train the value model in each
          RL epoch.
      value_evals_per_epoch: number of value trainer evaluations per RL epoch
          - only affects metric reporting.
      value_eval_steps: number of value trainer steps per evaluation - only
          affects metric reporting.
      exploration_rate: zero-argument callable returning the exploration rate
          schedule (it is called once here) - used in the policy method.
      n_eval_episodes: number of episodes to play with policy at
        temperature 0 in each epoch -- used for evaluation only
      only_eval: If set to True, then trajectories are collected only
        for evaluation purposes, but they are not recorded.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      max_slice_length: the maximum length of trajectory slices to use; it is
          the second dimension of the value network output:
          (batch, max_slice_length, number of actions)
          Higher max_slice_length implies that the network has to predict more
          values into the future.
      sync_freq: frequency when to synchronize the target
        network with the trained network (synchronization happens on steps
        that are multiples of `sync_freq`). This is necessary for training the
        network on bootstrapped targets, e.g. using n-step returns.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`. We are trying to fix the problem with very large
          returns in some games in a way which does not introduce an additional
          hyperparameter.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
    super(ValueAgent, self).__init__(
        task,
        n_eval_episodes=n_eval_episodes,
        output_dir=output_dir,
        **kwargs
    )
    self._value_batch_size = value_batch_size
    self._value_train_steps_per_epoch = value_train_steps_per_epoch
    self._value_evals_per_epoch = value_evals_per_epoch
    self._value_eval_steps = value_eval_steps
    self._only_eval = only_eval
    self._max_slice_length = max_slice_length
    self._policy_dist = distributions.create_distribution(task.action_space)
    self._n_replay_epochs = n_replay_epochs

    self._exploration_rate = exploration_rate()
    # True on steps that are multiples of sync_freq.
    self._sync_at = (lambda step: step % sync_freq == 0)

    # Factor 1 / (1 - gamma) applied to value targets when scaling is enabled.
    if scale_value_targets:
      self._value_network_scale = 1 / (1 - self._task.gamma)
    else:
      self._value_network_scale = 1

    # Factory for value models; each call with a `mode` builds a new
    # models.Quality instance sharing the same `value_body` argument.
    value_model = functools.partial(
        models.Quality,
        body=value_body,
        n_actions=self.task.action_space.n)

    # NOTE(review): `_value_model_signature` is not defined in this method;
    # presumably a property elsewhere in the class — confirm.
    self._value_eval_model = value_model(mode='eval')
    self._value_eval_model.init(self._value_model_signature)
    # Jitted forward function of the eval model across all devices.
    self._value_eval_jit = tl.jit_forward(
        self._value_eval_model.pure_fn, fastmath.device_count(), do_mean=False)

    # Inputs to the value model are produced by self.value_batches_stream.
    self._inputs = data.inputs.Inputs(
        train_stream=lambda _: self.value_batches_stream())

    # This is the value Trainer that will be used to train the value model.
    # * inputs to the trainer come from self.value_batches_stream
    # * outputs, targets and weights are passed to self.value_loss
    self._value_trainer = supervised.Trainer(
        model=value_model,
        optimizer=value_optimizer,
        lr_schedule=value_lr_schedule(),
        loss_fn=self.value_loss,
        inputs=self._inputs,
        output_dir=output_dir,
        metrics={'value_loss': self.value_loss,
                 'value_mean': self.value_mean,
                 'returns_mean': self.returns_mean}
    )
    # Single-device accelerated 'collect'-mode model, initialized from the
    # signature of one batch drawn from value_batches_stream().
    value_batch = next(self.value_batches_stream())
    self._eval_model = tl.Accelerate(
        value_model(mode='collect'), n_devices=1)
    self._eval_model.init(shapes.signature(value_batch))
    # NOTE(review): reads the task's private `_initial_trajectories`.
    # Presumably, when no initial trajectories were requested, epoch 0 is
    # dropped and replaced with trajectories collected by this agent —
    # confirm against RLTask.
    if self._task._initial_trajectories == 0:
      self._task.remove_epoch(0)
      self._collect_trajectories()
Exemplo n.º 19
0
def beam_search(model,
                inputs=None,
                batch_size=1,
                n_beams=2,
                start_id=0,
                eos_id=1,
                max_length=100,
                length_penalty=1.0,
                accelerate=True):
    """Returns a batch of n_beams-sequences created by beam search.

    This function uses `model` to generate outputs one position at a time, with
    access to inputs for the current position and all preceding positions. The
    new output becomes the next position's input, and this loop repeats for
    `max_length` steps, keeping after each step the `n_beams` beams with the
    highest summed log-probability.

    Args:
      model: A layer object (subclass of `trax.layers.Layer`) created in
          `'predict'` mode and initialized from trained weights. The model
          must have a structure that allows it to run as autoregressive
          one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`).
      inputs: Sequence of symbols the model sees as input the first time it
          generates an output. If None, the model must generate the first
          output with no input to guide it.
      batch_size: Number of sequences to generate in parallel as a batch.
          Only 1 is currently supported.
      n_beams: How many beams to consider at the same time.
      start_id: The start symbol (ID/integer) for the autoregressive process,
          or array of shape (`batch_size`, 1) of such integers.
      eos_id: Currently ignored (end-of-sequence handling is not implemented
          yet); generation always runs for `max_length` steps.
      max_length: Number of symbols generated for every beam.
      length_penalty: Currently ignored (not implemented yet).
      accelerate: If True, create an accelerated version of `model` and use it
          for generating outputs.

    Returns:
      A list of `n_beams` pairs `(sequence, score)`, best beam first.
      `sequence` stacks the symbol chosen at every step along the last axis,
      i.e. shape (`batch_size`, `max_length`) assuming `fastmath.top_k`
      returns (batch, k) arrays; `score` is the beam's summed log-probability.

    Raises:
      AssertionError: If `batch_size` is not 1.
      ValueError: If the batch dimension of `inputs` differs from
          `batch_size`.
    """
    del eos_id, length_penalty  # TODO(lukaszkaiser): add length penalty, eos
    assert batch_size == 1, 'Batch size > 1 not supported yet'
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size ({inputs.shape[0]}) does not match '
            f'batch_size arg ({batch_size}).')

    fast_model = tl.Accelerate(model) if accelerate else model
    if np.isscalar(start_id):
        start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    else:
        start_symbol = start_id
    # Single-input models see start symbol + inputs as one sequence;
    # multi-input models receive `inputs` separately on every call below.
    if model.n_in == 1 and inputs is not None:
        current_symbols = np.concatenate([start_symbol, inputs], axis=1)
    else:
        current_symbols = start_symbol

    beams = [current_symbols for _ in range(n_beams)]
    results = [([], 0.0) for _ in range(n_beams)]
    states = [fast_model.state for _ in range(n_beams)]
    top_k = [None] * n_beams
    counter = 0
    while counter < max_length:
        counter += 1
        # All beams are identical before the first expansion, so only beam 0
        # is run on the first step; afterwards every beam is expanded.
        n_live_beams = n_beams if counter > 1 else 1

        # Run the model on all beams, collect states and top_k for each beam.
        for beam_id in range(n_live_beams):
            fast_model.state = states[beam_id]
            if model.n_in > 1 and inputs is not None:
                logits = fast_model((inputs, beams[beam_id]))[0]
            else:
                logits = fast_model(beams[beam_id])
            logits = tl.log_softmax(logits[:, -1, :])
            states[beam_id] = fast_model.state
            top_k[beam_id] = fastmath.top_k(logits, k=n_beams)

        # Select new beams.
        cur_values = []  # will hold triples (sum-of-logprobs, beam-id, symbol)
        for beam_id in range(n_live_beams):
            values, symbols = top_k[beam_id]
            for k in range(n_beams):
                value, symbol = values[:, k], symbols[:, k]
                cur_values.append(
                    (results[beam_id][1] + value, beam_id, symbol))
        cur_values.sort(key=lambda x: -x[0][0])  # x[0][0] as batch_size=1
        # Collect top beams to the new states and results.
        new_results, new_states, new_beams = [], [], []
        for (value, beam_id, symbol) in cur_values[:n_beams]:
            new_results.append((results[beam_id][0] + [symbol], value))
            new_states.append(states[beam_id])  # copy?
            new_beams.append(symbol[:, None])
        results, states, beams = new_results, new_states, new_beams

    return [(np.stack(r, axis=-1), v) for (r, v) in results]
Exemplo n.º 20
0
    def __init__(
            self,
            loop,
            model=None,
            eval_at=1000,
            eval_task=None,
            context_lengths=(1, ),
            horizon_lengths=(1, ),
            n_steps=1,
            accelerate_model=True,
    ):
        """Initializes SerializedModelEvaluation.

        Args:
          loop: Instance of `trax.supervised.training.Loop` or `None`. Can be
            set to `None` for testing - in such a case, `model` and
            `eval_task` must be provided.
          model: Instance of `trax.rl.serialization_utils.SerializedModel`. Not
            required if `loop` is provided (then `loop.model` is used). Its
            `observation_serializer` and `action_serializer` are stored, and
            `model.make_predict_model()` builds the model used for evaluation.
          eval_at: When to evaluate. Either an int N (evaluate at steps 1,
            N + 1, 2N + 1, ...; N == 1 means every step), or a collection of
            ints (step numbers), or a function int -> bool (step predicate).
          eval_task: Instance of `trax.supervised.training.EvalTask` with the
            evaluation data, or None. If not provided, the task will be taken
            from `loop`.
          context_lengths: List of lengths of the context sequence fed into the
            model before starting prediction (stored sorted ascending).
          horizon_lengths: List of lengths of the predicted sequence (stored
            sorted ascending).
          n_steps: Number of batches to run evaluation for.
          accelerate_model (bool): Whether to wrap the predict model in
            `tl.Accelerate`.

        Raises:
          TypeError: If `eval_at` is not an int, a collection or a callable.
          ValueError: If `eval_task` is None and `loop` does not have exactly
            one eval task.
        """
        super().__init__(loop)

        if model is None:
            model = loop.model

        observation_serializer = model.observation_serializer
        action_serializer = model.action_serializer

        predict_model = model.make_predict_model()
        if accelerate_model:
            predict_model = tl.Accelerate(predict_model)
        self._predict_model = predict_model
        self._obs_serializer = observation_serializer
        self._act_serializer = action_serializer

        if isinstance(eval_at, int):
            # Fires at steps 1, 1 + eval_at, 1 + 2 * eval_at, ... Comparing
            # with `1 % eval_at` instead of a literal 1 keeps that schedule for
            # eval_at >= 2 and makes eval_at == 1 fire on every step (with a
            # literal 1 it could never fire, because step % 1 is always 0).
            self._eval_at = lambda step: step % eval_at == 1 % eval_at
        elif hasattr(eval_at, '__contains__'):
            # `__contains__` is the method behind the `in` operator, so this
            # accepts lists, tuples, sets, ranges, etc. of step numbers.
            self._eval_at = lambda step: step in eval_at
        elif callable(eval_at):
            self._eval_at = eval_at
        else:
            raise TypeError(f'Unsupported type for eval_at: {type(eval_at)}.')

        if eval_task is None:
            if len(loop.eval_tasks) != 1:
                raise ValueError(
                    'If eval_task is not provided, the number of eval_tasks registered '
                    'in Loop must be exactly 1.')
            eval_task = loop.eval_tasks[0]
        self._eval_task = eval_task

        self._context_lengths = list(sorted(context_lengths))
        self._horizon_lengths = list(sorted(horizon_lengths))
        self._n_steps = n_steps

        # Batch size is taken from the first element of the task's sample batch.
        self._batch_size = eval_task.sample_batch[0].shape[0]
        # Only the state half of init()'s (weights, state) pair is kept.
        (_, self._init_state) = predict_model.init(
            shapes.ShapeDtype((self._batch_size, 1), dtype=np.int32))