def test_trains_on_two_tasks(self):
    """Alternates steps between two (identical) tiny tasks on a small net.

    The same TrainTask/EvalTask pair is registered twice and `which_task`
    picks task index `step_n % 2`, so consecutive steps alternate between the
    two slots. The test only checks that the step counter advances by the
    requested number of steps across two `run` calls.
    """
    model = tl.Serial(tl.Dense(3), tl.Dense(1))
    train_task = training.TrainTask(
        _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
    eval_task = training.EvalTask(
        _very_simple_data(),  # deliberately re-using training data
        [tl.L2Loss()],
    )
    loop = training.Loop(
        model,
        tasks=(train_task, train_task),
        eval_tasks=(eval_task, eval_task),
        which_task=lambda step_n: step_n % 2,
    )

    self.assertEqual(0, loop.step)
    # Each pair is (steps to run now, cumulative step count expected after).
    for n_steps, expected_step in ((15, 15), (5, 20)):
        loop.run(n_steps=n_steps)
        self.assertEqual(expected_step, loop.step)
def test_train_save_restore_dense(self):
    """Saves and restores a checkpoint to check for equivalence.

    A first Loop trains one step and saves a checkpoint into a temp dir. A
    second Loop built against the same directory must produce the same model
    output before training. It must diverge from the first model after one
    more step, and its optimizer slots are compared against the first loop's.
    """
    task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                              optimizers.Adam(.0001))
    eval_task = training.EvalTask(
        _very_simple_data(),  # deliberately re-using training data
        [tl.L2Loss()],
        metric_names=['SGD.L2Loss'])
    tmp_dir = self.create_tempdir().full_path

    def _make_model_and_session():
        # Both loops share `tmp_dir`, so the second one is created against the
        # checkpoint written by the first.
        m = tl.Serial(tl.Dense(1))
        ts = training.Loop(m, [task], eval_tasks=[eval_task],
                           eval_at=lambda step_n: step_n % 2 == 0,
                           output_dir=tmp_dir)
        return m, ts

    model, training_session = _make_model_and_session()
    self.assertEqual(0, training_session.step)
    training_session.run(n_steps=1)
    training_session.save_checkpoint()
    model2, training_session2 = _make_model_and_session()

    x = np.ones((8, 1))
    y1 = model(x, rng=fastmath.random.get_prng(0))
    y2 = model2(x, rng=fastmath.random.get_prng(0))
    self.assertEqual(str(y1), str(y2))

    training_session2.run(n_steps=1)
    y1 = model(x, rng=fastmath.random.get_prng(0))
    y2 = model2(x, rng=fastmath.random.get_prng(0))
    self.assertNotEqual(str(y1), str(y2))

    # Compare the ORIGINAL loop's slots with the RESTORED loop's slots.
    # Reading both from `training_session` made this a self-comparison that
    # could never fail.
    # NOTE(review): both loops reuse the same `task` (and so possibly the same
    # optimizer object); if slots live on the optimizer, this check is weaker
    # than it looks.
    slots1 = training_session._trainer_per_task[0].slots
    slots2 = training_session2._trainer_per_task[0].slots
    np.testing.assert_array_equal(slots1, slots2)
def test_train_memory_efficient(self):
    """Trains a large reversible network with the memory-efficient trainer.

    Skipped (early return) on single-device hosts while `ram_limited` is set,
    because the model needs far more RAM than a unit-test machine has.
    """
    # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU
    # and CPU when you run it locally, but it's too big for unit-testing.
    ram_limited = True  # Set to False to run this test locally.
    if fastmath.device_count() == 1 and ram_limited:
        return

    # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram
    n_layers = 16
    width = 16 * 1024
    embedding = tl.Embedding(9, width)
    duplicate = tl.Dup()
    reversible_stack = [
        [tl.ReversibleHalfResidual(tl.Dense(width)), tl.ReversibleSwap()]
        for _ in range(n_layers)
    ]
    model = tl.Serial(embedding, duplicate, reversible_stack,
                      tl.Concatenate(), tl.Dense(9), tl.LogSoftmax())

    # One fixed batch: targets are the inputs themselves, all weights are 1.
    inputs_batch = np.arange(8).reshape((2, 4))
    labeled_batch = (inputs_batch, inputs_batch, np.ones_like(inputs_batch))

    def _repeat_forever():
        while True:
            yield labeled_batch

    # Adafactor is passed as a class, not an instance; presumably the
    # memory-efficient trainer instantiates it itself -- confirm.
    train_task = training.TrainTask(_repeat_forever(), tl.CrossEntropyLoss(),
                                    optimizers.Adafactor)
    eval_task = training.EvalTask(_repeat_forever(), [tl.CrossEntropyLoss()])
    loop = training.Loop(model, [train_task],
                         eval_tasks=[eval_task],
                         eval_at=lambda step_n: step_n == 2,
                         use_memory_efficient_trainer=True)

    self.assertEqual(0, loop.step)
    loop.run(n_steps=2)
    self.assertEqual(2, loop.step)
def test_loop_checkpoint_high_metric(self):
    """Checks that 'highest' checkpoints fire only when the metric peaks.

    `save_checkpoint` is replaced by a recorder, so the test counts how often
    each checkpoint name is requested over 10 steps with eval and checkpoint
    both scheduled on even steps.
    """
    model = tl.Serial(tl.Dense(1))
    train_task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                                    optimizers.SGD(.01))
    eval_task = training.EvalTask(_very_simple_data(), [tl.L2Loss()],
                                  metric_names=['l2_loss'])
    tmp_dir = self.create_tempdir().full_path

    def _on_even_step(step_n):
        return step_n % 2 == 0

    loop = training.Loop(model, [train_task],
                         eval_tasks=[eval_task],
                         output_dir=tmp_dir,
                         eval_at=_on_even_step,
                         checkpoint_at=_on_even_step,
                         checkpoint_high_metric='l2_loss')

    saved_names = collections.Counter()

    def _record_save(name):
        saved_names[name] += 1

    loop.save_checkpoint = _record_save
    loop.run(n_steps=10)

    # Eval metric steadily descends, so high checkpoint triggered only once.
    # Low checkpoint not defined, so never triggered.
    self.assertEqual(saved_names['model'], 5)
    self.assertEqual(saved_names['lowest_l2_loss'], 0)
    self.assertEqual(saved_names['highest_l2_loss'], 1)
# NOTE(review): the `return model` below is the tail of a function whose start
# is not visible in this snippet (presumably the NMTAttn model builder) --
# confirm before editing.
return model

# Default-mode model instance; it is not referenced again in this snippet
# (the training loop below builds its own NMTAttn(mode='train')).
model = NMTAttn()
# print(model)

# Training task: cross-entropy loss, Adam(0.01), the
# trax.lr.warmup_and_rsqrt_decay(1000, 0.01) schedule, and
# n_steps_per_checkpoint=20.
train_task = training.TrainTask(
    labeled_data=train_batch_data,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    lr_schedule=trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
    n_steps_per_checkpoint=20,
)

# Evaluation task reporting cross-entropy loss and accuracy.
eval_task = training.EvalTask(
    labeled_data=eval_batch_data,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

# NOTE: "Nueral" is misspelled, but this is a real on-disk path; changing it
# would point training at a different directory.
output_dir = 'Nueral_Machine_Translation_With_Attention/output_dir/'
model_file_path = os.path.join(output_dir,"model.pkl.gz")

# Remove a previously saved model file, if any, so training restarts.
if os.path.exists(model_file_path):
    os.remove(model_file_path)

# Define the training loop on a fresh train-mode model.
training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 3 training steps.
training_loop.run(3)
def train_model(model, data_generator, batch_size=32, max_length=64,
                lines=lines, eval_lines=eval_lines, n_steps=1,
                output_dir='model/'):
    """Builds train/eval tasks for `model`, runs a training loop, returns it.

    Args:
        model (trax.layers.combinators.Serial): GRU model.
        data_generator (function): Data generator function, called with
            `batch_size`, `max_length` and `data_lines` keyword arguments.
        batch_size (int, optional): Number of lines per batch. Defaults to 32.
        max_length (int, optional): Maximum length allowed for a line to be
            processed. Defaults to 64.
        lines (list, optional): Lines to use for training. Defaults to the
            module-level `lines`.
        eval_lines (list, optional): Lines to use for evaluation. Defaults to
            the module-level `eval_lines`.
        n_steps (int, optional): Number of steps to train. Defaults to 1.
        output_dir (str, optional): Relative path of directory to save the
            model. Defaults to "model/".

    Returns:
        trax.supervised.training.Loop: The loop after running. It is returned
        because it holds a handle to the model and its weights.
    """
    def _endless_batches(data_lines):
        # itertools.cycle replays the generator's output once it is
        # exhausted, so neither task runs out of batches.
        return itertools.cycle(
            data_generator(batch_size=batch_size, max_length=max_length,
                           data_lines=data_lines))

    train_stream = _endless_batches(lines)
    eval_stream = _endless_batches(eval_lines)

    train_task = training.TrainTask(
        labeled_data=train_stream,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(0.0005),
    )
    eval_task = training.EvalTask(
        labeled_data=eval_stream,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
        n_eval_batches=3,  # For better evaluation accuracy in reasonable time
    )

    loop = training.Loop(model, train_task, eval_task=eval_task,
                         output_dir=output_dir)
    loop.run(n_steps=n_steps)
    return loop
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss(),
                            name='CrossEntropyLoss'),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False):
    """Train the model on the inputs.

    Args:
        output_dir: Directory where to put the logs and checkpoints.
        model: The model to train as a callable returning 2 callables, an
            init_fn and apply_fn.
        loss_fn: callable with signature: weights, trax.inputs.Inputs, model,
            state, rng -> loss.
        inputs: callable returning trax.inputs.Inputs.
        optimizer: The optimizer (see optimizers/base.py for signature).
        lr_schedule_fn: A learning rate schedule function, that when called
            returns a function from step to learning rate (a float).
        trainer_class: The trainer class to use.
        steps: int, total number of training steps.
        checkpoints_at: list of integers. Save a checkpoint for each training
            step in the list.
        eval_steps: int, num of steps per evaluation. If None or 0, eval
            disabled.
        eval_frequency: int, how often to run evaluation (every eval_frequency
            steps). If None or 0, eval disabled.
        random_seed: the random seed to use; time/os dependent if None
            (default).
        save_graphs: bool, if True, save computation graph to file.
        metrics: optionally override the default metrics dictionary.
        checkpoint_highest: save the checkpoint highest at this metric.
        checkpoint_lowest: save the checkpoint lowest at this metric.
        use_loop: whether to use training.Loop instead of Trainer.
        loss_chunk_size: int, if > 0 chunk loss into these sizes to save
            memory.
        use_memory_efficient_trainer: whether to use memory-efficient trainer.

    Returns:
        trax.TrainerState or training.Loop if use_loop is True
    """
    if use_loop:
        n_devices = num_devices() or fastmath.device_count()

        # Prepare the training task.
        # Inputs is either an Inputs instance or a function that returns it.
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            inputs = inputs()
        # The memory-efficient trainer receives the optimizer un-instantiated.
        opt = optimizer if use_memory_efficient_trainer else optimizer()
        train_task = training.TrainTask(inputs.train_stream(n_devices),
                                        loss_layer=loss_fn,
                                        optimizer=opt,
                                        lr_schedule=lr_schedule_fn(),
                                        n_steps_per_checkpoint=eval_frequency)

        # Prepare the evaluation.
        metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        names, metrics = zip(*metrics_dict.items())
        eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                      metrics,
                                      metric_names=names,
                                      n_eval_batches=eval_steps)

        # Prepare the training loop.
        checkpoint_at = None
        if checkpoints_at is not None:
            checkpoint_at = lambda step: step in checkpoints_at
        loop = training.Loop(
            model(mode='train'),
            [train_task],
            eval_model=model(mode='eval'),
            eval_tasks=[eval_task],
            output_dir=output_dir,
            checkpoint_at=checkpoint_at,
            n_devices=n_devices,
            loss_chunk_size=loss_chunk_size,
            use_memory_efficient_trainer=use_memory_efficient_trainer,
            random_seed=random_seed)

        steps_to_go = steps - loop.step
        if steps_to_go <= 0:
            log('Stop training, already reached the total training steps %d' %
                steps)
            return loop

        # Train and return the loop.
        loop.run(steps_to_go)
        return loop

    n_devices = num_devices()
    trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest)

    epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
    # eval_steps=None is documented as "eval disabled"; check for None first,
    # because `None > 0` raises TypeError in Python 3.
    if eval_frequency and eval_steps is not None and eval_steps > 0:
        epoch_steps = itertools.chain(
            [1,  # first epoch only 1 step
             eval_frequency - 1],
            itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    # try/finally ensures the trainer is closed even if training raises; the
    # exception propagates unchanged.
    try:
        for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(n_epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    finally:
        trainer.close()
    return trainer.state
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          adasum=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None,
          additional_eval_streams=None):
    """Train the model on the inputs.

    Args:
        output_dir: Directory where to put the logs and checkpoints.
        model: The model to train as a callable returning 2 callables, an
            init_fn and apply_fn.
        loss_fn: callable with signature: weights, trax.inputs.Inputs, model,
            state, rng -> loss.
        inputs: callable returning trax.inputs.Inputs.
        optimizer: The optimizer (see optimizers/base.py for signature).
        lr_schedule_fn: A learning rate schedule function, that when called
            returns a function from step to learning rate (a float).
        trainer_class: The trainer class to use.
        steps: int, total number of training steps.
        checkpoints_at: list of integers. Save a checkpoint for each training
            step in the list.
        permanent_checkpoints_at: list of integers. Save a permanent
            checkpoint for each training step in the list.
        eval_steps: int, num of steps per evaluation. If None or 0, eval
            disabled.
        eval_frequency: int, how often to run evaluation (every eval_frequency
            steps). If None or 0, eval disabled.
        permanent_checkpoint_frequency: int, how often to save permanent
            checkpoints (every permanent_checkpoint_frequency steps).
        random_seed: the random seed to use; time/os dependent if None
            (default).
        save_graphs: bool, if True, save computation graph to file.
        metrics: optionally override the default metrics dictionary.
        checkpoint_highest: save the checkpoint highest at this metric.
        checkpoint_lowest: save the checkpoint lowest at this metric.
        use_loop: whether to use training.Loop instead of Trainer.
        loss_chunk_size: int, if > 0 chunk loss into these sizes to save
            memory.
        use_memory_efficient_trainer: whether to use memory-efficient trainer.
        adasum: if True, use adaptive summation for multi-device gradients.
        init_checkpoint: a checkpoint for fine tuning.
        callbacks: a list of callbacks to call during training.
        additional_train_tasks: additional tasks which should be performed
            during training.
        additional_eval_tasks: additional tasks which should be performed
            during evaluation.
        additional_eval_streams: List[NamedStream], additional data streams
            that should be used during evaluation. Can be provided
            independently of additional_eval_tasks.

    Returns:
        trax.TrainerState or training.Loop if use_loop is True

    Raises:
        ValueError: if both permanent_checkpoint_frequency and
            permanent_checkpoints_at are set.
    """
    if (permanent_checkpoint_frequency is not None
            and permanent_checkpoints_at is not None):
        raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                         '"permanent_checkpoints_at"] should be set.')
    if use_loop:
        n_devices = num_devices() or fastmath.local_device_count()

        # Prepare the training task.
        # Inputs is either an Inputs instance or a function that returns it.
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            inputs = inputs()
        # The memory-efficient trainer receives the optimizer un-instantiated.
        opt = optimizer if use_memory_efficient_trainer else optimizer()
        train_task = training.TrainTask(
            inputs.train_stream(n_devices),
            loss_layer=loss_fn,
            optimizer=opt,
            lr_schedule=lr_schedule_fn(),
            n_steps_per_checkpoint=eval_frequency,
            n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

        if additional_train_tasks is None:
            additional_train_tasks = []

        # Prepare the evaluation.
        metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        names, metrics = zip(*metrics_dict.items())
        eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                      metrics,
                                      metric_names=names,
                                      n_eval_batches=eval_steps)

        if additional_eval_tasks is None:
            additional_eval_tasks = []

        # Every extra named stream gets its own EvalTask with the same metrics,
        # exported under the stream's name.
        additional_eval_tasks_from_streams = []
        if additional_eval_streams is not None:
            for stream in additional_eval_streams:
                additional_eval_tasks_from_streams.append(
                    training.EvalTask(stream.stream,
                                      metrics,
                                      metric_names=names,
                                      n_eval_batches=eval_steps,
                                      export_prefix=stream.name))

        # Prepare the training loop.
        checkpoint_at = None
        if checkpoints_at is not None:
            checkpoint_at = lambda step: step in checkpoints_at

        permanent_checkpoint_at = None
        if permanent_checkpoints_at is not None:
            permanent_checkpoint_at = (lambda step: step in
                                       permanent_checkpoints_at)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        if init_checkpoint:
            model_train.init_from_file(init_checkpoint, weights_only=True)
            model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
        loop = training.Loop(
            model_train,
            [train_task] + additional_train_tasks,
            eval_model=model_predict_eval,
            eval_tasks=[eval_task] + additional_eval_tasks
            + additional_eval_tasks_from_streams,
            output_dir=output_dir,
            checkpoint_at=checkpoint_at,
            permanent_checkpoint_at=permanent_checkpoint_at,
            n_devices=n_devices,
            loss_chunk_size=loss_chunk_size,
            use_memory_efficient_trainer=use_memory_efficient_trainer,
            adasum=adasum,
            random_seed=random_seed,
            callbacks=callbacks,
        )

        steps_to_go = steps - loop.step
        if steps_to_go <= 0:
            log('Stop training, already reached the total training steps %d' %
                steps)
            return loop

        # Train and return the loop.
        loop.run(steps_to_go)
        return loop

    n_devices = num_devices()
    trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest,
                            init_checkpoint=init_checkpoint)

    epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
    # eval_steps=None is documented as "eval disabled"; check for None first,
    # because `None > 0` raises TypeError in Python 3.
    if eval_frequency and eval_steps is not None and eval_steps > 0:
        epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                       eval_frequency - 1],
                                      itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    # try/finally ensures the trainer is closed even if training raises; the
    # exception propagates unchanged.
    try:
        for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(n_epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    finally:
        trainer.close()
    return trainer.state
def train(self, callbacks, epochs, loss, metrics, metric_emit_freq, optimizer,
          save_directory, output_type='infer', writer=stdout,
          n_eval_batches=10, batch_size_per_device=256, eval_batch_size=256,
          variable_shapes=False, *args, **kwargs):
    """
    Run the training loop for your ML pipeline.

    :param callbacks: Collection of callables that are run inside the training loop
    :type callbacks: ```None or List[Callable] or Tuple[Callable]```

    :param epochs: number of epochs (must be greater than 0). Passed to
        ``training.Loop.run`` as ``n_steps``, so it counts training steps.
    :type epochs: ```int```

    :param loss: Loss function, can be a string (depending on the framework) or an instance of a class
    :type loss: ```str or Callable or Any```

    :param metrics: Collection of metrics to monitor, e.g., accuracy, f1
    :type metrics: ```None or List[Callable or str] or Tuple[Callable or str]```

    :param metric_emit_freq: Frequency of metric emission, e.g., `lambda: epochs % 10 == 0`, defaults to every epoch.
        Passed to ``training.Loop`` as ``eval_at``.
    :type metric_emit_freq: ```None or (*args, **kwargs) -> bool```

    :param optimizer: Optimizer, can be a string (depending on the framework) or an instance of a class
    :type optimizer: ```str or Callable or Any```

    :param save_directory: Directory to save output in, e.g., weights in h5 files. If None, don't save.
    :type save_directory: ```None or str```

    :param output_type: `if save_directory is not None` then save in this format, e.g., 'h5'.
    :type output_type: ```str```

    :param writer: Writer for all output, could be a TensorBoard instance, a file handler like stdout or stderr
    :type writer: ```stdout or Any```

    :param n_eval_batches: Number of evaluation batches, forwarded to ``training.EvalTask``.
    :type n_eval_batches: ```int```

    :param batch_size_per_device: Not used by this implementation.
    :type batch_size_per_device: ```int```

    :param eval_batch_size: Not used by this implementation.
    :type eval_batch_size: ```int```

    :param variable_shapes: Not used by this implementation.
    :type variable_shapes: ```bool```

    :param args: Extra positional arguments forwarded to the base class ``train``.
    :param kwargs: Extra keyword arguments forwarded to the base class ``train``.

    :return: The training loop after it has run ``epochs`` steps.
    :rtype: ```training.Loop```
    """
    super(TraxTrainer, self).train(callbacks=callbacks,
                                   epochs=epochs,
                                   loss=loss,
                                   metrics=metrics,
                                   metric_emit_freq=metric_emit_freq,
                                   optimizer=optimizer,
                                   save_directory=save_directory,
                                   # Forward the caller's value instead of a
                                   # hard-coded 'infer'.
                                   output_type=output_type,
                                   writer=writer,
                                   *args, **kwargs)
    assert self.data is not None
    assert self.model is not None

    # Cycle the single-device streams so tasks never run out of batches.
    task = training.TrainTask(itertools.cycle(self.data.train_stream(1)),
                              loss, optimizer)
    eval_task = training.EvalTask(itertools.cycle(self.data.eval_stream(1)),
                                  metrics, n_eval_batches=n_eval_batches)

    training_session = training.Loop(self.model, task, eval_task=eval_task,
                                     eval_at=metric_emit_freq)
    training_session.run(n_steps=epochs)
    return training_session
# Training input pipeline: trax.data.Shuffle() followed by trax.data.Batch(8).
train_data_pipeline = trax.data.Serial(
    trax.data.Shuffle(),
    trax.data.Batch(8),
)
train_batches_stream = train_data_pipeline(train_stream)
# Evaluation input pipeline: batches of one example (trax.data.Batch(1)).
eval_data_pipeline = trax.data.Batch(1)
eval_batches_stream = eval_data_pipeline(eval_stream)

# Define Train and Eval tasks using Trax Training
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(args.learning_rate),
)
# Eval reports cross-entropy and accuracy over n_eval_batches=20 batches.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CategoryCrossEntropy(), tl.CategoryAccuracy()],
    n_eval_batches=20,
)

# Train Model
model = get_model(n_output_classes=10)
# NOTE(review): no output_dir is passed to Loop here -- confirm whether
# checkpoints/logs are expected on disk.
training_loop = training.Loop(model, train_task, eval_tasks=[eval_task])
training_loop.run(args.train_steps)

# Save Model
save_model_tf(model)
print(f"(device count, tokens per device) = {test.shape}\n")
# Release the probe stream and batch; in this snippet they are only used by
# the print above.
del teststream, test

# Training task (one input stream sized by device count, "train" split).
train_task = training.TrainTask(
    labeled_data=stream(trax.fastmath.device_count(), "train"),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    lr_schedule=trax.lr.multifactor(),
    optimizer=trax.optimizers.Adam(),
    n_steps_per_checkpoint=1000,
)

# Evaluation task ("validation" split).
eval_task = training.EvalTask(
    labeled_data=stream(trax.fastmath.device_count(), "validation"),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=10  # For less variance in eval numbers.
)

# Expand a leading "~" in the user-supplied directory.
output_dir = os.path.expanduser(args.dir)

print("~~Begin Training~~")
# Train a ReformerLM (default configuration, train mode) with Loop.
training_loop = training.Loop(
    trax.models.ReformerLM(mode="train"),
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir)

# Run 1,000,000 training steps (batches).
# NOTE(review): the original comment here said 1000 steps; confirm which
# count is intended.
training_loop.run(1000000)