def test_train_save_restore_sharded(self):
  """Saves a sharded checkpoint and restores it into a fresh training loop.

  Only runs with at least two local accelerators. The global
  `base.N_WEIGHTS_SHARDS` is set to the local device count for the duration
  of the test and restored afterwards -- even if the test body fails -- so the
  setting cannot leak into other tests.
  """
  if fastmath.local_device_count() < 2:
    return  # multi-accelerator only
  original_n_shards = base.N_WEIGHTS_SHARDS
  base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
  try:
    train_data = data.Serial(lambda _: _very_simple_data(2, 2),
                             data.CountAndSkip('simple_data'))
    task = training.TrainTask(train_data(), tl.L2Loss(),
                              optimizers.Adam(.0001))
    eval_task = training.EvalTask(
        _very_simple_data(2, 2),  # deliberately re-using training data
        [tl.L2Loss()],
        metric_names=['SGD.L2Loss'])
    tmp_dir = self.create_tempdir().full_path

    def _make_session():
      # Builds a fresh model and Loop that share the same output_dir.
      model = tl.Serial(tl.Dense(2))
      return training.Loop(model, [task], eval_tasks=[eval_task],
                           eval_at=lambda step_n: step_n % 2 == 0,
                           output_dir=tmp_dir)

    training_session = _make_session()
    self.assertEqual(0, training_session.step)
    training_session.run(n_steps=1)
    training_session.save_checkpoint('model')
    training_session2 = _make_session()
    training_session2.run(n_steps=1)
  finally:
    # Always undo the global override so later tests see the original value.
    base.N_WEIGHTS_SHARDS = original_n_shards
def test_train_mnist_multitask(self, mock_stdout):
  """Briefly trains a two-headed MNIST model on two alternating tasks.

  Head 0 learns digit classification; head 1 learns brightness regression.
  The loop switches task every step and evaluates every 20 steps; metrics are
  read back from the captured stdout.
  """
  model = _build_model(two_heads=True)
  # Head 0: digit classification.
  cls_task, cls_eval_task = _mnist_tasks(head=tl.Select([0], n_in=2))

  def _brightness_loss():
    # Route the second model output (head 1) into an L2 loss.
    return tl.Serial(tl.Select([1]), tl.L2Loss())

  # Head 1: auxiliary brightness regression.
  reg_task = training.TrainTask(
      itertools.cycle(_mnist_brightness_dataset().train_stream(1)),
      _brightness_loss(),
      adam.Adam(0.001),
  )
  reg_eval_task = training.EvalTask(
      itertools.cycle(_mnist_brightness_dataset().eval_stream(1)),
      [_brightness_loss()],
      n_eval_batches=1,
      metric_names=['L2'],
  )
  loop = training.Loop(
      model,
      tasks=[cls_task, reg_task],
      eval_tasks=[cls_eval_task, reg_eval_task],
      eval_at=lambda n: n % 20 == 0,
      which_task=lambda n: n % 2,
  )
  loop.run(n_steps=100)
  self.assertEqual(loop.step, 100)
  # Classification head should reach at least 80% eval accuracy.
  self.assertGreater(_read_metric('Accuracy', mock_stdout), 0.8)
  # Brightness head should get below 0.03 L2 error.
  self.assertLess(_read_metric('L2', mock_stdout), 0.03)
def test_can_predict_with_trained_model(self):
  """Trains a two-headed model with one task per head, then predicts with it.

  The model ends in `tl.Branch(tl.Dense(1), tl.Dense(2))`, so it emits two
  outputs: head 0 (dimension 1) and head 1 (dimension 2). Each train/eval
  task selects its own head with `tl.Select` before the L2 loss, so the loss
  only sees that head's output. The trained eval model must still produce
  both heads with the expected shapes.
  """
  model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))
  output_dims = (1, 2)

  def _head_loss(output_dim):
    # Head index is 0 for output_dim 1 and 1 for output_dim 2.
    return tl.Serial(tl.Select([output_dim - 1], n_in=2), tl.L2Loss())

  tasks = tuple(
      training.TrainTask(  # pylint: disable=g-complex-comprehension
          _very_simple_data(output_dim),
          _head_loss(output_dim),
          optimizers.SGD(.01),
      ) for output_dim in output_dims)
  # Each element must be an EvalTask itself (previously each was wrapped in a
  # single-element list, producing a tuple of lists).
  eval_tasks = tuple(
      training.EvalTask(  # pylint: disable=g-complex-comprehension
          # deliberately re-using training data
          _very_simple_data(output_dim),
          [_head_loss(output_dim)],
      ) for output_dim in output_dims)
  tmp_dir = self.create_tempdir().full_path
  training_session = training.Loop(
      model,
      tasks=tasks,
      eval_tasks=eval_tasks,
      checkpoint_at=lambda step_n: step_n == 1,
      output_dir=tmp_dir,
      which_task=lambda step_n: step_n % 2,
  )
  training_session.run(n_steps=2)

  trained_model = training_session.eval_model
  inp = next(_very_simple_data())[0]
  out = trained_model(inp)
  self.assertEqual(
      shapes.signature(out),
      (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))),
  )
def test_train_one_task_eval_two_tasks(self):
  """Trains on one task while evaluating on two tasks with distinct prefixes.

  The Loop must keep both eval tasks and their `export_prefix` values.
  """
  model = tl.Serial(tl.Dense(3), tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  prefixes = ('eval_1', 'eval_2')
  eval_tasks = tuple(
      training.EvalTask(
          _very_simple_data(),  # deliberately re-using training data
          [tl.L2Loss()],
          export_prefix=prefix,
      ) for prefix in prefixes)
  loop = training.Loop(
      model,
      tasks=(train_task,),
      eval_tasks=eval_tasks,
  )
  self.assertEqual(0, loop.step)
  loop.run(n_steps=5)
  self.assertEqual(5, loop.step)
  seen_prefixes = [t.export_prefix for t in loop.eval_tasks]
  self.assertCountEqual(list(prefixes), seen_prefixes)
def test_restores_history(self):
  """A new Loop on the same output_dir reloads the metric history saved there.

  Evals and checkpoints happen on even steps, so after 4 steps every recorded
  metric has exactly one entry, taken at step 2.
  """
  model = tl.Serial(tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()])
  output_dir = self.create_tempdir().full_path
  on_even_step = lambda n: n % 2 == 0
  first_loop = training.Loop(model, [train_task],
                             eval_tasks=[eval_task],
                             eval_at=on_even_step,
                             checkpoint_at=on_even_step,
                             output_dir=output_dir)
  first_loop.run(4)

  history = training.Loop(model, [train_task], output_dir=output_dir).history
  self.assertLen(history.modes, 2)
  self.assertLen(history.metrics_for_mode('train'), 6)
  self.assertLen(history.metrics_for_mode('eval'), 1)
  expected_metrics = (
      ('train', 'metrics/L2Loss'),
      ('train', 'training/learning_rate'),
      ('train', 'training/steps per second'),
      ('train', 'training/gradients_l2'),
      ('train', 'training/loss'),
      ('train', 'training/weights_l2'),
      ('eval', 'metrics/L2Loss'),
  )
  for mode, metric in expected_metrics:
    entries = history.get(mode, metric)
    self.assertLen(entries, 1)
    self.assertEqual(2, entries[0][0])
def test_train_save_restore_dense(self):
  """A restored Dense model matches the saved one, then diverges with training.

  Right after restoring from the checkpoint, both models give identical
  outputs on the same input and rng. Once only the restored loop trains one
  more step, the outputs must differ.
  """
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  output_dir = self.create_tempdir().full_path

  def _new_model_and_loop():
    net = tl.Serial(tl.Dense(1))
    loop = training.Loop(net, [train_task], eval_tasks=[eval_task],
                         eval_at=lambda n: n % 2 == 0,
                         output_dir=output_dir)
    return net, loop

  model, loop = _new_model_and_loop()
  self.assertEqual(0, loop.step)
  loop.run(n_steps=1)
  loop.save_checkpoint()

  model2, loop2 = _new_model_and_loop()
  x = np.ones((8, 1))

  def _outputs():
    # String forms of both models' outputs under the same fixed rng.
    return (str(model(x, rng=fastmath.random.get_prng(0))),
            str(model2(x, rng=fastmath.random.get_prng(0))))

  saved_out, restored_out = _outputs()
  self.assertEqual(saved_out, restored_out)
  loop2.run(n_steps=1)
  saved_out, restored_out = _outputs()
  self.assertNotEqual(saved_out, restored_out)
def test_summaries_are_written(self):
  """Training writes metric summaries when an output_dir is given.

  Checks that the `train` and `eval` summary directories do not exist before
  training, exist with exactly one file each after the first `run()`, and
  hold two files each after a second `run()`.
  """
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  tmp_dir = self.create_tempdir().full_path
  training_session = training.Loop(model, [task], eval_tasks=[eval_task],
                                   eval_at=lambda step_n: step_n % 2 == 0,
                                   output_dir=tmp_dir)
  expected_train_metric_dir = os.path.join(tmp_dir, 'train')
  expected_eval_metric_dir = os.path.join(tmp_dir, 'eval')
  metric_dirs = (expected_train_metric_dir, expected_eval_metric_dir)
  for directory in metric_dirs:
    self.assertFalse(
        os.path.isdir(directory), 'Failed for directory %s.' % directory)

  training_session.run(n_steps=15)
  time.sleep(1)  # wait for the files to be closed
  for directory in metric_dirs:
    self.assertTrue(
        os.path.isdir(directory), 'Failed for directory %s.' % directory)
    self.assertEqual(
        1, _count_files(directory), 'Failed for directory %s.' % directory)

  training_session.run(n_steps=5)
  time.sleep(1)  # wait for the files to be closed
  for directory in metric_dirs:
    self.assertEqual(
        2, _count_files(directory), 'Failed for directory %s.' % directory)
def test_can_predict_with_trained_model(self):
  """Trains a two-headed model one head per task, then runs prediction.

  Each task picks its head out of the model's two Branch outputs with
  `tl.Select` before applying the L2 loss.
  """
  model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))

  def _head_loss(head_index):
    # Keep only the chosen head's output, then apply the L2 loss.
    return tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss())

  train_tasks, eval_tasks = [], []
  # Head 0 has output_dim 1, head 1 has output_dim 2.
  for head_index, output_dim in enumerate((1, 2)):
    train_tasks.append(training.TrainTask(
        _very_simple_data(output_dim),
        _head_loss(head_index),
        optimizers.SGD(.01)))
    eval_tasks.append(training.EvalTask(
        _very_simple_data(output_dim),  # deliberately re-use training data
        [_head_loss(head_index)]))

  output_dir = self.create_tempdir().full_path
  loop = training.Loop(
      model,
      tasks=train_tasks,
      eval_tasks=eval_tasks,
      checkpoint_at=lambda n: n == 1,
      output_dir=output_dir,
      which_task=lambda n: n % 2,
  )
  loop.run(n_steps=2)

  predictions = loop.eval_model(next(_very_simple_data())[0])
  expected_signature = (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2)))
  self.assertEqual(shapes.signature(predictions), expected_signature)
def test_train_dense_layer_with_momentum(self):
  """Trains with Momentum, an optimizer whose slots need initialization."""
  model = tl.Serial(tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Momentum(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['Momentum.L2Loss'])
  loop = training.Loop(model, [train_task], eval_tasks=[eval_task],
                       eval_at=lambda n: not n % 2)
  self.assertEqual(0, loop.step)
  n_steps = 20
  loop.run(n_steps=n_steps)
  self.assertEqual(n_steps, loop.step)
def test_train_dense_layer(self):
  """Trains a bare Dense layer with SGD, evaluating on every other step."""
  model = tl.Dense(1)
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), sgd.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      names=['SGD.L2Loss'],
      eval_at=lambda n: n % 2 == 0,
      eval_N=1)
  loop = training.Loop(model, train_task, eval_task=eval_task)
  # No step has been taken before run() is called.
  self.assertIsNone(loop.current_step())
  n_steps = 20
  loop.run(n_steps=n_steps)
  self.assertEqual(n_steps, loop.current_step())
def test_train_dense_layer_with_momentum(self):
  """Trains a bare Dense layer with Momentum, whose slots need initialization."""
  model = tl.Dense(1)
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), momentum.Momentum(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      names=['Momentum.L2Loss'],
      eval_at=lambda n: n % 2 == 0,
      eval_N=1)
  loop = training.Loop(model, train_task, eval_task=eval_task)
  # No step has been taken before run() is called.
  self.assertIsNone(loop.current_step())
  n_steps = 20
  loop.run(n_steps=n_steps)
  self.assertEqual(n_steps, loop.current_step())
def test_train_dense_layer_evals(self):
  """Trains 10 steps with scheduled evals disabled, then evaluates manually.

  `eval_at` never fires, so evals run only through the explicit
  `run_evals()` call. That call must not advance the training step.
  """
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()])
  training_session = training.Loop(model, [task], eval_tasks=[eval_task],
                                   eval_at=lambda step_n: False)
  self.assertEqual(0, training_session.step)
  training_session.run(n_steps=10)
  self.assertEqual(10, training_session.step)
  training_session.run_evals()
  self.assertEqual(10, training_session.step)  # Unchanged
def test_calls_step_callbacks(self):
  """Step callbacks fire exactly on the steps their call_at() selects.

  Also checks that on_step_begin() runs before on_step_end() for each step.
  """
  expected_steps = [1, 3, 4]
  begun, ended = [], []
  outer = self

  class _RecordingCallback(callbacks.TrainingStepCallback):

    def call_at(self, step):
      return step in expected_steps

    def on_step_begin(self, step):
      begun.append(step)

    def on_step_end(self, step):
      # on_step_begin() must already have fired for this step.
      outer.assertIn(step, begun)
      ended.append(step)

  model = tl.Serial(tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  loop = training.Loop(model, [train_task], callbacks=[_RecordingCallback])
  loop.run(n_steps=5)
  self.assertEqual(begun, expected_steps)
  self.assertEqual(ended, expected_steps)
def test_initializes_step_callbacks_with_loop_instance(self):
  """Callback classes are instantiated with the Loop that owns them."""
  seen = {}  # mutable holder so the nested class can record the loop

  class _LoopCapturingCallback(callbacks.TrainingStepCallback):

    def __init__(self, loop):
      super().__init__(loop)
      seen['loop'] = loop

    def call_at(self, step):
      return False

    def on_step_begin(self, step):
      del step

    def on_step_end(self, step):
      del step

  model = tl.Serial(tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  loop = training.Loop(model, [train_task],
                       callbacks=[_LoopCapturingCallback])
  self.assertIs(seen.get('loop'), loop)
def test_loop_checkpoint_high_metric(self):
  """Runs a training loop that saves checkpoints for high metric values.

  Evals and checkpoints run on every even step, and `checkpoint_high_metric`
  is keyed on the eval metric named 'l2_loss'. The test checks that such a
  loop runs to completion. It does not inspect which checkpoint files were
  written.
  """
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_metric = tl.L2Loss()
  eval_task = training.EvalTask(
      _very_simple_data(), [eval_metric], metric_names=['l2_loss'])
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model,
                       [task],
                       eval_tasks=[eval_task],
                       output_dir=tmp_dir,
                       eval_at=lambda step_n: step_n % 2 == 0,
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       checkpoint_high_metric='l2_loss')
  loop.run(n_steps=18)
  # Previously the test had no assertion at all; at least check completion.
  self.assertEqual(18, loop.step)
def test_loop_with_initialized_model(self):
  """Check that loop does not re-initialize an already initialized model."""
  model = tl.Serial(tl.Dense(1))
  example_data = next(_very_simple_data())
  model.init(example_data)
  w = model.weights[0][0]
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  loop = training.Loop(model, [task], eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n % 2 == 0)
  self.assertEqual(0, loop.step)
  # Compare element-wise. assertEqual on arrays depends on the truth value of
  # `==`, which only works when the array has exactly one element.
  np.testing.assert_array_equal(loop.model.weights[0][0], w)
def test_train_dense_layer(self):
  """Consecutive run() calls accumulate the step count (15 + 5 = 20)."""
  model = tl.Serial(tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  loop = training.Loop(model, [train_task], eval_tasks=[eval_task],
                       eval_at=lambda n: n % 2 == 0)
  self.assertEqual(0, loop.step)
  for n_steps, expected_step in ((15, 15), (5, 20)):
    loop.run(n_steps=n_steps)
    self.assertEqual(expected_step, loop.step)
def __init__(
    self,
    trajectory_batch_stream,
    optimizer,
    lr_schedule,
    advantage_estimator,
    model,
    target_model=None,
    sync_at=(lambda step: step % 100 == 0),
    loss_layer=None,
    head_selector=(),
):
  """Initializes ValueTrainTask.

  Args:
    trajectory_batch_stream: Generator of trax.rl.task.TrajectoryNp.
    optimizer: Optimizer used to train the value network.
    lr_schedule: Learning rate schedule used to train the value network.
    advantage_estimator: Function (rewards, returns, values, dones) ->
      advantages, created by one of the functions in trax.rl.advantages.
    model: Model being trained; its weights are used to synchronize the
      target model.
    target_model: Model used to compute TD targets. Defaults to (a deep copy
      of) `model` when `None`.
    sync_at: Function step -> bool saying when to copy the trained network's
      weights into the target network. Needed when training on bootstrapped
      targets, e.g. TD-k. The default fires every 100 steps.
    loss_layer: Value loss layer; `tl.L2Loss()` when `None`.
    head_selector: Layer appended to the network to select the value head.
      Only needed in multitask training.
  """
  self._trajectory_batch_stream = trajectory_batch_stream
  self._advantage_estimator = advantage_estimator
  self._sync_at = sync_at
  self._head_selector = head_selector

  self._train_model = tl.Serial(model, self._head_selector)
  # The target network is a deep copy, so its weights only change when they
  # are explicitly synchronized.
  # TODO(pkozakowski): Use target_model.clone() once it's implemented.
  source_model = model if target_model is None else target_model
  self._target_model = tl.Serial(
      copy.deepcopy(source_model), self._head_selector)

  # Index of the batch most recently pulled from the stream (None until the
  # first one), used to decide when to synchronize the target network.
  self._step = None
  # enumerate() is built eagerly, matching when the stream was bound before.
  indexed_batches = enumerate(self._trajectory_batch_stream)

  def _labeled_batches():
    for step, trajectory_batch in indexed_batches:
      self._step = step
      yield self.value_batch(trajectory_batch)

  labeled_data = _labeled_batches()

  value_loss = tl.L2Loss() if loss_layer is None else loss_layer
  super().__init__(
      labeled_data,
      tl.Serial(head_selector, value_loss),
      optimizer,
      lr_schedule=lr_schedule,
  )
def test_restores_memory_efficient_from_standard(self):
  """A memory-efficient Loop resumes from a checkpoint written by a standard one.

  A standard loop trains 4 steps and checkpoints on even steps. A second loop
  with `use_memory_efficient_trainer=True` on the same output_dir then trains
  2 more steps and must end at step 6.
  """
  model = tl.Serial(tl.Dense(4), tl.Dense(1))
  task_std = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                                optimizers.Adam(.0001))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model, [task_std],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(4)
  # NOTE(review): the optimizer is passed as the class `optimizers.Adam`, not
  # an instance; presumably the memory-efficient trainer instantiates it
  # itself -- confirm.
  task_memeff = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                                   optimizers.Adam)
  loop2 = training.Loop(model, [task_memeff],
                        output_dir=tmp_dir,
                        use_memory_efficient_trainer=True)
  loop2.run(2)
  self.assertEqual(6, loop2.step)
def test_train_one_task_eval_two_tasks(self):
  """Trains on one task while passing the same EvalTask twice as eval tasks."""
  model = tl.Serial(tl.Dense(3), tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  shared_eval = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
  )
  loop = training.Loop(
      model,
      tasks=(train_task,),
      eval_tasks=(shared_eval,) * 2,
  )
  self.assertEqual(0, loop.step)
  loop.run(n_steps=5)
  self.assertEqual(5, loop.step)
def test_loop_no_eval_task(self):
  """A Loop built without any eval task still initializes and trains."""
  loop = training.Loop(
      tl.Serial(tl.Dense(1)),
      [training.TrainTask(
          _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))])
  # Success here just means no exception was raised.
  loop.run(n_steps=5)
def test_loop_no_eval_task_tfnp(self):
  """A Loop with no eval task initializes and trains under the TFNP backend."""
  with fastmath.use_backend(fastmath.Backend.TFNP):
    loop = training.Loop(
        tl.Serial(tl.Dense(1)),
        [training.TrainTask(
            _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))])
    # Success here just means no exception was raised.
    loop.run(n_steps=5)
def test_train_save_restore_dense(self):
  """Checkpoint round trip restores weights, data-stream counter and slots.

  The training stream is wrapped in `data.CountAndSkip('simple_data')`.
  After the counter is reset by hand, building a new loop on the same
  output_dir must bring it back to its checkpointed value.
  """
  train_stream = data.Serial(lambda _: _very_simple_data(),
                             data.CountAndSkip('simple_data'))
  train_task = training.TrainTask(
      train_stream(), tl.L2Loss(), optimizers.Adam(.0001))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  output_dir = self.create_tempdir().full_path

  def _new_model_and_loop():
    net = tl.Serial(tl.Dense(1))
    loop = training.Loop(net, [train_task], eval_tasks=[eval_task],
                         eval_at=lambda n: n % 2 == 0,
                         output_dir=output_dir)
    return net, loop

  model, loop = _new_model_and_loop()
  self.assertEqual(0, loop.step)
  loop.run(n_steps=1)
  loop.save_checkpoint()

  self.assertEqual(data.inputs.data_counters['simple_data'], 2)
  data.inputs.data_counters['simple_data'] = 0  # reset manually
  self.assertEqual(data.inputs.data_counters['simple_data'], 0)  # check
  model2, loop2 = _new_model_and_loop()
  self.assertEqual(data.inputs.data_counters['simple_data'], 2)  # restored

  x = np.ones((8, 1))

  def _outputs():
    # String forms of both models' outputs under the same fixed rng.
    return (str(model(x, rng=fastmath.random.get_prng(0))),
            str(model2(x, rng=fastmath.random.get_prng(0))))

  saved_out, restored_out = _outputs()
  self.assertEqual(saved_out, restored_out)
  loop2.run(n_steps=1)
  saved_out, restored_out = _outputs()
  self.assertNotEqual(saved_out, restored_out)

  # NOTE(review): loop2 has taken one more step than loop, so it is not
  # obvious why their optimizer slots should still be equal here; confirm
  # this assertion checks what is intended.
  np.testing.assert_array_equal(loop._trainer_per_task[0].slots,
                                loop2._trainer_per_task[0].slots)
def test_names(self):
  """Each loss/metric layer's string form is its name with an `_in3` suffix."""
  expected_names = (
      (tl.L2Loss, 'L2Loss_in3'),
      (tl.Accuracy, 'Accuracy_in3'),
      (tl.SequenceAccuracy, 'SequenceAccuracy_in3'),
      (tl.CrossEntropyLoss, 'CrossEntropyLoss_in3'),
      (tl.CrossEntropySum, 'CrossEntropySum_in3'),
  )
  for make_layer, expected in expected_names:
    self.assertEqual(expected, str(make_layer()))
def test_trains_on_two_tasks(self):
  """Alternates between two train tasks (here the same task object twice)."""
  model = tl.Serial(tl.Dense(3), tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
  )
  loop = training.Loop(
      model,
      tasks=(train_task,) * 2,
      eval_tasks=(eval_task,) * 2,
      which_task=lambda n: n % 2,
  )
  self.assertEqual(0, loop.step)
  for n_steps, expected_step in ((15, 15), (5, 20)):
    loop.run(n_steps=n_steps)
    self.assertEqual(expected_step, loop.step)
def test_restores_step(self):
  """A second Loop on the same output_dir resumes at the checkpointed step."""
  model = tl.Serial(tl.Dense(1))
  train_task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  output_dir = self.create_tempdir().full_path
  first_loop = training.Loop(model, [train_task],
                             checkpoint_at=lambda n: n % 2 == 0,
                             output_dir=output_dir)
  first_loop.run(4)
  resumed_loop = training.Loop(model, [train_task], output_dir=output_dir)
  self.assertEqual(4, resumed_loop.step)
def test_run_simple_task(self):
  """Takes a single accelerated SGD step on a tiny regression batch."""
  inputs = np.arange(8).reshape((8, 1))  # 8 items per batch
  targets = np.pi * np.ones_like(inputs)
  batch = (inputs, targets, np.ones_like(targets))

  loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss())
  loss_layer.init(batch)
  sgd_optimizer = optimizers.SGD(.01)
  sgd_optimizer.tree_init(loss_layer.weights)
  trainer = optimizers.Trainer(loss_layer, sgd_optimizer)
  trainer.one_step(batch, fastmath.random.get_prng(0))
def test_l2_loss(self):
  """L2Loss ignores zero-weight entries.

  The expected values match a weighted mean of squared errors: the only
  nonzero error (1) gives 0.0 when masked out, and 1 / 2 = 0.5 when it has
  weight 1 out of a total weight of 2.
  """
  layer = tl.L2Loss()
  model_outputs = np.ones((2, 2))
  targets = np.array([[1., 1.], [1., 0.]])
  cases = (
      (np.array([[1., 1.], [1., 0.]]), 0.0),
      (np.array([[1., 0.], [0., 1.]]), 0.5),
  )
  for weights, expected_loss in cases:
    loss = layer([model_outputs, targets, weights])
    np.testing.assert_allclose(loss, expected_loss)
def test_train_save_restore_transformer(self):
  """A restored tiny TransformerLM matches the saved one, then diverges.

  Right after restoring from the 'model' checkpoint, both models give
  identical outputs under the same rng. Once only the restored loop trains one
  more step, the outputs must differ.
  """
  vocab_size = 8
  train_task = training.TrainTask(
      _very_simple_transformer_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_transformer_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  output_dir = self.create_tempdir().full_path

  def _new_model_and_loop():
    net = transformer.TransformerLM(vocab_size, d_model=4, d_ff=4,
                                    n_layers=1, n_heads=2, dropout=0.)
    loop = training.Loop(net, [train_task], eval_tasks=[eval_task],
                         eval_at=lambda n: n % 2 == 0,
                         output_dir=output_dir)
    return net, loop

  model, loop = _new_model_and_loop()
  self.assertEqual(0, loop.step)
  loop.run(n_steps=1)
  loop.save_checkpoint('model')

  model2, loop2 = _new_model_and_loop()
  tokens = np.ones((2, 2)).astype(np.int32)

  def _outputs():
    # String forms of both models' outputs under the same fixed rng.
    return (str(model(tokens, rng=fastmath.random.get_prng(0))),
            str(model2(tokens, rng=fastmath.random.get_prng(0))))

  saved_out, restored_out = _outputs()
  self.assertEqual(saved_out, restored_out)
  loop2.run(n_steps=1)
  saved_out, restored_out = _outputs()
  self.assertNotEqual(saved_out, restored_out)
def test_restore_fails_different_model(self):
  """Restoring a checkpoint into a model with a different shape fails.

  A checkpoint is written for `tl.Serial(tl.Dense(1))`. Building a Loop for
  `tl.Serial(tl.Dense(2))` on the same output_dir is expected to raise
  IndexError instead of silently restoring mismatched weights.
  """
  model1 = tl.Serial(tl.Dense(1))
  task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                            optimizers.SGD(.01))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model1, [task],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(2)

  model2 = tl.Serial(tl.Dense(2))
  with self.assertRaises(IndexError):
    training.Loop(model2, [task], output_dir=tmp_dir)