def train_one_epoch(epoch, dataloader, gen, critic, opt_gen, opt_critic, fixed_noise, device, metric_logger, num_samples, freq=100):
    """Train one epoch of a WGAN-GP setup (critic + generator).

    For each batch the critic is updated ``cfg.CRITIC_ITERATIONS`` times with
    the Wasserstein loss plus gradient penalty, then the generator is updated
    once to maximize the critic score of its fakes.

    :param epoch: ``int`` current epoch
    :param dataloader: object of dataloader
    :param gen: Generator model
    :param critic: Discriminator (critic) model
    :param opt_gen: Optimizer for generator
    :param opt_critic: Optimizer for critic
    :param fixed_noise: ``tensor[[cfg.BATCH_SIZE, latent_space_dimension, 1, 1]]`` fixed noise (latent space) for image metrics
    :param device: cuda device or cpu
    :param metric_logger: object of MetricLogger
    :param num_samples: ``int`` well retrievable sqrt() (for example: 4, 16, 64) for good result, number of samples for grid image metric
    :param freq: ``int``, must be < len(dataloader), frequency for displaying/logging results
    """
    for batch_idx, img in enumerate(dataloader):
        real = img.to(device)

        # Train critic: maximize E[critic(real)] - E[critic(fake)] - lambda * GP
        # Accumulators for average critic scores over the critic iterations.
        acc_real = 0
        acc_fake = 0
        for _ in range(cfg.CRITIC_ITERATIONS):
            noise = get_random_noise(cfg.BATCH_SIZE, cfg.Z_DIMENSION, device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            # detach(): accumulate values only for logging — without it every
            # iteration's autograd graph stays alive through the accumulator,
            # leaking memory across the critic loop.
            acc_real += critic_real.detach()
            critic_fake = critic(fake).reshape(-1)
            acc_fake += critic_fake.detach()
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = -1 * (torch.mean(critic_real) - torch.mean(critic_fake)) + cfg.LAMBDA_GP * gp
            critic.zero_grad()
            # retain_graph=True: `fake` (and gen's graph) is reused below for
            # the generator update, so the graph must survive this backward.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()
        acc_real = acc_real / cfg.CRITIC_ITERATIONS
        acc_fake = acc_fake / cfg.CRITIC_ITERATIONS

        # Train generator: minimize -E[critic(gen_fake)]
        output = critic(fake).reshape(-1)
        loss_gen = -1 * torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # logs metrics
        if batch_idx % freq == 0:
            with torch.no_grad():
                metric_logger.log(loss_critic, loss_gen, acc_real, acc_fake)
                fake = gen(fixed_noise)
                metric_logger.log_image(fake, num_samples, epoch, batch_idx, len(dataloader))
                metric_logger.display_status(epoch, cfg.NUM_EPOCHS, batch_idx, len(dataloader), loss_critic, loss_gen, acc_real, acc_fake)
cfg.NUM_EPOCHS = end_epoch else: print("=> Init default weights of models and fixed noise") # FIXME sometime (usually) when the weights is initialized from normal distribution can cause mode collapse # init_weights(gen) # init_weights(disc) # defining optimizers after init weights opt_gen = optim.Adam(gen.parameters(), lr=cfg.LEARNING_RATE, betas=(0.5, 0.999)) opt_critic = optim.Adam(critic.parameters(), lr=cfg.LEARNING_RATE, betas=(0.5, 0.999)) start_epoch = 1 end_epoch = cfg.NUM_EPOCHS fixed_noise = get_random_noise(cfg.BATCH_SIZE, cfg.Z_DIMENSION, device) if args.resume_id: metric_logger = MetricLogger(cfg.PROJECT_VERSION_NAME, resume_id=args.resume_id) else: metric_logger = MetricLogger(cfg.PROJECT_VERSION_NAME) # gradients metric wandb.watch(gen) wandb.watch(critic) # model mode gen.train() critic.train() start_time = time.time()
def train_one_epoch(epoch, dataloader, gen, disc, criterion, opt_gen, opt_disc, fixed_noise, device, metric_logger, num_samples, freq=100):
    """Run a single training epoch for a standard GAN (BCE objective).

    Per batch: the discriminator is pushed towards 1 on real images and 0 on
    detached fakes, then the generator is pushed so the discriminator outputs
    1 on its fakes.

    :param epoch: ``int`` current epoch
    :param dataloader: object of dataloader
    :param gen: Generator model
    :param disc: Discriminator model
    :param criterion: Loss function (for this case: binary cross entropy)
    :param opt_gen: Optimizer for generator
    :param opt_disc: Optimizer for discriminator
    :param fixed_noise: ``tensor[[cfg.BATCH_SIZE, latent_space_dimension, 1, 1]]`` fixed noise (latent space) for image metrics
    :param device: cuda device or cpu
    :param metric_logger: object of MetricLogger
    :param num_samples: ``int`` well retrievable sqrt() (for example: 4, 16, 64) for good result, number of samples for grid image metric
    :param freq: ``int``, must be < len(dataloader), frequency for displaying results
    """
    total_batches = len(dataloader)
    for batch_idx, img in enumerate(dataloader):
        real = img.to(device)
        latent = get_random_noise(cfg.BATCH_SIZE, cfg.Z_DIMENSION, device)
        fake = gen(latent)

        # Discriminator step: maximize log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        disc_fake = disc(fake.detach()).reshape(-1)  # detach: no grad into gen here
        real_term = criterion(disc_real, torch.ones_like(disc_real))
        fake_term = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = 0.5 * (real_term + fake_term)
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Generator step: minimize log(1 - D(G(z))) == maximize log(D(G(z)))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Periodic metric logging and sample-grid generation.
        if batch_idx % freq == 0:
            with torch.no_grad():
                metric_logger.log(loss_disc, loss_gen, disc_real, disc_fake)
                fake = gen(fixed_noise)
                metric_logger.log_image(fake, num_samples, epoch, batch_idx, total_batches)
                metric_logger.display_status(epoch, cfg.NUM_EPOCHS, batch_idx, total_batches, loss_disc, loss_gen, disc_real, disc_fake)
out_path = 'DCGAN-Anime-Faces' else: out_path = args.out_path if args.device == 'cuda': if torch.cuda.is_available(): device = torch.device('cuda') else: device = torch.device('cpu') gen = Generator(128, 3, 64) load_gen(gen, args.path_ckpt, device) gen.eval() if args.grid: noise = get_random_noise(args.num_samples, args.z_size, device) print("==> Generate IMAGE GRID...") output = gen(noise) show_batch(output, out_path, num_samples=args.num_samples, figsize=(args.img_size, args.img_size)) elif args.gif: noise = get_random_noise(args.num_samples, args.z_size, device) print("==> Generate GIF...") images = latent_space_interpolation_sequence(noise, step_interpolation=args.steps) output = gen(images) if args.resize and isinstance(args.resize, int): print(f"==> Resize images to {args.resize}px") output = F.interpolate(output, size=args.resize) images = [] for img in output: img = img.detach().permute(1, 2, 0)
betas=(0.0, 0.99)) opt_dis = optim.Adam(params=dis.parameters(), lr=cfg.LEARNING_RATE, betas=(0.0, 0.99)) # defining gradient scalers for automatic mixed precision scaler_gen = torch.cuda.amp.GradScaler() scaler_dis = torch.cuda.amp.GradScaler() if args.checkpoint: fixed_noise, cfg.START_EPOCH = load_checkpoint(args.checkpoint, gen, opt_gen, scaler_gen, dis, opt_dis, scaler_dis, cfg.LEARNING_RATE) else: fixed_noise = get_random_noise(cfg.FIXED_NOISE_SAMPLES, cfg.Z_DIMENSION, device) # logger.info("load weights from normal distribution") # init_weights(gen) # init_weights(dis) gen.train() dis.train() metric_logger = MetricLogger(cfg.PROJECT_VERSION_NAME) for epoch in range(cfg.START_EPOCH, cfg.END_EPOCH): if args.wgp: train_one_epoch_with_gp(gen, opt_gen, dis, opt_dis,