if args.cuda:
    Gz = Gz.cuda()
    Gx = Gx.cuda()
    D = D.cuda()

Gz.init_weights()
Gx.init_weights()
D.init_weights()

if args.use_morph_network:
    Gz.pretrain_morph_network()

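# The listener objects below handle side tasks during training: loss reporting,
# periodic image and morph sampling, checkpoint saving and loss plotting
# (roles inferred from the class names).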
listeners = [
    LossReporter(),
    AEImageSampleLogger(output_path,
                        valid_dataset,
                        args,
                        folder_name="AE_samples_valid",
                        print_stats=True),
    AEImageSampleLogger(output_path,
                        dataset,
                        args,
                        folder_name="AE_samples_train"),
    MorphImageLogger(output_path, valid_dataset, args, slerp=args.use_slerp),
    ModelSaver(output_path, n=1, overwrite=True, print_output=True),
    LossPlotter(output_path)
]
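# The train loop presumably calls into each listener at fixed points, in an
# observer-style pattern. Purely hypothetical sketch of such a listener; the class
# and method names below are illustrative assumptions, not this codebase's interface:
#
#     class EpochTimer:
#         def __init__(self):
#             self.start = None
#
#         def on_epoch_start(self, epoch):
#             self.start = time.time()
#
#         def on_epoch_end(self, epoch):
#             print("Epoch %d took %.1fs" % (epoch, time.time() - self.start))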

if args.use_dis_l_reconstruction_loss:
    rec_loss = "dis_l"
elif args.use_frs_reconstruction_loss:
    rec_loss = "frs"
Example #2
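# Adam with beta1 = 0.5 instead of the default 0.9, a common setting for GAN training.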
Gz_optimizer = torch.optim.Adam(Gz.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
Gx_optimizer = torch.optim.Adam(Gx.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
D_optimizer = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.999))

Gz.init_weights()
Gx.init_weights()
D.init_weights()

listeners = [
    LossReporter(),
    AEImageSampleLogger(output_path,
                        valid_dataset,
                        args,
                        folder_name="AE_samples_valid"),
    AEImageSampleLogger(output_path,
                        valid_dataset,
                        args,
                        folder_name="AE_samples_valid_train_mode",
                        eval_mode=False),
    AEImageSampleLogger(output_path,
                        dataset,
                        args,
                        folder_name="AE_samples_train"),
    ModelSaver(output_path, n=1, overwrite=True, print_output=True)
]
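# The second valid-set logger above passes eval_mode=False, so its samples are presumably
# drawn with the networks left in training mode (hence the "AE_samples_valid_train_mode" folder).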
train_loop = VAEGANTrainLoop(
    listeners=listeners,
    Gz=Gz,
Example #3
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img * 2 - 1)
]))
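# ToTensor() yields values in [0, 1]; the Lambda above rescales them to [-1, 1],
# which would match a tanh-activated generator output (an assumption based on the scaling).
# Quick check of the mapping:
#
#     x = torch.rand(3, 28, 28)   # range [0, 1]
#     y = x * 2 - 1               # range [-1, 1]
#     assert y.min() >= -1 and y.max() <= 1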

dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=12)


enc = VAEGANEncoder28(args.l_size, args.h_size, n_channels=3)
dec = VAEGANGenerator28(args.l_size, args.h_size, n_channels=3)
enc_optimizer = torch.optim.Adam(enc.parameters(), lr=args.lr, betas=(0.5, 0.999))
dec_optimizer = torch.optim.Adam(dec.parameters(), lr=args.lr, betas=(0.5, 0.999))

if args.cuda:
    enc = enc.cuda()
    dec = dec.cuda()

enc.init_weights()
dec.init_weights()

listeners = [
    LossReporter(),
    # GanImageSampleLogger(output_path, args, pad_value=1),
    AEImageSampleLogger(output_path, valid_dataset, args),
    ModelSaver(output_path, n=5, overwrite=True, print_output=True)
]
train_loop = VaeTrainLoop(listeners, enc, dec, enc_optimizer, dec_optimizer, dataloader,
                          cuda=args.cuda, epochs=args.epochs, beta=args.beta)
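# beta presumably weights the KL-divergence term of the VAE objective (beta-VAE style);
# beta = 1 would recover the standard VAE loss.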

train_loop.train()
Example #4
                               betas=(0.5, 0.999))
D_optimizer = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.999))

if args.cuda:
    Gz = Gz.cuda()
    Gx = Gx.cuda()
    D = D.cuda()

Gz.init_weights()
Gx.init_weights()
D.init_weights()

listeners = [
    LossReporter(),
    AEImageSampleLogger(output_path,
                        dataset,
                        args,
                        folder_name="AE_samples_train"),
    ModelSaver(output_path, n=1, overwrite=True, print_output=True)
]
train_loop = ALITrainLoop(listeners=listeners,
                          Gz=Gz,
                          Gx=Gx,
                          D=D,
                          optim_G=G_optimizer,
                          optim_D=D_optimizer,
                          dataloader=dataloader,
                          cuda=args.cuda,
                          epochs=args.epochs,
                          d_img_noise_std=0.1,
                          decrease_noise=True,
                          use_sigmoid=True)
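# d_img_noise_std adds noise to the discriminator's image inputs, decayed over training
# since decrease_noise=True ("instance noise", a common GAN stabilisation trick; this
# interpretation is assumed from the argument names). The snippet stops here; following
# the pattern of Example #3, training would presumably be launched with:
train_loop.train()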
Example #5
    frs_model.eval()
    if args.cuda:
        frs_model = frs_model.cuda()

if args.cuda:
    Gz = Gz.cuda()
    Gx = Gx.cuda()
    D = D.cuda()

Gz.init_weights()
Gx.init_weights()
D.init_weights()

listeners = [
    LossReporter(),
    AEImageSampleLogger(output_path, valid_dataset, args,
                        folder_name="AE_samples_valid", print_stats=True, every_n_epochs=10),
    AEImageSampleLogger(output_path, dataset, args,
                        folder_name="AE_samples_train", every_n_epochs=10),
    # DiscriminatorOverfitMonitor(dataset, valid_dataset, 100, args),
    ModelSaver(output_path, n=10, overwrite=True, print_output=True),
]

reconstruction_loss_mode = "dis_l" if args.use_dis_l_reconstruction_loss else "pixelwise"
if frs_model is not None:
    reconstruction_loss_mode = "frs"
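# Note: an available frs_model takes precedence over use_dis_l_reconstruction_loss here.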

train_loop = ALITrainLoop(
    listeners=listeners,
    Gz=Gz,
    Gx=Gx,
    D=D,
    optim_G=G_optimizer,