コード例 #1
0
def train_tnt_and_reference_models(model_config,
                                   optimizer,
                                   micro_batch_size,
                                   nbatches,
                                   number_epochs,
                                   optimizer_kwargs=None):
    """Train a Tarantella model runner and a reference model runner.

    Two MNIST training datasets are created with the same `nbatches` and
    `micro_batch_size`, one per runner. Both runners are obtained from
    `get_compiled_models` and trained for `number_epochs` epochs; the
    resulting histories are logged together with the current rank.

    Args:
        model_config: Forwarded to `get_compiled_models`.
        optimizer: Forwarded to `get_compiled_models`.
        micro_batch_size: Micro-batch size used to build both datasets.
        nbatches: Number of batches used to build both datasets.
        number_epochs: Number of epochs passed to each runner's
            `train_model`.
        optimizer_kwargs: Optional dict of extra keyword arguments forwarded
            to `get_compiled_models`. `None` (the default) means no extra
            arguments.

    Returns:
        A tuple `(tnt_history, ref_history)` holding the values returned by
        the two `train_model` calls.
    """
    # A `None` sentinel avoids sharing one mutable dict between calls.
    if optimizer_kwargs is None:
        optimizer_kwargs = {}

    # Each runner gets its own dataset object, built from the same settings.
    (train_dataset,
     _) = util.train_test_mnist_datasets(nbatches=nbatches,
                                         micro_batch_size=micro_batch_size)
    (ref_train_dataset,
     _) = util.train_test_mnist_datasets(nbatches=nbatches,
                                         micro_batch_size=micro_batch_size)
    tnt_model_runner, ref_model_runner = get_compiled_models(
        model_config, optimizer, **optimizer_kwargs)

    tnt_history = tnt_model_runner.train_model(train_dataset, number_epochs)
    ref_history = ref_model_runner.train_model(ref_train_dataset,
                                               number_epochs)

    rank = tnt.get_rank()
    logging.getLogger().info(f"[Rank {rank}] Tarantella (loss, accuracy) = "
                             f"({tnt_history.history})")
    logging.getLogger().info(f"[Rank {rank}] Reference (loss, accuracy) = "
                             f"({ref_history.history})")
    return tnt_history, ref_history
コード例 #2
0
    def test_save_load_train_models(self, model, save_setup, parallel_strategy,
                                    optimizer_type,
                                    check_configuration_identical):
        """Save a trained `tnt.Model`, reload it, and keep training both.

        The model is trained for two epochs, saved, and loaded back through
        `tnt.models.load_model`. After the configurations are checked with
        `check_configuration_identical`, the original and the reloaded model
        are each trained for two more epochs on the same unshuffled dataset
        and their weights are compared with a tolerance of 1e-6.
        """
        fit_options = dict(epochs=2, verbose=0)
        dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                    micro_batch_size=32,
                                                    shuffle=False)

        # Build the model, train it and write it to disk.
        original = get_tnt_model_compiled(model, parallel_strategy,
                                          optimizer_type())
        original.fit(dataset, **fit_options)
        original.save(save_setup['save_dir'],
                      tnt_save_all_devices=save_setup['all_devices'])

        # Restore the saved model into a fresh `tnt.Model` instance.
        restored = tnt.models.load_model(save_setup['save_dir'])
        assert isinstance(restored, tnt.Model)
        check_configuration_identical(restored, original)

        # Resume training: first the original model, then the restored one.
        original.fit(dataset, **fit_options)
        restored.fit(dataset, **fit_options)

        restored_weights = restored.get_weights()
        original_weights = original.get_weights()
        util.compare_weights(restored_weights, original_weights, 1e-6)
コード例 #3
0
    def test_keras_models(self, model, save_setup, optimizer,
                          check_configuration_identical):
        """Save a trained Keras model, reload it, and keep training both.

        The model is compiled with a sparse categorical cross-entropy loss
        and accuracy metric, trained for two epochs and saved to
        `save_setup['save_dir']`. It is then loaded back with
        `keras.models.load_model`, checked with
        `check_configuration_identical`, and both models are trained for two
        more epochs on the same unshuffled dataset. Their final weights are
        compared with a tolerance of 1e-6.

        Args:
            model: Keras model to train; it is compiled in place.
            save_setup: Mapping providing the `'save_dir'` path.
            optimizer: Callable returning a new optimizer instance.
            check_configuration_identical: Callable invoked with the
                reloaded and the original model.
        """
        train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                          micro_batch_size=32,
                                                          shuffle=False)
        # train model
        keras_model = model
        keras_model.compile(
            optimizer(),
            loss=keras.losses.SparseCategoricalCrossentropy(),
            metrics=[keras.metrics.SparseCategoricalAccuracy()])
        keras_model.fit(train_dataset, epochs=2, verbose=0)
        keras_model.save(save_setup['save_dir'])

        reloaded_model = keras.models.load_model(save_setup['save_dir'])
        check_configuration_identical(reloaded_model, keras_model)

        # continue training on the original model
        keras_model.fit(train_dataset, epochs=2, shuffle=False, verbose=0)

        # continue training on the loaded model
        reloaded_model.fit(train_dataset, epochs=2, shuffle=False, verbose=0)

        # Compare against the original *model's* weights; the `keras` module
        # itself has no `get_weights` attribute.
        util.compare_weights(reloaded_model.get_weights(),
                             keras_model.get_weights(), 1e-6)
コード例 #4
0
    def test_compare_accuracy_against_reference(self, model_runners,
                                                micro_batch_size,
                                                number_epochs, nbatches,
                                                test_nbatches,
                                                remainder_samples_per_batch,
                                                last_incomplete_batch_size):
        """Compare Tarantella evaluation results with a reference runner.

        Both runners are trained and evaluated on separately created, but
        identically configured, unshuffled MNIST datasets. On the master rank
        the first result entry (loss) must match within 1e-2 and the second
        (accuracy) within 1e-6; all other ranks report success. The outcome
        is checked through `util.assert_on_all_ranks`.
        """
        # One set of dataset options, used for both dataset pairs.
        dataset_options = dict(
            nbatches=nbatches,
            test_nbatches=test_nbatches,
            micro_batch_size=micro_batch_size,
            shuffle=False,
            remainder_samples_per_batch=remainder_samples_per_batch,
            last_incomplete_batch_size=last_incomplete_batch_size)
        train_dataset, test_dataset = util.train_test_mnist_datasets(
            **dataset_options)
        ref_train_dataset, ref_test_dataset = util.train_test_mnist_datasets(
            **dataset_options)

        tnt_model_runner, reference_model_runner = model_runners

        # The reference runner is trained before the Tarantella runner.
        reference_model_runner.train_model(ref_train_dataset, number_epochs)
        tnt_model_runner.train_model(train_dataset, number_epochs)

        tnt_loss_accuracy = tnt_model_runner.evaluate_model(test_dataset)
        ref_loss_accuracy = reference_model_runner.evaluate_model(
            ref_test_dataset)

        rank = tnt.get_rank()
        logger = logging.getLogger()
        logger.info(
            f"[Rank {rank}] Tarantella[loss, accuracy] = {tnt_loss_accuracy}")
        logger.info(
            f"[Rank {rank}] Reference [loss, accuracy] = {ref_loss_accuracy}")

        if tnt.is_master_rank():
            # Index 0 is the loss (might not be identical, so looser
            # tolerance); index 1 is the accuracy.
            tolerances = (1e-2, 1e-6)
            result = [
                np.isclose(tnt_loss_accuracy[index],
                           ref_loss_accuracy[index],
                           atol=tolerance)
                for index, tolerance in enumerate(tolerances)
            ]
        else:
            result = [True, True]
        util.assert_on_all_ranks(result)
コード例 #5
0
    def test_load_model_with_compile_flag(self, model, save_setup,
                                          parallel_strategy,
                                          load_compiled_model):
        """Check how the `compile` flag of `load_model` affects training.

        A compiled `tnt` model is saved without training and loaded back with
        `compile=load_compiled_model`. When the flag is set, a one-epoch
        `fit` must run; otherwise `fit` is expected to raise `RuntimeError`.
        """
        compiled_model = get_tnt_model_compiled(model, parallel_strategy)
        dataset, _ = util.train_test_mnist_datasets(nbatches=1,
                                                    micro_batch_size=32)

        compiled_model.save(save_setup['save_dir'],
                            tnt_save_all_devices=save_setup['all_devices'])
        restored = tnt.models.load_model(save_setup['save_dir'],
                                         compile=load_compiled_model)

        if load_compiled_model:
            # A model restored in compiled state can be trained right away.
            restored.fit(dataset, epochs=1, verbose=0)
        else:
            # Without compilation, trying to train has to fail.
            with pytest.raises(RuntimeError):
                restored.fit(dataset, epochs=1, verbose=0)
コード例 #6
0
    def test_save_weights_after_training(self, model, save_setup,
                                         parallel_strategy, tf_format):
        """Save weights of a trained `tnt` model and load them into a new one.

        The model is trained for one epoch and its weights are written under
        `save_setup['save_dir']` (with an `.h5` suffix when `tf_format` is
        false). A new model is rebuilt from the trained model's config,
        compiled with `get_compile_params()`, and the saved weights are
        loaded into it; both models must then hold matching weights (1e-6).
        With `tf_format` set, both models are trained for one more epoch and
        compared again.

        Args:
            model: Keras model the `tnt` model is built from; must be a
                `tf.keras.Sequential` or another `tf.keras.Model`.
            save_setup: Mapping providing `'save_dir'` and `'all_devices'`.
            parallel_strategy: Forwarded to `get_tnt_model_compiled`.
            tf_format: Whether to save without the `.h5` suffix.

        Raises:
            TypeError: If `model` is not a `tf.keras.Model` instance.
        """
        # create un-shuffling dataset to be able to continue training identically
        # both on the `tnt.Model` and then on the `keras.Model`
        train_dataset, _ = util.train_test_mnist_datasets(nbatches=10,
                                                          micro_batch_size=32,
                                                          shuffle=False)
        # create and train model
        tnt_model = get_tnt_model_compiled(model, parallel_strategy)
        tnt_model.fit(train_dataset, epochs=1, verbose=0)

        os.makedirs(save_setup['save_dir'], exist_ok=True)
        save_path = os.path.join(save_setup['save_dir'], "weight")
        if not tf_format:
            save_path = save_path + ".h5"

        tnt_model.save_weights(save_path,
                               tnt_save_all_devices=save_setup['all_devices'])

        # create new model with same architecture and optimizer
        # (`Sequential` is tested first so that it gets its own constructor)
        if isinstance(model, tf.keras.Sequential):
            model_from_config = tnt.Sequential.from_config(
                tnt_model.get_config())
        elif isinstance(model, tf.keras.Model):
            model_from_config = tnt.models.model_from_config(
                tnt_model.get_config())
        else:
            # Fail with a clear message instead of an unbound-variable error
            # further down.
            raise TypeError(
                f"Unsupported model type: {type(model).__name__}")

        model_from_config.compile(**get_compile_params())
        model_from_config.load_weights(save_path)
        util.compare_weights(tnt_model.get_weights(),
                             model_from_config.get_weights(), 1e-6)

        # using the TF format saves the state together with the weights
        # such that training can continue on the loaded model
        if tf_format:
            tnt_model.fit(train_dataset, epochs=1, verbose=0)
            model_from_config.fit(train_dataset, epochs=1, verbose=0)

            util.compare_weights(tnt_model.get_weights(),
                                 model_from_config.get_weights(), 1e-6)