Beispiel #1
0
def test_val_step(model_type, scale_factor: int):
    """Run one validation step on ``model_type`` and expect a ``'val_loss'`` entry.

    The input is a single 3-channel batch built from a 16x16 size; the target
    batch is built with both size arguments multiplied by ``scale_factor``.
    """
    lr_size = (16, 16)
    model = model_type(generate_hparams(model_type, scale_factor=scale_factor))

    # Both batches receive the size components in (lr_size[1], lr_size[0]) order.
    lr_batch = generate_img_batch(1, 3, lr_size[1], lr_size[0])
    hr_batch = generate_img_batch(1, 3,
                                  lr_size[1] * scale_factor,
                                  lr_size[0] * scale_factor)

    step_output = model.validation_step((lr_batch, hr_batch), 0)
    assert 'val_loss' in step_output
Beispiel #2
0
def test_srgan_train_step(scale_factor: int, opt_index: int):
    """Run one SrGan training step for ``opt_index`` and expect a ``'loss'`` entry.

    The hparams are generated with ``img_shape`` set to the 48x48 base size
    scaled by ``scale_factor``. The target batch is built from that same
    upscaled size, and both batches use ``hparams.in_channels`` channels.
    """
    base_size = (48, 48)
    upscaled = tuple(dim * scale_factor for dim in base_size)

    hparams = generate_hparams(SrGan, img_shape=upscaled, scale_factor=scale_factor)
    model = SrGan(hparams)
    channels = hparams.in_channels

    # Size components are passed in (size[1], size[0]) order for both batches.
    lr_batch = generate_img_batch(1, channels, base_size[1], base_size[0])
    hr_batch = generate_img_batch(1, channels, upscaled[1], upscaled[0])

    step_output = model.training_step((lr_batch, hr_batch), 0, opt_index)
    assert 'loss' in step_output
Beispiel #3
0
def test_discriminator_train_step():
    """Run one Discriminator training step on a 96x96 batch and expect ``'loss'``.

    The targets are a column of ones, with one row per sample in the input batch.
    """
    img_size = (96, 96)
    disc = Discriminator(generate_hparams(Discriminator, img_shape=img_size))

    images = generate_img_batch(1, 3, img_size[1], img_size[0])
    # Shape (batch, 1), equivalent to torch.ones(batch).view(-1, 1).
    labels = torch.ones(images.size(0), 1)

    step_output = disc.training_step((images, labels), 0)
    assert 'loss' in step_output
Beispiel #4
0
def test_output_scale(model_type, scale_factor: int, img_size: Tuple[int, int]):
    """Check the model output has correct dimensions.

    The output's spatial dimensions (tensor dims 2 and 3) must equal the
    corresponding dimensions of the input batch multiplied by ``scale_factor``.

    The expected sizes are read from the generated input batch itself, not
    from ``img_size`` directly. ``generate_img_batch`` receives the size
    components as ``(img_size[1], img_size[0])``. Comparing against
    ``img_size[0]`` / ``img_size[1]`` by position would therefore silently
    depend on which helper argument maps to which tensor dimension, and could
    swap the check for non-square ``img_size`` values.
    """
    hparams = generate_hparams(model_type, scale_factor=scale_factor)
    model = model_type(hparams)
    batch = generate_img_batch(1, 3, img_size[1], img_size[0])
    output = model(batch)

    assert output.size(2) == batch.size(2) * scale_factor
    assert output.size(3) == batch.size(3) * scale_factor
Beispiel #5
0
def test_output_has_correct_num_channels(module_type, in_channels, out_channels):
    """Forward a 16x16 batch through ``module_type`` and check its channel count.

    The module is built with ``in_channels``/``out_channels`` keyword arguments.
    Dimension 1 of the output must equal ``out_channels``.
    """
    layer = module_type(in_channels=in_channels, out_channels=out_channels)
    result = layer(generate_img_batch(1, in_channels, 16, 16))
    assert result.size(1) == out_channels