def train(args, epoch, loader, model, optimizer, scheduler):
    torch.backends.cudnn.benchmark = True

    model.train()

    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)
    else:
        pbar = loader

    for i, (img, annot) in enumerate(pbar):
        img = img.to('cuda')
        annot = annot.to('cuda')

        loss, _ = model(img, annot)
        loss_sum = loss['loss'] + args.aux_weight * loss['aux']

        model.zero_grad()
        loss_sum.backward()
        optimizer.step()
        scheduler.step()

        loss_dict = reduce_loss_dict(loss)
        loss = loss_dict['loss'].mean().item()
        aux_loss = loss_dict['aux'].mean().item()

        if get_rank() == 0:
            lr = optimizer.param_groups[0]['lr']

            pbar.set_description(
                f'epoch: {epoch + 1}; loss: {loss:.5f}; '
                f'aux loss: {aux_loss:.5f}; lr: {lr:.5f}'
            )
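# The distributed helpers used above are not defined in this file. As a
# reference, `reduce_loss_dict` is assumed to follow the common
# maskrcnn-benchmark/rosinality pattern: stack the loss tensors, reduce them
# to rank 0, and average over world size so rank 0 logs global values.
# A minimal sketch, assuming `get_world_size` wraps torch.distributed:

import torch
from torch import distributed as dist


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            # only rank 0 needs the averaged values for logging
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses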
def valid(args, epoch, loader, model, show):
    torch.backends.cudnn.benchmark = False

    model.eval()

    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)
    else:
        pbar = loader

    intersect_sum = None
    union_sum = None
    correct_sum = 0
    total_sum = 0

    for i, (img, annot) in enumerate(pbar):
        img = img.to('cuda')
        annot = annot.to('cuda')

        _, out = model(img)
        _, pred = out.max(1)

        if get_rank() == 0 and i % 10 == 0:
            result = show(img[0], annot[0], pred[0])
            result.save(f'sample/{str(epoch + 1).zfill(3)}-{str(i).zfill(4)}.png')

        pred = (annot > 0) * pred
        correct = (pred > 0) * (pred == annot)
        correct_sum += correct.sum().float().item()
        total_sum += (annot > 0).sum().float()

        for g, p, c in zip(annot, pred, correct):
            intersect, union = intersection_union(g, p, c, args.n_class)

            if intersect_sum is None:
                intersect_sum = intersect
            else:
                intersect_sum += intersect

            if union_sum is None:
                union_sum = union
            else:
                union_sum += union

        all_intersect = sum(all_gather(intersect_sum.to('cpu')))
        all_union = sum(all_gather(union_sum.to('cpu')))

        if get_rank() == 0:
            iou = all_intersect / (all_union + 1e-10)
            m_iou = iou.mean().item()

            pbar.set_description(
                f'acc: {correct_sum / total_sum:.5f}; mIoU: {m_iou:.5f}'
            )
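# `intersection_union` is assumed to be the usual histogram-based IoU
# accumulator from semantic-segmentation evaluation code: per class, count
# the correctly predicted pixels and the pixels in prediction/ground truth,
# with class 0 treated as ignore. A sketch under those assumptions, not
# necessarily the exact helper this repo ships:


def intersection_union(gt, pred, correct, n_class):
    # histogram of correctly classified pixels, binned by class id (1..n_class)
    intersect = torch.histc(gt[correct].float(), bins=n_class, min=1, max=n_class)

    # union = |pred| + |gt| - |intersect|, computed per class
    area_pred = torch.histc(pred.float(), bins=n_class, min=1, max=n_class)
    area_gt = torch.histc(gt.float(), bins=n_class, min=1, max=n_class)
    union = area_pred + area_gt - intersect

    return intersect, union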
def valid(args, epoch, loader, dataset, model, device):
    if args.distributed:
        model = model.module

    torch.cuda.empty_cache()

    model.eval()

    pbar = tqdm(loader, dynamic_ncols=True)

    preds = {}

    for images, targets, ids in pbar:
        model.zero_grad()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        pred, _ = model(images.tensors, images.sizes)

        pred = [p.to('cpu') for p in pred]

        preds.update({id: p for id, p in zip(ids, pred)})

    preds = accumulate_predictions(preds)

    if get_rank() != 0:
        return

    evaluate(dataset, preds)

    return
def valid(args, epoch, loader, dataset, model, device, logger=None):
    if args.distributed:
        model = model.module

    torch.cuda.empty_cache()

    model.eval()

    if get_rank() == 0:
        pbar = tqdm(enumerate(loader), total=len(loader), dynamic_ncols=True)
    else:
        pbar = enumerate(loader)

    preds = {}

    for idx, (images, targets, ids) in pbar:
        model.zero_grad()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        pred, _ = model(images.tensors, images.sizes)

        pred = [p.to('cpu') for p in pred]

        preds.update({id: p for id, p in zip(ids, pred)})

    preds = accumulate_predictions(preds)

    if get_rank() != 0:
        return

    evl_res = evaluate(dataset, preds)

    # writing log to tensorboard
    if logger:
        log_group_name = "validation"
        box_result = evl_res['bbox']
        logger.add_scalar(log_group_name + '/AP', box_result['AP'], epoch)
        logger.add_scalar(log_group_name + '/AP50', box_result['AP50'], epoch)
        logger.add_scalar(log_group_name + '/AP75', box_result['AP75'], epoch)
        logger.add_scalar(log_group_name + '/APl', box_result['APl'], epoch)
        logger.add_scalar(log_group_name + '/APm', box_result['APm'], epoch)
        logger.add_scalar(log_group_name + '/APs', box_result['APs'], epoch)

    return preds
def train(args, epoch, loader, model, optimizer, device, logger=None):
    model.train()

    if get_rank() == 0:
        pbar = tqdm(enumerate(loader), total=len(loader), dynamic_ncols=True)
    else:
        pbar = enumerate(loader)

    for idx, (images, targets, _) in pbar:
        model.zero_grad()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        _, loss_dict = model(images, targets=targets)
        loss_cls = loss_dict['loss_cls'].mean()
        loss_box = loss_dict['loss_reg'].mean()
        loss_center = loss_dict['loss_centerness'].mean()

        loss = loss_cls + loss_box + loss_center
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        loss_reduced = reduce_loss_dict(loss_dict)
        loss_cls = loss_reduced['loss_cls'].mean().item()
        loss_box = loss_reduced['loss_reg'].mean().item()
        loss_center = loss_reduced['loss_centerness'].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (f'epoch: {epoch + 1}; cls: {loss_cls:.4f}; '
                 f'box: {loss_box:.4f}; center: {loss_center:.4f}'))

            # writing log to tensorboard
            if logger and idx % 10 == 0:
                totalStep = (epoch * len(loader) + idx) * args.batch * args.n_gpu
                logger.add_scalar('training/loss_cls', loss_cls, totalStep)
                logger.add_scalar('training/loss_box', loss_box, totalStep)
                logger.add_scalar('training/loss_center', loss_center, totalStep)
                logger.add_scalar('training/loss_all',
                                  (loss_cls + loss_box + loss_center), totalStep)
def set_logger(self):
    if get_rank() == 0 and wandb is not None and self.args.wandb:
        wandb.init(project="stylegan 2")
    else:
        self.log_dir = '%s/%d' % (self.args.log_dir, self.args.manualSeed)
        os.makedirs(self.log_dir, exist_ok=True)
        self.summary = SummaryWriter(log_dir=self.log_dir)

        with tarfile.open(os.path.join(self.log_dir, 'code.tar.gz'), "w:gz") as tar:
            for addfile in ['train.py', 'dataset.py', 'model.py']:
                tar.add(addfile)

        # (dangling snippet in the original, kept as a comment)
        # with open(os.path.join(self.log_dir, 'args.txt'), 'w') as f:
def __init__(self, img_root_path, img_keys_path, transform, batch_size,
             dist_mode: bool = False, rank_seed: Optional[int] = None,
             with_key: bool = False):
    self.img_root_path = img_root_path
    self.img_keys_path = img_keys_path
    self.transform = transform
    self.batch_size = batch_size

    if rank_seed is not None:
        self.rand = random.Random(rank_seed)
    else:
        self.rand = random.Random(time.time())

    self.img_keys_file_list = [
        os.path.join(self.img_keys_path, f)
        for f in os.listdir(self.img_keys_path) if not f.startswith('.')
    ]
    self.rand.shuffle(self.img_keys_file_list)

    self.rank = -1
    if dist_mode:
        # shard the shuffled key files evenly across ranks
        rank_pic_size = int(
            math.ceil(len(self.img_keys_file_list) / dist.get_world_size()))
        self.img_keys_file_list = self.img_keys_file_list[
            rank_pic_size * dist.get_rank():rank_pic_size * (dist.get_rank() + 1)]
        self.rank = dist.get_rank()

    # estimate the example count from the first few key files
    self.num_examples = max(
        sum(1 for _ in open(f)) for f in self.img_keys_file_list[:10]
    ) * len(self.img_keys_file_list)
    self.num_iterations = int(math.ceil(self.num_examples / batch_size))
    self.with_key = with_key
def get_logit(dataloader, netD, device):
    data_iter = iter(dataloader)
    logit_list = np.zeros(len(dataloader.dataset))

    netD.eval()

    with torch.no_grad():
        if get_rank() == 0:
            data_iter = tqdm(data_iter)

        for data, idx in data_iter:
            real_data = data.to(device)
            idx = idx.to(device)

            logit_r = netD(real_data).view(-1)

            idx_all = concat_all_gather(idx)
            logit_r = concat_all_gather(logit_r)
            logit_list[idx_all.cpu().numpy()] = logit_r.detach().cpu().numpy()

    netD.train()

    return logit_list
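# `concat_all_gather` is assumed to be the MoCo-style helper: gather a tensor
# from every rank and concatenate along the batch dimension. Note that this
# version carries no gradient, which is fine here since logits are only
# collected for bookkeeping. A sketch:

@torch.no_grad()
def concat_all_gather(tensor):
    tensors_gather = [
        torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)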
def accumulate_predictions(predictions):
    all_predictions = all_gather(predictions)

    if get_rank() != 0:
        return

    predictions = {}

    for p in all_predictions:
        predictions.update(p)

    ids = list(sorted(predictions.keys()))

    if len(ids) != ids[-1] + 1:
        print('Evaluation results are not contiguous')

    predictions = [predictions[i] for i in ids]

    return predictions
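# `all_gather` here must transport an arbitrary Python dict, not a tensor, so
# it is assumed to be an object-gathering helper. On torch >= 1.8 this can be
# a thin wrapper over all_gather_object; a sketch, assuming `get_world_size`
# as above:


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = [None] * world_size
    # pickles `data` on each rank and gathers the results everywhere
    torch.distributed.all_gather_object(buffer, data)

    return buffer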
def save_predictions_to_images(dataset, predictions):
    # if get_rank() != 0:
    #     return

    for id, pred in enumerate(predictions):
        orig_id = dataset.id2img[id]

        if len(pred) == 0:
            continue

        img_meta = dataset.get_image_meta(id)
        width = img_meta['width']
        height = img_meta['height']
        pred = pred.resize((width, height))

        boxes = pred.bbox.tolist()
        scores = pred.get_field('scores').tolist()
        ids = pred.get_field('labels').tolist()

        img_name = img_meta['file_name']
        img_baseName = os.path.splitext(img_name)[0]
        # print('saving ' + img_name + ' ...')
        imgroot = dataset.root
        show_bbox(imgroot + '/' + img_name, boxes, ids, CLASS_NAME,
                  file_name=img_name, scores=scores)

        categories = [dataset.id2category[i] for i in ids]

        for k, box in enumerate(boxes):
            category_id = categories[k]
            score = scores[k]
def train(args, dataset, gen, dis, g_ema, device):
    if args.distributed:
        g_module = gen.module
        d_module = dis.module
    else:
        g_module = gen
        d_module = dis

    vgg = VGGFeature("vgg16", [4, 9, 16, 23, 30], use_fc=True).eval().to(device)
    requires_grad(vgg, False)

    g_optim = optim.Adam(gen.parameters(), lr=1e-4, betas=(0, 0.999))
    d_optim = optim.Adam(dis.parameters(), lr=1e-4, betas=(0, 0.999))

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        num_workers=4,
        sampler=dist.data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )
    loader_iter = sample_data(loader)

    pbar = range(args.start_iter, args.iter)

    if dist.get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True)

    eps = 1e-8

    for i in pbar:
        real, class_id = next(loader_iter)
        real = real.to(device)
        class_id = class_id.to(device)
        masks = make_mask(real.shape[0], device, args.crop_prob)

        features, fcs = vgg(real)
        features = features + fcs[1:]

        requires_grad(dis, True)
        requires_grad(gen, False)

        real_pred = dis(real, class_id)

        z = torch.randn(args.batch, args.dim_z, device=device)
        fake = gen(z, class_id, features, masks)
        fake_pred = dis(fake, class_id)

        d_loss = d_ls_loss(real_pred, fake_pred)

        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        z1 = torch.randn(args.batch, args.dim_z, device=device)
        z2 = torch.randn(args.batch, args.dim_z, device=device)

        requires_grad(gen, True)
        requires_grad(dis, False)

        masks = make_mask(real.shape[0], device, args.crop_prob)

        if args.distributed:
            gen.broadcast_buffers = True

        fake1 = gen(z1, class_id, features, masks)

        if args.distributed:
            gen.broadcast_buffers = False

        fake2 = gen(z2, class_id, features, masks)

        fake_pred = dis(fake1, class_id)
        a_loss = g_ls_loss(None, fake_pred)

        features_fake, fcs_fake = vgg(fake1)
        features_fake = features_fake + fcs_fake[1:]
        r_loss = recon_loss(features_fake, features, masks)
        div_loss = diversity_loss(z1, z2, fake1, fake2, eps)

        g_loss = a_loss + args.rec_weight * r_loss + args.div_weight * div_loss

        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        accumulate(g_ema, g_module)

        if dist.get_rank() == 0:
            pbar.set_description(
                f"d: {d_loss.item():.4f}; g: {a_loss.item():.4f}; "
                f"rec: {r_loss.item():.4f}; div: {div_loss.item():.4f}"
            )

            if i % 100 == 0:
                utils.save_image(
                    fake1,
                    f"sample/{str(i).zfill(6)}.png",
                    nrow=int(args.batch ** 0.5),
                    normalize=True,
                    range=(-1, 1),
                )

            if i % 10000 == 0:
                torch.save(
                    {
                        "args": args,
                        "g_ema": g_ema.state_dict(),
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
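# `requires_grad` and `accumulate` follow the standard rosinality GAN
# utilities: the first toggles gradient computation for a whole module, the
# second keeps `g_ema` as an exponential moving average of the generator's
# weights. Sketches of the assumed implementations:


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        # EMA update: par1 = decay * par1 + (1 - decay) * par2
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)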
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
            ))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
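# The GAN losses and path regularizer referenced in the loop are assumed to
# be the standard StyleGAN2 non-saturating logistic losses with R1 and
# path-length regularization (rosinality-style). Sketches:

import math

from torch import autograd
from torch.nn import functional as F


def d_logistic_loss(real_pred, fake_pred):
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)

    return real_loss.mean() + fake_loss.mean()


def d_r1_loss(real_pred, real_img):
    # gradient penalty on real images only (lazy R1)
    grad_real, = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

    return grad_penalty


def g_nonsaturating_loss(fake_pred):
    return F.softplus(-fake_pred).mean()


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    grad, = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # running mean of the path lengths, penalize deviation from it
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean.detach()).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths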
def train(opt):
    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        # the validation loader should sample from valid_dataset
        # (the original passed train_dataset here)
        sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)

    opt.num_class = len(converter.character)

    # styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    # genModel = AdaIN_Tensor_WordGenerator(opt)
    # disModel = MsImageDisV2(opt)

    # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none')
    # mixModel = Mixer(opt, nblk=3, dim=opt.latent)
    genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class,
                           channel_multiplier=opt.channel_multiplier).to(device)
    disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier,
                           input_dim=opt.input_channel).to(device)
    g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class,
                        channel_multiplier=opt.channel_multiplier).to(device)
    ocrModel = ModelV1(opt).to(device)

    accumulate(g_ema, genModel, 0)

    # # weight initialization
    # for currModel in [styleModel, mixModel]:
    #     for name, param in currModel.named_parameters():
    #         if 'localization_fc2' in name:
    #             print(f'Skip {name} as it is already initialized')
    #             continue
    #         try:
    #             if 'bias' in name:
    #                 init.constant_(param, 0.0)
    #             elif 'weight' in name:
    #                 init.kaiming_normal_(param)
    #         except Exception as e:  # for batchnorm.
    #             if 'weight' in name:
    #                 param.data.fill_(1)
    #             continue

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            # ignore [GO] token = ignore index 0
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # vggRecCriterion = torch.nn.L1Loss()
    # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion)

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length)

    if opt.distributed:
        genModel = torch.nn.parallel.DistributedDataParallel(
            genModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        disModel = torch.nn.parallel.DistributedDataParallel(
            disModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        ocrModel = torch.nn.parallel.DistributedDataParallel(
            ocrModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )

    # styleModel = torch.nn.DataParallel(styleModel).to(device)
    # styleModel.train()
    # mixModel = torch.nn.DataParallel(mixModel).to(device)
    # mixModel.train()
    # genModel = torch.nn.DataParallel(genModel).to(device)
    # g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()
    # disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()
    # vggModel = torch.nn.DataParallel(vggModel).to(device)
    # vggModel.eval()
    # ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    # if opt.distributed:
    #     ocrModel.module.Transformation.eval()
    #     ocrModel.module.FeatureExtraction.eval()
    #     ocrModel.module.AdaptiveAvgPool.eval()
    #     # ocrModel.module.SequenceModeling.eval()
    #     ocrModel.module.Prediction.eval()
    # else:
    #     ocrModel.Transformation.eval()
    #     ocrModel.FeatureExtraction.eval()
    #     ocrModel.AdaptiveAvgPool.eval()
    #     # ocrModel.SequenceModeling.eval()
    #     ocrModel.Prediction.eval()
    ocrModel.eval()

    if opt.distributed:
        g_module = genModel.module
        d_module = disModel.module
    else:
        g_module = genModel
        d_module = disModel

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        genModel.parameters(),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        if len(glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))) > 0:
            opt.saved_synth_model = glob.glob(
                os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        if not opt.distributed:
            ocrModel = torch.nn.DataParallel(ocrModel)
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)
        # temporary fix
        if not opt.distributed:
            ocrModel = ocrModel.module

    if opt.saved_gen_model != '' and opt.saved_gen_model != 'None':
        print(f'loading pretrained gen model from {opt.saved_gen_model}')
        checkpoint = torch.load(opt.saved_gen_model,
                                map_location=lambda storage, loc: storage)
        genModel.module.load_state_dict(checkpoint['g'])
        g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disModel.load_state_dict(checkpoint['disModel'])
        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim

    # loss averager
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while True:
        # print(cntr)
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, False)
        # requires_grad(styleModel, False)
        # requires_grad(mixModel, False)
        requires_grad(disModel, True)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # forward pass from style and word generator
        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device)
        # scInput = mixModel(style, text_2)

        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2, input_is_latent=opt.input_latent)
        else:
            images_recon_2, _ = genModel(style, text_2[:, 1:-1], input_is_latent=opt.input_latent)

        # Domain discriminator: Dis update
        if opt.augment:
            image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p)
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)
        else:
            image_gt_tensors_aug = image_gt_tensors

        fake_pred = disModel(images_recon_2)
        real_pred = disModel(image_gt_tensors_aug)
        disCost = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = disCost * opt.disWeight
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        loss_avg_dis.add(disCost)

        disModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if opt.augment and opt.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > opt.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            image_gt_tensors.requires_grad = True
            image_input_tensors.requires_grad = True

            cat_tensor = image_gt_tensors
            real_pred = disModel(cat_tensor)
            r1_loss = d_r1_loss(real_pred, cat_tensor)

            disModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()

        loss_dict["r1"] = r1_loss

        # [Style Encoder] + [Word Generator] update
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, True)
        # requires_grad(styleModel, True)
        # requires_grad(mixModel, True)
        requires_grad(disModel, False)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        # scInput = mixModel(style, text_2)
        # images_recon_2, _ = genModel([scInput], input_is_latent=opt.input_latent)
        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)

        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2)
        else:
            images_recon_2, _ = genModel(style, text_2[:, 1:-1])

        if opt.augment:
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        fake_pred = disModel(images_recon_2)
        disGenCost = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = disGenCost

        # # Adversarial loss
        # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2, image_input_tensors), dim=1))

        # # Input reconstruction loss
        # recCost = recCriterion(images_recon_2, image_gt_tensors)

        # # vgg loss
        # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

        # ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)

        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False,
                                   returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False,
                                returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
                # preds_o = preds_recon[:, :text_1.shape[1], :]
                preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2)

                # predict ocr recognition on generated images
                # preds_recon_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                _, preds_recon_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data)

                # predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                # preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

                # predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                # preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)
            else:
                preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1],
                                       is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]),
                                       target_2.contiguous().view(-1))

                # predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_o_ocr[idx] = pred[:pred_EOS]

                # predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_s_ocr[idx] = pred[:pred_EOS]

                # predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_sc_ocr[idx] = pred[:pred_EOS]

        # cost = opt.reconWeight * recCost + opt.disWeight * disGenCost + opt.vggPerWeight * vggPerCost + opt.vggStyWeight * vggStyleCost + opt.ocrWeight * ocrCost
        cost = opt.disWeight * disGenCost + opt.ocrWeight * ocrCost

        # styleModel.zero_grad()
        genModel.zero_grad()
        # mixModel.zero_grad()
        disModel.zero_grad()
        # vggModel.zero_grad()
        ocrModel.zero_grad()

        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

            image_input_tensors = image_input_tensors.to(device)
            image_gt_tensors = image_gt_tensors.to(device)
            batch_size = image_input_tensors.size(0)

            text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
            text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

            path_batch_size = max(1, batch_size // opt.path_batch_shrink)

            # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
            # scInput = mixModel(style, text_2)
            # images_recon_2, latents = genModel([scInput], input_is_latent=opt.input_latent, return_latents=True)
            style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)

            if 'CTC' in opt.Prediction:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size],
                                                   return_latents=True)
            else:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size, 1:-1],
                                                   return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                images_recon_2, latents, mean_path_length)

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        # Individual losses
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(torch.tensor(0.0))
        loss_avg_vgg_per.add(torch.tensor(0.0))
        loss_avg_vgg_sty.add(torch.tensor(0.0))
        loss_avg_ocr.add(opt.ocrWeight * ocrCost)

        log_r1_val.add(loss_reduced["path"])
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            # pbar.set_description(
            #     (
            #         f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
            #         f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            #         f"augment: {ada_aug_p:.4f}"
            #     )
            # )

            if wandb and opt.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )

            # if cntr % 100 == 0:
            #     with torch.no_grad():
            #         g_ema.eval()
            #         sample, _ = g_ema([scInput[:, :opt.latent], scInput[:, opt.latent:]])
            #         utils.save_image(
            #             sample,
            #             os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"),
            #             nrow=int(opt.n_sample ** 0.5),
            #             normalize=True,
            #             range=(-1, 1),
            #         )

            # validation part
            # To see training progress, we also conduct validation when 'iteration == 0'
            if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
                # save training images
                curr_batch_size = style[0].shape[0]
                images_recon_2, _ = g_ema(style, text_2[:curr_batch_size],
                                          input_is_latent=opt.input_latent)
                os.makedirs(os.path.join(opt.trainDir, str(iteration)), exist_ok=True)

                for trImgCntr in range(batch_size):
                    try:
                        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                            save_image(tensor2im(image_input_tensors[trImgCntr].detach()),
                                       os.path.join(opt.trainDir, str(iteration),
                                                    str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '.png'))
                            save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),
                                       os.path.join(opt.trainDir, str(iteration),
                                                    str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '.png'))
                            save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                                       os.path.join(opt.trainDir, str(iteration),
                                                    str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '.png'))
                        else:
                            save_image(tensor2im(image_input_tensors[trImgCntr].detach()),
                                       os.path.join(opt.trainDir, str(iteration),
                                                    str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '_' + labels_s_ocr[trImgCntr] + '.png'))
                            save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),
                                       os.path.join(opt.trainDir, str(iteration),
                                                    str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '_' + labels_sc_ocr[trImgCntr] + '.png'))
                            save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                                       os.path.join(opt.trainDir, str(iteration),
                                                    str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '_' + labels_o_ocr[trImgCntr] + '.png'))
                    except:
                        print('Warning while saving training image')

                elapsed_time = time.time() - start_time

                # for log
                with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                    # styleModel.eval()
                    genModel.eval()
                    g_ema.eval()
                    # mixModel.eval()
                    disModel.eval()

                    with torch.no_grad():
                        valid_loss, infer_time, length_of_data = validation_synth_v6(
                            iteration, g_ema, ocrModel, disModel, ocrCriterion,
                            valid_loader, converter, opt)

                    # styleModel.train()
                    genModel.train()
                    # mixModel.train()
                    disModel.train()

                    # training loss and validation loss
                    loss_log = f'[{iteration + 1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \
Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f}, \
Train OCR loss: {loss_avg_ocr.val():0.5f}, \
Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
Valid Synth loss: {valid_loss[0]:0.5f}, \
Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \
Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                    # plotting
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Synth-Loss'), loss_avg.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-OCR-Loss'), loss_avg_ocr.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'), log_r1_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'), log_ada_aug_p.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Synth-Loss'), valid_loss[0].item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Dis-Loss'), valid_loss[1].item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Gen-Loss'), valid_loss[2].item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-ImgRecon1-Loss'), valid_loss[3].item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Per-Loss'), valid_loss[4].item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Sty-Loss'), valid_loss[5].item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-OCR-Loss'), valid_loss[6].item())

                    print(loss_log)

                    loss_avg.reset()
                    loss_avg_dis.reset()
                    loss_avg_gen.reset()
                    loss_avg_imgRecon.reset()
                    loss_avg_vgg_per.reset()
                    loss_avg_vgg_sty.reset()
                    loss_avg_ocr.reset()

                    log_r1_val.reset()
                    log_avg_path_loss_val.reset()
                    log_avg_mean_path_length_avg.reset()
                    log_ada_aug_p.reset()

                lib.plot.flush()

            lib.plot.tick()

            # save model per 1e+4 iter.
            if (iteration) % 1e+4 == 0:
                torch.save(
                    {
                        # 'styleModel': styleModel.state_dict(),
                        # 'mixModel': mixModel.state_dict(),
                        'genModel': g_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'disModel': d_module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'dis_optimizer': dis_optimizer.state_dict(),
                    },
                    os.path.join(opt.exp_dir, opt.exp_name,
                                 'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()

        iteration += 1
        cntr += 1
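# Shape reference for the CTC branch above: torch.nn.CTCLoss expects
# log-probabilities of shape (T, N, C), which is why `preds_recon` is passed
# through log_softmax(2).permute(1, 0, 2) before the criterion. A minimal,
# self-contained usage sketch with hypothetical sizes (T=26 timesteps, N=4
# batch, C=37 classes, blank index 0):


def _ctc_shape_demo():
    ctc = torch.nn.CTCLoss(zero_infinity=True)
    log_probs = torch.randn(26, 4, 37).log_softmax(2)          # (T, N, C)
    targets = torch.randint(1, 37, (4, 10), dtype=torch.long)  # 0 is blank
    input_lengths = torch.full((4,), 26, dtype=torch.long)
    target_lengths = torch.randint(1, 11, (4,), dtype=torch.long)

    return ctc(log_probs, targets, input_lengths, target_lengths)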
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device, save_dir):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 1000 == 0:
                # save some samples
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        save_dir + f"/samples/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 2000 == 0:
                # save the model
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    save_dir + f"/checkpoints/{str(i).zfill(6)}.pt",
                )
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_labels = []

    while len(sample_labels) < args.n_sample:
        real_img, real_label = next(loader)
        sample_labels.append(real_label.to(device))

    sample_labels = torch.cat(sample_labels, 0)[:args.n_sample]

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print('Done!')
            break

        real_img, real_label = next(loader)
        real_img = real_img.to(device)
        real_label = real_label.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(real_label, noise)
        fake_pred = discriminator(real_label, fake_img)

        real_pred = discriminator(real_label, real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_label, real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(real_label, noise)
        fake_pred = discriminator(real_label, fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(real_label[:path_batch_size], noise,
                                          return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_loss_val = loss_reduced['path'].mean().item()
        real_score_val = loss_reduced['real_score'].mean().item()
        fake_score_val = loss_reduced['fake_score'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
                f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
            ))

            if wandb and args.wandb:
                wandb.log({
                    'Generator': g_loss_val,
                    'Discriminator': d_loss_val,
                    'R1': r1_val,
                    'Path Length Regularization': path_loss_val,
                    'Mean Path Length': mean_path_length,
                    'Real Score': real_score_val,
                    'Fake Score': fake_score_val,
                    'Path Length': path_length_val,
                })

            if i % 200 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema(sample_labels, [sample_z])
                    utils.save_image(
                        sample,
                        f'sample/{str(i).zfill(6)}.png',
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    f'checkpoint/{str(i).zfill(6)}.pt',
                )
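# `mixing_noise` is assumed to implement StyleGAN2 style mixing: with
# probability `prob`, return two latent codes so the generator crosses over
# between them mid-network; otherwise return a single code. A sketch of the
# standard rosinality helpers:

import random


def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)


def mixing_noise(batch, latent_dim, prob, device):
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)

    else:
        return [make_noise(batch, latent_dim, 1, device)]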
def train(args, loader, generator, discriminator, extra, g_optim, d_optim, e_optim,
          g_ema, device, g_source, d_source):
    loader = sample_data(loader)

    imsave_path = os.path.join('samples', args.exp)
    model_path = os.path.join('checkpoints', args.exp)

    if not os.path.exists(imsave_path):
        os.makedirs(imsave_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # this defines the anchor points, and when sampling noise close to these,
    # we impose image-level adversarial loss (Eq. 4 in the paper)
    init_z = torch.randn(args.n_train, args.latent, device=device)

    pbar = range(args.iter)
    sfm = nn.Softmax(dim=1)
    kl_loss = nn.KLDivLoss()
    sim = nn.CosineSimilarity()

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    g_module = generator
    d_module = discriminator
    g_ema_module = g_ema.module

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    # this defines which level feature of the discriminator is used to implement
    # the patch-level adversarial loss: could be anything between [0, args.highp]
    lowp, highp = 0, args.highp

    # the following defines the constant noise used for generating images at
    # different stages of training
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    requires_grad(g_source, False)
    requires_grad(d_source, False)
    sub_region_z = get_subspace(args, init_z.clone(), vis_flag=True)

    for idx in pbar:
        i = idx + args.start_iter
        # defines whether we sample from anchor region in this iteration or other
        which = i % args.subspace_freq

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        requires_grad(extra, True)

        if which > 0:
            # sample normally, apply patch-level adversarial loss
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            # sample from anchors, apply image-level adversarial loss
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            real_img, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(
            fake_img, extra=extra, flag=which, p_ind=np.random.randint(lowp, highp))
        real_pred, _ = discriminator(
            real_img, extra=extra, flag=which,
            p_ind=np.random.randint(lowp, highp), real=True)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        extra.zero_grad()
        d_loss.backward()
        d_optim.step()
        e_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred, _ = discriminator(
                real_img, extra=extra, flag=which,
                p_ind=np.random.randint(lowp, highp))
            real_pred = real_pred.view(real_img.size(0), -1)
            real_pred = real_pred.mean(dim=1).unsqueeze(1)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            extra.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()
            e_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        requires_grad(extra, False)

        if which > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(
            fake_img, extra=extra, flag=which, p_ind=np.random.randint(lowp, highp))
        g_loss = g_nonsaturating_loss(fake_pred)

        # distance consistency loss
        with torch.set_grad_enabled(False):
            z = torch.randn(args.feat_const_batch, args.latent, device=device)
            feat_ind = numpy.random.randint(
                1, g_source.module.n_latent - 1, size=args.feat_const_batch)

            # computing source distances
            source_sample, feat_source = g_source([z], return_feats=True)
            dist_source = torch.zeros(
                [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

            # iterating over different elements in the batch
            for pair1 in range(args.feat_const_batch):
                tmpc = 0
                # comparing the possible pairs
                for pair2 in range(args.feat_const_batch):
                    if pair1 != pair2:
                        anchor_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair1].reshape(-1), 0)
                        compare_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair2].reshape(-1), 0)
                        dist_source[pair1, tmpc] = sim(anchor_feat, compare_feat)
                        tmpc += 1

            dist_source = sfm(dist_source)

        # computing distances among target generations
        _, feat_target = generator([z], return_feats=True)
        dist_target = torch.zeros(
            [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

        # iterating over different elements in the batch
        for pair1 in range(args.feat_const_batch):
            tmpc = 0
            for pair2 in range(args.feat_const_batch):  # comparing the possible pairs
                if pair1 != pair2:
                    anchor_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair1].reshape(-1), 0)
                    compare_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair2].reshape(-1), 0)
                    dist_target[pair1, tmpc] = sim(anchor_feat, compare_feat)
                    tmpc += 1

        dist_target = sfm(dist_target)
        rel_loss = args.kl_wt * \
            kl_loss(torch.log(dist_target), dist_source)  # distance consistency loss
        g_loss = g_loss + rel_loss

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        # to save up space
        del rel_loss, g_loss, d_loss, fake_img, fake_pred, real_img, real_pred, \
            anchor_feat, compare_feat, dist_source, dist_target, feat_source, feat_target

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema_module, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.img_freq == 0:
                with torch.set_grad_enabled(False):
                    g_ema.eval()
                    sample, _ = g_ema([sample_z.data])
                    sample_subz, _ = g_ema([sub_region_z.data])
                    utils.save_image(
                        sample,
                        f"%s/{str(i).zfill(6)}.png" % (imsave_path),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )
                    del sample

            if (i % args.save_freq == 0) and (i > 0):
                torch.save(
                    {
                        "g_ema": g_ema.state_dict(),
                        # uncomment the following lines only if you wish to resume
                        # training after saving. Otherwise, saving just the
                        # generator is sufficient for evaluations
                        # "g": g_module.state_dict(),
                        # "g_s": g_source.state_dict(),
                        # "d": d_module.state_dict(),
                        # "g_optim": g_optim.state_dict(),
                        # "d_optim": d_optim.state_dict(),
                    },
                    f"%s/{str(i).zfill(6)}.pt" % (model_path),
                )
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader) current_ckpt = args.current_ckpt # generate one fake image to check data correct test_imgs = next(loader) real_grid = utils.make_grid(test_imgs, nrow=2, normalize=True, range=(-1, 1)) wandb.log({"reals": [wandb.Image(real_grid, caption='Real Data')]}) pbar = tqdm(dynamic_ncols=True, smoothing=0.01, initial=current_ckpt + 1, total=args.iter) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator none_g_grads = set() test_in = torch.randn(1, args.latent, device=device) fake, latent = g_module([test_in], return_latents=True) path = g_path_regularize(fake, latent, 0) path[0].backward() for n, p in generator.named_parameters(): if p.grad is None: none_g_grads.add(n) test_in = torch.randn(1, 3, args.size, args.size, requires_grad=True, device=device) pred = d_module(test_in) r1_loss = d_r1_loss(pred, test_in) r1_loss.backward() none_d_grads = set() for n, p in discriminator.named_parameters(): if p.grad is None: none_d_grads.add(n) seed = torch.initial_seed() % 10000000 torch.manual_seed(20) torch.cuda.manual_seed_all(20) sample_z = torch.randn(4 * 4, args.latent, device=device) sample_z_chunks = torch.split(sample_z, args.batch) # reset seed torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) i = current_ckpt + 1 while i < args.iter: real_img = next(loader) real_img = real_img.to(device) requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) fake_pred = discriminator(fake_img) real_pred = discriminator(real_img) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict['d'] = d_loss loss_dict['real_score'] = real_pred.mean() loss_dict['fake_score'] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() set_grad_none(discriminator, none_d_grads) d_optim.step() loss_dict['r1'] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) fake_pred = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) loss_dict['g'] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: noise = mixing_noise(args.batch // args.path_batch_shrink, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() set_grad_none(g_module, none_g_grads) g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict['path'] = path_loss loss_dict['path_length'] = 
path_lengths.mean() accumulate(g_ema, g_module) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced['d'].mean().item() g_loss_val = loss_reduced['g'].mean().item() r1_val = loss_reduced['r1'].mean().item() path_loss_val = loss_reduced['path'].mean().item() real_score_val = loss_reduced['real_score'].mean().item() fake_score_val = loss_reduced['fake_score'].mean().item() path_length_val = loss_reduced['path_length'].mean().item() if get_rank() == 0: pbar.set_postfix(d_loss=f'{d_loss_val:.4f}', g_loss=f'{g_loss_val:.4f}', r1_loss=f'{r1_val:.4f}', path=f'{path_loss_val:.4f}', mean=f'{mean_path_length_avg:.4f}') if wandb and args.wandb: wandb.log({ 'Generator': g_loss_val, 'Discriminator': d_loss_val, 'R1': r1_val, 'Path Length Regularization': path_loss_val, 'Mean Path Length': mean_path_length, 'Real Score': real_score_val, 'Fake Score': fake_score_val, 'Path Length': path_length_val, 'current_ckpt': current_ckpt, 'iteration': i, }) if i % 500 == 0: with torch.no_grad(): g_ema.eval() sample = generate_fake_images(g_ema, sample_z_chunks) if wandb and args.wandb: label = f'{str(i).zfill(8)}.png' image = utils.make_grid(sample, nrow=4, normalize=True, range=(-1, 1)) wandb.log( {"samples": [wandb.Image(image, caption=label)]}) else: utils.save_image( sample, f'sample/{str(i).zfill(8)}.png', nrow=8, normalize=True, range=(-1, 1), ) if i % 2000 == 0: ckpt_name = f'checkpoint/{str(i).zfill(8)}.pt' # remove the previous checkpoint shutil.rmtree('checkpoint') os.mkdir('checkpoint') torch.save( { 'g': g_module.state_dict(), 'd': d_module.state_dict(), 'g_ema': g_ema.state_dict(), 'g_optim': g_optim.state_dict(), 'd_optim': d_optim.state_dict(), }, ckpt_name, ) current_ckpt = i if wandb and args.wandb: wandb.save(ckpt_name) i = i + 1 pbar.update()
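# --- Added sketch (not part of the original script): accumulate(g_ema, g_module)
# above is assumed to maintain an exponential moving average of the generator
# weights; decay=0.999 is an assumed default, since these call sites omit it.
def accumulate(model_ema, model, decay=0.999):
    ema_params = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    for name in ema_params.keys():
        # ema <- decay * ema + (1 - decay) * live weights
        ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)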
generator.proj.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) d_optim = optim.Adam( discriminator.parameters(), lr=args.lr * d_reg_ratio / 2, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) transform = transforms.Compose( [ transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ] ) dataset = MultiResolutionDataset(args.path, transform, args.size) loader = data.DataLoader( dataset, batch_size=args.batch, sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), drop_last=True, ) if get_rank() == 0 and wandb is not None and args.wandb: wandb.init(project='stylegan 2') train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
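# --- Added note (assumption, following the StyleGAN2 lazy-regularization
# recipe): when the R1 and path-length penalties run only every k-th step, the
# optimizer hyperparameters are rescaled by c = k / (k + 1) -- lr becomes
# lr * c and each Adam beta becomes beta ** c -- so the running moment
# estimates behave roughly as if the penalty were applied at every step.
# Illustration with commonly used defaults (values assumed, not read from args):
g_reg_every, d_reg_every = 4, 16
g_reg_ratio = g_reg_every / (g_reg_every + 1)   # 0.8
d_reg_ratio = d_reg_every / (d_reg_every + 1)   # ~0.941
# G betas then become (0 ** 0.8, 0.99 ** 0.8), i.e. about (0.0, 0.992)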
device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, find_unused_parameters=True, ) transform = transforms.Compose( [ transforms.RandomVerticalFlip(p=0.5 if args.vflip else 0), transforms.RandomHorizontalFlip(p=0.5 if args.hflip else 0), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ] ) dataset = MultiResolutionDataset(args.path, transform, args.size) loader = data.DataLoader( dataset, batch_size=args.batch_size, sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), num_workers=8, drop_last=True, ) if get_rank() == 0: validation.get_dataset_inception_features(loader, args.name, args.size) wandb.init(project="maua-stylegan", name="Cyphept Correct BCR", config=vars(args)) scaler = th.cuda.amp.GradScaler() train(args, loader, generator, discriminator, contrast_learner, augment_fn, g_optim, d_optim, scaler, g_ema, device)
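# --- Added sketch (not part of the original script): plausible definitions of
# the data_sampler() and sample_data() helpers these training scripts share.
# data_sampler picks a sampler for (non-)distributed runs; sample_data wraps a
# finite DataLoader into an endless generator, which suits the
# iteration-counted (rather than epoch-counted) loops used above.
from torch.utils import data
from torch.utils.data.distributed import DistributedSampler

def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)

def sample_data(loader):
    while True:  # restart the DataLoader whenever it runs out
        for batch in loader:
            yield batch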
def train(opt): lib.print_model_settings(locals().copy()) # train_transform = transforms.Compose([ # # transforms.RandomResizedCrop(input_size), # transforms.Resize((opt.imgH, opt.imgW)), # # transforms.RandomHorizontalFlip(), # transforms.ToTensor(), # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ]) # val_transform = transforms.Compose([ # transforms.Resize((opt.imgH, opt.imgW)), # # transforms.CenterCrop(input_size), # transforms.ToTensor(), # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ]) AlignFontCollateObj = AlignFontCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset = fontDataset(imgDir=opt.train_img_dir, annFile=opt.train_ann_file, transform=None, numClasses=opt.numClasses) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size, shuffle= False, # 'True' to check training progress with validation function. sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed), num_workers=int(opt.workers), collate_fn=AlignFontCollateObj, pin_memory=True, drop_last=False) # numClasses = len(train_dataset.Idx2F) numClasses = np.unique(train_dataset.fontIdx).size train_loader = sample_data(train_loader) print('-' * 80) numTrainSamples = len(train_dataset) # valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt) valid_dataset = fontDataset(imgDir=opt.train_img_dir, annFile=opt.val_ann_file, transform=None, F2Idx=train_dataset.F2Idx, Idx2F=train_dataset.Idx2F, numClasses=opt.numClasses) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle= False, # 'True' to check training progress with validation function. sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed), num_workers=int(opt.workers), collate_fn=AlignFontCollateObj, pin_memory=True, drop_last=False) numTestSamples = len(valid_dataset) print('numClasses', numClasses) print('numTrainSamples', numTrainSamples) print('numTestSamples', numTestSamples) vggFontModel = VGGFontModel(models.vgg19(pretrained=opt.preTrained), numClasses).to(device) for name, param in vggFontModel.classifier.named_parameters(): try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. 
print('Exception in weight init' + name) if 'weight' in name: param.data.fill_(1) continue if opt.optim == "sgd": print('SGD optimizer') optimizer = optim.SGD(vggFontModel.parameters(), lr=opt.lr, momentum=0.9) elif opt.optim == "adam": print('Adam optimizer') optimizer = optim.Adam(vggFontModel.parameters(), lr=opt.lr) #get schedulers scheduler = get_scheduler(optimizer, opt) criterion = torch.nn.CrossEntropyLoss() if opt.modelFolderFlag: if len( glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_vggFont.pth"))) > 0: opt.saved_font_model = glob.glob( os.path.join(opt.exp_dir, opt.exp_name, "iter_*_vggFont.pth"))[-1] ## Loading pre-trained files if opt.saved_font_model != '' and opt.saved_font_model != 'None': print(f'loading pretrained synth model from {opt.saved_font_model}') checkpoint = torch.load(opt.saved_font_model, map_location=lambda storage, loc: storage) vggFontModel.load_state_dict(checkpoint['vggFontModel']) optimizer.load_state_dict(checkpoint["optimizer"]) scheduler.load_state_dict(checkpoint["scheduler"]) # print('Model Initialization') # # print('Loaded checkpoint') if opt.distributed: vggFontModel = torch.nn.parallel.DistributedDataParallel( vggFontModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False, find_unused_parameters=True) vggFontModel.train() # print('Loaded distributed') if opt.distributed: vggFontModel_module = vggFontModel.module else: vggFontModel_module = vggFontModel # print('Loading module') # loss averager loss_train = Averager() loss_val = Averager() train_acc = Averager() val_acc = Averager() train_acc_5 = Averager() val_acc_5 = Averager() """ final options """ with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_font_model != '' and opt.saved_font_model != 'None': try: start_iter = int(opt.saved_font_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass iteration = start_iter cntr = 0 # trainCorrect=0 # tCntr=0 while (True): # print(cntr) # train part start_time = time.time() if not opt.testFlag: image_input_tensors, labels_gt = next(train_loader) image_input_tensors = image_input_tensors.to(device) labels_gt = labels_gt.view(-1).to(device) preds = vggFontModel(image_input_tensors) loss = criterion(preds, labels_gt) vggFontModel.zero_grad() loss.backward() optimizer.step() # _, preds_max = preds.max(dim=1) # trainCorrect += (preds_max == labels_gt).sum() # tCntr+=preds_max.shape[0] acc1, acc5 = getNumCorrect(preds, labels_gt, topk=(1, min(numClasses, 5))) train_acc.addScalar(acc1, preds.shape[0]) train_acc_5.addScalar(acc5, preds.shape[0]) loss_train.add(loss) if opt.lr_policy != "None": scheduler.step() # print if get_rank() == 0: if ( iteration + 1 ) % opt.valInterval == 0 or iteration == 0 or opt.testFlag: # To see training progress, we also conduct validation when 'iteration == 0' #validation # iCntr=torch.tensor(0.0).to(device) # valCorrect=torch.tensor(0.0).to(device) vggFontModel.eval() print('Inside val', iteration) for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader): # print('vCntr--',vCntr) if opt.debugFlag and vCntr > 2: break with torch.no_grad(): image_input_tensors = image_input_tensors.to(device) labels_gt = 
labels_gt.view(-1).to(device) preds = vggFontModel(image_input_tensors) loss = criterion(preds, labels_gt) loss_val.add(loss) # _, preds_max = preds.max(dim=1) # valCorrect += (preds_max == labels_gt).sum() # iCntr+=preds_max.shape[0] acc1, acc5 = getNumCorrect(preds, labels_gt, topk=(1, min(numClasses, 5))) val_acc.addScalar(acc1, preds.shape[0]) val_acc_5.addScalar(acc5, preds.shape[0]) vggFontModel.train() elapsed_time = time.time() - start_time #DO HERE with open( os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log: # print('COUNT-------',val_acc_5.n_count) # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] \ Train loss: {loss_train.val():0.5f}, Val loss: {loss_val.val():0.5f}, \ Train Top-1 Acc: {train_acc.val()*100:0.5f}, Train Top-5 Acc: {train_acc_5.val()*100:0.5f}, \ Val Top-1 Acc: {val_acc.val()*100:0.5f}, Val Top-5 Acc: {val_acc_5.val()*100:0.5f}, \ Elapsed_time: {elapsed_time:0.5f}' #plotting lib.plot.plot(os.path.join(opt.plotDir, 'Train-Loss'), loss_train.val().item()) lib.plot.plot(os.path.join(opt.plotDir, 'Val-Loss'), loss_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-1-Acc'), train_acc.val() * 100) lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-5-Acc'), train_acc_5.val() * 100) lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-1-Acc'), val_acc.val() * 100) lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-5-Acc'), val_acc_5.val() * 100) print(loss_log) log.write(loss_log + "\n") loss_train.reset() loss_val.reset() train_acc.reset() val_acc.reset() train_acc_5.reset() val_acc_5.reset() # trainCorrect=0 # tCntr=0 lib.plot.flush() # save model per 30000 iter. if (iteration) % 15000 == 0: torch.save( { 'vggFontModel': vggFontModel_module.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict() }, os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_vggFont.pth')) lib.plot.tick() if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1 cntr += 1
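# --- Added sketch (not part of the original script): getNumCorrect() above is
# assumed to return top-k correct-prediction counts (not percentages), since
# the Averager is fed preds.shape[0] as the denominator. A minimal version in
# the style of the classic top-k accuracy recipe:
import torch

def getNumCorrect(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()  # (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        # one count per requested k, e.g. topk=(1, 5) -> (top-1 hits, top-5 hits)
        return [correct[:k].reshape(-1).float().sum() for k in topk]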
def train(args, loader, generator, discriminator, contrast_learner, g_optim, d_optim, g_ema): if args.distributed: g_module = generator.module d_module = discriminator.module if contrast_learner is not None: cl_module = contrast_learner.module else: g_module = generator d_module = discriminator cl_module = contrast_learner loader = sample_data(loader) sample_z = th.randn(args.n_sample, args.latent_size, device=device) mse = th.nn.MSELoss() mean_path_length = 0 ada_augment = th.tensor([0.0, 0.0], device=device) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 ada_aug_step = args.ada_target / args.ada_length r_t_stat = 0 fids = [] pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break loss_dict = { "Generator": th.tensor(0, device=device).float(), "Discriminator": th.tensor(0, device=device).float(), "Real Score": th.tensor(0, device=device).float(), "Fake Score": th.tensor(0, device=device).float(), "Contrastive": th.tensor(0, device=device).float(), "Consistency": th.tensor(0, device=device).float(), "R1 Penalty": th.tensor(0, device=device).float(), "Path Length Regularization": th.tensor(0, device=device).float(), "Augment": th.tensor(0, device=device).float(), "Rt": th.tensor(0, device=device).float(), } requires_grad(generator, False) requires_grad(discriminator, True) discriminator.zero_grad() for _ in range(args.num_accumulate): real_img_og = next(loader).to(device) noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob) fake_img_og, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img_og, ada_aug_p) real_img, _ = augment(real_img_og, ada_aug_p) else: fake_img = fake_img_og real_img = real_img_og fake_pred = discriminator(fake_img) real_pred = discriminator(real_img) logistic_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["Discriminator"] += logistic_loss.detach() loss_dict["Real Score"] += real_pred.mean().detach() loss_dict["Fake Score"] += fake_pred.mean().detach() d_loss = logistic_loss if args.contrastive > 0: contrast_learner(fake_img_og, fake_img, accumulate=True) contrast_learner(real_img_og, real_img, accumulate=True) contrast_loss = cl_module.calculate_loss() loss_dict["Contrastive"] += contrast_loss.detach() d_loss += args.contrastive * contrast_loss if args.balanced_consistency > 0: consistency_loss = mse( real_pred, discriminator(real_img_og)) + mse( fake_pred, discriminator(fake_img_og)) loss_dict["Consistency"] += consistency_loss.detach() d_loss += args.balanced_consistency * consistency_loss d_loss /= args.num_accumulate d_loss.backward() d_optim.step() if args.r1 > 0 and i % args.d_reg_every == 0: discriminator.zero_grad() for _ in range(args.num_accumulate): real_img = next(loader).to(device) real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_penalty(real_img, real_pred, args) loss_dict["R1 Penalty"] += r1_loss.detach().squeeze() r1_loss = args.r1 * args.d_reg_every * r1_loss / args.num_accumulate r1_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_augment += th.tensor( (th.sign(real_pred).sum().item(), real_pred.shape[0]), device=device) ada_augment = reduce_sum(ada_augment) if ada_augment[1] > 255: pred_signs, n_pred = ada_augment.tolist() r_t_stat = pred_signs / n_pred loss_dict["Rt"] = th.tensor(r_t_stat, device=device).float() if r_t_stat > args.ada_target: sign = 1 else: sign = -1 ada_aug_p += sign * 
ada_aug_step * n_pred ada_aug_p = min(1, max(0, ada_aug_p)) ada_augment.mul_(0) loss_dict["Augment"] = th.tensor(ada_aug_p, device=device).float() requires_grad(generator, True) requires_grad(discriminator, False) generator.zero_grad() for _ in range(args.num_accumulate): noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss = g_non_saturating_loss(fake_pred) loss_dict["Generator"] += g_loss.detach() g_loss /= args.num_accumulate g_loss.backward() g_optim.step() if args.path_regularize > 0 and i % args.g_reg_every == 0: generator.zero_grad() for _ in range(args.num_accumulate): path_loss, mean_path_length = g_path_length_regularization( generator, mean_path_length, args) loss_dict["Path Length Regularization"] += path_loss.detach() path_loss = args.path_regularize * args.g_reg_every * path_loss / args.num_accumulate path_loss.backward() g_optim.step() accumulate(g_ema, g_module) loss_reduced = reduce_loss_dict(loss_dict) log_dict = { k: v.mean().item() / args.num_accumulate for k, v in loss_reduced.items() if v != 0 } if get_rank() == 0: if args.log_spec_norm: G_norms = [] for name, spec_norm in g_module.named_buffers(): if "spectral_norm" in name: G_norms.append(spec_norm.cpu().numpy()) G_norms = np.array(G_norms) D_norms = [] for name, spec_norm in d_module.named_buffers(): if "spectral_norm" in name: D_norms.append(spec_norm.cpu().numpy()) D_norms = np.array(D_norms) log_dict[f"Spectral Norms/G min spectral norm"] = np.log( G_norms).min() log_dict[f"Spectral Norms/G mean spectral norm"] = np.log( G_norms).mean() log_dict[f"Spectral Norms/G max spectral norm"] = np.log( G_norms).max() log_dict[f"Spectral Norms/D min spectral norm"] = np.log( D_norms).min() log_dict[f"Spectral Norms/D mean spectral norm"] = np.log( D_norms).mean() log_dict[f"Spectral Norms/D max spectral norm"] = np.log( D_norms).max() if i % args.img_every == 0: gc.collect() th.cuda.empty_cache() with th.no_grad(): g_ema.eval() sample = [] for sub in range(0, len(sample_z), args.batch_size): subsample, _ = g_ema( [sample_z[sub:sub + args.batch_size]]) sample.append(subsample.cpu()) sample = th.cat(sample) grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1)) log_dict["Generated Images EMA"] = [ wandb.Image(grid, caption=f"Step {i}") ] if i % args.eval_every == 0: fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample, args.fid_truncation, args.name) fid = fid_dict["FID"] fids.append(fid) density = fid_dict["Density"] coverage = fid_dict["Coverage"] ppl = validation.ppl( g_ema, args.val_batch_size, args.ppl_n_sample, args.ppl_space, args.ppl_crop, args.latent_size, ) log_dict["Evaluation/FID"] = fid log_dict["Sweep/FID_smooth"] = gaussian_filter( np.array(fids), [5])[-1] log_dict["Evaluation/Density"] = density log_dict["Evaluation/Coverage"] = coverage log_dict["Evaluation/PPL"] = ppl gc.collect() th.cuda.empty_cache() wandb.log(log_dict) description = ( f"FID: {fid:.4f} PPL: {ppl:.4f} Dens: {density:.4f} Cov: {coverage:.4f} " + f"G: {log_dict['Generator']:.4f} D: {log_dict['Discriminator']:.4f}" ) if "Augment" in log_dict: description += f" Aug: {log_dict['Augment']:.4f}" # Rt: {log_dict['Rt']:.4f}" if "R1 Penalty" in log_dict: description += f" R1: {log_dict['R1 Penalty']:.4f}" if "Path Length Regularization" in log_dict: description += f" Path: {log_dict['Path Length Regularization']:.4f}" pbar.set_description(description) if i % 
args.checkpoint_every == 0: check_name = "-".join([ args.name, args.runname, wandb.run.dir.split("/")[-1].split("-")[-1], str(int(fid)), str(args.size), str(i).zfill(6), ]) th.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), # "cl": cl_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), }, f"/home/hans/modelzoo/maua-sg2/{check_name}.pt", )
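# --- Added note (not part of the original script): the num_accumulate loops in
# this trainer implement plain gradient accumulation -- each micro-batch loss
# is divided by num_accumulate before backward(), so the gradients that pile up
# in .grad match those of one large batch. A self-contained toy version of the
# same pattern:
import torch

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
num_accumulate = 4

opt.zero_grad()
for _ in range(num_accumulate):
    x = torch.randn(2, 8)                            # micro-batch
    loss = model(x).pow(2).mean() / num_accumulate   # rescale to keep the mean
    loss.backward()                                  # grads sum into .grad
opt.step()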
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): inception = real_mean = real_cov = mean_latent = None if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] if get_rank() == 0: if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator # accum = 0.5 ** (32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device) args.n_sheets = int(np.ceil(args.n_classes / args.n_class_per_sheet)) args.n_sample_per_sheet = args.n_sample_per_class * args.n_class_per_sheet args.n_sample = args.n_sample_per_sheet * args.n_sheets sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_y = torch.arange(args.n_classes).repeat(args.n_sample_per_class, 1).t().reshape(-1).to(device) if args.n_sample > args.n_sample_per_class * args.n_classes: sample_y1 = make_fake_label(args.n_sample - args.n_sample_per_class * args.n_classes, args.n_classes, device) sample_y = torch.cat([sample_y, sample_y1], 0) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break # Train Discriminator requires_grad(generator, False) requires_grad(discriminator, True) for step_index in range(args.n_step_d): real_img, real_labels = next(loader) real_img, real_labels = real_img.to(device), real_labels.to(device) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_labels = make_fake_label(args.batch, args.n_classes, device) fake_img, _ = generator(noise, fake_labels) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_pred = discriminator(fake_img, fake_labels) real_pred = discriminator(real_img_aug, real_labels) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) else: real_img_aug = real_img real_pred = discriminator(real_img_aug, real_labels) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss # Train Generator requires_grad(generator, True) 
requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_labels = make_fake_label(args.batch, args.n_classes, device) fake_img, _ = generator(noise, fake_labels) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img, fake_labels) g_loss = g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_labels = make_fake_label(path_batch_size, args.n_classes, device) fake_img, latents = generator(noise, fake_labels, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length ) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = ( reduce_sum(mean_path_length).item() / get_world_size() ) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() # Update G_ema # G_ema = G * (1-ema_beta) + G_ema * ema_beta ema_nimg = args.ema_kimg * 1000 if args.ema_rampup is not None: ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup) accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8)) accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0: pbar.set_description( ( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}" ) ) if wandb and args.wandb: wandb.log( { "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, } ) if i % args.log_every == 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( ( f"{i:07d}; " f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f};\n" ) ) if i % args.log_every == 0: with torch.no_grad(): g_ema.eval() for sheet_index in range(args.n_sheets): sample_z_sheet = sample_z[sheet_index*args.n_sample_per_sheet:(sheet_index+1)*args.n_sample_per_sheet] sample_y_sheet = sample_y[sheet_index*args.n_sample_per_sheet:(sheet_index+1)*args.n_sample_per_sheet] sample, _ = g_ema([sample_z_sheet], sample_y_sheet) utils.save_image( sample, os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}_{sheet_index}.png"), nrow=args.n_sample_per_class, normalize=True, value_range=(-1, 1), ) if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): g_ema.eval() if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, 
args.device, n_classes=args.n_classes, ).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) # print("fid:", fid) with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"{i:07d}; fid: {float(fid):.4f};\n") if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', "latest.pt"), )
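# --- Added sketch (not part of the original script): calc_fid() above is
# assumed to be the standard Frechet Inception Distance between Gaussians
# fitted to Inception features:
#     FID = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2 (C_1 C_2)^(1/2))
import numpy as np
from scipy import linalg

def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    diff = sample_mean - real_mean
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
    if not np.isfinite(cov_sqrt).all():
        # regularize with a small ridge if the product is near singular
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt, _ = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset), disp=False)
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real  # drop tiny imaginary parts left by sqrtm
    return float(diff @ diff + np.trace(sample_cov + real_cov - 2 * cov_sqrt))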
def train(opt): lib.print_model_settings(locals().copy()) if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) text_len = opt.batch_max_length+2 else: converter = CTCLabelConverter(opt.character) text_len = opt.batch_max_length opt.classes = converter.character """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images shuffle=True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images shuffle=False, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) log.write('-' * 80 + '\n') log.close() text_dataset = text_gen(opt) text_loader = torch.utils.data.DataLoader( text_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers), pin_memory=True, drop_last=True) opt.num_class = len(converter.character) c_code_size = opt.latent cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size) ocrModel = ModelV1(opt) genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size) accumulate(g_ema, genModel, 0) # uCriterion = torch.nn.MSELoss() # sCriterion = torch.nn.MSELoss() # if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': # ocrCriterion = torch.nn.L1Loss() # else: if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: print('Not implemented error') sys.exit() # ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 cEncoder= torch.nn.DataParallel(cEncoder).to(device) cEncoder.train() genModel = torch.nn.DataParallel(genModel).to(device) g_ema = torch.nn.DataParallel(g_ema).to(device) genModel.train() g_ema.eval() disEncModel = torch.nn.DataParallel(disEncModel).to(device) disEncModel.train() ocrModel = torch.nn.DataParallel(ocrModel).to(device) if opt.ocrFixed: if opt.Transformation == 'TPS': ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() # ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() else: ocrModel.train() g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1) d_reg_ratio = opt.d_reg_every / 
(opt.d_reg_every + 1) optimizer = optim.Adam( list(genModel.parameters())+list(cEncoder.parameters()), lr=opt.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) dis_optimizer = optim.Adam( disEncModel.parameters(), lr=opt.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) ocr_optimizer = optim.Adam( ocrModel.parameters(), lr=opt.lr, betas=(0.9, 0.99), ) ## Loading pre-trained files if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None': print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model) ocrModel.load_state_dict(checkpoint) # if opt.saved_gen_model !='' and opt.saved_gen_model !='None': # print(f'loading pretrained gen model from {opt.saved_gen_model}') # checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage) # genModel.module.load_state_dict(checkpoint['g']) # g_ema.module.load_state_dict(checkpoint['g_ema']) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) # styleModel.load_state_dict(checkpoint['styleModel']) # mixModel.load_state_dict(checkpoint['mixModel']) genModel.load_state_dict(checkpoint['genModel']) g_ema.load_state_dict(checkpoint['g_ema']) disEncModel.load_state_dict(checkpoint['disEncModel']) ocrModel.load_state_dict(checkpoint['ocrModel']) optimizer.load_state_dict(checkpoint["optimizer"]) dis_optimizer.load_state_dict(checkpoint["dis_optimizer"]) ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"]) # if opt.imgReconLoss == 'l1': # recCriterion = torch.nn.L1Loss() # elif opt.imgReconLoss == 'ssim': # recCriterion = ssim # elif opt.imgReconLoss == 'ms-ssim': # recCriterion = msssim # loss averager loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_unsup = Averager() loss_avg_sup = Averager() log_r1_val = Averager() log_avg_path_loss_val = Averager() log_avg_mean_path_length_avg = Averager() log_ada_aug_p = Averager() loss_avg_ocr_sup = Averager() loss_avg_ocr_unsup = Averager() """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) ocr_scheduler = get_scheduler(ocr_optimizer,opt) start_time = time.time() iteration = start_iter cntr=0 mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 # loss_dict = {} accum = 0.5 ** (32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0 ada_aug_step = opt.ada_target / opt.ada_length r_t_stat = 0 epsilon = 10e-50 # sample_z = 
torch.randn(opt.n_sample, opt.latent, device=device) while True: # print(cntr) # train part if opt.lr_policy !="None": scheduler.step() dis_scheduler.step() ocr_scheduler.step() image_input_tensors, _, labels, _ = next(iter(train_loader)) labels_z_c = next(iter(text_loader)) image_input_tensors = image_input_tensors.to(device) gt_image_tensors = image_input_tensors[:opt.batch_size].detach() real_image_tensors = image_input_tensors[opt.batch_size:].detach() labels_gt = labels[:opt.batch_size] requires_grad(cEncoder, False) requires_grad(genModel, False) requires_grad(disEncModel, True) requires_grad(ocrModel, False) text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length) text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length) z_c_code = cEncoder(text_z_c) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img,_ = genModel(style, input_is_latent=opt.input_latent) # #unsupervised code prediction on generated image # u_pred_code = disEncModel(fake_img, mode='enc') # uCost = uCriterion(u_pred_code, z_code) # #supervised code prediction on gt image # s_pred_code = disEncModel(gt_image_tensors, mode='enc') # sCost = uCriterion(s_pred_code, gt_phoc_tensors) #Domain discriminator fake_pred = disEncModel(fake_img) real_pred = disEncModel(real_image_tensors) disCost = d_logistic_loss(real_pred, fake_pred) # dis_cost = disCost + opt.gamma_e*uCost + opt.beta*sCost loss_avg_dis.add(disCost) # loss_avg_sup.add(opt.beta*sCost) # loss_avg_unsup.add(opt.gamma_e * uCost) disEncModel.zero_grad() disCost.backward() dis_optimizer.step() d_regularize = cntr % opt.d_reg_every == 0 if d_regularize: real_image_tensors.requires_grad = True real_pred = disEncModel(real_image_tensors) r1_loss = d_r1_loss(real_pred, real_image_tensors) disEncModel.zero_grad() (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward() dis_optimizer.step() log_r1_val.add(r1_loss) # Recognizer update if not opt.ocrFixed and not opt.zAlone: requires_grad(disEncModel, False) requires_grad(ocrModel, True) if 'CTC' in opt.Prediction: preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True) preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size) preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt) else: print("Not implemented error") sys.exit() ocrModel.zero_grad() ocrCost.backward() # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) ocr_optimizer.step() loss_avg_ocr_sup.add(ocrCost) else: loss_avg_ocr_sup.add(torch.tensor(0.0)) # [Word Generator] update # image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = next(iter(text_loader)) # image_input_tensors = image_input_tensors.to(device) # gt_image_tensors = image_input_tensors[:opt.batch_size] # real_image_tensors = image_input_tensors[opt.batch_size:] # labels_gt = labels[:opt.batch_size] requires_grad(cEncoder, True) requires_grad(genModel, True) requires_grad(disEncModel, False) requires_grad(ocrModel, False) text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length) z_c_code 
= cEncoder(text_z_c) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img,_ = genModel(style, input_is_latent=opt.input_latent) fake_pred = disEncModel(fake_img) disGenCost = g_nonsaturating_loss(fake_pred) if opt.zAlone: ocrCost = torch.tensor(0.0) else: #Compute OCR prediction (Reconstruction of content) # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device) # length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device) if 'CTC' in opt.Prediction: preds_recon = ocrModel(fake_img, text_z_c, is_train=False) preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size) preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c) else: print("Not implemented error") sys.exit() genModel.zero_grad() cEncoder.zero_grad() gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, retain_graph=True)[0] loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2) grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, retain_graph=True)[0] loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) if opt.grad_balance: gen_enc_cost.backward(retain_graph=True) grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0] grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0] a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon+torch.std(grad_fake_OCR)) if not torch.isfinite(a): # a is a tensor, so guard against inf/nan rather than the never-true None check print(ocrCost, disGenCost, torch.std(grad_fake_adv), torch.std(grad_fake_OCR)) if a>1000 or a<0.0001: print(a) ocrCost = a.detach() * ocrCost gen_enc_cost = disGenCost + ocrCost gen_enc_cost.backward(retain_graph=True) grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=False, retain_graph=True)[0] grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=False, retain_graph=True)[0] loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) with torch.no_grad(): gen_enc_cost.backward() else: gen_enc_cost.backward() loss_avg_gen.add(disGenCost) loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost) optimizer.step() g_regularize = cntr % opt.g_reg_every == 0 if g_regularize: path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink) # image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = next(iter(text_loader)) # image_input_tensors = image_input_tensors.to(device) # gt_image_tensors = image_input_tensors[:path_batch_size] # labels_gt = labels[:path_batch_size] text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length) # text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length) z_c_code = cEncoder(text_z_c) noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: 
newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length) decay = 0.01 path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length) path_loss = (path_lengths - mean_path_length_orig).pow(2).mean() mean_path_length = mean_path_length_orig.detach().item() genModel.zero_grad() cEncoder.zero_grad() weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss if opt.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() optimizer.step() # mean_path_length_avg = ( # reduce_sum(mean_path_length).item() / get_world_size() # ) #commented above for multi-gpu , non-distributed setting mean_path_length_avg = mean_path_length accumulate(g_ema, genModel, accum) log_avg_path_loss_val.add(path_loss) log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg)) log_ada_aug_p.add(torch.tensor(ada_aug_p)) if get_rank() == 0: if wandb and opt.wandb: # log the tracked running averages; per-batch real/fake scores are not collected in this loop wandb.log( { "Generator": loss_avg_gen.val().item(), "Discriminator": loss_avg_dis.val().item(), "Augment": ada_aug_p, "Rt": r_t_stat, "R1": log_r1_val.val().item(), "Path Length Regularization": log_avg_path_loss_val.val().item(), "Mean Path Length": mean_path_length, } ) # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' #generate paired content with similar style labels_z_c_1 = next(iter(text_loader)) labels_z_c_2 = next(iter(text_loader)) text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length) text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length) z_c_code_1 = cEncoder(text_z_c_1) z_c_code_2 = cEncoder(text_z_c_2) style_c1_s1 = [] style_c2_s1 = [] style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style_c1_s1.append(style_s1[0]*z_c_code_1) style_c2_s1.append(style_s1[0]*z_c_code_2) if len(style_s1)>1: style_c1_s1.append(style_s1[1]*z_c_code_1) style_c2_s1.append(style_s1[1]*z_c_code_2) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style_c1_s2 = [] style_c1_s2.append(noise_style[0]*z_c_code_1) if len(noise_style)>1: style_c1_s2.append(noise_style[1]*z_c_code_1) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style_c1_s1[0][:,:opt.latent]) if len(style_c1_s1)>1: newstyle.append(style_c1_s1[1][:,:opt.latent]) style_c1_s1 = newstyle style_c2_s1 = newstyle style_c1_s2 = newstyle fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent) fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent) fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent) if not opt.zAlone: #Run OCR prediction if 'CTC' in opt.Prediction: preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False) preds_size = 
torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(gt_image_tensors, text_gt, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0]) _, preds_index = preds.max(2) preds_str_gt = converter.decode(preds_index.data, preds_size.data) else: print("Not implemented error") sys.exit() else: preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0] preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0] os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True) for trImgCntr in range(opt.batch_size): try: save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s1_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s1[trImgCntr]+'.png')) if not opt.zAlone: save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c2_s1_'+labels_z_c_2[trImgCntr]+'_ocr:'+preds_str_fake_img_c2_s1[trImgCntr]+'.png')) save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s2_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s2[trImgCntr]+'.png')) if trImgCntr<gt_image_tensors.shape[0]: save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_gt_act:'+labels_gt[trImgCntr]+'_ocr:'+preds_str_gt[trImgCntr]+'.png')) except: print('Warning while saving training image') elapsed_time = time.time() - start_time # for log with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log: # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] \ Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\ Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \ Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \ Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \ Elapsed_time: {elapsed_time:0.5f}' #plotting lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item()) print(loss_log) loss_avg_dis.reset() loss_avg_gen.reset() loss_avg_ocr_unsup.reset() loss_avg_ocr_sup.reset() log_r1_val.reset() log_avg_path_loss_val.reset() log_avg_mean_path_length_avg.reset() log_ada_aug_p.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+5 iter. 
if iteration % 10000 == 0: torch.save({ 'cEncoder':cEncoder.state_dict(), 'genModel':genModel.state_dict(), 'g_ema':g_ema.state_dict(), 'ocrModel':ocrModel.state_dict(), 'disEncModel':disEncModel.state_dict(), 'optimizer':optimizer.state_dict(), 'ocr_optimizer':ocr_optimizer.state_dict(), 'dis_optimizer':dis_optimizer.state_dict()}, os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1 cntr += 1
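# --- Added note (not part of the original script): a self-contained example of
# the CTC call pattern used by the recognizer updates above. nn.CTCLoss wants
# log-probabilities shaped (T, N, C) -- hence log_softmax over the class
# dimension followed by permute(1, 0, 2) -- plus per-sample input and target
# lengths. Every size below is made up for illustration.
import torch

T, N, C, S = 26, 4, 38, 10  # time steps, batch, classes, max target length
ctc = torch.nn.CTCLoss(zero_infinity=True)
preds = torch.randn(N, T, C)                              # (batch, time, classes)
log_probs = preds.log_softmax(2).permute(1, 0, 2)         # -> (T, N, C)
targets = torch.randint(1, C, (N, S), dtype=torch.long)   # index 0 = blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(1, S + 1, (N,), dtype=torch.long)
loss = ctc(log_probs, targets, input_lengths, target_lengths)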
def train(args, loader_src, loader_norm, generator, discriminator, ExpertModel, g_optim, d_optim, g_ema, device): # Save Path date = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') ImgSavePath = 'sample/{}'.format(date) CheckpointSavePath = 'checkpoint/{}'.format(date) if not os.path.exists(ImgSavePath): os.makedirs(ImgSavePath) if not os.path.exists(CheckpointSavePath): os.makedirs(CheckpointSavePath) shutil.copy('./train.py', './{}/train.py'.format(CheckpointSavePath)) shutil.copy('./model.py', './{}/model.py'.format(CheckpointSavePath)) loader_src = sample_data(loader_src) loader_norm = sample_data(loader_norm) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 r1_loss = torch.tensor(0.0, device=device) path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5**(32 / (10 * 1000)) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader_src) # source set tgt_img = next(loader_norm) # normal set real_img = real_img.to(device) tgt_img = tgt_img.to(device) #################################### Train discrimiantor #################################### requires_grad(generator, False) requires_grad(discriminator, True) Profile_Fea, Profile_Map = ExpertModel( TrainingSize_Select(real_img, device, args), args) Profile_Syn_Img, _ = generator(Profile_Fea, Profile_Map) Front_Fea, Front_Map = ExpertModel( TrainingSize_Select(tgt_img, device, args), args) Front_Syn_Img, _ = generator(Front_Fea, Front_Map) Profile_Syn_Pred = discriminator(Profile_Syn_Img) Front_Syn_Pred = discriminator(Front_Syn_Img) Real_Pred = discriminator(tgt_img) d_loss = (d_logistic_loss(Real_Pred, Profile_Syn_Pred) + d_logistic_loss(Real_Pred, Front_Syn_Pred)) / 2 loss_dict["d"] = d_loss loss_dict["real_score"] = Real_Pred.mean() loss_dict["profile_fake_score"] = Profile_Syn_Pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss #################################### Train generator #################################### requires_grad(generator, True) requires_grad(discriminator, False) Front_Fea, Front_Map = ExpertModel( TrainingSize_Select(tgt_img, device, args), args) Front_Syn_Img, _ = generator(Front_Fea, Front_Map) Front_Syn_Pred = discriminator(Front_Syn_Img) Front_Syn_Fea, _ = ExpertModel( TrainingSize_Select(Front_Syn_Img, device, args), args) Profile_Fea, Profile_Map = ExpertModel( TrainingSize_Select(real_img, device, args), args) Profile_Syn_Img, _ = generator(Profile_Fea, Profile_Map) Profile_Syn_Pred = discriminator(Profile_Syn_Img) Profile_Syn_Fea, _ = ExpertModel( TrainingSize_Select(Profile_Syn_Img, device, args), args) adv_g_loss = (g_nonsaturating_loss(Profile_Syn_Pred) + g_nonsaturating_loss(Front_Syn_Pred)) / 2 fea_loss = (feature_loss(Profile_Syn_Fea[0], Profile_Fea[0]) + feature_loss(Front_Syn_Fea[0], Front_Fea[0])) / 2 sym_loss = (SymLoss(Front_Syn_Img) + SymLoss(Profile_Syn_Img)) / 2 L1_loss = L1Loss(Front_Syn_Img, tgt_img) g_loss = 
args.lambda_adv * adv_g_loss + args.lambda_fea * fea_loss + args.lambda_sym * sym_loss + args.lambda_l1 * L1_loss loss_dict["g"] = g_loss loss_dict["adv_g_loss"] = args.lambda_adv * adv_g_loss loss_dict["fea_loss"] = args.lambda_fea * fea_loss loss_dict["symmetry_loss"] = args.lambda_sym * sym_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: noise, noise_map = ExpertModel( TrainingSize_Select(real_img, device, args), args) fake_img, latents = generator(noise, noise_map, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() fea_loss_val = loss_reduced["fea_loss"].mean().item() sym_loss_val = loss_reduced["symmetry_loss"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() profile_fake_score_val = loss_reduced["profile_fake_score"].mean( ).item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g_total: {g_loss_val:.4f}; fea: {fea_loss_val:.4f}; sym: {sym_loss_val:.4f}; r1: {r1_val:.4f};" f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}" )) if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Profile Score": profile_fake_score_val, "Path Length": path_length_val, }) if i % 100 == 0: with torch.no_grad(): g_ema.eval() pro_fea, pro_map = ExpertModel( TrainingSize_Select(real_img, device, args), args) pro_syn, _ = g_ema(pro_fea, pro_map) tgt_fea, tgt_map = ExpertModel( TrainingSize_Select(tgt_img, device, args), args) tgt_syn, _ = g_ema(tgt_fea, tgt_map) result = torch.cat([real_img, pro_syn, tgt_img, tgt_syn], 2) utils.save_image( result, f"{ImgSavePath}/{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) if i % 100 == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), }, f"{CheckpointSavePath}/{str(i).zfill(6)}.pt", )
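# --- Added sketch (not part of the original script): the d_r1_loss() these
# trainers call is assumed to be the R1 gradient penalty of Mescheder et al.
# ("Which Training Methods for GANs do actually Converge?") -- the squared norm
# of the discriminator's gradient at real samples:
import torch
from torch import autograd

def d_r1_loss(real_pred, real_img):
    # real_img must have requires_grad=True, as the call sites above arrange
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()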
def train(args, loader, encoder, generator, discriminator, discriminator_z, g1, vggnet, pwcnet, e_optim, d_optim, dz_optim, g1_optim, e_ema, e_tf, g1_ema, device): mmd_eval = functools.partial(mix_rbf_mmd2, sigma_list=[2.0, 5.0, 10.0, 20.0, 40.0, 80.0]) loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) d_loss_val = 0 e_loss_val = 0 rec_loss_val = 0 vgg_loss_val = 0 adv_loss_val = 0 loss_dict = { "d": torch.tensor(0., device=device), "real_score": torch.tensor(0., device=device), "fake_score": torch.tensor(0., device=device), "r1_d": torch.tensor(0., device=device), "r1_e": torch.tensor(0., device=device), "rec": torch.tensor(0., device=device), } avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: e_module = encoder.module d_module = discriminator.module g_module = generator.module g1_module = g1.module if args.train_latent_mlp else None else: e_module = encoder d_module = discriminator g_module = generator g1_module = g1 if args.train_latent_mlp else None accum = 0.5**(32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) # sample_x = accumulate_batches(loader, args.n_sample).to(device) sample_x = load_real_samples(args, loader) requires_grad(generator, False) # always False generator.eval() # Generator should be ema and in eval mode # if args.no_ema or e_ema is None: # e_ema = encoder for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader) real_img = real_img.to(device) batch = real_img.shape[0] # Train Encoder if args.toggle_grads: requires_grad(encoder, True) requires_grad(discriminator, False) pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device) kld_z = torch.tensor(0., device=device) mmd_z = torch.tensor(0., device=device) gan_z = torch.tensor(0., device=device) etf_z = torch.tensor(0., device=device) latent_real, logvar = encoder(real_img) if args.reparameterization: latent_real = reparameterize(latent_real, logvar) if args.train_latent_mlp: fake_img, _ = generator([g1(latent_real)], input_is_latent=True, return_latents=False) else: fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False) if args.lambda_adv > 0: if args.augment: fake_img_aug, _ = augment(fake_img, ada_aug_p) else: fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) adv_loss = g_nonsaturating_loss(fake_pred) if args.lambda_pix > 0: pix_loss = torch.mean((real_img - fake_img)**2) if args.lambda_vgg > 0: real_feat = vggnet(real_img) fake_feat = vggnet(fake_img) vgg_loss = torch.mean((real_feat - fake_feat)**2) if args.lambda_kld_z > 0: z_mean = latent_real.view(batch, -1) kld_z = -0.5 * torch.sum(1. 
+ logvar - z_mean.pow(2) - logvar.exp()) / batch # print(kld_z) if args.lambda_mmd_z > 0: z_real = torch.randn(batch, args.latent_full, device=device) mmd_z = mmd_eval(latent_real, z_real) # print(mmd_z) if args.lambda_gan_z > 0: fake_pred = discriminator_z(latent_real) gan_z = g_nonsaturating_loss(fake_pred) # print(gan_z) if args.use_latent_teacher_forcing and args.lambda_etf > 0: w_tf, _ = e_tf(real_img) if args.train_latent_mlp: w_pred = g1(latent_real) else: w_pred = generator.get_latent(latent_real) etf_z = torch.mean((w_tf - w_pred)**2) # print(etf_z) if args.train_on_fake and args.lambda_rec > 0: z_real = torch.randn(args.batch, args.latent_full, device=device) if args.train_latent_mlp: fake_img, _ = generator([g1(z_real)], input_is_latent=True, return_latents=False) else: fake_img, _ = generator([z_real], input_is_latent=False, return_latents=False) # fake_img, _ = generator([z_real], input_is_latent=False, return_latents=True) z_fake, z_logvar = encoder(fake_img) if args.reparameterization: z_fake = reparameterize(z_fake, z_logvar) rec_loss = torch.mean((z_real - z_fake)**2) loss_dict["rec"] = rec_loss # print(rec_loss) e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv e_loss = e_loss + args.lambda_kld_z * kld_z + args.lambda_mmd_z * mmd_z + args.lambda_gan_z * gan_z + args.lambda_etf * etf_z + rec_loss * args.lambda_rec loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss if args.train_latent_mlp and g1 is not None: g1.zero_grad() encoder.zero_grad() e_loss.backward() e_optim.step() if args.train_latent_mlp and g1_optim is not None: g1_optim.step() # if args.train_on_fake: # e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0 # if e_regularize and args.lambda_rec > 0: # # noise = mixing_noise(args.batch, args.latent, args.mixing, device) # # fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True) # z_real = torch.randn(args.batch, args.latent_full, device=device) # fake_img, w_real = generator([z_real], input_is_latent=False, return_latents=True) # z_fake, logvar = encoder(fake_img) # if args.reparameterization: # z_fake = reparameterize(z_fake, logvar) # rec_loss = torch.mean((z_real - z_fake) ** 2) # encoder.zero_grad() # (rec_loss * args.lambda_rec).backward() # e_optim.step() # loss_dict["rec"] = rec_loss e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0 if e_regularize: # why not regularize on augmented real? 
real_img.requires_grad = True real_pred, logvar = encoder(real_img) if args.reparameterization: real_pred = reparameterize(real_pred, logvar) r1_loss_e = d_r1_loss(real_pred, real_img) encoder.zero_grad() (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward() e_optim.step() loss_dict["r1_e"] = r1_loss_e if not args.no_ema and e_ema is not None: accumulate(e_ema, e_module, accum) if args.train_latent_mlp: accumulate(g1_ema, g1_module, accum) # Train Discriminator if args.toggle_grads: requires_grad(encoder, False) requires_grad(discriminator, True) if not args.no_update_discriminator and args.lambda_adv > 0: latent_real, logvar = encoder(real_img) if args.reparameterization: latent_real = reparameterize(latent_real, logvar) if args.train_latent_mlp: fake_img, _ = generator([g1(latent_real)], input_is_latent=True, return_latents=False) else: fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img_aug, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() z_real = torch.randn(batch, args.latent_full, device=device) fake_pred = discriminator_z(latent_real.detach()) real_pred = discriminator_z(z_real) d_loss_z = d_logistic_loss(real_pred, fake_pred) discriminator_z.zero_grad() d_loss_z.backward() dz_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: # why not regularize on augmented real? real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss_d = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward() # Why 0* ? 
Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76 d_optim.step() loss_dict["r1_d"] = r1_loss_d loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() e_loss_val = loss_reduced["e"].mean().item() r1_d_val = loss_reduced["r1_d"].mean().item() r1_e_val = loss_reduced["r1_e"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() rec_loss_val = loss_reduced["rec"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0]) if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; " f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}")) if i % args.log_every == 0: with torch.no_grad(): latent_x, _ = e_ema(sample_x) if args.train_latent_mlp: g1_ema.eval() fake_x, _ = generator([g1_ema(latent_x)], input_is_latent=True, return_latents=False) else: fake_x, _ = generator([latent_x], input_is_latent=False, return_latents=False) sample_pix_loss = torch.sum((sample_x - fake_x)**2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; " f"ref: {sample_pix_loss.item()};\n") if wandb and args.wandb: wandb.log({ "Encoder": e_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1 D": r1_d_val, "R1 E": r1_e_val, "Pix Loss": pix_loss_val, "VGG Loss": vgg_loss_val, "Adv Loss": adv_loss_val, "Rec Loss": rec_loss_val, "Real Score": real_score_val, "Fake Score": fake_score_val, }) if i % args.log_every == 0: with torch.no_grad(): e_eval = encoder if args.no_ema else e_ema e_eval.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x.shape)[1:] latent_real, _ = e_eval(sample_x) if args.train_latent_mlp: g1_ema.eval() fake_img, _ = generator([g1_ema(latent_real)], input_is_latent=True, return_latents=False) else: fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False) sample = torch.cat( (sample_x.reshape(args.n_sample // nrow, nrow, *nchw), fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) e_eval.train() if i % args.save_every == 0: e_eval = encoder if args.no_ema else e_ema torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g1": g1_module.state_dict() if args.train_latent_mlp else None, "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None, "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g1": g1_module.state_dict() if args.train_latent_mlp else None, "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None, "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, 
os.path.join(args.log_dir, 'weight', "latest.pt"), )
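
# For reference: reparameterize(...) is used above but not defined here. A
# minimal sketch of the standard VAE reparameterization, plus the closed-form
# KLD against N(0, I) matching the kld_z term above (assumes `mu` and
# `logvar` are [batch, latent] tensors produced by the encoder).
import torch


def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


def kld_loss(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)): sum over dimensions, mean over batch.
    batch = mu.shape[0]
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch
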
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader) start_iter = args.start_iter // get_world_size() // args.batch pbar = range(args.iter // get_world_size() // args.batch) if get_rank() == 0: pbar = tqdm(pbar, initial=start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 seg_loss = torch.tensor(0.0, device=device) r1_loss = torch.tensor(0.0, device=device) path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg, seg_loss_val, shift_loss_val = 0, 0, 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5**(32 / (10 * 1000)) sample_condition_img, sample_conditions, condition_img_color = random_condition_img( args.n_sample) if get_rank() == 0: os.makedirs(f'sample', exist_ok=True) os.makedirs(f'sample/{args.name}', exist_ok=True) os.makedirs(f'ckpts/{args.name}', exist_ok=True) if args.with_tensorboard: os.makedirs(f'tensorboard/{args.name}', exist_ok=True) writer = SummaryWriter(f'tensorboard/{args.name}') for idx in pbar: i = idx + start_iter if i > args.iter: print('Done!') break real_img, condition_img = next(loader) real_img = real_img.to(device) if args.condition_path is not None: condition_img = condition_img.to(device) else: condition_img = None requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _, _, _ = generator(noise, condition_img=condition_img) if args.with_rgbs: condition_img_encoder = F.interpolate(condition_img, size=args.resolution, mode='nearest') real_img = torch.cat((real_img, condition_img_encoder), dim=1) fake_pred, _ = discriminator(fake_img) real_pred, real_pred_feat = discriminator(real_img) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict['d'] = d_loss loss_dict['real_score'] = real_pred.mean() loss_dict['fake_score'] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred, _ = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict['r1'] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _, _, parsing_feature = generator( noise, condition_img=condition_img) fake_pred, fake_pred_feat = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) loss_dict['g'] = g_loss loss_dict['seg'] = seg_loss loss_dict['shift_loss'] = seg_loss loss = g_loss generator.zero_grad() loss.backward() g_optim.step() requires_grad(generator, True) requires_grad(discriminator, False) g_regularize = i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) if args.condition_path is not None: condition_img = condition_img[range(path_batch_size)] condition_img.requires_grad = True fake_img, latents, _, _ = generator(noise, return_latents=True, condition_img=condition_img) path_loss, mean_path_length, path_lengths, isNaN = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.g_reg_every * args.path_regularize * path_loss if 
args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() if not isNaN: g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict['path'] = path_loss loss_dict['path_length'] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced['d'].mean().item() g_loss_val = loss_reduced['g'].mean().item() r1_val = loss_reduced['r1'].mean().item() path_length_val = loss_reduced['path_length'].mean().item() if args.condition_path is not None and (0 == i % args.g_reg_every): seg_loss_val = loss_reduced['seg'].mean().item() shift_loss_val = loss_reduced['shift_loss'].mean().item() if get_rank() == 0: pbar.set_description((f'mean path: {mean_path_length_avg:.4f}')) if args.with_tensorboard: writer.add_scalar('Loss/Generator', g_loss_val, i) writer.add_scalar('Loss/Discriminator', d_loss_val, i) writer.add_scalar('Loss/R1', r1_val, i) writer.add_scalar('Loss/Path Length', path_length_val, i) writer.add_scalar('Loss/mean path', mean_path_length_avg, i) if args.condition_path is not None: writer.add_scalar('Loss/seg_img', seg_loss_val, i) writer.add_scalar('Loss/shift_loss', shift_loss_val, i) steps = get_world_size() * args.batch * (1 + i) if steps % 100000 < get_world_size() * args.batch or ( steps < 1000 and steps % 500 == get_world_size() * args.batch): with torch.no_grad(): g_ema.eval() samples, featuresMaps, parsing_features = [], [], [] small_batch = args.n_sample // args.batch if 0 != args.n_sample % args.batch: small_batch += 1 # only condition change rows = int(args.n_sample**0.5) if args.condition_path is not None: sample_z = mixing_noise(rows, args.latent, args.mixing, device) sample_z = sample_z.unsqueeze(1).repeat( 1, rows, 1, 1).view(args.n_sample, sample_z.shape[1], sample_z.shape[2]) else: sample_z = mixing_noise(args.n_sample, args.latent, args.mixing, device) for k in range(small_batch): start, end = k * args.batch, (k + 1) * args.batch if k == small_batch - 1: end = sample_z.shape[0] if args.condition_path is not None: sample_condition_img_sub = sample_condition_img[ start:end] sample_condition_img_sub = random_affine( sample_condition_img_sub.clone(), Scale=0.0).to(device) else: sample_condition_img_sub = None sample, _, _, _ = g_ema( sample_z[start:end], condition_img=sample_condition_img_sub) samples.append(sample.cpu().detach()) samples = torch.cat(samples, dim=0) nrow = int(args.n_sample**0.5) c, h, w = samples.shape[-3:] samples = samples.reshape(nrow, nrow, c, h, w).transpose( 1, 0).reshape(-1, c, h, w) utils.save_image( samples, f'sample/{args.name}/{str(steps).zfill(6)}.png', nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) if 0 == i: c, h, w = condition_img_color.shape[-3:] condition_img_color = condition_img_color.reshape( nrow, nrow, c, h, w).transpose(1, 0).reshape(-1, c, h, w) utils.save_image( condition_img_color, f'sample/{args.name}/seg_vis.png', nrow=nrow, normalize=True, range=(-1, 1), ) if (steps + get_world_size() * args.batch) % 100000 < get_world_size( ) * args.batch and steps != args.start_iter: torch.save( { 'g': g_module.state_dict(), 'd': d_module.state_dict(), 'g_ema': g_ema.state_dict(), # 'g_optim': g_optim.state_dict(), # 'd_optim': d_optim.state_dict(), }, f'ckpts/{args.name}/{str(steps).zfill(6)}.pt', )
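
# For reference: every loop in this file calls accumulate(ema_model, module,
# accum) with accum = 0.5 ** (32 / (10 * 1000)) ~= 0.9978, i.e. a half-life
# of roughly 10k images at batch size 32. A minimal sketch of the helper
# (parameter EMA only; buffers are assumed to be handled elsewhere if needed):


def accumulate(model1, model2, decay=0.999):
    # In-place exponential moving average of model2's parameters into model1.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
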
def train(args, loader, generator, discriminator, contrast_learner, augment, g_optim, d_optim, scaler, g_ema, device): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = th.zeros(size=(1,), device=device) g_loss_val = 0 path_loss = th.zeros(size=(1,), device=device) path_lengths = th.zeros(size=(1,), device=device) loss_dict = {} mse = th.nn.MSELoss() if args.distributed: g_module = generator.module d_module = discriminator.module if contrast_learner is not None: cl_module = contrast_learner.module else: g_module = generator d_module = discriminator cl_module = contrast_learner sample_z = th.randn(args.n_sample, args.latent_size, device=device) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break requires_grad(generator, False) requires_grad(discriminator, True) discriminator.zero_grad() loss_dict["d"], loss_dict["real_score"], loss_dict["fake_score"] = 0, 0, 0 loss_dict["cl_reg"], loss_dict["bc_reg"] = ( th.tensor(0, device=device).float(), th.tensor(0, device=device).float(), ) for _ in range(args.num_accumulate): # sample = [] # for _ in range(0, len(sample_z), args.batch_size): # subsample = next(loader) # sample.append(subsample) # sample = th.cat(sample) # utils.save_image(sample, "reals-no-augment.png", nrow=10, normalize=True) # utils.save_image(augment(sample), "reals-augment.png", nrow=10, normalize=True) real_img = next(loader) real_img = real_img.to(device) # with th.cuda.amp.autocast(): noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device) fake_img, _ = generator(noise) if args.augment_D: fake_pred = discriminator(augment(fake_img)) real_pred = discriminator(augment(real_img)) else: fake_pred = discriminator(fake_img) real_pred = discriminator(real_img) # logistic loss real_loss = F.softplus(-real_pred) fake_loss = F.softplus(fake_pred) d_loss = real_loss.mean() + fake_loss.mean() loss_dict["d"] += d_loss.detach() loss_dict["real_score"] += real_pred.mean().detach() loss_dict["fake_score"] += fake_pred.mean().detach() if i > 10000 or i == 0: if args.contrastive > 0: contrast_learner(fake_img.clone().detach(), accumulate=True) contrast_learner(real_img, accumulate=True) contrast_loss = cl_module.calculate_loss() loss_dict["cl_reg"] += contrast_loss.detach() d_loss += args.contrastive * contrast_loss if args.balanced_consistency > 0: aug_fake_pred = discriminator(augment(fake_img.clone().detach())) aug_real_pred = discriminator(augment(real_img)) consistency_loss = mse(real_pred, aug_real_pred) + mse(fake_pred, aug_fake_pred) loss_dict["bc_reg"] += consistency_loss.detach() d_loss += args.balanced_consistency * consistency_loss d_loss /= args.num_accumulate # scaler.scale(d_loss).backward() d_loss.backward() # scaler.step(d_optim) d_optim.step() # R1 regularization if args.r1 > 0 and i % args.d_reg_every == 0: discriminator.zero_grad() loss_dict["r1"] = 0 for _ in range(args.num_accumulate): real_img = next(loader) real_img = real_img.to(device) real_img.requires_grad = True # with th.cuda.amp.autocast(): # if args.augment_D: # real_pred = discriminator( # augment(real_img) # ) # RuntimeError: derivative for grid_sampler_2d_backward is not implemented :( # else: real_pred = discriminator(real_img) real_pred_sum = real_pred.sum() (grad_real,) = th.autograd.grad(outputs=real_pred_sum, inputs=real_img, create_graph=True) # (grad_real,) = 
th.autograd.grad(outputs=scaler.scale(real_pred_sum), inputs=real_img, create_graph=True) # grad_real = grad_real * (1.0 / scaler.get_scale()) # with th.cuda.amp.autocast(): r1_loss = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() weighted_r1_loss = args.r1 / 2.0 * r1_loss * args.d_reg_every + 0 * real_pred[0] loss_dict["r1"] += r1_loss.detach() weighted_r1_loss /= args.num_accumulate # scaler.scale(weighted_r1_loss).backward() weighted_r1_loss.backward() # scaler.step(d_optim) d_optim.step() requires_grad(generator, True) requires_grad(discriminator, False) generator.zero_grad() loss_dict["g"] = 0 for _ in range(args.num_accumulate): # with th.cuda.amp.autocast(): noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device) fake_img, _ = generator(noise) if args.augment_G: fake_img = augment(fake_img) fake_pred = discriminator(fake_img) # non-saturating loss g_loss = F.softplus(-fake_pred).mean() loss_dict["g"] += g_loss.detach() g_loss /= args.num_accumulate # scaler.scale(g_loss).backward() g_loss.backward() # scaler.step(g_optim) g_optim.step() # path length regularization if args.path_regularize > 0 and i % args.g_reg_every == 0: generator.zero_grad() loss_dict["path"], loss_dict["path_length"] = 0, 0 for _ in range(args.num_accumulate): path_batch_size = max(1, args.batch_size // args.path_batch_shrink) # with th.cuda.amp.autocast(): noise = make_noise(path_batch_size, args.latent_size, args.mixing_prob, device) fake_img, latents = generator(noise, return_latents=True) img_noise = th.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) noisy_img_sum = (fake_img * img_noise).sum() (grad,) = th.autograd.grad(outputs=noisy_img_sum, inputs=latents, create_graph=True) # (grad,) = th.autograd.grad(outputs=scaler.scale(noisy_img_sum), inputs=latents, create_graph=True) # grad = grad * (1.0 / scaler.get_scale()) # with th.cuda.amp.autocast(): path_lengths = th.sqrt(grad.pow(2).sum(2).mean(1)) path_mean = mean_path_length + 0.01 * (path_lengths.mean() - mean_path_length) path_loss = (path_lengths - path_mean).pow(2).mean() mean_path_length = path_mean.detach() loss_dict["path"] += path_loss.detach() loss_dict["path_length"] += path_lengths.mean().detach() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss /= args.num_accumulate # scaler.scale(weighted_path_loss).backward() weighted_path_loss.backward() # scaler.step(g_optim) g_optim.step() # scaler.update() accumulate(g_ema, g_module) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() / args.num_accumulate g_loss_val = loss_reduced["g"].mean().item() / args.num_accumulate cl_reg_val = loss_reduced["cl_reg"].mean().item() / args.num_accumulate bc_reg_val = loss_reduced["bc_reg"].mean().item() / args.num_accumulate r1_val = loss_reduced["r1"].mean().item() / args.num_accumulate path_loss_val = loss_reduced["path"].mean().item() / args.num_accumulate real_score_val = loss_reduced["real_score"].mean().item() / args.num_accumulate fake_score_val = loss_reduced["fake_score"].mean().item() / args.num_accumulate path_length_val = loss_reduced["path_length"].mean().item() / args.num_accumulate if get_rank() == 0: log_dict = { "Generator": g_loss_val, "Discriminator": d_loss_val, "Real Score": real_score_val, "Fake Score": fake_score_val, "Contrastive": cl_reg_val, "Consistency": bc_reg_val, } if args.log_spec_norm: G_norms = [] for name, spec_norm in 
g_module.named_buffers(): if "spectral_norm" in name: G_norms.append(spec_norm.cpu().numpy()) G_norms = np.array(G_norms) D_norms = [] for name, spec_norm in d_module.named_buffers(): if "spectral_norm" in name: D_norms.append(spec_norm.cpu().numpy()) D_norms = np.array(D_norms) log_dict[f"Spectral Norms/G min spectral norm"] = np.log(G_norms).min() log_dict[f"Spectral Norms/G mean spectral norm"] = np.log(G_norms).mean() log_dict[f"Spectral Norms/G max spectral norm"] = np.log(G_norms).max() log_dict[f"Spectral Norms/D min spectral norm"] = np.log(D_norms).min() log_dict[f"Spectral Norms/D mean spectral norm"] = np.log(D_norms).mean() log_dict[f"Spectral Norms/D max spectral norm"] = np.log(D_norms).max() if args.r1 > 0 and i % args.d_reg_every == 0: log_dict["R1"] = r1_val if args.path_regularize > 0 and i % args.g_reg_every == 0: log_dict["Path Length Regularization"] = path_loss_val log_dict["Mean Path Length"] = mean_path_length log_dict["Path Length"] = path_length_val if i % args.img_every == 0: gc.collect() th.cuda.empty_cache() with th.no_grad(): g_ema.eval() sample = [] for sub in range(0, len(sample_z), args.batch_size): subsample, _ = g_ema([sample_z[sub : sub + args.batch_size]]) sample.append(subsample.cpu()) sample = th.cat(sample) grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1)) # utils.save_image(sample, "fakes-no-augment.png", nrow=10, normalize=True) # utils.save_image(augment(sample), "fakes-augment.png", nrow=10, normalize=True) # exit() log_dict["Generated Images EMA"] = [wandb.Image(grid, caption=f"Step {i}")] if i % args.eval_every == 0: start_time = time.time() pbar.set_description((f"Calculating FID...")) fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample, args.fid_truncation, args.name) fid = fid_dict["FID"] density = fid_dict["Density"] coverage = fid_dict["Coverage"] pbar.set_description((f"Calculating PPL...")) ppl = validation.ppl( g_ema, args.val_batch_size, args.ppl_n_sample, args.ppl_space, args.ppl_crop, args.latent_size, ) pbar.set_description( ( f"FID: {fid:.4f}; Density: {density:.4f}; Coverage: {coverage:.4f}; PPL: {ppl:.4f} in {time.time() - start_time:.1f}s" ) ) log_dict["Evaluation/FID"] = fid log_dict["Evaluation/Density"] = density log_dict["Evaluation/Coverage"] = coverage log_dict["Evaluation/PPL"] = ppl gc.collect() th.cuda.empty_cache() wandb.log(log_dict) if i % args.checkpoint_every == 0: th.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), # "cl": cl_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), }, f"/home/hans/modelzoo/maua-sg2/{args.name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}-{int(fid)}-{int(ppl)}-{str(i).zfill(6)}.pt", )
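
# For reference: the R1 penalty computed inline above is the same quantity as
# the d_r1_loss(...) helper used by the other loops; a minimal sketch below.
# The "+ 0 * real_pred[0]" / "+ 0 * fake_img[0, 0, 0, 0]" terms added to the
# regularizer losses keep otherwise-unused outputs attached to the graph so
# DistributedDataParallel does not error on parameters that received no
# gradient (see the stylegan2-pytorch issue #76 linked in the comments above).
import torch
from torch import autograd


def d_r1_loss(real_pred, real_img):
    # real_img must have requires_grad=True before the discriminator forward.
    grad_real, = autograd.grad(outputs=real_pred.sum(),
                               inputs=real_img,
                               create_graph=True)
    return grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
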
def train( args, loader, encoder, generator, discriminator, discriminator3d, # video disctiminator posterior, prior, factor, # a learnable matrix vggnet, e_optim, d_optim, dv_optim, q_optim, # q for posterior p_optim, # p for prior f_optim, # f for factor e_ema, device ): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) d_loss_val = 0 e_loss_val = 0 rec_loss_val = 0 vgg_loss_val = 0 adv_loss_val = 0 loss_dict = {"d": torch.tensor(0., device=device), "real_score": torch.tensor(0., device=device), "fake_score": torch.tensor(0., device=device), "r1_d": torch.tensor(0., device=device), "r1_e": torch.tensor(0., device=device), "rec": torch.tensor(0., device=device),} if args.distributed: e_module = encoder.module d_module = discriminator.module g_module = generator.module else: e_module = encoder d_module = discriminator g_module = generator accum = 0.5 ** (32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 latent_full = args.latent_full factor_dim_full = args.factor_dim_full if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) sample_x = accumulate_batches(loader, args.n_sample).to(device) utils.save_image( sample_x.view(-1, *list(sample_x.shape)[2:]), os.path.join(args.log_dir, 'sample', f"real-img.png"), nrow=sample_x.shape[1], normalize=True, value_range=(-1, 1), ) util.save_video( sample_x[0], os.path.join(args.log_dir, 'sample', f"real-vid.mp4") ) requires_grad(generator, False) # always False generator.eval() # Generator should be ema and in eval mode if args.no_update_encoder: encoder = e_ema if e_ema is not None else encoder requires_grad(encoder, False) encoder.eval() from models.networks_3d import GANLoss criterionGAN = GANLoss() # criterionL1 = nn.L1Loss() # if args.no_ema or e_ema is None: # e_ema = encoder for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break data = next(loader) real_seq = data['frames'] real_seq = real_seq.to(device) # [N, T, C, H, W] shape = list(real_seq.shape) N, T = shape[:2] # Train Encoder with frame-level objectives if args.toggle_grads: if not args.no_update_encoder: requires_grad(encoder, True) requires_grad(discriminator, False) pix_loss = vgg_loss = adv_loss = rec_loss = vid_loss = l1y_loss = torch.tensor(0., device=device) # TODO: real_seq -> encoder -> posterior -> generator -> fake_seq # f: [N, latent_full]; y: [N, T, D] fake_img, fake_seq, y_post = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior, i, ret_y=True) # if args.debug == 'no_lstm': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # elif args.debug == 'decomp': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # [N*T, latent_full] # f_post = real_lat[::T, ...] 
# z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # if args.use_multi_head: # y_post = [] # for z, w in zip(torch.split(z_post, 512, 2), factor.weight): # y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1)) # y_post = torch.cat(y_post, 2) # else: # y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1) # z_post_hat = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post_hat # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # else: # real_lat = encoder(real_seq.view(-1, *shape[2:])) # # single head: f_post [N, latent_full]; y_post [N, T, D] # # multi head: f_post [N, n_latent, latent]; y_post [N, T, n_latent, d] # f_post, y_post = posterior(real_lat.view(N, T, latent_full)) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # shape [N, T, latent_full] # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # TODO: sample frames real_img = real_seq.view(N*T, *shape[2:]) # fake_img = fake_seq.view(N*T, *shape[2:]) if args.lambda_adv > 0: if args.augment: fake_img_aug, _ = augment(fake_img, ada_aug_p) else: fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) adv_loss = g_nonsaturating_loss(fake_pred) # TODO: do we always put pix and vgg loss for all frames? if args.lambda_pix > 0: pix_loss = torch.mean((real_img - fake_img) ** 2) if args.lambda_vgg > 0: real_feat = vggnet(real_img) fake_feat = vggnet(fake_img) vgg_loss = torch.mean((real_feat - fake_feat) ** 2) # Train Encoder with video-level objectives # TODO: video adversarial loss if args.lambda_vid > 0: fake_pred = discriminator3d(flip_video(fake_seq.transpose(1, 2))) vid_loss = criterionGAN(fake_pred, True) if args.lambda_l1y > 0: # l1y_loss = criterionL1(y_post) l1y_loss = torch.mean(torch.abs(y_post)) e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv e_loss = e_loss + args.lambda_vid * vid_loss + args.lambda_l1y * l1y_loss loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss if not args.no_update_encoder: encoder.zero_grad() posterior.zero_grad() e_loss.backward() q_optim.step() if not args.no_update_encoder: e_optim.step() # if args.train_on_fake: # e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0 # if e_regularize and args.lambda_rec > 0: # noise = mixing_noise(args.batch, args.latent, args.mixing, device) # fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True) # latent_pred = encoder(fake_img) # if latent_pred.ndim < 3: # latent_pred = latent_pred.unsqueeze(1).repeat(1, latent_fake.size(1), 1) # rec_loss = torch.mean((latent_fake - latent_pred) ** 2) # encoder.zero_grad() # (rec_loss * args.lambda_rec).backward() # e_optim.step() # loss_dict["rec"] = rec_loss # e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0 # if e_regularize: # # why not regularize on augmented real? 
# real_img.requires_grad = True # real_pred = encoder(real_img) # r1_loss_e = d_r1_loss(real_pred, real_img) # encoder.zero_grad() # (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward() # e_optim.step() # loss_dict["r1_e"] = r1_loss_e if not args.no_update_encoder: if not args.no_ema and e_ema is not None: accumulate(e_ema, e_module, accum) # Train Discriminator if args.toggle_grads: requires_grad(encoder, False) requires_grad(discriminator, True) fake_img, fake_seq = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior) # if args.debug == 'no_lstm': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # elif args.debug == 'decomp': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # [N*T, latent_full] # f_post = real_lat[::T, ...] # z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # if args.use_multi_head: # y_post = [] # for z, w in zip(torch.split(z_post, 512, 2), factor.weight): # y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1)) # y_post = torch.cat(y_post, 2) # else: # y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1) # z_post_hat = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post_hat # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # elif args.debug == 'coef': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # [N*T, latent_full] # f_post = real_lat[::T, ...] # z_post_hat = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # y_post = torch.mm(z_post_hat.view(N*T, -1), factor.weight).view(N, T, -1) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # else: # real_lat = encoder(real_seq.view(-1, *shape[2:])) # f_post, y_post = posterior(real_lat.view(N, T, latent_full)) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # shape [N, T, latent_full] # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # fake_img = fake_seq.view(N*T, *shape[2:]) if not args.no_update_discriminator: if args.lambda_adv > 0: if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img_aug, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) # Train video discriminator if args.lambda_vid > 0: pred_real = discriminator3d(flip_video(real_seq.transpose(1, 2))) pred_fake = discriminator3d(flip_video(fake_seq.transpose(1, 2))) dv_loss_real = criterionGAN(pred_real, True) dv_loss_fake = criterionGAN(pred_fake, False) dv_loss = 0.5 * (dv_loss_real + dv_loss_fake) d_loss = d_loss + dv_loss loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() if args.lambda_adv > 0: discriminator.zero_grad() if args.lambda_vid > 0: discriminator3d.zero_grad() d_loss.backward() if args.lambda_adv > 0: d_optim.step() if args.lambda_vid > 0: dv_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = 
ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: # why not regularize on augmented real? real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss_d = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward() # Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76 d_optim.step() loss_dict["r1_d"] = r1_loss_d loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() e_loss_val = loss_reduced["e"].mean().item() r1_d_val = loss_reduced["r1_d"].mean().item() r1_e_val = loss_reduced["r1_e"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() rec_loss_val = loss_reduced["rec"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() if get_rank() == 0: pbar.set_description( ( f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; " f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}" ) ) if wandb and args.wandb: wandb.log( { "Encoder": e_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1 D": r1_d_val, "R1 E": r1_e_val, "Pix Loss": pix_loss_val, "VGG Loss": vgg_loss_val, "Adv Loss": adv_loss_val, "Rec Loss": rec_loss_val, "Real Score": real_score_val, "Fake Score": fake_score_val, } ) if i % args.log_every == 0: with torch.no_grad(): e_eval = encoder if args.no_ema else e_ema e_eval.eval() posterior.eval() # N = sample_x.shape[0] fake_img, fake_seq = reconstruct_sequence(args, sample_x, e_eval, generator, factor, posterior) # if args.debug == 'no_lstm': # real_lat = encoder(sample_x.view(-1, *shape[2:])) # fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # elif args.debug == 'decomp': # real_lat = encoder(sample_x.view(-1, *shape[2:])) # [N*T, latent_full] # f_post = real_lat[::T, ...] 
# z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # if args.use_multi_head: # y_post = [] # for z, w in zip(torch.split(z_post, 512, 2), factor.weight): # y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1)) # y_post = torch.cat(y_post, 2) # else: # y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1) # z_post_hat = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post_hat # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # else: # x_lat = encoder(sample_x.view(-1, *shape[2:])) # f_post, y_post = posterior(x_lat.view(N, T, latent_full)) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) utils.save_image( torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_recon.png"), nrow=T, normalize=True, value_range=(-1, 1), ) util.save_video( fake_seq[random.randint(0, args.n_sample-1)], os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-vid_recon.mp4") ) fake_img, fake_seq = swap_sequence(args, sample_x, e_eval, generator, factor, posterior) utils.save_image( torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_swap.png"), nrow=T, normalize=True, value_range=(-1, 1), ) e_eval.train() posterior.train() if i % args.save_every == 0: e_eval = encoder if args.no_ema else e_ema torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if not args.debug and i % args.save_latest_every == 0: torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
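
# For reference: reduce_loss_dict(...) averages logged scalars across ranks
# so that rank 0 reports global values. A minimal sketch along the lines of
# the distributed helpers these loops assume (keys must be inserted in the
# same order on every rank, which is why several loops above pre-populate the
# dict with zero tensors):
import torch
import torch.distributed as dist


def get_world_size():
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        keys = sorted(loss_dict.keys())  # identical order on every rank
        losses = torch.stack([loss_dict[k] for k in keys], 0)
        dist.reduce(losses, dst=0)
        if dist.get_rank() == 0:
            losses /= world_size  # only rank 0 reads the reduced values
        return dict(zip(keys, losses))
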
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator none_g_grads = set() test_in = torch.randn(1, args.latent, device=device) fake, latent = g_module([test_in], return_latents=True) path = g_path_regularize(fake, latent, 0) path[0].backward() for n, p in generator.named_parameters(): if p.grad is None: none_g_grads.add(n) test_in = torch.randn(1, 3, args.size, args.size, requires_grad=True, device=device) pred = d_module(test_in) r1_loss = d_r1_loss(pred, test_in) r1_loss.backward() none_d_grads = set() for n, p in discriminator.named_parameters(): if p.grad is None: none_d_grads.add(n) sample_z = torch.randn(2 * 2, args.latent, device=device) for i in pbar: real_img = next(loader) real_img = real_img.to(device) requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) fake_pred = discriminator(fake_img) real_pred = discriminator(real_img) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict['d'] = d_loss loss_dict['real_score'] = real_pred.mean() loss_dict['fake_score'] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() set_grad_none(discriminator, none_d_grads) d_optim.step() loss_dict['r1'] = r1_loss requires_grad(generator.proj, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) noise_proj_loss = sum([(generator.proj(noise_i) - noise_i).abs().sum() for noise_i in noise]) fake_pred = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) print(noise_proj_loss.item()) loss_dict['g'] = g_loss generator.zero_grad() (g_loss + noise_proj_loss).backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: noise = mixing_noise( args.batch // args.path_batch_shrink, args.latent, args.mixing, device ) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length ) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() set_grad_none(g_module, none_g_grads) g_optim.step() mean_path_length_avg = ( reduce_sum(mean_path_length).item() / get_world_size() ) loss_dict['path'] = path_loss loss_dict['path_length'] = path_lengths.mean() accumulate(g_ema, g_module) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced['d'].mean().item() g_loss_val = loss_reduced['g'].mean().item() r1_val = loss_reduced['r1'].mean().item() path_loss_val = loss_reduced['path'].mean().item() real_score_val = loss_reduced['real_score'].mean().item() 
fake_score_val = loss_reduced['fake_score'].mean().item()
path_length_val = loss_reduced['path_length'].mean().item()

if get_rank() == 0:
    pbar.set_description(
        (
            f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
            f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
        )
    )

    if wandb and args.wandb:
        wandb.log(
            {
                'Generator': g_loss_val,
                'Discriminator': d_loss_val,
                'R1': r1_val,
                'Path Length Regularization': path_loss_val,
                'Mean Path Length': mean_path_length,
                'Real Score': real_score_val,
                'Fake Score': fake_score_val,
                'Path Length': path_length_val,
            }
        )

    if i % 10000 == 0:
        torch.save(
            {
                'g': g_module.state_dict(),
                'd': d_module.state_dict(),
                'g_ema': g_ema.state_dict(),
                'g_optim': g_optim.state_dict(),
                'd_optim': d_optim.state_dict(),
            },
            f'checkpoint/{str(i).zfill(6)}.pt',
        )

    if i % 100 == 0:
        with torch.no_grad():
            g_ema.eval()
            sample, _ = g_ema([sample_z])
            utils.save_image(
                sample,
                f'sample/{str(i).zfill(6)}.png',
                nrow=2,
                normalize=True,
                value_range=(-1, 1),  # renamed from `range=` in torchvision
            )
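
# For reference: mixing_noise(...) implements style-mixing regularization:
# with probability `prob` two independent latent codes are passed to the
# generator (which crosses over between them at a random layer), otherwise a
# single code is used. A minimal sketch of the helpers assumed above; note
# the contrastive loop earlier uses make_noise(batch, latent, mixing_prob,
# device) as the same idea under a different name.
import random

import torch


def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)


def mixing_noise(batch, latent_dim, prob, device):
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    return [make_noise(batch, latent_dim, 1, device)]
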
def train(args, loader, generator, encoder, discriminator, vggnet, g_optim, e_optim, d_optim, g_ema, e_ema, device): kwargs_d = {'detach_aux': False} if args.dataset == 'imagefolder': loader = sample_data2(loader) else: loader = sample_data(loader) if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] else: inception = real_mean = real_cov = None mean_latent = None pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: g_module = generator.module e_module = encoder.module d_module = discriminator.module else: g_module = generator e_module = encoder d_module = discriminator accum = 0.5**(32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_x = load_real_samples(args, loader) if sample_x.ndim > 4: sample_x = sample_x[:, 0, ...] for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader) real_img = real_img.to(device) # Train Discriminator requires_grad(generator, False) requires_grad(encoder, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) real_pred = discriminator(real_img) fake_pred = discriminator(fake_img) rec_pred = discriminator(rec_img) d_loss_real = F.softplus(-real_pred).mean() d_loss_fake = F.softplus(fake_pred).mean() d_loss_rec = F.softplus(rec_pred).mean() loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() loss_dict["rec_score"] = rec_pred.mean() d_loss = d_loss_real + d_loss_fake + d_loss_rec loss_dict["d"] = d_loss discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss # # Train Encoder and Generator # requires_grad(generator, True) # requires_grad(encoder, True) # requires_grad(discriminator, False) # pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device) # noise = mixing_noise(args.batch, args.latent, args.mixing, device) # fake_img, _ = generator(noise) # latent_real, _ = encoder(real_img) # rec_img, _ = generator([latent_real], input_is_latent=True) # fake_pred = discriminator(fake_img) # rec_pred = discriminator(rec_img) # g_loss_fake = g_nonsaturating_loss(fake_pred) # g_loss_rec = g_nonsaturating_loss(rec_pred) # adv_loss = g_loss_fake + g_loss_rec # if args.lambda_pix > 0: # if args.pix_loss == 'l2': # pix_loss = torch.mean((rec_img - real_img) ** 2) # else: # pix_loss = F.l1_loss(rec_img, real_img) # 
if args.lambda_vgg > 0: # vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2) # e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv # loss_dict["e"] = e_loss # encoder.zero_grad() # generator.zero_grad() # e_loss.backward() # e_optim.step() # g_optim.step() # Train Encoder requires_grad(generator, False) requires_grad(encoder, True) requires_grad(discriminator, False) pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device) latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) rec_pred = discriminator(rec_img) g_loss_rec = g_nonsaturating_loss(rec_pred) adv_loss = g_loss_rec if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - real_img)**2) else: pix_loss = F.l1_loss(rec_img, real_img) if args.lambda_vgg > 0: vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2) e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv loss_dict["e"] = e_loss encoder.zero_grad() e_loss.backward() e_optim.step() # Train Generator requires_grad(generator, True) requires_grad(encoder, False) requires_grad(discriminator, False) pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) fake_pred = discriminator(fake_img) rec_pred = discriminator(rec_img) g_loss_fake = g_nonsaturating_loss(fake_pred) g_loss_rec = g_nonsaturating_loss(rec_pred) adv_loss = g_loss_fake + g_loss_rec if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - real_img)**2) else: pix_loss = F.l1_loss(rec_img, real_img) if args.lambda_vgg > 0: vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2) g_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv loss_dict["g"] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() with torch.no_grad(): latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - real_img)**2) else: pix_loss = F.l1_loss(rec_img, real_img) vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2) pix_loss_val = pix_loss.mean().item() vgg_loss_val = vgg_loss.mean().item() accumulate(e_ema, e_module, accum) accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() e_loss_val = loss_reduced["e"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = 
loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0]) if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}")) if i % args.log_every == 0: with torch.no_grad(): latent_x, _ = e_ema(sample_x) fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False) sample_pix_loss = torch.sum((sample_x - fake_x)**2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; " f"ref: {sample_pix_loss.item()};\n") if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): g_ema.eval() if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) print("fid:", fid) with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"{i:07d}: fid: {float(fid):.4f}\n") if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, }) if i % args.log_every == 0: with torch.no_grad(): # Fixed fake samples g_ema.eval() sample, _ = g_ema([sample_z]) utils.save_image( sample, os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"), nrow=int(args.n_sample**0.5), normalize=True, value_range=(-1, 1), ) # Reconstruction samples e_ema.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x.shape)[1:] latent_real, _ = e_ema(sample_x) fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False) sample = torch.cat( (sample_x.reshape(args.n_sample // nrow, nrow, *nchw), fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )