Example 1
 def test_trains_on_two_tasks(self):
   """Trains a very simple network on two very simple tasks."""
   model = tl.Serial(tl.Dense(3), tl.Dense(1))
   task = training.TrainTask(
       _very_simple_data(),
       tl.L2Loss(),
       optimizers.SGD(.01)
   )
   eval_task = training.EvalTask(
       _very_simple_data(),  # deliberately re-using training data
       [tl.L2Loss()],
   )
   training_session = training.Loop(
       model,
       tasks=(task, task),
       eval_tasks=(eval_task, eval_task),
       which_task=lambda step_n: step_n % 2,
   )
   self.assertEqual(0, training_session.step)
   training_session.run(n_steps=15)
   self.assertEqual(15, training_session.step)
   training_session.run(n_steps=5)
   self.assertEqual(20, training_session.step)
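
Examples 1 through 4 come from Trax's own training tests and rely on a _very_simple_data() helper that is not shown. A minimal sketch of such a generator, together with the imports these snippets assume, could look like the following; the exact helper in the test file may differ.

import numpy as np
from trax import fastmath
from trax import layers as tl
from trax import optimizers
from trax.supervised import training


def _very_simple_data():
    """Hypothetical stand-in: maps a constant input to a constant target."""
    inputs_batch = np.ones((8, 1))             # 8 examples per batch, 1 feature each
    targets_batch = np.pi * np.ones((8, 1))    # constant target, trivially learnable
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    while True:
        yield labeled_batch                    # (inputs, targets, weights) forever
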
Example 2
    def test_train_save_restore_dense(self):
        """Saves and restores a checkpoint to check for equivalence."""
        task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                                  optimizers.Adam(.0001))
        eval_task = training.EvalTask(
            _very_simple_data(),  # deliberately re-using training data
            [tl.L2Loss()],
            metric_names=['SGD.L2Loss'])
        tmp_dir = self.create_tempdir().full_path

        def _make_model_and_session():
            m = tl.Serial(tl.Dense(1))
            ts = training.Loop(m, [task],
                               eval_tasks=[eval_task],
                               eval_at=lambda step_n: step_n % 2 == 0,
                               output_dir=tmp_dir)
            return m, ts

        model, training_session = _make_model_and_session()
        self.assertEqual(0, training_session.step)
        training_session.run(n_steps=1)
        training_session.save_checkpoint()
        model2, training_session2 = _make_model_and_session()

        x = np.ones((8, 1))
        y1 = model(x, rng=fastmath.random.get_prng(0))
        y2 = model2(x, rng=fastmath.random.get_prng(0))
        self.assertEqual(str(y1), str(y2))

        training_session2.run(n_steps=1)
        y1 = model(x, rng=fastmath.random.get_prng(0))
        y2 = model2(x, rng=fastmath.random.get_prng(0))
        self.assertNotEqual(str(y1), str(y2))

        slots1 = training_session._trainer_per_task[0].slots
        slots2 = training_session2._trainer_per_task[0].slots
        np.testing.assert_array_equal(slots1, slots2)
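
The test above relies on training.Loop restoring its state from output_dir when a new Loop is constructed over the same directory. Outside of a test, a saved checkpoint can also be loaded straight into a model with init_from_file, the same call Example 8 uses for fine-tuning. A rough sketch, reusing tmp_dir from the snippet above:

import os

model3 = tl.Serial(tl.Dense(1))
# 'model.pkl.gz' is the checkpoint file the Loop writes into its output_dir
# (see Example 5); weights_only=True loads weights without optimizer state.
model3.init_from_file(os.path.join(tmp_dir, 'model.pkl.gz'), weights_only=True)
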
Example 3
    def test_train_memory_efficient(self):
        """Trains a large network in a memory-efficient way."""
        # This test requires > 16 GB of RAM, so only run it on TPUs. It does pass
        # on GPU and CPU when run locally, but it's too big for unit testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.device_count() == 1 and ram_limited:
            return

        # Create the model.
        n_layers = 16  # 16 layers, each 16K x 16K = 256M weights ~= 1 GB; ~16 GB of RAM total
        model = tl.Serial(
            tl.Embedding(9, 16 * 1024),
            tl.Dup(),
            [[tl.ReversibleHalfResidual(tl.Dense(16 * 1024)),
              tl.ReversibleSwap()] for _ in range(n_layers)],
            tl.Concatenate(),
            tl.Dense(9),
            tl.LogSoftmax(),
        )

        # Create inputs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))

        def _data_gen():
            while True:
                yield labeled_batch

        # Run training.
        task = training.TrainTask(_data_gen(), tl.CrossEntropyLoss(),
                                  optimizers.Adafactor)
        eval_task = training.EvalTask(_data_gen(), [tl.CrossEntropyLoss()])
        loop = training.Loop(model, [task],
                             eval_tasks=[eval_task],
                             eval_at=lambda step_n: step_n == 2,
                             use_memory_efficient_trainer=True)
        self.assertEqual(0, loop.step)
        loop.run(n_steps=2)
        self.assertEqual(2, loop.step)
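
The same reversible pattern can be exercised at a tiny scale on a single CPU; per the comment above, the memory-efficient trainer itself runs fine locally, it is only the 16K-wide layers that need the RAM. A minimal sketch with assumed small dimensions, reusing _data_gen from the snippet above:

d_model = 32  # tiny stand-in for the 16K width used above
small_model = tl.Serial(
    tl.Embedding(9, d_model),
    tl.Dup(),
    [[tl.ReversibleHalfResidual(tl.Dense(d_model)),
      tl.ReversibleSwap()] for _ in range(2)],
    tl.Concatenate(),
    tl.Dense(9),
    tl.LogSoftmax(),
)
small_task = training.TrainTask(_data_gen(), tl.CrossEntropyLoss(),
                                optimizers.Adafactor)  # optimizer class, not instance
small_loop = training.Loop(small_model, [small_task],
                           eval_tasks=[training.EvalTask(_data_gen(),
                                                         [tl.CrossEntropyLoss()])],
                           eval_at=lambda step_n: step_n == 2,
                           use_memory_efficient_trainer=True)
small_loop.run(n_steps=2)
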
Example 4
    def test_loop_checkpoint_high_metric(self):
        """Runs a training loop that saves checkpoints for high metric values."""
        model = tl.Serial(tl.Dense(1))
        task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                                  optimizers.SGD(.01))
        eval_metric = tl.L2Loss()
        eval_task = training.EvalTask(_very_simple_data(), [eval_metric],
                                      metric_names=['l2_loss'])
        tmp_dir = self.create_tempdir().full_path
        loop = training.Loop(model, [task],
                             eval_tasks=[eval_task],
                             output_dir=tmp_dir,
                             eval_at=lambda step_n: step_n % 2 == 0,
                             checkpoint_at=lambda step_n: step_n % 2 == 0,
                             checkpoint_high_metric='l2_loss')
        call_counter = collections.Counter()
        loop.save_checkpoint = lambda name: call_counter.update([name])
        loop.run(n_steps=10)

        # Eval metric steadily descends, so high checkpoint triggered only once.
        # Low checkpoint not defined, so never triggered.
        self.assertEqual(call_counter['model'], 5)
        self.assertEqual(call_counter['lowest_l2_loss'], 0)
        self.assertEqual(call_counter['highest_l2_loss'], 1)
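
The counters above ('lowest_l2_loss', 'highest_l2_loss') hint at the companion argument for metrics that should be minimized. A sketch of a Loop that tracks both directions, reusing task and eval_task from the test above and assuming a checkpoint_low_metric argument symmetric to checkpoint_high_metric:

loop_both = training.Loop(
    tl.Serial(tl.Dense(1)), [task],
    eval_tasks=[eval_task],
    output_dir=self.create_tempdir().full_path,
    eval_at=lambda step_n: step_n % 2 == 0,
    checkpoint_at=lambda step_n: step_n % 2 == 0,
    checkpoint_low_metric='l2_loss',    # would write 'lowest_l2_loss' checkpoints (assumed)
    checkpoint_high_metric='l2_loss')   # writes 'highest_l2_loss', as tested above
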
Example 5
    return model


model = NMTAttn()
# print(model)

train_task = training.TrainTask(
    labeled_data=train_batch_data,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    lr_schedule=trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
    n_steps_per_checkpoint=20,
)

eval_task = training.EvalTask(
    labeled_data=eval_batch_data,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

output_dir = 'Nueral_Machine_Translation_With_Attention/output_dir/'
model_file_path = os.path.join(output_dir,"model.pkl.gz")
# Remove the old model if it exists; this restarts training from scratch.
if os.path.exists(model_file_path):
    os.remove(model_file_path)

# define the training loop
training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

training_loop.run(3)
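
Because Loop.step persists across calls to run (see Example 1) and checkpoints land in output_dir, training can be continued simply by calling run again; a re-created Loop pointed at the same output_dir would likewise resume from the last saved checkpoint, as Example 2 demonstrates.

training_loop.run(10)        # continues from step 3 to step 13
print(training_loop.step)    # 13
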
Example 6
def train_model(model,
                data_generator,
                batch_size=32,
                max_length=64,
                lines=lines,
                eval_lines=eval_lines,
                n_steps=1,
                output_dir='model/'):
    """Function that trains the model

    Args:
        model (trax.layers.combinators.Serial): GRU model.
        data_generator (function): Data generator function.
        batch_size (int, optional): Number of lines per batch. Defaults to 32.
        max_length (int, optional): Maximum length allowed for a line to be processed. Defaults to 64.
        lines (list, optional): List of lines to use for training. Defaults to lines.
        eval_lines (list, optional): List of lines to use for evaluation. Defaults to eval_lines.
        n_steps (int, optional): Number of steps to train. Defaults to 1.
        output_dir (str, optional): Relative path of directory to save model. Defaults to "model/".

    Returns:
        trax.supervised.training.Loop: Training loop for the model.
    """

    ### START CODE HERE (Replace instances of 'None' with your code) ###
    bare_train_generator = data_generator(batch_size=batch_size,
                                          max_length=max_length,
                                          data_lines=lines)
    infinite_train_generator = itertools.cycle(bare_train_generator)

    bare_eval_generator = data_generator(batch_size=batch_size,
                                         max_length=max_length,
                                         data_lines=eval_lines)
    infinite_eval_generator = itertools.cycle(bare_eval_generator)

    train_task = training.TrainTask(
        labeled_data=infinite_train_generator,  # Use infinite train data generator
        loss_layer=tl.CrossEntropyLoss(),  # Don't forget to instantiate this object
        optimizer=trax.optimizers.Adam(0.0005)  # Don't forget to add the learning rate parameter
    )

    eval_task = training.EvalTask(
        labeled_data=infinite_eval_generator,  # Use infinite eval data generator
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],  # Don't forget to instantiate these objects
        n_eval_batches=3  # For better evaluation accuracy in reasonable time
    )

    training_loop = training.Loop(model,
                                  train_task,
                                  eval_task=eval_task,
                                  output_dir=output_dir)

    training_loop.run(n_steps=n_steps)

    ### END CODE HERE ###

    # We return this because it contains a handle to the model, which has the weights etc.
    return training_loop
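
train_model expects data_generator to yield batches that TrainTask and EvalTask can consume, i.e. (inputs, targets, example_weights) arrays. The course's real generator tokenizes, pads and masks each line; the hypothetical stand-in below only illustrates the expected output structure and the (batch_size, max_length, data_lines) signature used above.

import numpy as np

def data_generator(batch_size, max_length, data_lines):
    """Hypothetical stand-in yielding (inputs, targets, weights) batches of token ids."""
    for i in range(0, len(data_lines) - batch_size + 1, batch_size):
        batch = np.zeros((batch_size, max_length), dtype=np.int32)
        for j, line in enumerate(data_lines[i:i + batch_size]):
            ids = [ord(c) % 256 for c in line][:max_length]  # toy character "tokenization"
            batch[j, :len(ids)] = ids
        mask = (batch != 0).astype(np.float32)  # ignore padding positions in the loss
        yield batch, batch, mask                # language-model style: targets == inputs
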
Example 7
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.Serial(tl.LogSoftmax(),
                            tl.CrossEntropyLoss(),
                            name='CrossEntropyLoss'),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
    if use_loop:
        n_devices = num_devices() or fastmath.device_count()

        # Prepare the training task.
        # Inputs is either an Inputs instance or a function that returns it.
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            inputs = inputs()
        opt = optimizer if use_memory_efficient_trainer else optimizer()
        train_task = training.TrainTask(inputs.train_stream(n_devices),
                                        loss_layer=loss_fn,
                                        optimizer=opt,
                                        lr_schedule=lr_schedule_fn(),
                                        n_steps_per_checkpoint=eval_frequency)

        # Prepare the evaluation.
        metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        names, metrics = zip(*metrics_dict.items())
        eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                      metrics,
                                      metric_names=names,
                                      n_eval_batches=eval_steps)

        # Prepare the training loop.
        checkpoint_at = None
        if checkpoints_at is not None:
            checkpoint_at = lambda step: step in checkpoints_at
        loop = training.Loop(
            model(mode='train'), [train_task],
            eval_model=model(mode='eval'),
            eval_tasks=[eval_task],
            output_dir=output_dir,
            checkpoint_at=checkpoint_at,
            n_devices=n_devices,
            loss_chunk_size=loss_chunk_size,
            use_memory_efficient_trainer=use_memory_efficient_trainer,
            random_seed=random_seed)

        steps_to_go = steps - loop.step
        if steps_to_go <= 0:
            log('Stop training, already reached the total training steps %d' %
                steps)
            return loop

        # Train and return the loop.
        loop.run(steps_to_go)
        return loop

    n_devices = num_devices()
    trainer = trainer_class(model,
                            loss_fn,
                            optimizer,
                            lr_schedule_fn(),
                            inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest)

    epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
    if eval_frequency and eval_steps > 0:
        epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                       eval_frequency - 1],
                                      itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    try:
        for epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    except Exception as e:
        raise e
    finally:
        trainer.close()
    return trainer.state
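
Because model=gin.REQUIRED, this train function is normally driven by gin bindings rather than called with explicit Python arguments. A rough sketch of such bindings set from Python; the configurable names are assumptions and depend on how the project registers its configurables.

import gin
import trax  # importing trax is assumed to register its gin configurables

gin.parse_config("""
train.model = @trax.models.TransformerLM
train.steps = 1000
train.eval_frequency = 100
""")
# train(output_dir='/tmp/example_run')  # the remaining arguments are then supplied by gin
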
Example 8
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          adasum=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None,
          additional_eval_streams=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
      each training step in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    permanent_checkpoint_frequency: int, how often to save permanent checkpoints
      (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.
    adasum: if True, use adaptive summation for multi-device gradients.
    init_checkpoint: a checkpoint for fine tuning.
    callbacks: a list of callbacks to call during training.
    additional_train_tasks: additional tasks which should be performed during
      training.
    additional_eval_tasks: additional tasks which should be performed during
      evaluation.
    additional_eval_streams: List[NamedStream], additional data streams that
      should be used during evaluation. Can be provided independently of
      additional_eval_tasks.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
  if (permanent_checkpoint_frequency is not None
      and permanent_checkpoints_at is not None):
    raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                     '"permanent_checkpoints_at"] should be set.')
  if use_loop:
    n_devices = num_devices() or fastmath.local_device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency,
        n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

    if additional_train_tasks is None:
      additional_train_tasks = []

    # Prepare the evaluation.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    names, metrics = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metrics,
                                  metric_names=names,
                                  n_eval_batches=eval_steps)

    if additional_eval_tasks is None:
      additional_eval_tasks = []

    additional_eval_tasks_from_streams = []
    if additional_eval_streams is not None:
      for stream in additional_eval_streams:
        additional_eval_tasks_from_streams.append(
            training.EvalTask(stream.stream,
                              metrics,
                              metric_names=names,
                              n_eval_batches=eval_steps,
                              export_prefix=stream.name))

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    permanent_checkpoint_at = None
    if permanent_checkpoints_at is not None:
      permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    if init_checkpoint:
      model_train.init_from_file(init_checkpoint, weights_only=True)
      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
    loop = training.Loop(
        model_train, [train_task] + additional_train_tasks,
        eval_model=model_predict_eval,
        eval_tasks=[eval_task] +
        additional_eval_tasks + additional_eval_tasks_from_streams,
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        permanent_checkpoint_at=permanent_checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        adasum=adasum,
        random_seed=random_seed,
        callbacks=callbacks,
    )

    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stop training, already reached the total training steps %d' % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest,
                          init_checkpoint=init_checkpoint)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  except Exception as e:
    raise e
  finally:
    trainer.close()
  return trainer.state
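
additional_eval_streams expects objects exposing .name and .stream, as the unpacking above shows. A minimal stand-in, assuming no particular NamedStream import path:

import collections
import numpy as np

# Hypothetical stand-in exposing the two fields read above (.name and .stream).
NamedStream = collections.namedtuple('NamedStream', ['name', 'stream'])

def _tiny_eval_batches():
    batch = np.arange(8).reshape((2, 4))
    while True:
        yield batch, batch, np.ones_like(batch)  # (inputs, targets, weights)

extra_eval = NamedStream(name='tiny_eval', stream=_tiny_eval_batches())
# Passing additional_eval_streams=[extra_eval] to train() would add an EvalTask
# whose metrics are exported under the 'tiny_eval' prefix.
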
Example 9
    def train(self,
              callbacks,
              epochs,
              loss,
              metrics,
              metric_emit_freq,
              optimizer,
              save_directory,
              output_type='infer',
              writer=stdout,
              n_eval_batches=10,
              batch_size_per_device=256,
              eval_batch_size=256,
              variable_shapes=False,
              *args,
              **kwargs):
        """
        Run the training loop for your ML pipeline.

        :param callbacks: Collection of callables that are run inside the training loop
        :type callbacks: ```None or List[Callable] or Tuple[Callable]```

        :param epochs: number of epochs (must be greater than 0)
        :type epochs: ```int```

        :param loss: Loss function, can be a string (depending on the framework) or an instance of a class
        :type loss: ```str or Callable or Any```

        :param metrics: Collection of metrics to monitor, e.g., accuracy, f1
        :type metrics: ```None or List[Callable or str] or Tuple[Callable or str]```

        :param metric_emit_freq: Frequency of metric emission, e.g., `lambda step: step % 10 == 0`, defaults to every epoch
        :type metric_emit_freq: ```None or (*args, **kwargs) -> bool```

        :param optimizer: Optimizer, can be a string (depending on the framework) or an instance of a class
        :type optimizer: ```str or Callable or Any```

        :param save_directory: Directory to save output in, e.g., weights in h5 files. If None, don't save.
        :type save_directory: ```None or str```

        :param output_type: `if save_directory is not None` then save in this format, e.g., 'h5'.
        :type output_type: ```str```

        :param writer: Writer for all output, could be a TensorBoard instance, a file handler like stdout or stderr
        :type writer: ```stdout or Any```

        :param n_eval_batches: Number of batches to run per evaluation
        :type n_eval_batches: ```int```

        :param batch_size_per_device: Training batch size on each device
        :type batch_size_per_device: ```int```

        :param eval_batch_size: Batch size to use during evaluation
        :type eval_batch_size: ```int```

        :param variable_shapes: Whether input batches may have variable shapes
        :type variable_shapes: ```bool```

        :param args:
        :param kwargs:
        :return:
        """
        super(TraxTrainer, self).train(callbacks=callbacks,
                                       epochs=epochs,
                                       loss=loss,
                                       metrics=metrics,
                                       metric_emit_freq=metric_emit_freq,
                                       optimizer=optimizer,
                                       save_directory=save_directory,
                                       output_type='infer',
                                       writer=writer,
                                       *args,
                                       **kwargs)
        assert self.data is not None
        assert self.model is not None

        task = training.TrainTask(itertools.cycle(self.data.train_stream(1)),
                                  loss, optimizer)

        eval_task = training.EvalTask(
            itertools.cycle(self.data.eval_stream(1)),
            metrics,
            n_eval_batches=n_eval_batches)

        training_session = training.Loop(self.model,
                                         task,
                                         eval_task=eval_task,
                                         eval_at=metric_emit_freq)

        training_session.run(n_steps=epochs)
        return training_session
Example 10
    train_data_pipeline = trax.data.Serial(
        trax.data.Shuffle(),
        trax.data.Batch(8),
    )
    train_batches_stream = train_data_pipeline(train_stream)

    eval_data_pipeline = trax.data.Batch(1)
    eval_batches_stream = eval_data_pipeline(eval_stream)

    # Define Train and Eval tasks using Trax Training
    train_task = training.TrainTask(
        labeled_data=train_batches_stream,
        loss_layer=tl.CategoryCrossEntropy(),
        optimizer=trax.optimizers.Adam(args.learning_rate),
    )

    eval_task = training.EvalTask(
        labeled_data=eval_batches_stream,
        metrics=[tl.CategoryCrossEntropy(),
                 tl.CategoryAccuracy()],
        n_eval_batches=20,
    )

    # Train Model
    model = get_model(n_output_classes=10)
    training_loop = training.Loop(model, train_task, eval_tasks=[eval_task])
    training_loop.run(args.train_steps)

    # Save Model
    save_model_tf(model)
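
get_model and save_model_tf are project helpers that are not shown here. One hypothetical definition of a 10-way classifier compatible with the snippet above; the input representation is an assumption.

from trax import layers as tl

def get_model(n_output_classes):
    """Hypothetical stand-in for the project's model factory."""
    return tl.Serial(
        tl.Flatten(),                # assumes image-like inputs; adjust to the real data
        tl.Dense(128),
        tl.Relu(),
        tl.Dense(n_output_classes),  # raw scores; tl.CategoryCrossEntropy handles normalization
    )
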
Example 11
File: train.py Project: JEF1056/R5
    print(f"(device count, tokens per device) = {test.shape}\n")
del teststream, test

# Training task.
train_task = training.TrainTask(
    labeled_data=stream(trax.fastmath.device_count(), "train"),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    lr_schedule=trax.lr.multifactor(),
    optimizer=trax.optimizers.Adam(),
    n_steps_per_checkpoint=1000,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=stream(trax.fastmath.device_count(), "validation"),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=10  # For less variance in eval numbers.
)

output_dir = os.path.expanduser(args.dir)

print("~~Begin Training~~")
# Train tiny model with Loop.
training_loop = training.Loop(
    trax.models.ReformerLM(mode="train"),
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir)

# Run training for 1,000,000 steps (batches).
training_loop.run(1000000)
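
stream(...) is a helper from the JEF1056/R5 project and is not shown; from its use above it must return an iterator of batches shaped (device count, tokens per device), which is exactly what the earlier print checks. A hypothetical stand-in that only illustrates the expected structure:

import numpy as np

def stream(n_devices, split):
    """Hypothetical stand-in: infinite iterator of (inputs, targets, weights) batches."""
    del split  # the real helper selects the dataset split; omitted in this sketch
    batch = (np.arange(n_devices * 64, dtype=np.int32) % 256).reshape((n_devices, 64))
    while True:
        # Language-model batches: targets equal inputs, weights of 1 count every token.
        yield batch, batch, np.ones_like(batch, dtype=np.float32)
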