Example #1
def BERTPretrainingLoss():
    """Returns the combined BERT pretraining loss: NSP loss + MLM loss."""
    # Of the 6 inputs, items 0, 2, 3 are the NSP logits, targets, and weights;
    # items 1, 4, 5 are the MLM logits, targets, and weights.
    nsp_loss = [
        tl.Select([0, 2, 3], n_in=6),
        tl.WeightedCategoryCrossEntropy()
    ]
    mlm_loss = [
        tl.Select([1, 4, 5], n_in=6),
        tl.WeightedCategoryCrossEntropy()
    ]
    return tl.Serial(tl.Branch(nsp_loss, mlm_loss), tl.Add())
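
A minimal sketch of driving this loss with dummy tensors. The shapes, and the
stack order they imply, are assumptions read off the Select indices above:
(nsp_logits, mlm_logits, nsp_targets, nsp_weights, mlm_targets, mlm_weights).

import numpy as np
from trax import layers as tl
from trax import shapes

loss = BERTPretrainingLoss()
nsp_logits = np.zeros((2, 2), dtype=np.float32)       # (batch, 2 NSP classes)
mlm_logits = np.zeros((2, 8, 100), dtype=np.float32)  # (batch, seq, vocab)
nsp_targets = np.zeros((2,), dtype=np.int32)
mlm_targets = np.zeros((2, 8), dtype=np.int32)
nsp_weights = np.ones((2,), dtype=np.float32)
mlm_weights = np.ones((2, 8), dtype=np.float32)
batch = (nsp_logits, mlm_logits, nsp_targets, nsp_weights,
         mlm_targets, mlm_weights)
loss.init(shapes.signature(batch))
total = loss(batch)  # Scalar: NSP loss + MLM loss.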
Example #2
    def test_no_int32_or_uint32_returned(self):
        """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

        TF pins int32/uint32 tensors to CPU, which causes XLA-compiled
        computation to copy int32/uint32 outputs to CPU. This test makes sure
        that won't happen.
        """
        with fastmath.use_backend(fastmath.Backend.TFNP):
            n_classes = 1001
            model_fn = functools.partial(models.Resnet50,
                                         n_output_classes=n_classes)
            inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=tl.WeightedCategoryCrossEntropy(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )
            output_dir = self.create_tempdir().full_path
            trainer.reset(output_dir)
            trainer.train_epoch(1, 0)
            # These are the values returned by Trainer._jit_update_fn.
            arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
                      trainer._model_state, trainer._rngs)
            arrays = tf.nest.flatten(arrays)
            for x in arrays:
                if isinstance(x, jnp.ndarray) and x.dtype in (jnp.int32,
                                                              jnp.uint32):
                    raise ValueError(
                        'Found an array of int32 or uint32: %s' % x)
Example #3
    def test_call_and_grad(self):
        layer_partial = tl.Serial(
            tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
            sparsity.Favor(d_feature=4, n_heads=2),
            tl.Select([0], n_in=2),
        )
        layer = tl.Serial(
            tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
            sparsity.Favor(d_feature=4, n_heads=2),
            tl.Select([0], n_in=2),
            tl.WeightedCategoryCrossEntropy(),
        )
        x = np.ones((1, 2), dtype=np.int32)
        w = np.ones_like(x).astype(np.float32)
        x_sig = shapes.signature(x)
        w_sig = shapes.signature(w)
        layer_partial.init(x_sig)
        y = layer_partial(x)
        self.assertEqual(y.shape, (1, 2, 4))
        layer.init((x_sig, x_sig, w_sig))
        y = layer((x, x, w))
        self.assertEqual(y.shape, ())
        state = layer.state
        rng = fastmath.random.get_prng(0)
        fwd = lambda weights, inp: layer.pure_fn(
            inp, weights, state, rng=rng)[0]
        g = fastmath.grad(fwd)(layer.weights, (x, x, w))
        self.assertEqual(g[0][1][0].shape, (3, 4))
Example #4
def _mnist_tasks(head=None):
    """Creates MNIST training and evaluation tasks.

  Args:
    head: Adaptor layer to put before loss and accuracy layers in the tasks.

  Returns:
    A pair (train_task, eval_task) consisting of the MNIST training task and the
    MNIST evaluation task using cross-entropy as loss and accuracy as metric.
  """
    loss = tl.WeightedCategoryCrossEntropy()
    accuracy = tl.WeightedCategoryAccuracy()
    if head is not None:
        loss = tl.Serial(head, loss)
        accuracy = tl.Serial(head, accuracy)
    train_task = training.TrainTask(
        itertools.cycle(_mnist_dataset().train_stream(1)),
        loss,
        adam.Adam(0.001),
    )
    eval_task = training.EvalTask(
        itertools.cycle(_mnist_dataset().eval_stream(1)),
        [loss, accuracy],
        n_eval_batches=10,
        metric_names=['CrossEntropy', 'WeightedCategoryAccuracy'],
    )
    return (train_task, eval_task)
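
A usage sketch showing how the returned pair plugs into a training loop
(mnist_model below is a hypothetical classifier, not part of the listing):

train_task, eval_task = _mnist_tasks()
loop = training.Loop(
    mnist_model,             # hypothetical model producing 10 class logits
    [train_task],
    eval_tasks=[eval_task],
    output_dir='/tmp/mnist_train',
)
loop.run(n_steps=1000)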
Example #5
    def test_train_memory_efficient(self):
        """Trains a large network in a memory-efficient way."""
        # This test requires > 16 GB of RAM; run it only on TPUs. It does pass
        # on GPU and CPU when run locally, but it's too big for unit testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.device_count() == 1 and ram_limited:
            return

        # Create the model.
        n_layers = 16  # 16 layers, each 16K x 16K = 2**28 weights (~1 GB); ~16 GB total.
        model = tl.Serial(
            tl.Embedding(9, 16 * 1024),
            tl.Dup(),
            [[
                tl.ReversibleHalfResidual(tl.Dense(16 * 1024)),
                tl.ReversibleSwap()
            ] for _ in range(n_layers)],
            tl.Concatenate(),
            tl.Dense(9),
        )

        # Create inputs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))

        def _data_gen():
            while True:
                yield labeled_batch

        # Run training.
        loss_layer = tl.WeightedCategoryCrossEntropy()
        task = training.TrainTask(_data_gen(), loss_layer,
                                  optimizers.Adafactor)
        eval_task = training.EvalTask(_data_gen(),
                                      [tl.WeightedCategoryCrossEntropy()])
        loop = training.Loop(model, [task],
                             eval_tasks=[eval_task],
                             eval_at=lambda step_n: step_n == 2,
                             use_memory_efficient_trainer=True)
        self.assertEqual(0, loop.step)
        loop.run(n_steps=2)
        self.assertEqual(2, loop.step)
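
The memory estimate in the comment above is simple arithmetic; a quick check,
assuming float32 (4 bytes per weight):

params_per_layer = 16 * 1024 * 16 * 1024   # one 16K x 16K Dense: 2**28 weights
bytes_per_layer = params_per_layer * 4     # float32
print(bytes_per_layer / 2**30)             # 1.0 GiB per layer
print(16 * bytes_per_layer / 2**30)        # 16.0 GiB across the 16 layers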
Example #6
    def test_weighted_category_cross_entropy(self):
        layer = tl.WeightedCategoryCrossEntropy()
        targets = np.array([0, 1])
        weights = np.array([30, 10])

        # Near-perfect prediction (for both items in batch).
        model_outputs = np.array([[9., 2., 0., -2.], [2., 9., 0., -2.]])
        loss = layer([model_outputs, targets, weights])
        self.assertAlmostEqual(loss, .001, places=3)

        # More right than wrong (for both items in batch).
        model_outputs = np.array([[2.2, 2., 0., -2.], [2., 2.2, 0., -2.]])
        loss = layer([model_outputs, targets, weights])
        self.assertAlmostEqual(loss, .665, places=3)

        # First item (with 75% weight) near perfect, second more right than wrong.
        model_outputs = np.array([[9., 2., 0., -2.], [2., 2.2, 0., -2.]])
        loss = layer([model_outputs, targets, weights])
        self.assertAlmostEqual(loss, .167, places=3)
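
The expected values are a weighted mean of per-example negative
log-likelihoods; here is an independent NumPy check of that arithmetic for the
last case (a sketch, not the Trax implementation):

import numpy as np

def weighted_xent(logits, targets, weights):
    # Log-softmax over the class axis, then pick out each target's log-prob.
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]
    # Weighted mean: the weight-30 item counts three times the weight-10 item.
    return (nll * weights).sum() / weights.sum()

logits = np.array([[9., 2., 0., -2.], [2., 2.2, 0., -2.]])
print(weighted_xent(logits, np.array([0, 1]), np.array([30., 10.])))  # ~0.167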
Example #7
    def test_reset_twice(self, backend):
        with fastmath.use_backend(backend):
            n_classes = 4
            model_fn = functools.partial(models.MLP,
                                         layer_widths=(16, 16, n_classes))
            inputs = _test_inputs(n_classes)

            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=tl.WeightedCategoryCrossEntropy(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )

            output_dir1 = self.create_tempdir(name='output_dir1').full_path
            trainer.reset(output_dir1)
            trainer.evaluate(1)
            output_dir2 = self.create_tempdir(name='output_dir2').full_path
            trainer.reset(output_dir2)
            trainer.evaluate(1)
Example #8
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
      each training step in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    permanent_checkpoint_frequency: int, how often to save permanent checkpoints
      (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer..
    init_checkpoint: a checkpoint for fine tuning.
    callbacks: a list of callbacks to call during training.
    additional_train_tasks: additional tasks which should be performed during
      training.
    additional_eval_tasks: additional tasks which should be performed during
      evaluation.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
    if (permanent_checkpoint_frequency is not None
            and permanent_checkpoints_at is not None):
        raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                         '"permanent_checkpoints_at"] should be set.')
    if use_loop:
        n_devices = num_devices() or fastmath.device_count()

        # Prepare the training task.
        # Inputs is either an Inputs instance or a function that returns it.
        # If inputs is a function (e.g., passed through gin), call it.
        if callable(inputs):
            inputs = inputs()
        opt = optimizer if use_memory_efficient_trainer else optimizer()
        train_task = training.TrainTask(
            inputs.train_stream(n_devices),
            loss_layer=loss_fn,
            optimizer=opt,
            lr_schedule=lr_schedule_fn(),
            n_steps_per_checkpoint=eval_frequency,
            n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

        # Prepare the evaluation.
        metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        names, metrics = zip(*metrics_dict.items())
        eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                      metrics,
                                      metric_names=names,
                                      n_eval_batches=eval_steps)

        # Prepare the training loop.
        checkpoint_at = None
        if checkpoints_at is not None:
            checkpoint_at = lambda step: step in checkpoints_at
        permanent_checkpoint_at = None
        if permanent_checkpoints_at is not None:
            permanent_checkpoint_at = (
                lambda step: step in permanent_checkpoints_at)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        if init_checkpoint:
            model_train.init_from_file(init_checkpoint, weights_only=True)
            model_predict_eval.init_from_file(init_checkpoint,
                                              weights_only=True)
        loop = training.Loop(
            model_train,
            [train_task] + (additional_train_tasks
                            if additional_train_tasks is not None else []),
            eval_model=model_predict_eval,
            eval_tasks=[eval_task] +
            (additional_eval_tasks
             if additional_eval_tasks is not None else []),
            output_dir=output_dir,
            checkpoint_at=checkpoint_at,
            permanent_checkpoint_at=permanent_checkpoint_at,
            n_devices=n_devices,
            loss_chunk_size=loss_chunk_size,
            use_memory_efficient_trainer=use_memory_efficient_trainer,
            random_seed=random_seed,
            callbacks=callbacks,
        )

        steps_to_go = steps - loop.step
        if steps_to_go <= 0:
            log('Stopping training: already reached the total of %d training '
                'steps.' % steps)
            return loop

        # Train and return the loop.
        loop.run(steps_to_go)
        return loop

    n_devices = num_devices()
    trainer = trainer_class(model,
                            loss_fn,
                            optimizer,
                            lr_schedule_fn(),
                            inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest,
                            init_checkpoint=init_checkpoint)

    epoch_steps = [steps]  # Train-only (no evals) if eval_frequency is 0 or None.
    if eval_frequency and eval_steps > 0:
        epoch_steps = itertools.chain(
            [
                1,  # first epoch only 1 step
                eval_frequency - 1
            ],
            itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    try:
        for epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    finally:
        trainer.close()
    return trainer.state
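
A hypothetical direct invocation (in practice most of these arguments arrive
through gin bindings; my_inputs is a stand-in for a trax.inputs.Inputs
instance, or a callable returning one):

loop = train(
    output_dir='/tmp/my_run',
    model=functools.partial(models.MLP, layer_widths=(128, 10)),
    inputs=my_inputs,
    steps=2000,
    eval_frequency=100,
)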
Example #9
TrainerState = collections.namedtuple(
    '_TrainerState',
    [
        'step',  # Current training step number.
        'opt_state',  # OptState.
        'history',  # trax.history.History.
        'model_state',  # Auxiliary state of the model.
    ])

OptState = collections.namedtuple(
    '_OptState',
    [
        'weights',  # Model weights.
        'slots',  # Per-parameter optimizer state, e.g. gradient moments.
        'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ])

_DEFAULT_METRICS = {
    'loss': tl.WeightedCategoryCrossEntropy(),
    'accuracy': tl.WeightedCategoryAccuracy(),
    'sequence_accuracy': tl.MaskedSequenceAccuracy(),
    'neg_log_perplexity': tl.Serial(tl.WeightedCategoryCrossEntropy(),
                                    tl.Negate()),
    # Drop the model outputs and targets from the stack; sum the weights.
    'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
}


class Trainer:
    """Trax trainer.

  A trainer allows to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """
    def __init__(self,
Example #10
File: train.py Project: JEF1056/R5
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input=args.train,
    model_prefix=os.path.join(args.dir, 'bpe'),
    train_extremely_large_corpus=True,
    input_sentence_size=100000,
    shuffle_input_sentence=True,
    vocab_size=args.vocab_size,
    model_type="bpe",
    character_coverage=1,
    user_defined_symbols=['/n', "/b", "/t", "/e"],
    bos_piece="/t",
    eos_piece="/e",
    bos_id=1,
    eos_id=2,
    pad_id=-1)

with open("config.json", "w") as f:
    json.dump([{"train":args.train, "validation": args.val}, args.max_length, args.dir], f)
from src.createtask import stream
teststream = stream(trax.fastmath.device_count(), "train", debug=True)
for _ in range(5):
    test = next(teststream)[0]
    print(f"(device count, tokens per device) = {test.shape}\n")
del teststream, test

# Training task.
train_task = training.TrainTask(
    labeled_data=stream(trax.fastmath.device_count(), "train"),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    lr_schedule=trax.lr.multifactor(),
    optimizer=trax.optimizers.Adam(),
    n_steps_per_checkpoint=1000,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=stream(trax.fastmath.device_count(), "validation"),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=10  # For less variance in eval numbers.
)

output_dir = os.path.expanduser(args.dir)

print("~~Begin Training~~")
Example #11
    return tl.base.Fn('LatentLossFunction', f)


def DropLast():
    """Drops the last stack element."""
    def f(x, u):  # Keep x; discard u (the last stack element).
        return x

    return tl.Fn('DropLast', f)


Latent_METRICS = {
    'next_state_loss':
    tl.Serial(tl.Select([0, 1, 9]),
              tl.WeightedCategoryCrossEntropy()),  # DropLast()),
    'recon_state_loss':
    tl.Serial(tl.Select([2, 3, 10]), tl.WeightedCategoryCrossEntropy()),
    'recon_action_loss':
    tl.Serial(tl.Select([4, 5, 11]), tl.WeightedCategoryCrossEntropy()),
    'next_state_accuracy':
    tl.Serial(tl.Select([0, 1, 9]), tl.Accuracy()),  # DropLast()),
    'recon_state_accuracy':
    tl.Serial(tl.Select([2, 3, 10]), tl.Accuracy()),
    'recon_action_accuracy':
    tl.Serial(tl.Select([4, 5, 11]), tl.Accuracy()),
    'next_state_sequence_accuracy':
    tl.Serial(tl.Select([0, 1, 9]), tl.SequenceAccuracy()),  # DropLast()),
    'recon_state_sequence_accuracy':
    tl.Serial(tl.Select([2, 3, 10]), tl.SequenceAccuracy()),
    'recon_action_sequence_accuracy':
    tl.Serial(tl.Select([4, 5, 11]), tl.SequenceAccuracy()),
}