# **Implementing the computations using the Serial combinator**

# Serial combinator
serial = tl.Serial(
    Addition(),        # add 3 + 4
    Multiplication(),  # multiply the result by 15
    Addition()         # add 3 to the result
)

# Initialization
x = (np.array([3]), np.array([4]), np.array([15]), np.array([3]))  # input

serial.init(shapes.signature(x))  # initializing serial instance

print("-- Serial Model --")
print(serial, "\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out, "\n")

# Inputs
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = serial(x)
print("-- Outputs --")
print("y :", y)
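
The `Addition()` and `Multiplication()` layers used above are not defined in this snippet. A minimal sketch of how such function layers could be written with `tl.Fn` (the exact bodies here are an assumption, not the original notebook's code):

# Hypothetical helper layers for the Serial example above. tl.Fn wraps a plain
# Python function as a Trax layer (two inputs, one output in both cases).
def Addition():
    def func(x, y):
        return x + y
    return tl.Fn("Addition", func)

def Multiplication():
    def func(x, y):
        return x * y
    return tl.Fn("Multiplication", func)

With the input tuple `(3, 4, 15, 3)`, each layer pops two values from the data stack and pushes one, so the Serial combinator computes `(3 + 4) * 15 + 3 = 108`.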
Example #2
 def test_new_weights_no_bias(self):
     layer = tl.Dense(4, use_bias=False)
     x = np.array([1, 2])
     _, _ = layer.init(shapes.signature(x))
     self.assertEqual(layer.weights.shape, (2, 4))
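
For contrast, a companion sketch (mirroring the test above; this exact test is an assumption, not code from the library): with the default `use_bias=True`, `tl.Dense` initializes its weights as a `(w, b)` pair rather than a single matrix.

 def test_new_weights_with_bias(self):
     layer = tl.Dense(4)  # use_bias defaults to True
     x = np.array([1, 2])
     _, _ = layer.init(shapes.signature(x))
     w, b = layer.weights
     self.assertEqual(w.shape, (2, 4))
     self.assertEqual(b.shape, (4,))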
Example #3
    def test_reformer2_one_step(self):
        vocab_size = 32
        max_len = 256
        pos_axial = 16
        assert pos_axial * pos_axial == max_len

        chunk_len = 32

        # Choose n_buckets so that 2 * chunk_len * n_buckets == max_len.
        n_buckets = max_len // (2 * chunk_len)
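        # i.e. with max_len = 256 and chunk_len = 32, n_buckets = 256 // 64 = 4.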

        lsh_self_attention = functools.partial(self._lsh_self_attention_fn(),
                                               chunk_len=chunk_len,
                                               n_buckets=n_buckets)

        timebin_self_attention = self._timebin_self_attention_fn()

        model = reformer.Reformer2(
            vocab_size,
            d_model=32,
            d_ff=64,
            d_attention_key=64,
            d_attention_value=64,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_heads=2,
            dropout=0.05,
            max_len=max_len,
            encoder_attention_type=lsh_self_attention,
            encoder_decoder_attention_type=[
                timebin_self_attention, lsh_self_attention
            ],
            pos_axial_shape=(pos_axial, pos_axial),
            pos_d_axial_embs=(64, 192),
            ff_activation=tl.Relu,
            ff_use_sru=0,
            ff_chunk_size=64,
            ff_sparsity=8,
            mode='train',
        )

        x = [
            np.ones((1, max_len)).astype(np.int32),
            np.ones((1, max_len)).astype(np.int32)
        ]
        weights, state = model.init(shapes.signature(x))

        @fastmath.jit
        def mock_training_step(x, weights, state, rng):
            def compute_mock_loss(weights):
                logits_and_dec_toks, new_state = model.pure_fn(
                    x, weights, state, rng)
                # This returns [logits, decoder tokens]
                logits = logits_and_dec_toks[0]
                loss = fastmath.numpy.mean(logits[..., 0])
                return loss, (new_state, logits)

            gradients, (new_state,
                        logits) = fastmath.grad(compute_mock_loss,
                                                has_aux=True)(weights)
            new_weights = fastmath.nested_map_multiarg(
                lambda w, g: w - 1e-4 * g, weights, gradients)
            return new_weights, new_state, logits

        weights, state, logits = mock_training_step(
            x, weights, state, fastmath.random.get_prng(0))

        self.assertEqual(logits.shape, (1, max_len, vocab_size))
Example #4
  def __init__(self, task,
               value_body=None,
               value_optimizer=None,
               value_lr_schedule=lr.multifactor,
               value_batch_size=64,
               value_train_steps_per_epoch=500,
               value_evals_per_epoch=1,
               value_eval_steps=1,
               exploration_rate=functools.partial(
                   lr.multifactor,
                   factors='constant * decay_every',
                   constant=1.,  # pylint: disable=redefined-outer-name
                   decay_factor=0.99,
                   steps_per_decay=1,
                   minimum=0.1),
               n_eval_episodes=0,
               only_eval=False,
               n_replay_epochs=1,
               max_slice_length=1,
               sync_freq=1000,
               scale_value_targets=True,
               output_dir=None,
               **kwargs):
    """Configures the value trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      value_body: Trax layer, representing the body of the value model. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      value_optimizer: the optimizer to use to train the value model.
      value_lr_schedule: learning rate schedule to use to train the value model.
      value_batch_size: batch size used to train the value model.
      value_train_steps_per_epoch: how long to train the value model in each RL
          epoch.
      value_evals_per_epoch: number of value trainer evaluations per RL epoch
          - only affects metric reporting.
      value_eval_steps: number of value trainer steps per evaluation - only
          affects metric reporting.
      exploration_rate: exploration rate schedule - used in the policy method.
      n_eval_episodes: number of episodes to play with the policy at
        temperature 0 in each epoch -- used for evaluation only.
      only_eval: If set to True, then trajectories are collected only
        for evaluation purposes, but they are not recorded.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      max_slice_length: the maximum length of trajectory slices to use; it is
          the second dimension of the value network output:
          (batch, max_slice_length, number of actions).
          Higher max_slice_length implies that the network has to predict more
          values into the future.
      sync_freq: how often (in steps) to synchronize the target
        network with the trained network. This is necessary for training the
        network on bootstrapped targets, e.g. using n-step returns.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`. We are trying to fix the problem with very large
          returns in some games in a way which does not introduce additional
          hyperparameters.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
    super(ValueAgent, self).__init__(
        task,
        n_eval_episodes=n_eval_episodes,
        output_dir=output_dir,
        **kwargs
    )
    self._value_batch_size = value_batch_size
    self._value_train_steps_per_epoch = value_train_steps_per_epoch
    self._value_evals_per_epoch = value_evals_per_epoch
    self._value_eval_steps = value_eval_steps
    self._only_eval = only_eval
    self._max_slice_length = max_slice_length
    self._policy_dist = distributions.create_distribution(task.action_space)
    self._n_replay_epochs = n_replay_epochs

    self._exploration_rate = exploration_rate()
    self._sync_at = (lambda step: step % sync_freq == 0)

    if scale_value_targets:
      self._value_network_scale = 1 / (1 - self._task.gamma)
    else:
      self._value_network_scale = 1

    value_model = functools.partial(
        models.Quality,
        body=value_body,
        n_actions=self.task.action_space.n)

    self._value_eval_model = value_model(mode='eval')
    self._value_eval_model.init(self._value_model_signature)
    self._value_eval_jit = tl.jit_forward(
        self._value_eval_model.pure_fn, fastmath.device_count(), do_mean=False)

    # Inputs to the value model are produced by self.value_batches_stream.
    self._inputs = data.inputs.Inputs(
        train_stream=lambda _: self.value_batches_stream())

    # This is the value Trainer that will be used to train the value model.
    # * inputs to the trainer come from self.value_batches_stream
    # * outputs, targets and weights are passed to self.value_loss
    self._value_trainer = supervised.Trainer(
        model=value_model,
        optimizer=value_optimizer,
        lr_schedule=value_lr_schedule(),
        loss_fn=self.value_loss,
        inputs=self._inputs,
        output_dir=output_dir,
        metrics={'value_loss': self.value_loss,
                 'value_mean': self.value_mean,
                 'returns_mean': self.returns_mean}
    )
    value_batch = next(self.value_batches_stream())
    self._eval_model = tl.Accelerate(
        value_model(mode='collect'), n_devices=1)
    self._eval_model.init(shapes.signature(value_batch))
    if self._task._initial_trajectories == 0:
      self._task.remove_epoch(0)
      self._collect_trajectories()
Example #5
 def get_output(x, mask):
     xs = [x, mask]
     _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0))
     return layer(xs, rng=jax.random.PRNGKey(1))
Example #6
 def test_shared_weights_double_nested(self):
     layer = tl.Dense(5)
     model = tl.Serial(tl.Serial(layer), tl.Serial(layer))
     sample_input = np.array([1, 2, 3, 4, 5])
     weights, _ = model.init(shapes.signature(sample_input))
     self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE)
Example #7
 def init_and_run(layer, xs):
     layer.init(shapes.signature(xs))
     layer(xs)
Example #8
 def test_forward_shape(self):
     layer = tl.LayerNorm()
     x = np.ones((3, 2, 7)).astype(np.float32)
     _, _ = layer.init(shapes.signature(x))
     y = layer(x)
     self.assertEqual(y.shape, x.shape)
Example #9
  def test_reformer2_one_step(self):
    d_model = 1024
    vocab_size = 14041
    max_len = 16384
    pos_axial = (128, 128)  # should multiply to max_len
    pos_d_axial_embs = (512, 512)  # should sum to d_model

    assert operator.mul(*pos_axial) == max_len
    assert sum(pos_d_axial_embs) == d_model
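    # Axial position embeddings: the per-axis lengths multiply to max_len
    # (128 * 128 = 16384) and the per-axis embedding sizes are concatenated
    # to d_model (512 + 512 = 1024).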

    d_ff = 4096
    n_heads = 8
    d_attn = d_model // n_heads

    n_buckets = 128
    encoder_chunk_len = (2 * max_len) // n_buckets  # 256
    decoder_chunk_len = 2 * encoder_chunk_len       # 512
    encoder_n_chunks_after = 1                      # since it's not causal.

    lsh_self_attention = functools.partial(self._lsh_self_attention_fn(),
                                           n_buckets=n_buckets)

    encoder_lsh_self_attention = functools.partial(
        lsh_self_attention, n_chunks_after=encoder_n_chunks_after,
        chunk_len=encoder_chunk_len)

    decoder_lsh_self_attention = functools.partial(
        lsh_self_attention, n_chunks_after=0,
        chunk_len=decoder_chunk_len)

    model = reformer.Reformer2(
        vocab_size,
        d_model=d_model,
        d_ff=d_ff,
        d_attention_key=d_attn,
        d_attention_value=d_attn,
        n_encoder_layers=1,
        n_decoder_layers=1,
        n_heads=n_heads,
        dropout=0.05,
        max_len=max_len,
        encoder_attention_type=encoder_lsh_self_attention,
        encoder_decoder_attention_type=decoder_lsh_self_attention,
        pos_axial_shape=pos_axial,
        pos_d_axial_embs=pos_d_axial_embs,
        ff_activation=tl.Relu,
        ff_use_sru=0,
        mode='train',
    )

    def random_sentence():
      return np.random.randint(low=1, high=vocab_size - 1, size=(1, max_len),
                               dtype=np.int32)

    x = [random_sentence(), random_sentence()]
    weights, state = model.init(shapes.signature(x))

    @fastmath.jit
    def mock_training_step(x, weights, state, rng):
      def compute_mock_loss(weights):
        logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng)
        # This returns [logits, decoder tokens]
        logits = logits_and_dec_toks[0]
        loss = fastmath.numpy.mean(logits[..., 0])
        return loss, (new_state, logits)
      gradients, (new_state, logits) = fastmath.grad(
          compute_mock_loss, has_aux=True)(weights)
      new_weights = fastmath.nested_map_multiarg(
          lambda w, g: w - 1e-4 * g, weights, gradients)
      return new_weights, new_state, logits

    weights, state, logits = mock_training_step(
        x, weights, state, fastmath.random.get_prng(0))

    self.assertEqual(logits.shape, (1, max_len, vocab_size))
Example #10
def test_eval_equals_predict(inp, model_fn, seq_axis=1, seq_tensor=None,
                             init_tokens=3, message=''):
  """Utility method for testing equivalence of predict and eval modes.

  Args:
    inp: input fed to the model. It can be a tensor, or a tuple of tensors.
    model_fn: function creating a model after calling with `mode` argument.
    seq_axis: axis of sequence_length. In predict mode we iterate over this
      axis. By default `1`, which is the 2nd dimension.
    seq_tensor: if `inp` is a tuple, `seq_tensor` is the index of the input
      tensor in this tuple over which we iterate the sequence.
    init_tokens: how many tokens should be passed to the first `predict` call.
    message: Optional message to show when outputs of eval/predict mode don't
      match.
  """
  with fastmath.use_backend(fastmath.Backend.JAX):
    model_eval = model_fn(mode='eval')
    model_predict = model_fn(mode='predict')

    input_signature = shapes.signature(inp)
    model_eval.init(input_signature)
    model_predict.init(input_signature)
    model_eval.save_to_file('/tmp/unique_weights')
    model_predict.init_from_file('/tmp/unique_weights', weights_only=True,
                                 input_signature=input_signature)

    rng = fastmath.random.get_prng(0)
    output_eval = model_eval(inp, rng=rng)
    if not isinstance(output_eval, (tuple, list)):
      # We will automatically check each and every tensor returned.
      output_eval = [output_eval]

    if seq_tensor is None:
      length = inp.shape[seq_axis]
    else:
      length = inp[seq_tensor].shape[seq_axis]

    assert length >= init_tokens + 2  # Required to properly test predict mode.
    indices_list = [(0, init_tokens)] + [(i, i+1)
                                         for i in range(init_tokens, length)]

    for indices in indices_list:
      start, end = indices
      if seq_tensor is None:
        new_inp = inp.take(indices=range(start, end), axis=seq_axis)
      else:
        new_inp = list(inp)
        new_inp[seq_tensor] = new_inp[seq_tensor].take(
            indices=range(start, end), axis=seq_axis)

      output_predict = model_predict(new_inp, rng=rng)
      if not isinstance(output_predict, (tuple, list)):
        # We will automatically check each and every tensor returned.
        output_predict = [output_predict]

      np.testing.assert_equal(len(output_predict), len(output_eval))
      for outp, oute in zip(output_predict, output_eval):
        np.testing.assert_array_almost_equal(
            oute.take(indices=range(start, end), axis=seq_axis),
            outp.take(indices=range(0, end-start), axis=seq_axis),
            decimal=5,
            err_msg='Error on element {} out of {}.{}'.format(indices, length,
                                                              message))
Example #11
 def _value_model_signature(self):
   obs_sig = shapes.signature(self._task.observation_space)
   target_sig = mask_sig = shapes.ShapeDtype(
       shape=(1, 1, 1),
   )
   return (obs_sig.replace(shape=(1, 1) + obs_sig.shape), target_sig, mask_sig)
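
A small aside on the `shapes.ShapeDtype` calls used above (a sketch under assumptions, not part of the agent code): the dtype defaults to `np.float32`, and `replace` returns a new signature with the given fields changed.

sd = shapes.ShapeDtype(shape=(1, 1, 1))   # dtype defaults to np.float32
sd2 = sd.replace(shape=(1, 1, 4))         # new ShapeDtype with shape (1, 1, 4)
# A tuple of such signatures, one per model input, can then be passed to init().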
Example #12
    def pure_fn(self, x, weights, state, rng, use_cache=False):
        """Applies this layer as a pure function with no optional args.

    This method exposes the layer's computation as a pure function. This is
    especially useful for JIT compilation. Do not override, use `forward`
    instead.

    Args:
      x: Zero or more input tensors, packaged as described in the `Layer` class
          docstring.
      weights: A tuple or list of trainable weights, with one element for this
          layer if this layer has no sublayers, or one for each sublayer if
          this layer has sublayers. If a layer (or sublayer) has no trainable
          weights, the corresponding weights element is an empty tuple.
      state: Layer-specific non-parameter state that can update between batches.
      rng: Single-use random number generator (JAX PRNG key).
      use_cache: if `True`, cache weights and state in the layer object; used
        to implement layer sharing in combinators.

    Returns:
      A tuple of `(tensors, state)`. The tensors match the number (`n_out`)
      promised by this layer, and are packaged as described in the `Layer`
      class docstring.
    """
        try:
            old_weights, old_state, old_rng = self.weights, self.state, self.rng
            self._rng = rng
            # The isinstance check is only needed when == is overloaded, as in TF.
            if (isinstance(weights, dict) and isinstance(state, dict)
                    and weights == GET_WEIGHTS_FROM_CACHE
                    and state == GET_STATE_FROM_CACHE):
                was_cached = True
                weights = self.weights
                state = self.state
            else:
                # In this case, we're called for the first time: cache weights.
                was_cached = False
                self.weights, self.state = weights, state

            # If weights are sharded across multiple devices, unshard before forward.
            sharded_weights, weights_were_unsharded = weights, False
            if N_WEIGHTS_SHARDS > 1 and not self.sublayers:
                self.weights, weights_were_unsharded = unshard_in_pmap(
                    weights, N_WEIGHTS_SHARDS)

            if not self.has_backward:
                outputs = self.forward(x)
                s = self.state
            else:
                outputs, s = self._do_custom_gradients(x)
                self.state = s
            self._rng = old_rng
            if weights_were_unsharded:  # only store a shard of weights if sharded
                self.weights = sharded_weights

            if not use_cache:
                self.weights, self.state = old_weights, old_state
            if was_cached:  # If the layer was shared, return a state marking this.
                s = GET_STATE_FROM_CACHE
            return outputs, s

        except Exception:
            # Skipping 3 lines as it's always the uninteresting internal call.
            name, trace = self._name, _short_traceback(skip=3)
            raise LayerError(name, 'pure_fn', self._caller, signature(x),
                             trace) from None
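
The typical consumer of `pure_fn`, as in the Reformer2 examples above, is a jitted function that threads `weights`, `state`, and an explicit RNG through the computation. A condensed sketch (assuming `layer` is any Trax layer and `x` matches its input signature):

# Sketch: drive an initialized layer through pure_fn under fastmath.jit.
weights, state = layer.init(shapes.signature(x))

@fastmath.jit
def apply_fn(x, weights, state, rng):
    outputs, new_state = layer.pure_fn(x, weights, state, rng)
    return outputs, new_state

outputs, state = apply_fn(x, weights, state, fastmath.random.get_prng(0))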
Example #13
 def test_mlp_forward_shape(self):
     model = mlp.MLP(layer_widths=(32, 16, 8))
     x = np.ones((7, 28, 28, 3)).astype(np.float32)
     _, _ = model.init(shapes.signature(x))
     y = model(x)
     self.assertEqual(y.shape, (7, 8))
Example #14
 def _check_forward_shape(self, model, input_shape, output_vocab_size):
     x = np.ones(input_shape).astype(np.int32)
     model.init(shapes.signature(x))
     y = model(x)
     self.assertEqual(y.shape, (*input_shape, output_vocab_size))
Example #15
 def abstract_f(*args, **kwargs):
     real_args = [
         nested_map(lambda x: np.zeros(x.shape, x.dtype), a) for a in args
     ]
     real_res = f(*real_args, **kwargs)
     return signature(real_res)
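
A hedged usage sketch for the inner helper above (the wrapper name `output_signature` and the example call are assumptions): build zero-filled stand-ins with the same shapes and dtypes as the real arguments, run the function once, and return the signature of the result.

# Hypothetical wrapper around abstract_f; assumes the same imports as the
# snippet above (numpy as np, plus nested_map and signature from trax).
def output_signature(f):
    def abstract_f(*args, **kwargs):
        real_args = [
            nested_map(lambda x: np.zeros(x.shape, x.dtype), a) for a in args
        ]
        real_res = f(*real_args, **kwargs)
        return signature(real_res)
    return abstract_f

# e.g. output_signature(np.dot)(np.ones((3, 4)), np.ones((4, 5))) would report
# a signature with shape (3, 5).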
Example #16
 def test_value_forward_shape(self):
   model = rl.Value()
   x = np.ones((2, 3))
   _, _ = model.init(shapes.signature(x))
   y = model(x)
   self.assertEqual(y.shape, (2, 1))
Example #17
    def __init__(
        self,
        model,
        tasks,
        eval_model=None,
        eval_tasks=None,
        output_dir=None,
        checkpoint_at=None,
        eval_at=None,
        which_task=None,
        n_devices=None,
        random_seed=None,
        use_memory_efficient_trainer=False,
    ):
        """Configures a training `Loop`, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      tasks: List of TrainTask instances, which define the training data, loss
          function, and optimizer to be used in respective tasks in this
          training loop. It can also be a single TrainTask instance which is
          treated in the same way as a singleton list.
      eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If None, the training model (model)
        will be used.
      eval_tasks: List of EvalTask instances which define how to evaluate
        the model: which validation data to use and which metrics to report.
        Evaluation on each of the tasks will run and be reported separately,
        which allows scoring a model on different subtasks. This argument can
        also be None, in which case no evals will be run, or a single
        EvalTask, which will be treated in the same way as a singleton list.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_task` and `checkpoint_at` are None.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If None, the default is
          periodic checkpointing at `task.n_steps_per_checkpoint`.
      eval_at: Function (integer --> boolean) that says, for training step n,
          whether that step should run evals. If None, run when checkpointing.
      which_task: Function (integer --> integer) indicating which task should be
          used at which training step. Can be set to None in single-task
          training.
      n_devices: integer or None, the number of devices for this computation.
      random_seed: the random seed to use; time/os dependent if None (default).
      use_memory_efficient_trainer: whether to use a special memory-efficient
        trainer.
    """
        self._is_chief, self._n_hosts, self._n_devices, self._rng = (
            init_host_and_devices(n_devices, random_seed))

        # Handle single task case without lists too.
        if not isinstance(tasks, (list, tuple)):
            tasks = [tasks]

        if not tasks:
            raise ValueError('Must provide at least one training task.')
        if eval_tasks is None:
            eval_tasks = []
            eval_at = _never
        else:
            if not isinstance(eval_tasks, (list, tuple)):
                eval_tasks = [eval_tasks]

        self._tasks = tasks
        self._model = model
        self._eval_model = eval_model or model

        self._use_memory_efficient_trainer = use_memory_efficient_trainer
        # TODO(lukaszkaiser): can we have different eval models and save memory?
        if use_memory_efficient_trainer:
            assert len(tasks) == 1, 'only single task supported for now'
            assert len(eval_tasks) < 2, 'at most 1 eval task supported for now'
            self._eval_model = model

        default_at = _at_step_1_and_every_nth_step(
            tasks[0].n_steps_per_checkpoint)
        if output_dir is not None:
            self._output_dir = os.path.expanduser(output_dir)
            tf.io.gfile.makedirs(self._output_dir)
        else:
            self._output_dir = None

        # Prepare training components.
        self._step = 0
        self._checkpoint_at = checkpoint_at or default_at
        if which_task is None:
            if len(tasks) > 1:
                raise ValueError(
                    'Must provide which_task for multitask training.')
            which_task = lambda _: 0
        self._which_task = which_task

        # Initialize using the given random seed.
        # NOTE: If `random_seed` is `None` then `self._rng` will be different on
        # different hosts, leading to different weights on the different hosts.
        self._batch_signature = shapes.signature(tasks[0].sample_batch)
        self._model.rng = self.new_rng()
        # In the memory-efficient case, we initialize in init_trainer.
        if not use_memory_efficient_trainer:
            self._model.init(self._batch_signature)
            self._eval_model.rng = self.new_rng()
            self._eval_model.init(self._batch_signature)

        # To handle the above case (i.e. random_seed = None), we psum the weights
        # and state and average them.
        # NOTE: This adds time (how much?) so we prefer not to do it if it is
        # unnecessary, i.e. random_seed was set.
        # NOTE: Averaging the weights across devices can screw up the initial weight
        # statistics.
        # TODO(pkozakowski): Broadcast from one of the devices instead?
        if random_seed is None and self._n_hosts > 1:
            logging.info('Syncing weights/state across %d hosts.',
                         self._n_hosts)
            self._sync_weights_and_state_across_hosts()

        # Create the optimizer for the training loss function.
        self._trainer_per_task = tuple(
            self._init_trainer(task) for task in tasks)
        self.load_checkpoint()

        # Prepare eval components.
        self._eval_at = eval_at or default_at
        self._eval_tasks = eval_tasks
        loss_names = [task.loss_layer.name for task in self._tasks]
        metric_names = [
            name  # pylint: disable=g-complex-comprehension
            for eval_task in self._eval_tasks
            for name in eval_task.metric_names
        ]
        self._rjust_len = max(map(len, loss_names + metric_names))
        self._evaluator_per_task = tuple(
            self._init_evaluator(eval_task) for eval_task in self._eval_tasks)

        if self._output_dir is None:
            _log(
                'Will not write evaluation metrics, because output_dir is None.'
            )

        def task_output_dir(task_index, task_list):
            if self._output_dir is not None:
                if len(task_list) < 2:
                    output_dir = self._output_dir
                else:
                    output_dir = os.path.join(self._output_dir,
                                              str(task_index))
                tf.io.gfile.makedirs(output_dir)
                return output_dir
            else:
                return None

        self._output_dir_per_eval_task = [
            task_output_dir(i, eval_tasks) for i in range(len(eval_tasks))
        ]
        self._output_dir_per_train_task = [
            task_output_dir(i, tasks) for i in range(len(tasks))
        ]
Example #18
 def test_signature_on_ndarray(self):
     array = onp.array([[2, 3, 5, 7], [11, 13, 17, 19]], dtype=onp.int16)
     sd = shapes.signature(array)
     self.assertEqual(sd.shape, (2, 4))
     self.assertEqual(sd.dtype, onp.int16)
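
As the other examples on this page suggest, `shapes.signature` also accepts containers of arrays and returns a matching structure of `ShapeDtype` objects. A minimal companion sketch (this exact test and its shapes are assumptions):

 def test_signature_on_container(self):
     x = (onp.zeros((2, 3), dtype=onp.float32), onp.zeros((2,), dtype=onp.int32))
     sig = shapes.signature(x)
     self.assertEqual(sig[0].shape, (2, 3))
     self.assertEqual(sig[1].dtype, onp.int32)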
Example #19
 def test_shared_weights_nested(self):
     layer = tl.Dense(5)
     model = tl.Parallel([layer, tl.Dense(2)], [layer, tl.Dense(2)])
     sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5]))
     weights, _ = model.init(shapes.signature(sample_input))
     self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE)
Example #20
 def test_resnet(self):
     model = resnet.Resnet50(d_hidden=8, n_output_classes=10)
     x = np.ones((3, 256, 256, 3)).astype(np.float32)
     _, _ = model.init(shapes.signature(x))
     y = model(x)
     self.assertEqual(y.shape, (3, 10))
Example #21
  def __init__(
      self, task, model_fn,
      optimizer=adam.Adam,
      lr_schedule=lr.multifactor,
      batch_size=64,
      network_eval_at=None,
      n_eval_batches=1,
      max_slice_length=1,
      **kwargs
  ):
    """Initializes PolicyGradient.

    Args:
      task: Instance of trax.rl.task.RLTask.
      model_fn: Function (policy_distribution, mode) -> policy_model.
      optimizer: Optimizer for network training.
      lr_schedule: Learning rate schedule for network training.
      batch_size: Batch size for network training.
      network_eval_at: Function step -> bool indicating the training steps at
        which network evaluation should be performed.
      n_eval_batches: Number of batches to run during network evaluation.
      max_slice_length: The length of trajectory slices to run the network on.
      **kwargs: Keyword arguments passed to the superclass.
    """
    super().__init__(task, **kwargs)

    self._max_slice_length = max_slice_length
    trajectory_batch_stream = task.trajectory_batch_stream(
        batch_size,
        epochs=[-1],
        max_slice_length=self._max_slice_length,
        sample_trajectories_uniformly=True,
    )
    self._policy_dist = distributions.create_distribution(task.action_space)
    train_task = policy_tasks.PolicyTrainTask(
        trajectory_batch_stream,
        optimizer(),
        lr_schedule(),
        self._policy_dist,
        # Policy gradient uses the MC estimator. No need for margin - the MC
        # estimator only uses empirical returns.
        advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
        value_fn=self._value_fn,
    )
    eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
    model_fn = functools.partial(
        model_fn,
        policy_distribution=self._policy_dist,
    )

    if self._output_dir is not None:
      policy_output_dir = os.path.join(self._output_dir, 'policy')
    else:
      policy_output_dir = None
    # Checkpoint every epoch. We do one step per epoch, so that's every step.
    checkpoint_at = lambda _: True
    self._loop = supervised.training.Loop(
        model=model_fn(mode='train'),
        tasks=[train_task],
        eval_model=model_fn(mode='eval'),
        eval_tasks=[eval_task],
        output_dir=policy_output_dir,
        eval_at=network_eval_at,
        checkpoint_at=checkpoint_at,
    )
    self._collect_model = model_fn(mode='collect')
    self._collect_model.init(shapes.signature(train_task.sample_batch))

    # Validate the restored checkpoints. The number of network training steps
    # (self.loop.step) should be equal to the number of epochs (self._epoch),
    # because we do exactly one gradient step per epoch.
    # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
    if self.loop.step != self._epoch:
      raise ValueError(
          'The number of Loop steps must equal the number of Agent epochs, '
          'got {} and {}.'.format(self.loop.step, self._epoch)
      )
Example #22
 def test_wide_resnet(self):
     model = resnet.WideResnet(n_blocks=1, n_output_classes=10)
     x = np.ones((3, 32, 32, 3)).astype(np.float32)
     _, _ = model.init(shapes.signature(x))
     y = model(x)
     self.assertEqual(y.shape, (3, 10))
Example #23
 def get_output():
     _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0))
     return layer(x, rng=jax.random.PRNGKey(1))
Example #24
        def _test_for_chunk_lens(rel_chunk_len, vanilla_chunk_len):
            d_model = 8
            vocab_size = 4
            batch_size = 1
            n_len_eval = 42
            attention_type = tl.SelfAttention

            shorten_factor = 3
            n_rel_layers = 2
            vanilla_layers = (1, 1)
            n_heads = 2

            eval_funnel = ft.RelformerLM(vocab_size,
                                         shorten_factor=shorten_factor,
                                         n_rel_layers=n_rel_layers,
                                         vanilla_layers=vanilla_layers,
                                         d_model=d_model,
                                         d_ff=d_model,
                                         n_heads=n_heads,
                                         vanilla_attn_type=attention_type,
                                         rel_chunk_len=rel_chunk_len,
                                         vanilla_chunk_len=vanilla_chunk_len,
                                         mode='eval')

            inputs = jax.random.randint(key=jax.random.PRNGKey(0),
                                        minval=0,
                                        maxval=vocab_size,
                                        shape=(batch_size,
                                               n_len_eval)).astype(np.int32)
            _, _ = eval_funnel.init(shapes.signature(inputs),
                                    rng=jax.random.PRNGKey(0))
            y_eval = eval_funnel(inputs)
            self.assertEqual(y_eval.shape,
                             (batch_size, n_len_eval, vocab_size))

            if attention_type == tl.SelfAttention:
                gin.bind_parameter('trax.layers.SelfAttention.chunk_len',
                                   n_len_eval)

            predict_funnel = ft.RelformerLM(
                vocab_size,
                shorten_factor=shorten_factor,
                n_rel_layers=n_rel_layers,
                vanilla_layers=vanilla_layers,
                d_model=d_model,
                d_ff=d_model,
                n_heads=n_heads,
                vanilla_attn_type=attention_type,
                rel_chunk_len=rel_chunk_len,
                vanilla_chunk_len=vanilla_chunk_len,
                mode='predict')

            inputs = np.concatenate(
                [np.zeros((batch_size, 1)).astype(np.int32), inputs], axis=1)
            inputs = inputs[:, :-1]

            _, _ = predict_funnel.init(shapes.signature(inputs[:, 0:1]),
                                       rng=jax.random.PRNGKey(0),
                                       use_cache=False)

            for i in range(n_len_eval):
                y = predict_funnel(inputs[:, i:i + 1])
                np.testing.assert_array_almost_equal(y,
                                                     y_eval[:, i:i + 1, :],
                                                     decimal=5)
Example #25
    def test_lsh_and_pure_lsh_self_attention_equivalence(self):
        # Given the same weight matrices and random numbers, do these produce
        # the same output?
        with fastmath.use_backend(fastmath.Backend.JAX):
            n_heads = 4
            d_head = 4
            d_model = n_heads * d_head
            pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                bias=False,
                mode='train')
            lsh_layer = efficient_attention.LSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                mode='train')

            batch, seqlen = 3, 32
            input_shape = (batch, seqlen, d_model)

            x = jax.random.uniform(jax.random.PRNGKey(0),
                                   input_shape,
                                   dtype=jnp.float32)
            lsh_layer_input = x

            call_rng = jax.random.PRNGKey(42)

            lsh_layer_weights, lsh_layer_state = lsh_layer.init(
                shapes.signature(lsh_layer_input))
            lsh_layer.rng = call_rng
            lsh_layer_output = lsh_layer(lsh_layer_input)

            # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head),
            # (n_heads, d_head, d_model)
            # Abbreviated as - hmn, hmn, hnm
            w_qk, w_v, w_o = lsh_layer_weights

            qk = jnp.einsum('blm,hmn->bhln', x, w_qk)
            qk = qk.reshape((-1, qk.shape[2], qk.shape[3]))

            v = jnp.einsum('blm,hmn->bhln', x, w_v)
            v = v.reshape((-1, v.shape[2], v.shape[3]))

            pure_lsh_layer_input = (qk, v)
            _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input))
            pure_lsh_layer.rng = call_rng
            pure_lsh_layer.state = lsh_layer_state
            pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input)

            # b*h,l,n
            pure_lsh_layer_output = pure_lsh_layer_output.reshape(
                (batch, -1) + pure_lsh_layer_output.shape[1:])
            pure_lsh_layer_output_projected = (jnp.einsum(
                'bhld,hdm->blm', pure_lsh_layer_output, w_o))

            diff = pure_lsh_layer_output_projected - lsh_layer_output
            avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

            self.assertLess(avg_diff, 1e-5)
Example #26
 def test_simple_call(self):
     layer = tl.PositionalEncoding(max_len=8)
     x = np.array([[[2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0]]])
     layer.init(shapes.signature(x))
     y = layer(x)
     self.assertEqual(y.shape, (1, 2, 4))
Example #27
    def __init__(self,
                 task,
                 policy_model=None,
                 policy_optimizer=None,
                 policy_lr_schedule=lr.multifactor,
                 policy_batch_size=64,
                 policy_train_steps_per_epoch=500,
                 policy_evals_per_epoch=1,
                 policy_eval_steps=1,
                 n_eval_episodes=0,
                 only_eval=False,
                 max_slice_length=1,
                 output_dir=None,
                 **kwargs):
        """Configures the policy trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      policy_model: Trax layer, representing the policy model. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      policy_optimizer: the optimizer to use to train the policy model.
      policy_lr_schedule: learning rate schedule to use to train the policy.
      policy_batch_size: batch size used to train the policy model.
      policy_train_steps_per_epoch: how long to train policy in each RL epoch.
      policy_evals_per_epoch: number of policy trainer evaluations per RL epoch
          - only affects metric reporting.
      policy_eval_steps: number of policy trainer steps per evaluation - only
          affects metric reporting.
      n_eval_episodes: number of episodes to play with the policy at
        temperature 0 in each epoch -- used for evaluation only.
      only_eval: If set to True, then trajectories are collected only
        for evaluation purposes, but they are not recorded.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
        super().__init__(task,
                         n_eval_episodes=n_eval_episodes,
                         output_dir=output_dir,
                         **kwargs)
        self._policy_batch_size = policy_batch_size
        self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
        self._policy_evals_per_epoch = policy_evals_per_epoch
        self._policy_eval_steps = policy_eval_steps
        self._only_eval = only_eval
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.action_space)

        # Inputs to the policy model are produced by self._policy_batches_stream.
        self._policy_inputs = data.inputs.Inputs(
            train_stream=lambda _: self.policy_batches_stream())

        policy_model = functools.partial(
            policy_model,
            policy_distribution=self._policy_dist,
        )

        # This is the policy Trainer that will be used to train the policy model.
        # * inputs to the trainer come from self.policy_batches_stream
        # * outputs, targets and weights are passed to self.policy_loss
        self._policy_trainer = supervised.Trainer(
            model=policy_model,
            optimizer=policy_optimizer,
            lr_schedule=policy_lr_schedule(),
            loss_fn=self.policy_loss,
            inputs=self._policy_inputs,
            output_dir=output_dir,
            metrics=self.policy_metrics,
        )
        self._policy_collect_model = tl.Accelerate(
            policy_model(mode='collect'), n_devices=1)
        policy_batch = next(self.policy_batches_stream())
        self._policy_collect_model.init(shapes.signature(policy_batch))
        self._policy_eval_model = tl.Accelerate(
            policy_model(mode='eval'), n_devices=1)  # Not collecting stats
        self._policy_eval_model.init(shapes.signature(policy_batch))
        if self._task._initial_trajectories == 0:
            self._task.remove_epoch(0)
            self._collect_trajectories()
Example #28
    def __init__(self,
                 model,
                 task,
                 eval_model=None,
                 eval_task=None,
                 output_dir=None,
                 checkpoint_at=None,
                 eval_at=None):
        """Configures a training `Loop`, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      task: TrainTask instance, which defines the training data, loss function,
          and optimizer to be used in this training loop.
      eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If None, the training model (model)
        will be used.
      eval_task: EvalTask instance or None. If None, don't do any evals.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_task` and `checkpoint_at` are None.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If None, the default is
          periodic checkpointing at `task.n_steps_per_checkpoint`.
      eval_at: Function (integer --> boolean) that says, for training step n,
          whether that step should run evals. If None, run when checkpointing.
    """
        self._task = task
        self._model = model
        self._eval_model = eval_model or model
        default_at = (_at_step_1_and_every_nth_step(
            self._task.n_steps_per_checkpoint))
        if output_dir is not None:
            self._output_dir = os.path.expanduser(output_dir)
            tf.io.gfile.makedirs(self._output_dir)
        else:
            self._output_dir = None

        # Prepare training components.
        self._step = 0
        self._checkpoint_at = checkpoint_at or default_at
        self._model_in_training = tl.Serial(self._model, self._task.loss_layer)
        self._batch_signature = shapes.signature(self._task.sample_batch)
        self._eval_model.init(self._batch_signature)
        self._model_in_training.init(self._batch_signature)
        self._task.optimizer.tree_init(self._model_in_training.weights)
        self._forward_and_backward_fn = (
            fastmath.jit(
                fastmath.value_and_grad(
                    self._model_in_training.pure_fn,
                    argnums=1,  # arg1 of pure_fn: weights
                    has_aux=True)))  # return (loss, state), gradients

        # Prepare eval components.
        if eval_task is None:
            self._eval_at = _never
        else:
            self._eval_task = eval_task
            self._eval_at = eval_at or default_at
            metric_name_lengths = [
                len(name) for name in self._eval_task.metric_names
            ]
            self._rjust_len = max([len(self._task.loss_layer.name)] +
                                  metric_name_lengths)
            model_with_metrics = (_model_with_metrics(self._eval_model,
                                                      self._eval_task))
            self._eval_weights = model_with_metrics.weights[1]  # just the eval part
            self._eval_state = model_with_metrics.state[1]  # just the eval part
            self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
            if self._output_dir is None:
                _log(
                    'Will not write evaluation metrics, because output_dir is None.'
                )
Example #29
 def test_grulm_forward_shape(self):
     model = rnn.GRULM(vocab_size=20, d_model=16)
     x = np.ones((3, 28)).astype(np.int32)
     _, _ = model.init(shapes.signature(x))
     y = model(x)
     self.assertEqual(y.shape, (3, 28, 20))
Example #30
 def test_mlp_forward_shape(self):
     model = mlp.MLP(d_hidden=32, n_output_classes=10)
     x = np.ones((3, 28, 28, 1)).astype(np.float32)
     _, _ = model.init(shapes.signature(x))
     y = model(x)
     self.assertEqual(y.shape, (3, 10))