def BERTPretrainingLoss():
  """Returns the combined BERT pretraining loss layer.

  The layer consumes 6 stack inputs and outputs the sum of two weighted
  category cross-entropy losses:

    * the NSP loss, computed on inputs 0, 2 and 3;
    * the MLM loss, computed on inputs 1, 4 and 5.

  Each selected triple is fed to `tl.WeightedCategoryCrossEntropy`, which
  takes (model_output, targets, weights).

  Returns:
    A `tl.Serial` layer computing `nsp_loss + mlm_loss`.
  """

  def weighted_xent_over(indices):
    # Pick three of the six inputs and score them with weighted cross-entropy.
    return [tl.Select(indices, n_in=6), tl.WeightedCategoryCrossEntropy()]

  next_sentence_branch = weighted_xent_over([0, 2, 3])
  masked_lm_branch = weighted_xent_over([1, 4, 5])
  return tl.Serial(tl.Branch(next_sentence_branch, masked_lm_branch),
                   tl.Add())
def test_no_int32_or_uint32_returned(self):
  """Ensures Trainer._jit_update_fn returns no int32/uint32 arrays.

  TF pins int32/uint32 tensors to CPU, so an XLA-forced-compiled computation
  that returned them would have to copy those outputs back to CPU. This test
  guards against that.
  """
  with fastmath.use_backend(fastmath.Backend.TFNP):
    num_classes = 1001
    resnet_fn = functools.partial(models.Resnet50,
                                  n_output_classes=num_classes)
    test_inputs = _test_inputs(num_classes, input_shape=(224, 224, 3))
    trainer = trainer_lib.Trainer(
        model=resnet_fn,
        loss_fn=tl.WeightedCategoryCrossEntropy(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=test_inputs,
    )
    trainer.reset(self.create_tempdir().full_path)
    trainer.train_epoch(1, 0)

    # The values produced by Trainer._jit_update_fn, flattened to leaves.
    update_outputs = tf.nest.flatten((trainer._opt_state.weights,
                                      trainer._opt_state.slots,
                                      trainer._model_state,
                                      trainer._rngs))
    cpu_pinned_dtypes = (jnp.int32, jnp.uint32)
    for leaf in update_outputs:
      if not isinstance(leaf, jnp.ndarray):
        continue
      if any(leaf.dtype == bad for bad in cpu_pinned_dtypes):
        raise ValueError('Found an array of int32 or uint32: %s' % leaf)
def test_call_and_grad(self):
  """Checks Favor attention output shape, scalar loss, and gradient shape."""

  def favor_stack():
    # Fresh layer instances on every call, so the two models share nothing.
    return [
        tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
        sparsity.Favor(d_feature=4, n_heads=2),
        tl.Select([0], n_in=2),
    ]

  encoder = tl.Serial(*favor_stack())
  model_with_loss = tl.Serial(*favor_stack(),
                              tl.WeightedCategoryCrossEntropy())

  tokens = np.ones((1, 2), dtype=np.int32)
  token_weights = np.ones_like(tokens).astype(np.float32)
  tokens_sig = shapes.signature(tokens)
  weights_sig = shapes.signature(token_weights)

  # Without the loss, (1, 2) tokens map to (1, 2, 4) features.
  encoder.init(tokens_sig)
  self.assertEqual(encoder(tokens).shape, (1, 2, 4))

  # The loss model takes (tokens, targets, weights); tokens double as targets.
  model_with_loss.init((tokens_sig, tokens_sig, weights_sig))
  self.assertEqual(model_with_loss((tokens, tokens, token_weights)).shape, ())

  state = model_with_loss.state
  rng = fastmath.random.get_prng(0)

  def loss_value(weights, batch):
    return model_with_loss.pure_fn(batch, weights, state, rng=rng)[0]

  grads = fastmath.grad(loss_value)(model_with_loss.weights,
                                    (tokens, tokens, token_weights))
  # Shape (3, 4) matches the Embedding(3, 4) table (presumably that leaf).
  self.assertEqual(grads[0][1][0].shape, (3, 4))
def _mnist_tasks(head=None):
  """Builds the MNIST training and evaluation tasks.

  Args:
    head: Optional adaptor layer placed in front of both the loss layer and
        the accuracy layer. The same instance is used by both.

  Returns:
    A `(train_task, eval_task)` pair. The training task minimizes weighted
    category cross-entropy with the optimizer `adam.Adam(0.001)`. The
    evaluation task reports cross-entropy and weighted category accuracy
    over 10 batches.
  """

  def _with_head(layer):
    # Prepend the adaptor only when one was supplied.
    return layer if head is None else tl.Serial(head, layer)

  loss = _with_head(tl.WeightedCategoryCrossEntropy())
  accuracy = _with_head(tl.WeightedCategoryAccuracy())

  train_batches = itertools.cycle(_mnist_dataset().train_stream(1))
  train_task = training.TrainTask(train_batches, loss, adam.Adam(0.001))

  eval_batches = itertools.cycle(_mnist_dataset().eval_stream(1))
  eval_task = training.EvalTask(
      eval_batches,
      [loss, accuracy],
      n_eval_batches=10,
      metric_names=['CrossEntropy', 'WeightedCategoryAccuracy'],
  )
  return train_task, eval_task
def test_train_memory_efficient(self):
  """Trains a large network in a memory-efficient way."""
  # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.device_count() == 1 and ram_limited:
    # Skip explicitly: a bare `return` would make the test report as passed
    # even though no training was exercised.
    self.skipTest('Needs > 16GB RAM; runs only with multiple devices '
                  '(e.g. TPU) unless ram_limited is set to False.')

  # Create the model.
  n_layers = 16  # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram
  model = tl.Serial(
      tl.Embedding(9, 16 * 1024),
      tl.Dup(),
      [[
          tl.ReversibleHalfResidual(tl.Dense(16 * 1024)),
          tl.ReversibleSwap()
      ] for _ in range(n_layers)],
      tl.Concatenate(),
      tl.Dense(9),
  )

  # Create inputs: (inputs, targets, weights) with targets == inputs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))

  def _data_gen():
    while True:
      yield labeled_batch

  # Run training. The memory-efficient trainer is given the optimizer class.
  loss_layer = tl.WeightedCategoryCrossEntropy()
  task = training.TrainTask(_data_gen(), loss_layer, optimizers.Adafactor)
  eval_task = training.EvalTask(_data_gen(),
                                [tl.WeightedCategoryCrossEntropy()])
  loop = training.Loop(model, [task],
                       eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n == 2,
                       use_memory_efficient_trainer=True)
  self.assertEqual(0, loop.step)
  loop.run(n_steps=2)
  self.assertEqual(2, loop.step)
def test_weighted_category_cross_entropy(self):
  """Checks weighted cross-entropy against hand-computed values."""
  layer = tl.WeightedCategoryCrossEntropy()
  targets = np.array([0, 1])
  weights = np.array([30, 10])  # First item: 30 / (30 + 10) = 75% weight.

  cases = [
      # Near-perfect prediction (for both items in batch).
      (np.array([[9., 2., 0., -2.], [2., 9., 0., -2.]]), .001),
      # More right than wrong (for both items in batch).
      (np.array([[2.2, 2., 0., -2.], [2., 2.2, 0., -2.]]), .665),
      # First item (with 75% weight) near perfect, second more right than
      # wrong.
      (np.array([[9., 2., 0., -2.], [2., 2.2, 0., -2.]]), .167),
  ]
  for model_outputs, expected_loss in cases:
    loss = layer([model_outputs, targets, weights])
    self.assertAlmostEqual(loss, expected_loss, places=3)
def test_reset_twice(self, backend):
  """Checks a Trainer can be reset to a new output dir and evaluated twice."""
  with fastmath.use_backend(backend):
    num_classes = 4
    mlp_fn = functools.partial(models.MLP,
                               layer_widths=(16, 16, num_classes))
    test_inputs = _test_inputs(num_classes)
    trainer = trainer_lib.Trainer(
        model=mlp_fn,
        loss_fn=tl.WeightedCategoryCrossEntropy(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=test_inputs,
    )

    # Each pass gets its own, freshly created temp directory.
    for dir_name in ('output_dir1', 'output_dir2'):
      output_dir = self.create_tempdir(name=dir_name).full_path
      trainer.reset(output_dir)
      trainer.evaluate(1)
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: Callable accepting a `mode=` keyword and returning the model layer.
      The loop path calls it with mode='train' and with mode='eval'.
    loss_fn: Loss layer. In the loop path it becomes the TrainTask's
      `loss_layer`; otherwise it is passed to `trainer_class`. The default is
      a single layer instance created once, when this function is defined.
    inputs: trax.inputs.Inputs, or a callable returning one.
    optimizer: The optimizer (see optimizers/base.py for signature). It is
      instantiated (`optimizer()`) unless use_memory_efficient_trainer is set,
      in which case it is passed through uncalled.
    lr_schedule_fn: A learning rate schedule function, that when called
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use (only when use_loop is False).
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
      each training step in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    permanent_checkpoint_frequency: int, how often to save permanent
      checkpoints (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.
    init_checkpoint: a checkpoint for fine tuning.
    callbacks: a list of callbacks to call during training.
    additional_train_tasks: optional iterable of additional tasks to run
      during training (loop path only).
    additional_eval_tasks: optional iterable of additional tasks to run
      during evaluation (loop path only).

  Returns:
    trax.TrainerState or training.Loop if use_loop is True

  Raises:
    ValueError: if both permanent_checkpoint_frequency and
      permanent_checkpoints_at are set.
  """
  if (permanent_checkpoint_frequency is not None and
      permanent_checkpoints_at is not None):
    raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                     '"permanent_checkpoints_at"] should be set.')
  if use_loop:
    n_devices = num_devices() or fastmath.device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    # The memory-efficient trainer receives the optimizer uncalled.
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency,
        n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

    # Prepare the evaluation.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    names, metrics = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metrics,
                                  metric_names=names,
                                  n_eval_batches=eval_steps)

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    permanent_checkpoint_at = None
    if permanent_checkpoints_at is not None:
      permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    if init_checkpoint:
      model_train.init_from_file(init_checkpoint, weights_only=True)
      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)

    # Accept any iterable (or None) for the extra task collections.
    train_tasks = [train_task] + list(additional_train_tasks or [])
    eval_tasks = [eval_task] + list(additional_eval_tasks or [])
    loop = training.Loop(
        model_train,
        train_tasks,
        eval_model=model_predict_eval,
        eval_tasks=eval_tasks,
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        permanent_checkpoint_at=permanent_checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        random_seed=random_seed,
        callbacks=callbacks,
    )

    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stop training, already reached the total training steps %d' % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest,
                          init_checkpoint=init_checkpoint)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  # eval_steps may be None ("eval disabled"); check before comparing, since
  # `None > 0` raises TypeError in Python 3.
  if eval_frequency and eval_steps is not None and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  # Exceptions propagate unchanged; `finally` guarantees the trainer closes.
  try:
    for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(n_epoch_steps, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if save_graphs and fastmath.is_backend(fastmath.Backend.JAX):
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state
'step', # Current training step number. 'opt_state', # OptState. 'history', # trax.history.History. 'model_state', # Auxilliary state of the model. ]) OptState = collections.namedtuple( '_OptState', [ 'weights', # Model weights. 'slots', # Per-parameter optimizer state, e.g. gradient moments. 'opt_params', # Optimizer (hyper)parameters, e.g. learning rate, momentum. ]) _DEFAULT_METRICS = { 'loss': tl.WeightedCategoryCrossEntropy(), 'accuracy': tl.WeightedCategoryAccuracy(), 'sequence_accuracy': tl.MaskedSequenceAccuracy(), 'neg_log_perplexity': tl.Serial(tl.WeightedCategoryCrossEntropy(), tl.Negate()), 'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()), } class Trainer: """Trax trainer. A trainer allows to make training steps, train for full epochs, save the training state and access evaluation data. """ def __init__(self,
import sentencepiece as spm
# Train a BPE SentencePiece tokenizer on args.train; output files use the
# prefix <args.dir>/bpe. "/t" and "/e" are also used as the BOS/EOS pieces.
spm.SentencePieceTrainer.train(input=args.train,
                               model_prefix=os.path.join(args.dir,'bpe'),
                               train_extremely_large_corpus=True,
                               input_sentence_size=100000,
                               shuffle_input_sentence=True,
                               vocab_size=args.vocab_size,
                               model_type="bpe",
                               character_coverage = 1,
                               user_defined_symbols=['/n', "/b", "/t","/e"],
                               bos_piece="/t",
                               eos_piece="/e",
                               bos_id=1,eos_id=2,
                               pad_id=-1)
# Persist dataset paths, max length and output dir to config.json.
# NOTE(review): presumably read back by src.createtask.stream — confirm.
with open("config.json", "w") as f:
    json.dump([{"train":args.train, "validation": args.val}, args.max_length, args.dir], f)
from src.createtask import stream
# Sanity-check the input pipeline with a few debug batches and show the
# shape of the first element of each batch.
teststream=stream(trax.fastmath.device_count(), "train", debug=True)
for _ in range(5):
    test=next(teststream)[0]
    print(f"(device count, tokens per device) = {test.shape}\n")
del teststream, test
# Training task.
train_task = training.TrainTask(
    labeled_data=stream(trax.fastmath.device_count(), "train"),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    lr_schedule=trax.lr.multifactor(),
    optimizer=trax.optimizers.Adam(),
    n_steps_per_checkpoint=1000,
)
# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=stream(trax.fastmath.device_count(), "validation"),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=10  # For less variance in eval numbers.
)
output_dir = os.path.expanduser(args.dir)
print("~~Begin Training~~")
return tl.base.Fn('LatentLossFunction', f) def DropLast(): """Drops the last stack element.""" def f(x, u): return x return tl.Fn('DropLast', f) Latent_METRICS = { 'next_state_loss': tl.Serial(tl.Select([0, 1, 9]), tl.WeightedCategoryCrossEntropy()), # DropLast()), 'recon_state_loss': tl.Serial(tl.Select([2, 3, 10]), tl.WeightedCategoryCrossEntropy()), 'recon_action_loss': tl.Serial(tl.Select([4, 5, 11]), tl.WeightedCategoryCrossEntropy()), 'next_state_accuracy': tl.Serial(tl.Select([0, 1, 9]), tl.Accuracy()), # DropLast()), 'recon_state_accuracy': tl.Serial(tl.Select([2, 3, 10]), tl.Accuracy()), 'recon_action_accuracy': tl.Serial(tl.Select([4, 5, 11]), tl.Accuracy()), 'next_state_sequence_accuracy': tl.Serial(tl.Select([0, 1, 9]), tl.SequenceAccuracy()), # DropLast()), 'recon_state_sequence_accuracy': tl.Serial(tl.Select([2, 3, 10]), tl.SequenceAccuracy()), 'recon_action_sequence_accuracy':