Example #1
0
def test_losses(loss_type: AdversarialLossType, adversarial_logdir: str):
    """
    Check that the generator/discriminator loss pair selected by ``loss_type``
    can be passed to the trainer and run through a fake training loop.
    """
    # Resolve and instantiate each loss class (generator first, then
    # discriminator).
    make_generator_loss = get_adversarial_loss_generator(loss_type)
    g_loss = make_generator_loss()
    make_discriminator_loss = get_adversarial_loss_discriminator(loss_type)
    d_loss = make_discriminator_loss()

    fake_training_loop(
        adversarial_logdir,
        generator_loss=g_loss,
        discriminator_loss=d_loss,
    )
Example #2
0
def test_counter_callback(_models, adversarial_logdir):
    """Attach a ``FakeCounterCallback`` on ``Event.ON_EPOCH_END`` and run a
    single-epoch fake training loop with the given models.

    After the run, the callback's ``fake_counter`` must equal 1.
    """
    counter_callback = FakeCounterCallback(
        event=Event.ON_EPOCH_END,
        name="TestCounterCallback",
        fn=lambda context: print("Bloop"),
    )
    gen_model, disc_model = _models

    fake_training_loop(
        adversarial_logdir,
        callbacks=[counter_callback],
        generator=gen_model,
        discriminator=disc_model,
        epochs=1,
    )

    assert counter_callback.fake_counter == 1
Example #3
0
def test_metrics(adversarial_logdir: str):
    """
    Test the integration between metrics and trainer.

    Runs a fake training loop with SlicedWassersteinDistance, SSIM_Multiscale
    and InceptionScore attached. For every metric it then checks that:

    * the directory ``<adversarial_logdir>/best/<metric.name>`` exists,
    * it contains ``<metric.name>.json``, and
    * that JSON object has both a key named after the metric and a ``"step"``
      key.
    """
    # test parameters
    image_resolution = (256, 256)

    metrics = [
        SlicedWassersteinDistance(logdir=adversarial_logdir,
                                  resolution=image_resolution[0]),
        SSIM_Multiscale(logdir=adversarial_logdir),
        InceptionScore(
            # Fake inception model: a small ConvDiscriminator taking a
            # 299x299 input and producing 10 outputs, used in place of a real
            # Inception network to keep the test light.
            ConvDiscriminator(
                layer_spec_input_res=(299, 299),
                layer_spec_target_res=(7, 7),
                kernel_size=(5, 5),
                initial_filters=16,
                filters_cap=32,
                output_shape=10,
            ),
            logdir=adversarial_logdir,
        ),
    ]

    fake_training_loop(
        adversarial_logdir,
        metrics=metrics,
        image_resolution=image_resolution,
        layer_spec_input_res=(8, 8),
        layer_spec_target_res=(8, 8),
        channels=3,
    )

    # assert there exists a folder (and a JSON file) for each metric
    for metric in metrics:
        metric_dir = os.path.join(adversarial_logdir, "best", metric.name)
        assert os.path.isdir(metric_dir), (
            f"missing best-metric directory: {metric_dir}")
        json_path = os.path.join(metric_dir, f"{metric.name}.json")
        assert os.path.isfile(json_path), (
            f"missing best-metric file: {json_path}")

        # JSON text is UTF-8 by specification; do not depend on the
        # platform's locale encoding when reading it back.
        with open(json_path, "r", encoding="utf-8") as fp:
            metric_data = json.load(fp)

        # assert the metric data contains the expected keys
        assert metric.name in metric_data, (
            f"key {metric.name!r} not found in {json_path}")
        assert "step" in metric_data, f"key 'step' not found in {json_path}"
Example #4
0
def _test_save_callback_helper(adversarial_logdir, save_format,
                               save_sub_format, save_dir):
    """Run a fake training loop with a ``SaveCallback`` over small conv models.

    Builds a single-channel ConvGenerator (7x7 -> 28x28) and a
    ConvDiscriminator (28x28 -> 7x7, one output). Both are handed to a
    ``SaveCallback`` configured with ``save_dir``, ``save_format`` and
    ``save_sub_format``. This helper performs no assertions itself.
    """
    img_res = (28, 28)
    coarse_res = (7, 7)
    ksize = 5
    n_channels = 1

    # model definition
    generator = ConvGenerator(
        layer_spec_input_res=coarse_res,
        layer_spec_target_res=img_res,
        kernel_size=ksize,
        initial_filters=32,
        filters_cap=16,
        channels=n_channels,
    )
    discriminator = ConvDiscriminator(
        layer_spec_input_res=img_res,
        layer_spec_target_res=coarse_res,
        kernel_size=ksize,
        initial_filters=16,
        filters_cap=32,
        output_shape=1,
    )

    saver = SaveCallback(
        models=[generator, discriminator],
        save_dir=save_dir,
        verbose=1,
        save_format=save_format,
        save_sub_format=save_sub_format,
    )

    fake_training_loop(
        adversarial_logdir,
        callbacks=[saver],
        generator=generator,
        discriminator=discriminator,
    )
Example #5
0
def test_custom_callbacks(adversarial_logdir: str, event: Event):
    """Check that a custom ``MCallback`` bound to ``event`` fires the expected
    number of times during a fake training run.

    The expected count is taken from ``get_n_events_from_epochs`` using the
    same epoch, dataset-size and batch-size settings as the run.
    """
    spy = MCallback(event)
    n_epochs, n_samples, batch = 2, 2, 2

    fake_training_loop(
        adversarial_logdir,
        callbacks=[spy],
        epochs=n_epochs,
        dataset_size=n_samples,
        batch_size=batch,
    )

    # compare the observed call count against the analytic expectation
    observed_calls = spy.counter
    expected_calls = get_n_events_from_epochs(event, n_epochs, n_samples,
                                              batch)
    assert observed_calls == expected_calls
Example #6
0
def test_callbacks(adversarial_logdir: str):
    """Run a fake training loop with a ``LogImageGANCallback`` that fires at
    the end of every batch (``event_freq=1``)."""
    image_logger = LogImageGANCallback(event=Event.ON_BATCH_END, event_freq=1)
    fake_training_loop(adversarial_logdir, callbacks=[image_logger])