def test_keras_models(self, model, save_setup, optimizer, check_configuration_identical):
  train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                    micro_batch_size=32,
                                                    shuffle=False)
  # train model
  keras_model = model
  keras_model.compile(optimizer(),
                      loss=keras.losses.SparseCategoricalCrossentropy(),
                      metrics=[keras.metrics.SparseCategoricalAccuracy()])
  keras_model.fit(train_dataset, epochs=2, verbose=0)

  keras_model.save(save_setup['save_dir'])
  reloaded_model = keras.models.load_model(save_setup['save_dir'])
  check_configuration_identical(reloaded_model, keras_model)

  # continue training on the original model
  keras_model.fit(train_dataset, epochs=2, shuffle=False, verbose=0)
  # continue training on the loaded model
  reloaded_model.fit(train_dataset, epochs=2, shuffle=False, verbose=0)

  util.compare_weights(reloaded_model.get_weights(),
                       keras_model.get_weights(), 1e-6)
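# NOTE: `util.compare_weights` is a test helper not shown in this file; a minimal
# sketch of the behavior assumed by the tests above (pairwise comparison of two
# weight lists within an absolute tolerance), relying on the module-level
# `numpy as np` import -- the name `_compare_weights_sketch` is hypothetical:
def _compare_weights_sketch(weights, reference_weights, tolerance):
  assert len(weights) == len(reference_weights)
  for w, w_ref in zip(weights, reference_weights):
    # each weight tensor has to match its reference element-wise
    assert np.allclose(w, w_ref, atol=tolerance)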
def test_save_load_train_models(self, model, save_setup, parallel_strategy,
                                optimizer_type, check_configuration_identical):
  train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                    micro_batch_size=32,
                                                    shuffle=False)
  # create and train model
  tnt_model = get_tnt_model_compiled(model, parallel_strategy, optimizer_type())
  tnt_model.fit(train_dataset, epochs=2, verbose=0)

  tnt_model.save(save_setup['save_dir'],
                 tnt_save_all_devices=save_setup['all_devices'])

  # load into a new tnt.Model
  reloaded_tnt_model = tnt.models.load_model(save_setup['save_dir'])
  assert isinstance(reloaded_tnt_model, tnt.Model)
  check_configuration_identical(reloaded_tnt_model, tnt_model)

  # continue training on the original model
  tnt_model.fit(train_dataset, epochs=2, verbose=0)
  # continue training on the loaded model
  reloaded_tnt_model.fit(train_dataset, epochs=2, verbose=0)

  util.compare_weights(reloaded_tnt_model.get_weights(),
                       tnt_model.get_weights(), 1e-6)
def test_save_load_before_training(self, model, save_setup, parallel_strategy,
                                   check_configuration_identical):
  tnt_model = get_tnt_model_compiled(model, parallel_strategy)
  tnt_model.save(save_setup['save_dir'],
                 tnt_save_all_devices=save_setup['all_devices'])

  reloaded_tnt_model = tnt.models.load_model(save_setup['save_dir'], compile=True)
  check_configuration_identical(reloaded_tnt_model, tnt_model)
  util.compare_weights(reloaded_tnt_model.get_weights(),
                       tnt_model.get_weights(), 1e-6)
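# NOTE: `get_tnt_model_compiled` is a local helper of this test module; a plausible
# sketch under the assumption that `tnt.Model` forwards the `parallel_strategy`
# argument the way the fixtures above pass it in; the default optimizer and the
# compile settings below are only illustrative:
def _get_tnt_model_compiled_sketch(keras_model, parallel_strategy, optimizer=None):
  tnt_model = tnt.Model(keras_model, parallel_strategy=parallel_strategy)
  tnt_model.compile(optimizer=optimizer or keras.optimizers.SGD(learning_rate=0.01),
                    loss=keras.losses.SparseCategoricalCrossentropy(),
                    metrics=[keras.metrics.SparseCategoricalAccuracy()])
  return tnt_model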
def test_compare_accuracy_against_reference(self, tarantella_framework, model_runners,
                                            micro_batch_size, number_epochs, nbatches):
  batch_size = micro_batch_size * tarantella_framework.get_size()
  nsamples = nbatches * batch_size
  tnt_model_runner, reference_model_runner = model_runners

  # reuse model with its initial weights
  tnt_model_runner.reset_weights()
  reference_model_runner.reset_weights()

  # verify that both models have identical weights
  tnt_initial_weights = tnt_model_runner.get_weights()
  reference_initial_weights = reference_model_runner.get_weights()
  util.compare_weights(tnt_initial_weights, reference_initial_weights, 1e-6)

  # train reference model
  (ref_train_dataset, ref_test_dataset) = util.load_dataset(mnist.load_mnist_dataset,
                                                            train_size=nsamples,
                                                            train_batch_size=batch_size,
                                                            test_size=10000,
                                                            test_batch_size=batch_size)
  reference_model_runner.train_model(ref_train_dataset, number_epochs)
  reference_loss_accuracy = reference_model_runner.evaluate_model(ref_test_dataset)

  # train Tarantella model
  (train_dataset, test_dataset) = util.load_dataset(mnist.load_mnist_dataset,
                                                    train_size=nsamples,
                                                    train_batch_size=batch_size,
                                                    test_size=10000,
                                                    test_batch_size=batch_size)
  tnt_model_runner.train_model(train_dataset, number_epochs)
  tnt_loss_accuracy = tnt_model_runner.evaluate_model(test_dataset)

  rank = tarantella_framework.get_rank()
  logging.getLogger().info("[Rank %d] Tarantella[loss, accuracy] = %s"
                           % (rank, str(tnt_loss_accuracy)))
  logging.getLogger().info("[Rank %d] Reference [loss, accuracy] = %s"
                           % (rank, str(reference_loss_accuracy)))

  assert np.isclose(tnt_loss_accuracy[0], reference_loss_accuracy[0],
                    atol=1e-2)  # losses might not be identical
  assert np.isclose(tnt_loss_accuracy[1], reference_loss_accuracy[1], atol=1e-2)
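# NOTE: `util.load_dataset` is assumed to slice the arrays returned by the loader
# into batched `tf.data` pipelines of the requested sizes; a hypothetical sketch
# matching the call sites above, assuming the loader returns
# ((x_train, y_train), (x_test, y_test)):
def _load_dataset_sketch(dataset_loader, train_size, train_batch_size,
                         test_size, test_batch_size):
  (x_train, y_train), (x_test, y_test) = dataset_loader()
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train[:train_size], y_train[:train_size])).batch(train_batch_size)
  test_dataset = tf.data.Dataset.from_tensor_slices(
      (x_test[:test_size], y_test[:test_size])).batch(test_batch_size)
  return train_dataset, test_dataset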
def test_partition_core_models(self, model_and_partitions):
  num_micro_batches = 1
  model, partition_gen, expected_num_partitions, _, expected_model_gen = model_and_partitions
  rank_mapper = rmapper.RankMapper(num_ranks=expected_num_partitions,
                                   pipeline_graph=partition_gen.get_pipeline_graph())

  for rank in range(expected_num_partitions):
    partition_id = rank_mapper.get_partition_for_rank(rank)
    partition_graph = partition_gen.get_partition_graph(partition_id)
    cm_builder = core_model_builder.CoreModelBuilder(model, partition_id, partition_graph)
    core_model = cm_builder.get_model()

    reference_core_model = expected_model_gen(rank)
    utils.check_model_configuration_identical(core_model, reference_core_model)
    utils.compare_weights(core_model.get_weights(),
                          reference_core_model.get_weights(), 1e-6)
def test_compare_weights_across_ranks(self, tarantella_framework, model_runner,
                                      micro_batch_size, nbatches, number_epochs):
  comm_size = tarantella_framework.get_size()
  batch_size = micro_batch_size * comm_size
  nsamples = nbatches * batch_size
  (train_dataset, _) = util.load_dataset(mnist.load_mnist_dataset,
                                         train_size=nsamples,
                                         train_batch_size=batch_size,
                                         test_size=0,
                                         test_batch_size=batch_size)
  model_runner.reset_weights()
  model_runner.train_model(train_dataset, number_epochs)

  final_weights = model_runner.get_weights()
  # broadcast the weights from the master rank to all the participating ranks
  model_runner.model._broadcast_weights()
  reference_rank_weights = model_runner.get_weights()

  util.compare_weights(final_weights, reference_rank_weights, 1e-6)
def test_save_weights_after_training(self, model, save_setup, parallel_strategy, tf_format):
  # create a non-shuffling dataset to be able to continue training identically,
  # first on the `tnt.Model` and then on the `keras.Model`
  train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                    micro_batch_size=32,
                                                    shuffle=False)
  # create and train model
  tnt_model = get_tnt_model_compiled(model, parallel_strategy)
  tnt_model.fit(train_dataset, epochs=1, verbose=0)

  os.makedirs(save_setup['save_dir'], exist_ok=True)
  save_path = os.path.join(save_setup['save_dir'], "weight")
  if not tf_format:
    save_path = save_path + ".h5"
  tnt_model.save_weights(save_path, tnt_save_all_devices=save_setup['all_devices'])

  # create a new model with the same architecture and optimizer
  if isinstance(model, tf.keras.Sequential):
    model_from_config = tnt.Sequential.from_config(tnt_model.get_config())
  elif isinstance(model, tf.keras.Model):
    model_from_config = tnt.models.model_from_config(tnt_model.get_config())
  model_from_config.compile(**get_compile_params())
  model_from_config.load_weights(save_path)

  util.compare_weights(tnt_model.get_weights(),
                       model_from_config.get_weights(), 1e-6)

  # the TF format saves the optimizer state together with the weights,
  # so that training can continue on the loaded model
  if tf_format:
    tnt_model.fit(train_dataset, epochs=1, verbose=0)
    model_from_config.fit(train_dataset, epochs=1, verbose=0)
    util.compare_weights(tnt_model.get_weights(),
                         model_from_config.get_weights(), 1e-6)
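# NOTE: `get_compile_params` used above is assumed to return the compile keyword
# arguments shared by the models in this test; a hypothetical minimal version:
def _get_compile_params_sketch():
  return {'optimizer': keras.optimizers.SGD(learning_rate=0.01),
          'loss': keras.losses.SparseCategoricalCrossentropy(),
          'metrics': [keras.metrics.SparseCategoricalAccuracy()]}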
def test_simple_keras_models(self, save_setup, optimizer):
  def get_model():
    # create a simple model
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=optimizer(), loss="mean_squared_error")
    return model

  tf.random.set_seed(42)
  model = get_model()

  # train and save the model
  test_input = np.ones((128, 32))
  test_target = np.ones((128, 1))
  model.fit(test_input, test_target)
  model.save(save_setup['save_dir'])

  tf.random.set_seed(42)
  reconstructed_model = keras.models.load_model(save_setup['save_dir'])

  # check that the reloaded model reproduces the predictions and weights
  np.testing.assert_allclose(model.predict(test_input),
                             reconstructed_model.predict(test_input))
  util.compare_weights(reconstructed_model.get_weights(),
                       model.get_weights(), 1e-6)

  # continue training both models with identical seeds and compare again
  tf.random.set_seed(42)
  reconstructed_model.fit(test_input, test_target, epochs=3, shuffle=False)
  tf.random.set_seed(42)
  model.fit(test_input, test_target, epochs=3, shuffle=False)

  util.compare_weights(model.get_weights(),
                       reconstructed_model.get_weights(), 1e-6)