示例#1
0
                        folder_name="AE_samples_train"),
    # DiscriminatorOverfitMonitor(dataset, valid_dataset, 100, args),
    ModelSaver(output_path, n=1, overwrite=True, print_output=True),
]

# Select the reconstruction loss. The use_dis_l_reconstruction_loss flag
# chooses between the "dis_l" and "pixelwise" modes; when an frs_model has
# been supplied it overrides the flag and "frs" is used instead.
reconstruction_loss_mode = ("dis_l" if args.use_dis_l_reconstruction_loss
                            else "pixelwise")
if frs_model is not None:
    reconstruction_loss_mode = "frs"

# Build the ALI training loop from the networks (Gz, Gx, D), their
# generator/discriminator optimizers, the dataloader and settings read
# from `args`; frs_model is forwarded as-is (it may be None).
train_loop = ALITrainLoop(
    listeners=listeners,
    Gz=Gz,
    Gx=Gx,
    D=D,
    optim_G=G_optimizer,
    optim_D=D_optimizer,
    dataloader=dataloader,
    cuda=args.cuda,
    epochs=args.epochs,
    morgan_alpha=args.morgan_alpha,
    d_real_label=args.d_real_label,
    d_img_noise_std=args.instance_noise_std,
    decrease_noise=True,
    use_sigmoid=True,
    reconstruction_loss_mode=reconstruction_loss_mode,
    frs_model=frs_model,
    r1_reg_gamma=args.r1_gamma,
    non_saturating_G_loss=args.ns_gan,
    disable_D_limiting=args.no_D_limit,
)

# Start training.
train_loop.train()
示例#2
0
    Gx = Gx.cuda()
    D = D.cuda()

# Initialise the parameters of each network before training.
Gz.init_weights()
Gx.init_weights()
D.init_weights()

# Listeners handed to the training loop, in registration order: loss
# reporting, sample logging for the validation and training sets,
# ModelSaver and a KillSwitchListener, all rooted at output_path.
listeners = [
    LossReporter(),
    AEImageSampleLogger(output_path, valid_dataset, args,
                        folder_name="AE_samples_valid"),
    AEImageSampleLogger(output_path, dataset, args,
                        folder_name="AE_samples_train"),
    ModelSaver(output_path, n=1, overwrite=True, print_output=True),
    KillSwitchListener(output_path),
]

# Note: the instance-noise std is fixed at 0.1 in this script rather than
# read from `args`.
train_loop = ALITrainLoop(listeners=listeners,
                          Gz=Gz,
                          Gx=Gx,
                          D=D,
                          optim_G=G_optimizer,
                          optim_D=D_optimizer,
                          dataloader=dataloader,
                          cuda=args.cuda,
                          epochs=args.epochs,
                          d_img_noise_std=0.1,
                          decrease_noise=True,
                          use_sigmoid=True)

train_loop.train()
示例#3
0
        valid,
        output_reproductions=True,
        discriminator_output=True,
        cuda=args.cuda,
        sample_reconstructions=True,
        every_n_epochs=10,
        output_latent=True,
        output_grad_norm=True,
        ns_gan=args.ns_gan,
    )
]

# The first seven arguments are positional: the listeners, the three
# networks, the generator and discriminator optimizers, and the dataloader.
# Everything after that is passed by keyword, mostly from `args`.
trainloop = ALITrainLoop(listeners, Gz, Gx, D,
                         G_optimizer, D_optimizer,
                         dataloader,
                         cuda=args.cuda,
                         epochs=args.epochs,
                         morgan_alpha=args.morgan_alpha,
                         d_img_noise_std=args.instance_noise_std,
                         decrease_noise=True,
                         r1_reg_gamma=args.r1_gamma,
                         non_saturating_G_loss=args.ns_gan,
                         disable_D_limiting=args.no_D_limit)

trainloop.train()
示例#4
0
    D = D.cuda()

# Initialise the parameters of each network before training.
Gz.init_weights()
Gx.init_weights()
D.init_weights()

# Listeners, in registration order; only the validation-set sample logger
# is asked to print statistics.
listeners = [
    LossReporter(),
    AEImageSampleLogger(output_path, valid_dataset, args,
                        folder_name="AE_samples_valid", print_stats=True),
    AEImageSampleLogger(output_path, dataset, args,
                        folder_name="AE_samples_train"),
    ModelSaver(output_path, n=1, overwrite=True, print_output=True),
]

# Instance noise is turned off (std 0.0) and decrease_noise is not passed.
# The reconstruction loss is "dis_l" when the flag is set, else "pixelwise".
train_loop = ALITrainLoop(listeners=listeners,
                          Gz=Gz,
                          Gx=Gx,
                          D=D,
                          optim_G=G_optimizer,
                          optim_D=D_optimizer,
                          dataloader=dataloader,
                          cuda=args.cuda,
                          epochs=args.epochs,
                          morgan_alpha=args.morgan_alpha,
                          d_img_noise_std=0.0,
                          use_sigmoid=True,
                          reconstruction_loss_mode=(
                              "dis_l" if args.use_dis_l_reconstruction_loss
                              else "pixelwise"),
                          r1_reg_gamma=args.r1_gamma)

train_loop.train()
# Listeners for this run, in registration order. A
# DiscriminatorOverfitMonitor(dataset, valid_dataset, 100, args) entry was
# disabled in this setup and is intentionally not registered.
listeners = [
    LossReporter(),
    AEImageSampleLogger(output_path, valid_dataset, args,
                        folder_name="AE_samples_valid", print_stats=True),
    AEImageSampleLogger(output_path, dataset, args,
                        folder_name="AE_samples_train"),
    ModelSaver(output_path, n=1, overwrite=True, print_output=True),
]

# Train with instance noise taken from `args` and decreased over training.
train_loop = ALITrainLoop(
    listeners=listeners,
    Gz=Gz,
    Gx=Gx,
    D=D,
    optim_G=G_optimizer,
    optim_D=D_optimizer,
    dataloader=dataloader,
    cuda=args.cuda,
    epochs=args.epochs,
    morgan_alpha=args.morgan_alpha,
    d_real_label=args.d_real_label,
    d_img_noise_std=args.instance_noise_std,
    decrease_noise=True,
    use_sigmoid=True,
)

train_loop.train()