def create_model(options):
    """Build a GAN whose discriminator is wrapped by ``apply_spectral_norm``.

    Args:
        options: Accepted for interface compatibility; not read by this
            function.

    Returns:
        A ``Gan`` built from the networks below, with ``init_weights``
        applied to its modules via ``model.apply``.
    """
    # latent_size is 0: presumably the generator is driven only by the input
    # batch rather than a sampled noise vector -- TODO confirm against Gan.
    latent_size = 0

    # Construction order (discriminator first) is kept deliberately: it fixes
    # how the global RNG is consumed during parameter initialization.
    discriminator = apply_spectral_norm(Discriminator())
    generator = Generator()

    lr_base = 0.0002

    model = Gan(
        discriminator=discriminator,
        generator=generator,
        latent_size=latent_size,
        # Both networks use the same learning rate. A two time-scale update
        # rule (TTUR) would give the discriminator a different rate here.
        optimizer_discriminator_fn=functools.partial(optimizer_fn,
                                                     lr=lr_base),
        optimizer_generator_fn=functools.partial(optimizer_fn, lr=lr_base),
        real_image_from_batch_fn=get_image,
        # Pool of 100 entries; presumably a history buffer of generated
        # images fed to the discriminator -- TODO confirm GanDataPool semantics.
        image_pool=GanDataPool(100))

    model.apply(init_weights)
    return model
# Example #2
def create_model():
    """Assemble a GAN with a 64-dimensional latent input.

    The generator and discriminator optimizers both come from one Adam
    factory (lr=0.001, betas=(0.5, 0.999)). Each call of the factory yields
    a separate optimizer instance.

    Returns:
        The constructed ``Gan``.
    """
    latent_dim = 64

    # Keep generator-then-discriminator order so parameter initialization
    # draws from the RNG in the same sequence.
    gen_net = Generator(latent_size=latent_dim)
    disc_net = Discriminator()

    make_adam = functools.partial(
        torch.optim.Adam,
        lr=0.001,
        betas=(0.5, 0.999),
    )

    gan = Gan(
        generator=gen_net,
        discriminator=disc_net,
        latent_size=latent_dim,
        optimizer_generator_fn=make_adam,
        optimizer_discriminator_fn=make_adam,
        real_image_from_batch_fn=get_image,
    )
    return gan
# Example #3
def create_model():
    """Build a GAN with a plain (non-spectral-normalized) discriminator.

    Returns:
        A ``Gan`` whose generator and discriminator optimizers are both
        ``optimizer_fn`` partially applied with ``lr=0.0002``.
    """
    # latent_size is 0: presumably the generator is driven only by the input
    # batch rather than a sampled noise vector -- TODO confirm against Gan.
    latent_size = 0

    # Discriminator is constructed first; keep this order so parameter
    # initialization consumes the RNG in the same sequence.
    discriminator = Discriminator()
    generator = Generator()

    lr_base = 0.0002

    model = Gan(
        discriminator=discriminator,
        generator=generator,
        latent_size=latent_size,
        # Both networks use the same learning rate. A two time-scale update
        # rule (TTUR) would give the discriminator a different rate here.
        optimizer_discriminator_fn=functools.partial(optimizer_fn,
                                                     lr=lr_base),
        optimizer_generator_fn=functools.partial(optimizer_fn, lr=lr_base),
        real_image_from_batch_fn=get_image,
    )

    return model