if args.use_morph_network: Gz.pretrain_morph_network() listeners = [ LossReporter(), AEImageSampleLogger(output_path, valid_dataset, args, folder_name="AE_samples_valid", print_stats=True), AEImageSampleLogger(output_path, dataset, args, folder_name="AE_samples_train"), MorphImageLogger(output_path, valid_dataset, args, slerp=args.use_slerp), ModelSaver(output_path, n=1, overwrite=True, print_output=True), LossPlotter(output_path) ] if args.use_dis_l_reconstruction_loss: rec_loss = "dis_l" elif args.use_frs_reconstruction_loss: rec_loss = "frs" else: rec_loss = "pixelwise" if args.use_dis_l_morph_loss: morph_loss = "dis_l" elif args.use_frs_morph_loss: morph_loss = "frs" else:
# Train a 64x64, 3-channel GAN on the CelebaCropped training split.

# Training data: the "train" split, converted with ToTensor only (no extra
# normalisation), reshuffled by the DataLoader on every pass.
dataset = CelebaCropped(
    split="train",
    download=True,
    morgan_like_filtering=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=args.batch_size, shuffle=True, num_workers=4
)

# Generator output goes through a sigmoid (sigmoid_out=True); the
# discriminator has batch norm disabled and returns logits (use_logits=True).
G = Generator64(
    args.l_size,
    args.h_size,
    args.use_mish,
    n_channels=3,
    sigmoid_out=True,
    use_lr_norm=args.use_lr_norm,
)
D = Discriminator64(
    args.h_size,
    use_bn=False,
    use_mish=args.use_mish,
    n_channels=3,
    dropout=args.dropout_rate,
    use_logits=True,
)

# Move the models to the GPU *before* constructing their optimizers, as the
# torch.optim documentation recommends: parameters after .cuda() may be
# different objects from the ones seen before the call, and the optimizers
# must reference the final, device-resident parameters.
if args.cuda:
    G = G.cuda()
    D = D.cuda()

# Only D is explicitly initialised here (after the device move, as before).
# It is done before the optimizers are built so they capture whatever
# parameter objects init_weights leaves on the module.
# NOTE(review): confirm Generator64 performs its own weight initialisation.
D.init_weights()

G_optimizer = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.0, 0.9))
D_optimizer = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.0, 0.9))

listeners = [
    LossReporter(),
    GanImageSampleLogger(output_path, args, pad_value=1, n_images=6 * 6),
    # Two savers: one with overwrite=True at n=5 and one with overwrite=False
    # at n=20 -- presumably a rolling checkpoint plus periodic permanent
    # snapshots (confirm the unit of `n` in ModelSaver).
    ModelSaver(output_path, n=5, overwrite=True, print_output=True),
    ModelSaver(output_path, n=20, overwrite=False, print_output=True),
]

train_loop = GanTrainLoop(
    listeners,
    G,
    D,
    G_optimizer,
    D_optimizer,
    dataloader,
    D_steps_per_G_step=args.d_steps,
    cuda=args.cuda,
    epochs=args.epochs,
    lambd=args.lambd,
)
train_loop.train()