def train_tnt_and_reference_models(model_config, optimizer, micro_batch_size, nbatches,
                                   number_epochs, optimizer_kwargs={}):
  (train_dataset, _) = util.train_test_mnist_datasets(nbatches=nbatches,
                                                      micro_batch_size=micro_batch_size)
  (ref_train_dataset, _) = util.train_test_mnist_datasets(nbatches=nbatches,
                                                          micro_batch_size=micro_batch_size)
  tnt_model_runner, ref_model_runner = get_compiled_models(model_config, optimizer,
                                                           **optimizer_kwargs)

  tnt_history = tnt_model_runner.train_model(train_dataset, number_epochs)
  ref_history = ref_model_runner.train_model(ref_train_dataset, number_epochs)

  rank = tnt.get_rank()
  logging.getLogger().info(f"[Rank {rank}] Tarantella (loss, accuracy) = "
                           f"({tnt_history.history})")
  logging.getLogger().info(f"[Rank {rank}] Reference (loss, accuracy) = "
                           f"({ref_history.history})")
  return tnt_history, ref_history
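# A minimal usage sketch for `train_tnt_and_reference_models`, comparing the final
# training metrics of the Tarantella and reference models. The fixture name, the metric
# key, and the tolerances below are illustrative assumptions, not part of the original suite.
def test_compare_training_histories(model_and_optimizer):
  model_config, optimizer = model_and_optimizer
  tnt_history, ref_history = train_tnt_and_reference_models(model_config, optimizer,
                                                            micro_batch_size=32,
                                                            nbatches=10,
                                                            number_epochs=2)
  result = [True, True]
  if tnt.is_master_rank():
    result = [np.isclose(tnt_history.history['loss'][-1],
                         ref_history.history['loss'][-1],
                         atol=1e-2),  # losses might not be identical
              np.isclose(tnt_history.history['sparse_categorical_accuracy'][-1],
                         ref_history.history['sparse_categorical_accuracy'][-1],
                         atol=1e-6)]
  util.assert_on_all_ranks(result)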
def test_save_load_train_models(self, model, save_setup, parallel_strategy,
                                optimizer_type, check_configuration_identical):
  train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                    micro_batch_size=32,
                                                    shuffle=False)
  # create and train model
  tnt_model = get_tnt_model_compiled(model, parallel_strategy, optimizer_type())
  tnt_model.fit(train_dataset, epochs=2, verbose=0)
  tnt_model.save(save_setup['save_dir'],
                 tnt_save_all_devices=save_setup['all_devices'])

  # load into a new tnt.Model
  reloaded_tnt_model = tnt.models.load_model(save_setup['save_dir'])
  assert isinstance(reloaded_tnt_model, tnt.Model)
  check_configuration_identical(reloaded_tnt_model, tnt_model)

  # continue training on the original model
  tnt_model.fit(train_dataset, epochs=2, verbose=0)
  # continue training on the loaded model
  reloaded_tnt_model.fit(train_dataset, epochs=2, verbose=0)

  util.compare_weights(reloaded_tnt_model.get_weights(),
                       tnt_model.get_weights(), 1e-6)
def test_keras_models(self, model, save_setup, optimizer,
                      check_configuration_identical):
  train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                    micro_batch_size=32,
                                                    shuffle=False)
  # train model
  keras_model = model
  keras_model.compile(optimizer(),
                      loss=keras.losses.SparseCategoricalCrossentropy(),
                      metrics=[keras.metrics.SparseCategoricalAccuracy()])
  keras_model.fit(train_dataset, epochs=2, verbose=0)
  keras_model.save(save_setup['save_dir'])

  reloaded_model = keras.models.load_model(save_setup['save_dir'])
  check_configuration_identical(reloaded_model, keras_model)

  # continue training on the original model
  keras_model.fit(train_dataset, epochs=2, shuffle=False, verbose=0)
  # continue training on the loaded model
  reloaded_model.fit(train_dataset, epochs=2, shuffle=False, verbose=0)

  util.compare_weights(reloaded_model.get_weights(),
                       keras_model.get_weights(), 1e-6)
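# `get_tnt_model_compiled` is called in the tests above but not defined in this excerpt.
# A minimal sketch, assuming it wraps the Keras model in a `tnt.Model` with the given
# parallel strategy and compiles it with the same loss/metric as the Keras reference test;
# the default optimizer and the exact `tnt.Model` constructor signature are assumptions.
def get_tnt_model_compiled(model, parallel_strategy, optimizer=None):
  tnt_model = tnt.Model(model, parallel_strategy=parallel_strategy)
  tnt_model.compile(optimizer=optimizer or keras.optimizers.SGD(learning_rate=0.01),
                    loss=keras.losses.SparseCategoricalCrossentropy(),
                    metrics=[keras.metrics.SparseCategoricalAccuracy()])
  return tnt_model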
def test_compare_accuracy_against_reference(self, model_runners, micro_batch_size,
                                            number_epochs, nbatches, test_nbatches,
                                            remainder_samples_per_batch,
                                            last_incomplete_batch_size):
  (train_dataset, test_dataset) = util.train_test_mnist_datasets(
      nbatches=nbatches, test_nbatches=test_nbatches,
      micro_batch_size=micro_batch_size, shuffle=False,
      remainder_samples_per_batch=remainder_samples_per_batch,
      last_incomplete_batch_size=last_incomplete_batch_size)
  (ref_train_dataset, ref_test_dataset) = util.train_test_mnist_datasets(
      nbatches=nbatches, test_nbatches=test_nbatches,
      micro_batch_size=micro_batch_size, shuffle=False,
      remainder_samples_per_batch=remainder_samples_per_batch,
      last_incomplete_batch_size=last_incomplete_batch_size)
  tnt_model_runner, reference_model_runner = model_runners

  reference_model_runner.train_model(ref_train_dataset, number_epochs)
  tnt_model_runner.train_model(train_dataset, number_epochs)

  tnt_loss_accuracy = tnt_model_runner.evaluate_model(test_dataset)
  ref_loss_accuracy = reference_model_runner.evaluate_model(ref_test_dataset)

  rank = tnt.get_rank()
  logging.getLogger().info(
      f"[Rank {rank}] Tarantella [loss, accuracy] = {tnt_loss_accuracy}")
  logging.getLogger().info(
      f"[Rank {rank}] Reference [loss, accuracy] = {ref_loss_accuracy}")

  result = [True, True]
  if tnt.is_master_rank():
    result = [np.isclose(tnt_loss_accuracy[0], ref_loss_accuracy[0],
                         atol=1e-2),  # losses might not be identical
              np.isclose(tnt_loss_accuracy[1], ref_loss_accuracy[1],
                         atol=1e-6)]
  util.assert_on_all_ranks(result)
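# `util.assert_on_all_ranks` is not shown in this excerpt. A minimal local sketch that
# simply asserts every entry of `result` on the calling rank; the actual helper may also
# synchronize the outcome across ranks so that a failure on one rank fails the whole test.
def assert_on_all_ranks(result):
  assert all(np.asarray(result).flatten())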
def test_load_model_with_compile_flag(self, model, save_setup, parallel_strategy,
                                      load_compiled_model):
  tnt_model = get_tnt_model_compiled(model, parallel_strategy)
  train_dataset, _ = util.train_test_mnist_datasets(nbatches=1,
                                                    micro_batch_size=32)
  tnt_model.save(save_setup['save_dir'],
                 tnt_save_all_devices=save_setup['all_devices'])
  reloaded_tnt_model = tnt.models.load_model(save_setup['save_dir'],
                                             compile=load_compiled_model)

  if not load_compiled_model:
    # if the model is not compiled, training should not succeed
    with pytest.raises(RuntimeError):
      reloaded_tnt_model.fit(train_dataset, epochs=1, verbose=0)
  else:
    # the loaded model should be trainable if it was previously compiled
    reloaded_tnt_model.fit(train_dataset, epochs=1, verbose=0)
def test_save_weights_after_training(self, model, save_setup, parallel_strategy,
                                     tf_format):
  # use a non-shuffled dataset to be able to continue training identically
  # both on the `tnt.Model` and on the model rebuilt from its config
  train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                    micro_batch_size=32,
                                                    shuffle=False)
  # create and train model
  tnt_model = get_tnt_model_compiled(model, parallel_strategy)
  tnt_model.fit(train_dataset, epochs=1, verbose=0)

  os.makedirs(save_setup['save_dir'], exist_ok=True)
  save_path = os.path.join(save_setup['save_dir'], "weight")
  if not tf_format:
    save_path = save_path + ".h5"
  tnt_model.save_weights(save_path,
                         tnt_save_all_devices=save_setup['all_devices'])

  # create a new model with the same architecture and optimizer
  if isinstance(model, tf.keras.Sequential):
    model_from_config = tnt.Sequential.from_config(tnt_model.get_config())
  elif isinstance(model, tf.keras.Model):
    model_from_config = tnt.models.model_from_config(tnt_model.get_config())
  model_from_config.compile(**get_compile_params())

  model_from_config.load_weights(save_path)
  util.compare_weights(tnt_model.get_weights(),
                       model_from_config.get_weights(), 1e-6)

  # the TF format saves the optimizer state together with the weights,
  # so training can continue on the loaded model
  if tf_format:
    tnt_model.fit(train_dataset, epochs=1, verbose=0)
    model_from_config.fit(train_dataset, epochs=1, verbose=0)
    util.compare_weights(tnt_model.get_weights(),
                         model_from_config.get_weights(), 1e-6)
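# `get_compile_params` is used above but not defined in this excerpt. A minimal sketch
# assuming it returns the same compile arguments used elsewhere in these tests; the
# optimizer choice is an assumption.
def get_compile_params():
  return {'optimizer': keras.optimizers.SGD(learning_rate=0.01),
          'loss': keras.losses.SparseCategoricalCrossentropy(),
          'metrics': [keras.metrics.SparseCategoricalAccuracy()]}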