Example #1
def train(args, epoch, loader, model, optimizer, scheduler):
    torch.backends.cudnn.benchmark = True

    model.train()

    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)

    else:
        pbar = loader

    for i, (img, annot) in enumerate(pbar):
        img = img.to('cuda')
        annot = annot.to('cuda')

        loss, _ = model(img, annot)
        loss_sum = loss['loss'] + args.aux_weight * loss['aux']
        model.zero_grad()
        loss_sum.backward()
        optimizer.step()
        scheduler.step()

        loss_dict = reduce_loss_dict(loss)
        loss = loss_dict['loss'].mean().item()
        aux_loss = loss_dict['aux'].mean().item()

        if get_rank() == 0:
            lr = optimizer.param_groups[0]['lr']

            pbar.set_description(
                f'epoch: {epoch + 1}; loss: {loss:.5f}; aux loss: {aux_loss:.5f}; lr: {lr:.5f}'
            )
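The distributed helpers used above (get_rank, reduce_loss_dict) are external to the snippet. A minimal sketch, assuming plain torch.distributed in the style of the maskrcnn-benchmark utilities these examples appear to rely on:

import torch
from torch import distributed as dist

def get_world_size():
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()

def get_rank():
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()

def reduce_loss_dict(loss_dict):
    # Average each loss tensor across ranks; a no-op in single-process runs.
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[k] for k in keys], 0)
        dist.reduce(losses, dst=0)
        if dist.get_rank() == 0:
            losses /= world_size
        return {k: loss for k, loss in zip(keys, losses)}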
Example #2
def valid(args, epoch, loader, model, show):
    torch.backends.cudnn.benchmark = False

    model.eval()

    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)

    else:
        pbar = loader

    intersect_sum = None
    union_sum = None
    correct_sum = 0
    total_sum = 0

    for i, (img, annot) in enumerate(pbar):
        img = img.to('cuda')
        annot = annot.to('cuda')
        _, out = model(img)
        _, pred = out.max(1)

        if get_rank() == 0 and i % 10 == 0:
            result = show(img[0], annot[0], pred[0])
            result.save(f'sample/{str(epoch + 1).zfill(3)}-{str(i).zfill(4)}.png')

        pred = (annot > 0) * pred
        correct = (pred > 0) * (pred == annot)
        correct_sum += correct.sum().float().item()
        total_sum += (annot > 0).sum().float().item()

        for g, p, c in zip(annot, pred, correct):
            intersect, union = intersection_union(g, p, c, args.n_class)

            if intersect_sum is None:
                intersect_sum = intersect

            else:
                intersect_sum += intersect

            if union_sum is None:
                union_sum = union

            else:
                union_sum += union

        all_intersect = sum(all_gather(intersect_sum.to('cpu')))
        all_union = sum(all_gather(union_sum.to('cpu')))

        if get_rank() == 0:
            iou = all_intersect / (all_union + 1e-10)
            m_iou = iou.mean().item()

            pbar.set_description(
                f'acc: {correct_sum / total_sum:.5f}; mIoU: {m_iou:.5f}'
            )
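intersection_union is likewise not defined here. A plausible sketch consistent with the call above (histogram-based per-class counts, as in common semantic-segmentation evaluation code; the actual helper may differ):

def intersection_union(gt, pred, correct, n_class):
    # `correct` is a boolean mask of pixels where pred matches gt, so a class
    # histogram over the correctly predicted pixels gives the intersection.
    intersect = torch.histc(pred[correct].float(), bins=n_class, min=0, max=n_class - 1)
    area_pred = torch.histc(pred.float(), bins=n_class, min=0, max=n_class - 1)
    area_gt = torch.histc(gt.float(), bins=n_class, min=0, max=n_class - 1)
    union = area_pred + area_gt - intersect
    return intersect, union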
Example #3
def valid(args, epoch, loader, dataset, model, device):
    if args.distributed:
        model = model.module

    torch.cuda.empty_cache()

    model.eval()

    pbar = tqdm(loader, dynamic_ncols=True)

    preds = {}

    for images, targets, ids in pbar:
        model.zero_grad()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        pred, _ = model(images.tensors, images.sizes)

        pred = [p.to('cpu') for p in pred]

        preds.update({id: p for id, p in zip(ids, pred)})

    preds = accumulate_predictions(preds)

    if get_rank() != 0:
        return

    evaluate(dataset, preds)
Example #4
def valid(args, epoch, loader, dataset, model, device, logger=None):
    if args.distributed:
        model = model.module

    torch.cuda.empty_cache()

    model.eval()

    if get_rank() == 0:
        pbar = tqdm(enumerate(loader), total=len(loader), dynamic_ncols=True)
    else:
        pbar = enumerate(loader)

    preds = {}

    for idx, (images, targets, ids) in pbar:
        model.zero_grad()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        pred, _ = model(images.tensors, images.sizes)

        pred = [p.to('cpu') for p in pred]

        preds.update({id: p for id, p in zip(ids, pred)})

    preds = accumulate_predictions(preds)

    if get_rank() != 0:
        return

    evl_res = evaluate(dataset, preds)

    # writing log to tensorboard
    if logger:
        log_group_name = "validation"
        box_result = evl_res['bbox']
        logger.add_scalar(log_group_name + '/AP', box_result['AP'], epoch)
        logger.add_scalar(log_group_name + '/AP50', box_result['AP50'], epoch)
        logger.add_scalar(log_group_name + '/AP75', box_result['AP75'], epoch)
        logger.add_scalar(log_group_name + '/APl', box_result['APl'], epoch)
        logger.add_scalar(log_group_name + '/APm', box_result['APm'], epoch)
        logger.add_scalar(log_group_name + '/APs', box_result['APs'], epoch)

    return preds
Example #5
def train(args, epoch, loader, model, optimizer, device, logger=None):
    model.train()

    if get_rank() == 0:
        pbar = tqdm(enumerate(loader), total=len(loader), dynamic_ncols=True)
    else:
        pbar = enumerate(loader)

    for idx, (images, targets, _) in pbar:
        model.zero_grad()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        _, loss_dict = model(images, targets=targets)
        loss_cls = loss_dict['loss_cls'].mean()
        loss_box = loss_dict['loss_reg'].mean()
        loss_center = loss_dict['loss_centerness'].mean()

        loss = loss_cls + loss_box + loss_center
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        loss_reduced = reduce_loss_dict(loss_dict)
        loss_cls = loss_reduced['loss_cls'].mean().item()
        loss_box = loss_reduced['loss_reg'].mean().item()
        loss_center = loss_reduced['loss_centerness'].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (f'epoch: {epoch + 1}; cls: {loss_cls:.4f}; '
                 f'box: {loss_box:.4f}; center: {loss_center:.4f}'))

            # writing log to tensorboard
            if logger and idx % 10 == 0:
                totalStep = (epoch * len(loader) +
                             idx) * args.batch * args.n_gpu
                logger.add_scalar('training/loss_cls', loss_cls, totalStep)
                logger.add_scalar('training/loss_box', loss_box, totalStep)
                logger.add_scalar('training/loss_center', loss_center,
                                  totalStep)
                logger.add_scalar('training/loss_all',
                                  (loss_cls + loss_box + loss_center),
                                  totalStep)
Example #6
    def set_logger(self):
        # log_dir is needed unconditionally below for the code archive.
        self.log_dir = '%s/%d' % (self.args.log_dir, self.args.manualSeed)
        os.makedirs(self.log_dir, exist_ok=True)

        if get_rank() == 0 and wandb is not None and self.args.wandb:
            wandb.init(project="stylegan 2")
        else:
            self.summary = SummaryWriter(log_dir=self.log_dir)

        with tarfile.open(os.path.join(self.log_dir, 'code.tar.gz'),
                          "w:gz") as tar:
            for addfile in ['train.py', 'dataset.py', 'model.py']:
                tar.add(addfile)
Example #7
    def __init__(self,
                 img_root_path,
                 img_keys_path,
                 transform,
                 batch_size,
                 dist_mode: bool = False,
                 rank_seed: Optional[int] = None,
                 with_key: bool = False):
        self.img_root_path = img_root_path
        self.img_keys_path = img_keys_path
        self.transform = transform
        self.batch_size = batch_size
        if rank_seed is not None:
            self.rand = random.Random(rank_seed)
        else:
            self.rand = random.Random(time.time())

        self.img_keys_file_list = [
            os.path.join(self.img_keys_path, f)
            for f in os.listdir(self.img_keys_path) if not f.startswith('.')
        ]

        self.rand.shuffle(self.img_keys_file_list)
        self.rank = -1
        if dist_mode:
            rank_pic_size = int(
                math.ceil(
                    len(self.img_keys_file_list) / dist.get_world_size()))
            self.img_keys_file_list = self.img_keys_file_list[
                rank_pic_size * dist.get_rank():rank_pic_size *
                (dist.get_rank() + 1)]
            self.rank = dist.get_rank()
        self.num_examples = max(
            sum(1 for _ in open(f)) for f in self.img_keys_file_list[:10]
        ) * len(self.img_keys_file_list)
        self.num_iterations = int(math.ceil(self.num_examples / batch_size))
        self.with_key = with_key
Example #8
def get_logit(dataloader, netD, device):
    data_iter = iter(dataloader)
    logit_list = np.zeros(len(dataloader.dataset))
    netD.eval()
    with torch.no_grad():
        if get_rank() == 0:
            data_iter = tqdm(data_iter)
        for data, idx in data_iter:
            real_data = data.to(device)
            idx = idx.to(device)
            logit_r = netD(real_data).view(-1)
            idx_all = concat_all_gather(idx)
            logit_r = concat_all_gather(logit_r)
            logit_list[idx_all.cpu().numpy()] = logit_r.detach().cpu().numpy()
    netD.train()
    return logit_list
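concat_all_gather follows the MoCo-style helper; a minimal sketch (all_gather carries no gradient, hence the no_grad decorator):

import torch
import torch.distributed as dist

@torch.no_grad()
def concat_all_gather(tensor):
    # Gather `tensor` from every rank and concatenate along dim 0.
    gathered = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)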
Example #9
def accumulate_predictions(predictions):
    all_predictions = all_gather(predictions)

    if get_rank() != 0:
        return

    predictions = {}

    for p in all_predictions:
        predictions.update(p)

    ids = list(sorted(predictions.keys()))

    if len(ids) != ids[-1] + 1:
        print('Evaluation results are not contiguous')

    predictions = [predictions[i] for i in ids]

    return predictions
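Unlike the tensor gather above, the all_gather used here must collect an arbitrary picklable dict from every rank. A simplified sketch assuming a recent PyTorch with all_gather_object (older codebases hand-roll this with pickled byte tensors):

import torch.distributed as dist

def all_gather(data):
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size == 1:
        return [data]
    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data)
    return output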
Example #10
def save_predictions_to_images(dataset, predictions):
    #
    if get_rank() != 0:
        return

    for id, pred in enumerate(predictions):
        orig_id = dataset.id2img[id]

        if len(pred) == 0:
            continue

        img_meta = dataset.get_image_meta(id)
        width = img_meta['width']
        height = img_meta['height']
        pred = pred.resize((width, height))

        boxes = pred.bbox.tolist()
        scores = pred.get_field('scores').tolist()
        ids = pred.get_field('labels').tolist()

        img_name = img_meta['file_name']
        img_baseName = os.path.splitext(img_name)[0]
        #
        print('saving ' + img_name + ' ...')
        imgroot = dataset.root
        show_bbox(imgroot + '/' + img_name,
                  boxes,
                  ids,
                  CLASS_NAME,
                  file_name=img_name,
                  scores=scores)

        categories = [dataset.id2category[i] for i in ids]
        for k, box in enumerate(boxes):
            category_id = categories[k]
            score = scores[k]
Example #11
def train(args, dataset, gen, dis, g_ema, device):
    if args.distributed:
        g_module = gen.module
        d_module = dis.module

    else:
        g_module = gen
        d_module = dis

    vgg = VGGFeature("vgg16", [4, 9, 16, 23, 30],
                     use_fc=True).eval().to(device)
    requires_grad(vgg, False)

    g_optim = optim.Adam(gen.parameters(), lr=1e-4, betas=(0, 0.999))
    d_optim = optim.Adam(dis.parameters(), lr=1e-4, betas=(0, 0.999))

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        num_workers=4,
        sampler=dist.data_sampler(dataset,
                                  shuffle=True,
                                  distributed=args.distributed),
        drop_last=True,
    )

    loader_iter = sample_data(loader)

    pbar = range(args.start_iter, args.iter)

    if dist.get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True)

    eps = 1e-8

    for i in pbar:
        real, class_id = next(loader_iter)

        real = real.to(device)
        class_id = class_id.to(device)

        masks = make_mask(real.shape[0], device, args.crop_prob)
        features, fcs = vgg(real)
        features = features + fcs[1:]

        requires_grad(dis, True)
        requires_grad(gen, False)

        real_pred = dis(real, class_id)

        z = torch.randn(args.batch, args.dim_z, device=device)

        fake = gen(z, class_id, features, masks)

        fake_pred = dis(fake, class_id)

        d_loss = d_ls_loss(real_pred, fake_pred)

        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        z1 = torch.randn(args.batch, args.dim_z, device=device)
        z2 = torch.randn(args.batch, args.dim_z, device=device)

        requires_grad(gen, True)
        requires_grad(dis, False)

        masks = make_mask(real.shape[0], device, args.crop_prob)

        if args.distributed:
            gen.broadcast_buffers = True

        fake1 = gen(z1, class_id, features, masks)

        if args.distributed:
            gen.broadcast_buffers = False

        fake2 = gen(z2, class_id, features, masks)

        fake_pred = dis(fake1, class_id)

        a_loss = g_ls_loss(None, fake_pred)

        features_fake, fcs_fake = vgg(fake1)
        features_fake = features_fake + fcs_fake[1:]

        r_loss = recon_loss(features_fake, features, masks)
        div_loss = diversity_loss(z1, z2, fake1, fake2, eps)

        g_loss = a_loss + args.rec_weight * r_loss + args.div_weight * div_loss

        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        accumulate(g_ema, g_module)

        if dist.get_rank() == 0:
            pbar.set_description(
                f"d: {d_loss.item():.4f}; g: {a_loss.item():.4f}; rec: {r_loss.item():.4f}; div: {div_loss.item():.4f}"
            )

            if i % 100 == 0:
                utils.save_image(
                    fake1,
                    f"sample/{str(i).zfill(6)}.png",
                    nrow=int(args.batch**0.5),
                    normalize=True,
                    range=(-1, 1),
                )

            if i % 10000 == 0:
                torch.save(
                    {
                        "args": args,
                        "g_ema": g_ema.state_dict(),
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
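Examples #11 onward share a few StyleGAN2-style utilities. A minimal sketch of requires_grad and of accumulate (an exponential moving average of the generator weights into g_ema), matching how they are called above; the 0.999 default decay is an assumption:

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

def accumulate(model1, model2, decay=0.999):
    # EMA update: blend model2's parameters into model1 (the g_ema copy).
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)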
Example #12
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
            ))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
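sample_data and mixing_noise are also external. A minimal sketch consistent with their use above (an infinite batch stream, and optional two-latent style mixing):

import random
import torch

def sample_data(loader):
    # Turn a finite DataLoader into an endless stream of batches.
    while True:
        for batch in loader:
            yield batch

def mixing_noise(batch, latent_dim, prob, device):
    # With probability `prob`, return two latents for style mixing;
    # otherwise a single latent wrapped in a list.
    if prob > 0 and random.random() < prob:
        return list(torch.randn(2, batch, latent_dim, device=device).unbind(0))
    return [torch.randn(batch, latent_dim, device=device)]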
Example #13
def train(opt):
    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    
    opt.num_class = len(converter.character)

    
    # styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    # genModel = AdaIN_Tensor_WordGenerator(opt)
    # disModel = MsImageDisV2(opt)

    # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none')
    # mixModel = Mixer(opt,nblk=3, dim=opt.latent)
    genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel).to(device)
    g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    ocrModel = ModelV1(opt).to(device)
    accumulate(g_ema, genModel, 0)

    # #  weight initialization
    # for currModel in [styleModel, mixModel]:
    #     for name, param in currModel.named_parameters():
    #         if 'localization_fc2' in name:
    #             print(f'Skip {name} as it is already initialized')
    #             continue
    #         try:
    #             if 'bias' in name:
    #                 init.constant_(param, 0.0)
    #             elif 'weight' in name:
    #                 init.kaiming_normal_(param)
    #         except Exception as e:  # for batchnorm.
    #             if 'weight' in name:
    #                 param.data.fill_(1)
    #             continue

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # vggRecCriterion = torch.nn.L1Loss()
    # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion)
    
    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length)

    if opt.distributed:
        genModel = torch.nn.parallel.DistributedDataParallel(
            genModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        
        disModel = torch.nn.parallel.DistributedDataParallel(
            disModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        ocrModel = torch.nn.parallel.DistributedDataParallel(
            ocrModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False
        )
    
    # styleModel = torch.nn.DataParallel(styleModel).to(device)
    # styleModel.train()
    
    # mixModel = torch.nn.DataParallel(mixModel).to(device)
    # mixModel.train()
    
    # genModel = torch.nn.DataParallel(genModel).to(device)
    # g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    # disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # vggModel = torch.nn.DataParallel(vggModel).to(device)
    # vggModel.eval()

    # ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    # if opt.distributed:
    #     ocrModel.module.Transformation.eval()
    #     ocrModel.module.FeatureExtraction.eval()
    #     ocrModel.module.AdaptiveAvgPool.eval()
    #     # ocrModel.module.SequenceModeling.eval()
    #     ocrModel.module.Prediction.eval()
    # else:
    #     ocrModel.Transformation.eval()
    #     ocrModel.FeatureExtraction.eval()
    #     ocrModel.AdaptiveAvgPool.eval()
    #     # ocrModel.SequenceModeling.eval()
    #     ocrModel.Prediction.eval()
    ocrModel.eval()

    if opt.distributed:
        g_module = genModel.module
        d_module = disModel.module
    else:
        g_module = genModel
        d_module = disModel

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        genModel.parameters(),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0:
            opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None':
        if not opt.distributed:
            ocrModel = torch.nn.DataParallel(ocrModel)
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)
        #temporary fix
        if not opt.distributed:
            ocrModel = ocrModel.module
    
    if opt.saved_gen_model !='' and opt.saved_gen_model !='None':
        print(f'loading pretrained gen model from {opt.saved_gen_model}')
        checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
        genModel.module.load_state_dict(checkpoint['g'])
        g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)
        
        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disModel.load_state_dict(checkpoint['disModel'])
        
        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim
    

    # loss averager
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    
    #get schedulers
    scheduler = get_scheduler(optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)

    start_time = time.time()
    iteration = start_iter
    cntr=0
    
    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while True:
        # print(cntr)
        # train part
       
        if opt.lr_policy !="None":
            scheduler.step()
            dis_scheduler.step()
        
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, False)
        # requires_grad(styleModel, False)
        # requires_grad(mixModel, False)
        requires_grad(disModel, True)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        
        
        #forward pass from style and word generator
        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device)
        # scInput = mixModel(style,text_2)
        if 'CTC' in opt.Prediction:
            images_recon_2,_ = genModel(style, text_2, input_is_latent=opt.input_latent)
        else:
            images_recon_2,_ = genModel(style, text_2[:,1:-1], input_is_latent=opt.input_latent)
        
        #Domain discriminator: Dis update
        if opt.augment:
            image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p)
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        else:
            image_gt_tensors_aug = image_gt_tensors

        fake_pred = disModel(images_recon_2)
        real_pred = disModel(image_gt_tensors_aug)
        disCost = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = disCost*opt.disWeight
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        loss_avg_dis.add(disCost)

        disModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if opt.augment and opt.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device
            )
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > opt.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            image_gt_tensors.requires_grad = True
            image_input_tensors.requires_grad = True
            cat_tensor = image_gt_tensors
            real_pred = disModel(cat_tensor)
            
            r1_loss = d_r1_loss(real_pred, cat_tensor)

            disModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()

        loss_dict["r1"] = r1_loss

        
        # #[Style Encoder] + [Word Generator] update
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))
        
        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, True)
        # requires_grad(styleModel, True)
        # requires_grad(mixModel, True)
        requires_grad(disModel, False)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        # scInput = mixModel(style,text_2)

        # images_recon_2,_ = genModel([scInput], input_is_latent=opt.input_latent)
        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)
        
        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2)
        else:
            images_recon_2, _ = genModel(style, text_2[:,1:-1])

        if opt.augment:
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        fake_pred = disModel(images_recon_2)
        disGenCost = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = disGenCost

        # # #Adversarial loss
        # # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2,image_input_tensors),dim=1))

        # #Input reconstruction loss
        # recCost = recCriterion(images_recon_2,image_gt_tensors)

        # #vgg loss
        # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)
        #ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                
                preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
                # preds_o = preds_recon[:, :text_1.shape[1], :]
                preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2)
                
                #predict ocr recognition on generated images
                # preds_recon_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                _, preds_recon_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data)

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                # preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                # preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)

            else:
                preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1))

                #predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_s_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                
                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_sc_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        # cost =  opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.vggPerWeight*vggPerCost + opt.vggStyWeight*vggStyleCost + opt.ocrWeight*ocrCost
        cost =  opt.disWeight*disGenCost + opt.ocrWeight*ocrCost

        # styleModel.zero_grad()
        genModel.zero_grad()
        # mixModel.zero_grad()
        disModel.zero_grad()
        # vggModel.zero_grad()
        ocrModel.zero_grad()
        
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))
        
            image_input_tensors = image_input_tensors.to(device)
            image_gt_tensors = image_gt_tensors.to(device)
            batch_size = image_input_tensors.size(0)

            text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
            text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

            path_batch_size = max(1, batch_size // opt.path_batch_shrink)

            # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
            # scInput = mixModel(style,text_2)

            # images_recon_2, latents = genModel([scInput],input_is_latent=opt.input_latent, return_latents=True)

            style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)
            
            
            if 'CTC' in opt.Prediction:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size], return_latents=True)
            else:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size,1:-1], return_latents=True)
            
            
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                images_recon_2, latents, mean_path_length
            )

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()


        #Individual losses
        loss_avg_gen.add(opt.disWeight*disGenCost)
        loss_avg_imgRecon.add(torch.tensor(0.0))
        loss_avg_vgg_per.add(torch.tensor(0.0))
        loss_avg_vgg_sty.add(torch.tensor(0.0))
        loss_avg_ocr.add(opt.ocrWeight*ocrCost)

        log_r1_val.add(loss_reduced["r1"])
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))
        
        if get_rank() == 0:
            # pbar.set_description(
            #     (
            #         f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
            #         f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            #         f"augment: {ada_aug_p:.4f}"
            #     )
            # )

            if wandb and opt.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,   
                        "Path Length": path_length_val,
                    }
                )
            # if cntr % 100 == 0:
            #     with torch.no_grad():
            #         g_ema.eval()
            #         sample, _ = g_ema([scInput[:,:opt.latent],scInput[:,opt.latent:]])
            #         utils.save_image(
            #             sample,
            #             os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"),
            #             nrow=int(opt.n_sample ** 0.5),
            #             normalize=True,
            #             range=(-1, 1),
            #         )


        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
            
            #Save training images
            curr_batch_size = style[0].shape[0]
            images_recon_2, _ = g_ema(style, text_2[:curr_batch_size], input_is_latent=opt.input_latent)
            
            os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'.png'))
                    else:
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'_'+labels_s_ocr[trImgCntr]+'.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'_'+labels_sc_ocr[trImgCntr]+'.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'_'+labels_o_ocr[trImgCntr]+'.png'))
                except:
                    print('Warning while saving training image')
            
            elapsed_time = time.time() - start_time
            # for log
            
            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:
                # styleModel.eval()
                genModel.eval()
                g_ema.eval()
                # mixModel.eval()
                disModel.eval()
                
                with torch.no_grad():                    
                    valid_loss, infer_time, length_of_data = validation_synth_v6(
                        iteration, g_ema, ocrModel, disModel, ocrCriterion, valid_loader, converter, opt)
                
                # styleModel.train()
                genModel.train()
                # mixModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train OCR loss: {loss_avg_ocr.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Valid Synth loss: {valid_loss[0]:0.5f}, \
                    Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \
                    Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                
                
                #plotting
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item())

                lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item())

                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Synth-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Dis-Loss'), valid_loss[1].item())

                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Gen-Loss'), valid_loss[2].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-ImgRecon1-Loss'), valid_loss[3].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Per-Loss'), valid_loss[4].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Sty-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(opt.plotDir,'Valid-OCR-Loss'), valid_loss[6].item())
                
                print(loss_log)

                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_vgg_per.reset()
                loss_avg_vgg_sty.reset()
                loss_avg_ocr.reset()

                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()
                

            lib.plot.flush()

        lib.plot.tick()

        # save model every 1e+4 iterations
        if iteration % 1e+4 == 0:
            torch.save({
                # 'styleModel':styleModel.state_dict(),
                # 'mixModel':mixModel.state_dict(),
                'genModel':g_module.state_dict(),
                'g_ema':g_ema.state_dict(),
                'disModel':d_module.state_dict(),
                'optimizer':optimizer.state_dict(),
                'dis_optimizer':dis_optimizer.state_dict()}, 
                os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr+=1
Example #14
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device, save_dir):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 1000 == 0:  # save some samples

                with torch.no_grad():
                    g_ema.eval()

                    sample, _ = g_ema([sample_z])

                    utils.save_image(
                        sample,
                        save_dir + f"/samples/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 2000 == 0:  #save the model
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    save_dir + f"/checkpoints/{str(i).zfill(6)}.pt",
                )
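For reference: g_path_regularize is called throughout these snippets but never shown. A minimal sketch of the StyleGAN2 path-length penalty such a helper typically computes, under the assumption that it follows the standard formulation (the snippets' own version may differ in detail):

import math

import torch
from torch import autograd


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Jacobian of the images w.r.t. the latents, probed with white noise
    # scaled by image size (the standard StyleGAN2 estimator).
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad, = autograd.grad(outputs=(fake_img * noise).sum(),
                          inputs=latents,
                          create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # A running average of the path length serves as the moving target.
    path_mean = mean_path_length + decay * (path_lengths.mean() -
                                            mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths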
Example #15
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))  # EMA decay: half-life of ~10k images at batch size 32

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    # collect enough real labels (counting samples, not batches) for the grid
    sample_labels = []
    n_labels = 0
    while n_labels < args.n_sample:
        real_img, real_label = next(loader)
        sample_labels.append(real_label.to(device))
        n_labels += real_label.size(0)
    sample_labels = torch.cat(sample_labels, 0)[:args.n_sample]

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print('Done!')

            break

        real_img, real_label = next(loader)
        real_img = real_img.to(device)
        real_label = real_label.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(real_label, noise)
        fake_pred = discriminator(real_label, fake_img)

        real_pred = discriminator(real_label, real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_label, real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(real_label, noise)
        fake_pred = discriminator(real_label, fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(real_label[:path_batch_size],
                                          noise,
                                          return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_loss_val = loss_reduced['path'].mean().item()
        real_score_val = loss_reduced['real_score'].mean().item()
        fake_score_val = loss_reduced['fake_score'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
                f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
            ))

            if wandb and args.wandb:
                wandb.log({
                    'Generator': g_loss_val,
                    'Discriminator': d_loss_val,
                    'R1': r1_val,
                    'Path Length Regularization': path_loss_val,
                    'Mean Path Length': mean_path_length,
                    'Real Score': real_score_val,
                    'Fake Score': fake_score_val,
                    'Path Length': path_length_val,
                })

            if i % 200 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema(sample_labels, [sample_z])
                    utils.save_image(
                        sample,
                        f'sample/{str(i).zfill(6)}.png',
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    f'checkpoint/{str(i).zfill(6)}.pt',
                )
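mixing_noise is used by every training loop here but never defined. A plausible sketch following the common stylegan2-pytorch helper: with probability args.mixing it returns two latent codes so the generator can mix styles at a random crossover point:

import random

import torch


def make_noise(batch, latent_dim, n_noise, device):
    # One latent per sample, or several stacked for style mixing.
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)


def mixing_noise(batch, latent_dim, prob, device):
    # Two latents with probability `prob` (style mixing), else a single one.
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    return [make_noise(batch, latent_dim, 1, device)]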
Example #16
0
def train(args, loader, generator, discriminator, extra, g_optim, d_optim,
          e_optim, g_ema, device, g_source, d_source):
    loader = sample_data(loader)

    imsave_path = os.path.join('samples', args.exp)
    model_path = os.path.join('checkpoints', args.exp)

    if not os.path.exists(imsave_path):
        os.makedirs(imsave_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # this defines the anchor points, and when sampling noise close to these, we impose image-level adversarial loss (Eq. 4 in the paper)
    init_z = torch.randn(args.n_train, args.latent, device=device)
    pbar = range(args.iter)
    sfm = nn.Softmax(dim=1)
    kl_loss = nn.KLDivLoss()
    sim = nn.CosineSimilarity()
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    g_module = generator
    d_module = discriminator
    g_ema_module = g_ema.module

    accum = 0.5**(32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    # this defines which level feature of the discriminator is used to implement the patch-level adversarial loss: could be anything between [0, args.highp]
    lowp, highp = 0, args.highp

    # the following defines the constant noise used for generating images at different stages of training
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    requires_grad(g_source, False)
    requires_grad(d_source, False)
    sub_region_z = get_subspace(args, init_z.clone(), vis_flag=True)
    for idx in pbar:
        i = idx + args.start_iter
        which = i % args.subspace_freq  # defines whether we sample from anchor region in this iteration or other

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        requires_grad(extra, True)

        if which > 0:
            # sample normally, apply patch-level adversarial loss
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            # sample from anchors, apply image-level adversarial loss
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            real_img, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(fake_img,
                                     extra=extra,
                                     flag=which,
                                     p_ind=np.random.randint(lowp, highp))
        real_pred, _ = discriminator(real_img,
                                     extra=extra,
                                     flag=which,
                                     p_ind=np.random.randint(lowp, highp),
                                     real=True)

        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        extra.zero_grad()
        d_loss.backward()
        d_optim.step()
        e_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred, _ = discriminator(real_img,
                                         extra=extra,
                                         flag=which,
                                         p_ind=np.random.randint(lowp, highp))
            real_pred = real_pred.view(real_img.size(0), -1)
            real_pred = real_pred.mean(dim=1).unsqueeze(1)

            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            extra.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()
            e_optim.step()
        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        requires_grad(extra, False)
        if which > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(fake_img,
                                     extra=extra,
                                     flag=which,
                                     p_ind=np.random.randint(lowp, highp))
        g_loss = g_nonsaturating_loss(fake_pred)

        # distance consistency loss
        with torch.set_grad_enabled(False):
            z = torch.randn(args.feat_const_batch, args.latent, device=device)
            feat_ind = np.random.randint(1,
                                         g_source.module.n_latent - 1,
                                         size=args.feat_const_batch)

            # computing source distances
            source_sample, feat_source = g_source([z], return_feats=True)
            dist_source = torch.zeros(
                [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

            # iterating over different elements in the batch
            for pair1 in range(args.feat_const_batch):
                tmpc = 0
                # comparing the possible pairs
                for pair2 in range(args.feat_const_batch):
                    if pair1 != pair2:
                        anchor_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair1].reshape(-1), 0)
                        compare_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair2].reshape(-1), 0)
                        dist_source[pair1, tmpc] = sim(anchor_feat,
                                                       compare_feat)
                        tmpc += 1
            dist_source = sfm(dist_source)

        # computing distances among target generations
        _, feat_target = generator([z], return_feats=True)
        dist_target = torch.zeros(
            [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

        # iterating over different elements in the batch
        for pair1 in range(args.feat_const_batch):
            tmpc = 0
            for pair2 in range(
                    args.feat_const_batch):  # comparing the possible pairs
                if pair1 != pair2:
                    anchor_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair1].reshape(-1), 0)
                    compare_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair2].reshape(-1), 0)
                    dist_target[pair1, tmpc] = sim(anchor_feat, compare_feat)
                    tmpc += 1
        dist_target = sfm(dist_target)
        rel_loss = args.kl_wt * \
            kl_loss(torch.log(dist_target), dist_source) # distance consistency loss
        g_loss = g_loss + rel_loss

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        # to save up space
        del rel_loss, g_loss, d_loss, fake_img, fake_pred, real_img, real_pred, anchor_feat, compare_feat, dist_source, dist_target, feat_source, feat_target

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema_module, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.img_freq == 0:
                with torch.set_grad_enabled(False):
                    g_ema.eval()
                    sample, _ = g_ema([sample_z.data])
                    sample_subz, _ = g_ema([sub_region_z.data])
                    utils.save_image(
                        sample,
                        f"%s/{str(i).zfill(6)}.png" % (imsave_path),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )
                    del sample

            if (i % args.save_freq == 0) and (i > 0):
                torch.save(
                    {
                        "g_ema": g_ema.state_dict(),
                        # uncomment the following lines only if you wish to resume training after saving. Otherwise, saving just the generator is sufficient for evaluations

                        #"g": g_module.state_dict(),
                        #"g_s": g_source.state_dict(),
                        #"d": d_module.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                    },
                    f"%s/{str(i).zfill(6)}.pt" % (model_path),
                )
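get_subspace, which draws latents near the anchor points defined by init_z, is not included in this snippet. A rough sketch of the intended behavior; the perturbation scale args.subspace_std is an assumed name, not confirmed by the source:

import torch


def get_subspace(args, init_z, vis_flag=False):
    # Pick random anchors from the fixed init_z codes and jitter them with
    # small Gaussian noise, so sampled latents stay in the anchor regions
    # that receive the image-level adversarial loss (Eq. 4 in the paper).
    n = args.n_sample if vis_flag else args.batch
    ind = torch.randint(0, init_z.size(0), (n,), device=init_z.device)
    z = init_z[ind].clone()
    return z + args.subspace_std * torch.randn_like(z)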
Example #17
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    loader = sample_data(loader)
    current_ckpt = args.current_ckpt

    # log one batch of real images to sanity-check the data pipeline
    test_imgs = next(loader)
    real_grid = utils.make_grid(test_imgs,
                                nrow=2,
                                normalize=True,
                                range=(-1, 1))
    wandb.log({"reals": [wandb.Image(real_grid, caption='Real Data')]})

    pbar = tqdm(dynamic_ncols=True,
                smoothing=0.01,
                initial=current_ckpt + 1,
                total=args.iter)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Dry-run each regularizer once to find parameters that receive no
    # gradient from it; set_grad_none later clears those stale .grad buffers.
    none_g_grads = set()
    test_in = torch.randn(1, args.latent, device=device)
    fake, latent = g_module([test_in], return_latents=True)
    path = g_path_regularize(fake, latent, 0)
    path[0].backward()

    for n, p in generator.named_parameters():
        if p.grad is None:
            none_g_grads.add(n)

    test_in = torch.randn(1,
                          3,
                          args.size,
                          args.size,
                          requires_grad=True,
                          device=device)
    pred = d_module(test_in)
    r1_loss = d_r1_loss(pred, test_in)
    r1_loss.backward()

    none_d_grads = set()
    for n, p in discriminator.named_parameters():
        if p.grad is None:
            none_d_grads.add(n)

    # Temporarily fix the RNG so the sample-grid latents are reproducible,
    # then restore the captured seed below.
    seed = torch.initial_seed() % 10000000
    torch.manual_seed(20)
    torch.cuda.manual_seed_all(20)
    sample_z = torch.randn(4 * 4, args.latent, device=device)
    sample_z_chunks = torch.split(sample_z, args.batch)

    # reset seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    i = current_ckpt + 1
    while i < args.iter:
        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            set_grad_none(discriminator, none_d_grads)

            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            noise = mixing_noise(args.batch // args.path_batch_shrink,
                                 args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            set_grad_none(g_module, none_g_grads)

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_loss_val = loss_reduced['path'].mean().item()
        real_score_val = loss_reduced['real_score'].mean().item()
        fake_score_val = loss_reduced['fake_score'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

        if get_rank() == 0:
            pbar.set_postfix(d_loss=f'{d_loss_val:.4f}',
                             g_loss=f'{g_loss_val:.4f}',
                             r1_loss=f'{r1_val:.4f}',
                             path=f'{path_loss_val:.4f}',
                             mean=f'{mean_path_length_avg:.4f}')

            if wandb and args.wandb:
                wandb.log({
                    'Generator': g_loss_val,
                    'Discriminator': d_loss_val,
                    'R1': r1_val,
                    'Path Length Regularization': path_loss_val,
                    'Mean Path Length': mean_path_length,
                    'Real Score': real_score_val,
                    'Fake Score': fake_score_val,
                    'Path Length': path_length_val,
                    'current_ckpt': current_ckpt,
                    'iteration': i,
                })

            if i % 500 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample = generate_fake_images(g_ema, sample_z_chunks)
                    if wandb and args.wandb:
                        label = f'{str(i).zfill(8)}.png'
                        image = utils.make_grid(sample,
                                                nrow=4,
                                                normalize=True,
                                                range=(-1, 1))
                        wandb.log(
                            {"samples": [wandb.Image(image, caption=label)]})
                    else:
                        utils.save_image(
                            sample,
                            f'sample/{str(i).zfill(8)}.png',
                            nrow=8,
                            normalize=True,
                            range=(-1, 1),
                        )

            if i % 2000 == 0:
                ckpt_name = f'checkpoint/{str(i).zfill(8)}.pt'
                # remove the previous checkpoint
                shutil.rmtree('checkpoint')
                os.mkdir('checkpoint')
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    ckpt_name,
                )
                current_ckpt = i
                if wandb and args.wandb:
                    wandb.save(ckpt_name)
        # advance the counter on every rank so all processes stay in step
        i = i + 1
        pbar.update()
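set_grad_none pairs with the none_g_grads / none_d_grads sets probed at the top of this function. A minimal sketch of what it is presumably meant to do (an assumption, since the helper is not shown):

def set_grad_none(model, targets):
    # Clear gradients for parameters the regularization pass never touches,
    # so stale values from earlier steps cannot leak into the optimizer step.
    for name, param in model.named_parameters():
        if name in targets:
            param.grad = None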
Example #18
0
    g_optim = optim.Adam(
        generator.proj.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio / 2,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
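This snippet begins after g_reg_ratio and d_reg_ratio have been computed. Under StyleGAN2's lazy regularization these ratios usually rescale the Adam learning rate and betas so the infrequent regularization steps keep their intended effective strength; a sketch under that assumption:

# R1 and path-length penalties run only every *_reg_every steps, so the
# optimizer hyperparameters are rescaled to preserve their effective values.
g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)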
Example #19
0
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True,
            )

    transform = transforms.Compose(
        [
            transforms.RandomVerticalFlip(p=0.5 if args.vflip else 0),
            transforms.RandomHorizontalFlip(p=0.5 if args.hflip else 0),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        num_workers=8,
        drop_last=True,
    )

    if get_rank() == 0:
        validation.get_dataset_inception_features(loader, args.name, args.size)
        wandb.init(project=f"maua-stylegan", name="Cyphept Correct BCR", config=vars(args))
    scaler = th.cuda.amp.GradScaler()

    train(args, loader, generator, discriminator, contrast_learner, augment_fn, g_optim, d_optim, scaler, g_ema, device)
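data_sampler appears in several of the loaders above without a definition. A minimal sketch consistent with how it is called (dataset, shuffle, distributed):

from torch.utils import data


def data_sampler(dataset, shuffle, distributed):
    # Shard the dataset across processes in distributed runs; otherwise use
    # plain random or sequential sampling.
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)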
Example #20
0
def train(opt):
    lib.print_model_settings(locals().copy())

    # train_transform =  transforms.Compose([
    #     # transforms.RandomResizedCrop(input_size),
    #     transforms.Resize((opt.imgH, opt.imgW)),
    #     # transforms.RandomHorizontalFlip(),
    #     transforms.ToTensor(),
    #     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # ])

    # val_transform = transforms.Compose([
    #     transforms.Resize((opt.imgH, opt.imgW)),
    #     # transforms.CenterCrop(input_size),
    #     transforms.ToTensor(),
    #     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # ])

    AlignFontCollateObj = AlignFontCollate(imgH=opt.imgH,
                                           imgW=opt.imgW,
                                           keep_ratio_with_pad=opt.PAD)
    train_dataset = fontDataset(imgDir=opt.train_img_dir,
                                annFile=opt.train_ann_file,
                                transform=None,
                                numClasses=opt.numClasses)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # set True to check training progress with the validation function
        sampler=data_sampler(train_dataset,
                             shuffle=True,
                             distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignFontCollateObj,
        pin_memory=True,
        drop_last=False)
    # numClasses = len(train_dataset.Idx2F)
    numClasses = np.unique(train_dataset.fontIdx).size

    train_loader = sample_data(train_loader)
    print('-' * 80)
    numTrainSamples = len(train_dataset)

    # valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt)
    valid_dataset = fontDataset(imgDir=opt.train_img_dir,
                                annFile=opt.val_ann_file,
                                transform=None,
                                F2Idx=train_dataset.F2Idx,
                                Idx2F=train_dataset.Idx2F,
                                numClasses=opt.numClasses)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # set True to check training progress with the validation function
        sampler=data_sampler(valid_dataset,
                             shuffle=False,
                             distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignFontCollateObj,
        pin_memory=True,
        drop_last=False)
    numTestSamples = len(valid_dataset)

    print('numClasses', numClasses)
    print('numTrainSamples', numTrainSamples)
    print('numTestSamples', numTestSamples)

    vggFontModel = VGGFontModel(models.vgg19(pretrained=opt.preTrained),
                                numClasses).to(device)
    for name, param in vggFontModel.classifier.named_parameters():
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            print('Exception in weight init: ' + name)
            if 'weight' in name:
                param.data.fill_(1)
            continue

    if opt.optim == "sgd":
        print('SGD optimizer')
        optimizer = optim.SGD(vggFontModel.parameters(),
                              lr=opt.lr,
                              momentum=0.9)
    elif opt.optim == "adam":
        print('Adam optimizer')
        optimizer = optim.Adam(vggFontModel.parameters(), lr=opt.lr)
    #get schedulers
    scheduler = get_scheduler(optimizer, opt)

    criterion = torch.nn.CrossEntropyLoss()

    if opt.modelFolderFlag:
        ckpts = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_vggFont.pth"))
        if len(ckpts) > 0:
            # glob order is arbitrary; resume from the highest-numbered iteration
            opt.saved_font_model = max(
                ckpts, key=lambda p: int(os.path.basename(p).split('_')[1]))

    ## Loading pre-trained files
    if opt.saved_font_model != '' and opt.saved_font_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_font_model}')
        checkpoint = torch.load(opt.saved_font_model,
                                map_location=lambda storage, loc: storage)

        vggFontModel.load_state_dict(checkpoint['vggFontModel'])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    # print('Model Initialization')
    #
    # print('Loaded checkpoint')

    if opt.distributed:
        vggFontModel = torch.nn.parallel.DistributedDataParallel(
            vggFontModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True)
        vggFontModel.train()

    # print('Loaded distributed')

    if opt.distributed:
        vggFontModel_module = vggFontModel.module
    else:
        vggFontModel_module = vggFontModel

    # print('Loading module')

    # loss averager
    loss_train = Averager()
    loss_val = Averager()
    train_acc = Averager()
    val_acc = Averager()
    train_acc_5 = Averager()
    val_acc_5 = Averager()
    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    if opt.saved_font_model != '' and opt.saved_font_model != 'None':
        try:
            start_iter = int(opt.saved_font_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            pass

    iteration = start_iter

    cntr = 0
    # trainCorrect=0
    # tCntr=0
    while True:
        # print(cntr)
        # train part

        start_time = time.time()
        if not opt.testFlag:

            image_input_tensors, labels_gt = next(train_loader)
            image_input_tensors = image_input_tensors.to(device)
            labels_gt = labels_gt.view(-1).to(device)
            preds = vggFontModel(image_input_tensors)

            loss = criterion(preds, labels_gt)

            vggFontModel.zero_grad()
            loss.backward()
            optimizer.step()

            # _, preds_max = preds.max(dim=1)
            # trainCorrect += (preds_max == labels_gt).sum()
            # tCntr+=preds_max.shape[0]

            acc1, acc5 = getNumCorrect(preds,
                                       labels_gt,
                                       topk=(1, min(numClasses, 5)))
            train_acc.addScalar(acc1, preds.shape[0])
            train_acc_5.addScalar(acc5, preds.shape[0])

            loss_train.add(loss)

            if opt.lr_policy != "None":
                scheduler.step()

        # print
        if get_rank() == 0:
            if (iteration + 1) % opt.valInterval == 0 or iteration == 0 or opt.testFlag:
                # to see training progress, we also conduct validation when iteration == 0
                #validation
                # iCntr=torch.tensor(0.0).to(device)
                # valCorrect=torch.tensor(0.0).to(device)
                vggFontModel.eval()
                print('Inside val', iteration)

                for vCntr, (image_input_tensors,
                            labels_gt) in enumerate(valid_loader):
                    # print('vCntr--',vCntr)
                    if opt.debugFlag and vCntr > 2:
                        break

                    with torch.no_grad():
                        image_input_tensors = image_input_tensors.to(device)
                        labels_gt = labels_gt.view(-1).to(device)

                        preds = vggFontModel(image_input_tensors)
                        loss = criterion(preds, labels_gt)
                        loss_val.add(loss)

                        # _, preds_max = preds.max(dim=1)
                        # valCorrect += (preds_max == labels_gt).sum()
                        # iCntr+=preds_max.shape[0]

                        acc1, acc5 = getNumCorrect(preds,
                                                   labels_gt,
                                                   topk=(1, min(numClasses,
                                                                5)))
                        val_acc.addScalar(acc1, preds.shape[0])
                        val_acc_5.addScalar(acc5, preds.shape[0])

                vggFontModel.train()
                elapsed_time = time.time() - start_time

                #DO HERE
                with open(
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'log_train.txt'), 'a') as log:
                    # print('COUNT-------',val_acc_5.n_count)
                    # training loss and validation loss
                    loss_log = (
                        f'[{iteration+1}/{opt.num_iter}] '
                        f'Train loss: {loss_train.val():0.5f}, Val loss: {loss_val.val():0.5f}, '
                        f'Train Top-1 Acc: {train_acc.val()*100:0.5f}, Train Top-5 Acc: {train_acc_5.val()*100:0.5f}, '
                        f'Val Top-1 Acc: {val_acc.val()*100:0.5f}, Val Top-5 Acc: {val_acc_5.val()*100:0.5f}, '
                        f'Elapsed_time: {elapsed_time:0.5f}')

                    #plotting
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Loss'),
                                  loss_train.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Loss'),
                                  loss_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-1-Acc'),
                                  train_acc.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-5-Acc'),
                                  train_acc_5.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-1-Acc'),
                                  val_acc.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-5-Acc'),
                                  val_acc_5.val() * 100)

                    print(loss_log)
                    log.write(loss_log + "\n")

                    loss_train.reset()
                    loss_val.reset()
                    train_acc.reset()
                    val_acc.reset()
                    train_acc_5.reset()
                    val_acc_5.reset()
                    # trainCorrect=0
                    # tCntr=0

                lib.plot.flush()

            # save model every 15000 iterations
            if iteration % 15000 == 0:
                torch.save(
                    {
                        'vggFontModel': vggFontModel_module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()
                    },
                    os.path.join(opt.exp_dir, opt.exp_name, 'iter_' +
                                 str(iteration + 1) + '_vggFont.pth'))

            lib.plot.tick()

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
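getNumCorrect feeds the Averager counters above but is not defined in the snippet. A plausible sketch that returns raw top-k correct counts, matching how acc1 and acc5 are accumulated with the batch size as denominator:

import torch


def getNumCorrect(output, target, topk=(1,)):
    # For each requested k, count samples whose top-k predictions contain
    # the ground-truth label; callers divide by the number of samples.
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum() for k in topk]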
Example #21
0
def train(args, loader, generator, discriminator, contrast_learner, g_optim,
          d_optim, g_ema):
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    loader = sample_data(loader)
    sample_z = th.randn(args.n_sample, args.latent_size, device=device)
    mse = th.nn.MSELoss()
    mean_path_length = 0
    ada_augment = th.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0
    fids = []
    fid = ppl = density = coverage = float("nan")  # defined after the first evaluation

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0)
    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        loss_dict = {
            "Generator": th.tensor(0, device=device).float(),
            "Discriminator": th.tensor(0, device=device).float(),
            "Real Score": th.tensor(0, device=device).float(),
            "Fake Score": th.tensor(0, device=device).float(),
            "Contrastive": th.tensor(0, device=device).float(),
            "Consistency": th.tensor(0, device=device).float(),
            "R1 Penalty": th.tensor(0, device=device).float(),
            "Path Length Regularization": th.tensor(0, device=device).float(),
            "Augment": th.tensor(0, device=device).float(),
            "Rt": th.tensor(0, device=device).float(),
        }

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        discriminator.zero_grad()
        for _ in range(args.num_accumulate):
            real_img_og = next(loader).to(device)
            noise = make_noise(args.batch_size, args.latent_size,
                               args.mixing_prob)
            fake_img_og, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img_og, ada_aug_p)
                real_img, _ = augment(real_img_og, ada_aug_p)
            else:
                fake_img = fake_img_og
                real_img = real_img_og

            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img)
            logistic_loss = d_logistic_loss(real_pred, fake_pred)
            loss_dict["Discriminator"] += logistic_loss.detach()
            loss_dict["Real Score"] += real_pred.mean().detach()
            loss_dict["Fake Score"] += fake_pred.mean().detach()
            d_loss = logistic_loss

            if args.contrastive > 0:
                contrast_learner(fake_img_og, fake_img, accumulate=True)
                contrast_learner(real_img_og, real_img, accumulate=True)
                contrast_loss = cl_module.calculate_loss()
                loss_dict["Contrastive"] += contrast_loss.detach()
                d_loss += args.contrastive * contrast_loss

            if args.balanced_consistency > 0:
                consistency_loss = mse(
                    real_pred, discriminator(real_img_og)) + mse(
                        fake_pred, discriminator(fake_img_og))
                loss_dict["Consistency"] += consistency_loss.detach()
                d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            d_loss.backward()
        d_optim.step()

        if args.r1 > 0 and i % args.d_reg_every == 0:
            discriminator.zero_grad()
            for _ in range(args.num_accumulate):
                real_img = next(loader).to(device)
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss = d_r1_penalty(real_img, real_pred, args)
                loss_dict["R1 Penalty"] += r1_loss.detach().squeeze()
                r1_loss = args.r1 * args.d_reg_every * r1_loss / args.num_accumulate
                r1_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment += th.tensor(
                (th.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred
                loss_dict["Rt"] = th.tensor(r_t_stat, device=device).float()
                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)
                loss_dict["Augment"] = th.tensor(ada_aug_p,
                                                 device=device).float()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        generator.zero_grad()
        for _ in range(args.num_accumulate):
            noise = make_noise(args.batch_size, args.latent_size,
                               args.mixing_prob)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss = g_non_saturating_loss(fake_pred)
            loss_dict["Generator"] += g_loss.detach()
            g_loss /= args.num_accumulate
            g_loss.backward()
        g_optim.step()

        if args.path_regularize > 0 and i % args.g_reg_every == 0:
            generator.zero_grad()
            for _ in range(args.num_accumulate):
                path_loss, mean_path_length = g_path_length_regularization(
                    generator, mean_path_length, args)
                loss_dict["Path Length Regularization"] += path_loss.detach()
                path_loss = args.path_regularize * args.g_reg_every * path_loss / args.num_accumulate
                path_loss.backward()
            g_optim.step()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        log_dict = {
            k: v.mean().item() / args.num_accumulate
            for k, v in loss_reduced.items() if v != 0
        }
        if get_rank() == 0:
            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)
                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)
                log_dict[f"Spectral Norms/G min spectral norm"] = np.log(
                    G_norms).min()
                log_dict[f"Spectral Norms/G mean spectral norm"] = np.log(
                    G_norms).mean()
                log_dict[f"Spectral Norms/G max spectral norm"] = np.log(
                    G_norms).max()
                log_dict[f"Spectral Norms/D min spectral norm"] = np.log(
                    D_norms).min()
                log_dict[f"Spectral Norms/D mean spectral norm"] = np.log(
                    D_norms).mean()
                log_dict[f"Spectral Norms/D max spectral norm"] = np.log(
                    D_norms).max()

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema(
                            [sample_z[sub:sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample,
                                           nrow=10,
                                           normalize=True,
                                           range=(-1, 1))
                log_dict["Generated Images EMA"] = [
                    wandb.Image(grid, caption=f"Step {i}")
                ]

            if i % args.eval_every == 0:
                fid_dict = validation.fid(g_ema, args.val_batch_size,
                                          args.fid_n_sample,
                                          args.fid_truncation, args.name)

                fid = fid_dict["FID"]
                fids.append(fid)
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]

                ppl = validation.ppl(
                    g_ema,
                    args.val_batch_size,
                    args.ppl_n_sample,
                    args.ppl_space,
                    args.ppl_crop,
                    args.latent_size,
                )

                log_dict["Evaluation/FID"] = fid
                log_dict["Sweep/FID_smooth"] = gaussian_filter(
                    np.array(fids), [5])[-1]
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

                gc.collect()
                th.cuda.empty_cache()

            wandb.log(log_dict)
            description = (
                f"FID: {fid:.4f}   PPL: {ppl:.4f}   Dens: {density:.4f}   Cov: {coverage:.4f}   "
                +
                f"G: {log_dict['Generator']:.4f}   D: {log_dict['Discriminator']:.4f}"
            )
            if "Augment" in log_dict:
                description += f"   Aug: {log_dict['Augment']:.4f}"  #   Rt: {log_dict['Rt']:.4f}"
            if "R1 Penalty" in log_dict:
                description += f"   R1: {log_dict['R1 Penalty']:.4f}"
            if "Path Length Regularization" in log_dict:
                description += f"   Path: {log_dict['Path Length Regularization']:.4f}"
            pbar.set_description(description)

            if i % args.checkpoint_every == 0:
                check_name = "-".join([
                    args.name,
                    args.runname,
                    wandb.run.dir.split("/")[-1].split("-")[-1],
                    int(fid),
                    args.size,
                    str(i).zfill(6),
                ])
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{check_name}.pt",
                )
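Every snippet here calls accumulate(g_ema, g_module, ...) to maintain the EMA copy of the generator. A minimal sketch of the usual helper, assuming a default decay of 0.999 where call sites omit the third argument:

def accumulate(model1, model2, decay=0.999):
    # Parameter-wise exponential moving average:
    # g_ema = decay * g_ema + (1 - decay) * g
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

This matches the G_ema update formula noted in the comments of the next snippet.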
Example #22
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device)

    args.n_sheets = int(np.ceil(args.n_classes / args.n_class_per_sheet))
    args.n_sample_per_sheet = args.n_sample_per_class * args.n_class_per_sheet
    args.n_sample = args.n_sample_per_sheet * args.n_sheets
    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_y = torch.arange(args.n_classes).repeat(args.n_sample_per_class, 1).t().reshape(-1).to(device)
    if args.n_sample > args.n_sample_per_class * args.n_classes:
        sample_y1 = make_fake_label(args.n_sample - args.n_sample_per_class * args.n_classes, args.n_classes, device)
        sample_y = torch.cat([sample_y, sample_y1], 0)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        for step_index in range(args.n_step_d):
            real_img, real_labels = next(loader)
            real_img, real_labels = real_img.to(device), real_labels.to(device)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_labels = make_fake_label(args.batch, args.n_classes, device)
            fake_img, _ = generator(noise, fake_labels)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)

            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img, fake_labels)
            real_pred = discriminator(real_img_aug, real_labels)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug, real_labels)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_labels = make_fake_label(args.batch, args.n_classes, device)
        fake_img, _ = generator(noise, fake_labels)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img, fake_labels)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_labels = make_fake_label(path_batch_size, args.n_classes, device)  # match the reduced path-regularization batch
            fake_img, latents = generator(noise, fake_labels, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update G_ema
        # G_ema = G * (1-ema_beta) + G_ema * ema_beta
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                    f"augment: {ada_aug_p:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )
            
            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        (
                            f"{i:07d}; "
                            f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                            f"augment: {ada_aug_p:.4f};\n"
                        )
                    )

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    for sheet_index in range(args.n_sheets):
                        sample_z_sheet = sample_z[sheet_index*args.n_sample_per_sheet:(sheet_index+1)*args.n_sample_per_sheet]
                        sample_y_sheet = sample_y[sheet_index*args.n_sample_per_sheet:(sheet_index+1)*args.n_sample_per_sheet]
                        sample, _ = g_ema([sample_z_sheet], sample_y_sheet)
                        utils.save_image(
                            sample,
                            os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}_{sheet_index}.png"),
                            nrow=args.n_sample_per_class,
                            normalize=True,
                            value_range=(-1, 1),
                        )
            
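            # FID compares Gaussian fits of Inception features of generated vs. real data:
            # FID = ||mu_f - mu_r||^2 + Tr(C_f + C_r - 2 (C_f C_r)^(1/2)).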
            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    else:
                        mean_latent = None  # only read when truncation < 1
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device,
                        n_classes=args.n_classes,
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                # print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )
            
            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
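
The loop above leans on several helpers that are defined elsewhere. The minimal sketches below are stand-ins inferred from the call sites, not the original code; `make_fake_label` and `mixing_noise` in particular are assumptions:

import random

import torch
import torch.nn.functional as F


def make_fake_label(batch, n_classes, device):
    # Assumed: uniformly random class labels for conditional sampling.
    return torch.randint(n_classes, (batch,), device=device)


def mixing_noise(batch, latent_dim, prob, device):
    # Assumed: with probability `prob`, return two latents for style mixing.
    if prob > 0 and random.random() < prob:
        return [torch.randn(batch, latent_dim, device=device) for _ in range(2)]
    return [torch.randn(batch, latent_dim, device=device)]


def d_logistic_loss(real_pred, fake_pred):
    # Non-saturating logistic loss for the discriminator.
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()


def g_nonsaturating_loss(fake_pred):
    # Non-saturating logistic loss for the generator.
    return F.softplus(-fake_pred).mean()


def accumulate(model1, model2, decay=0.999):
    # Parameter EMA: model1 = decay * model1 + (1 - decay) * model2,
    # matching the G_ema update comment in the loop above.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)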
Example #23
def train(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length+2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character
    
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size*2,  # *2 to sample different images for the encoder and for the discriminator's real images
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    
    print('-' * 80)
    
    valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size*2,  # *2, matching the pairing used by the train loader
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    text_dataset = text_gen(opt)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
        pin_memory=True, drop_last=True)
    opt.num_class = len(converter.character)
    

    c_code_size = opt.latent
    cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size)
    ocrModel = ModelV1(opt)

    
    genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
   
    disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size)
    
    accumulate(g_ema, genModel, 0)
    
    # uCriterion = torch.nn.MSELoss()
    # sCriterion = torch.nn.MSELoss()
    # if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
    #     ocrCriterion = torch.nn.L1Loss()
    # else:
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        print('Prediction type not implemented')
        sys.exit()
        # ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    cEncoder = torch.nn.DataParallel(cEncoder).to(device)
    cEncoder.train()
    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    disEncModel = torch.nn.DataParallel(disEncModel).to(device)
    disEncModel.train()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if opt.ocrFixed:
        if opt.Transformation == 'TPS':
            ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        # ocrModel.module.SequenceModeling.eval()
        ocrModel.module.Prediction.eval()
    else:
        ocrModel.train()

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    
    optimizer = optim.Adam(
        list(genModel.parameters())+list(cEncoder.parameters()),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disEncModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
    
    ocr_optimizer = optim.Adam(
        ocrModel.parameters(),
        lr=opt.lr,
        betas=(0.9, 0.99),
    )


    ## Loading pre-trained files
    if opt.modelFolderFlag:
        saved_models = sorted(glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth")))
        if len(saved_models) > 0:
            # note: sorted() is lexicographic, not numeric, since iteration numbers are not zero-padded
            opt.saved_synth_model = saved_models[-1]

    if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)
    
    # if opt.saved_gen_model !='' and opt.saved_gen_model !='None':
    #     print(f'loading pretrained gen model from {opt.saved_gen_model}')
    #     checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
    #     genModel.module.load_state_dict(checkpoint['g'])
    #     g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)
        
        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disEncModel.load_state_dict(checkpoint['disEncModel'])
        ocrModel.load_state_dict(checkpoint['ocrModel'])
        
        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim
    

    # loss averager
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_unsup = Averager()
    loss_avg_sup = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()
    loss_avg_ocr_sup = Averager()
    loss_avg_ocr_unsup = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            pass

    
    #get schedulers
    scheduler = get_scheduler(optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)
    ocr_scheduler = get_scheduler(ocr_optimizer,opt)

    start_time = time.time()
    iteration = start_iter
    cntr=0
    
    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    # loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))
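    # i.e. an EMA half-life of roughly 10k images at a batch size of 32.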
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0
    epsilon = 10e-50
    # sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while True:
        # print(cntr)
        # train part
        if opt.lr_policy !="None":
            scheduler.step()
            dis_scheduler.step()
            ocr_scheduler.step()
        
        image_input_tensors, _, labels, _ = next(iter(train_loader))  # a fresh iterator per step draws one (reshuffled) batch
        labels_z_c = next(iter(text_loader))

        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size].detach()
        real_image_tensors = image_input_tensors[opt.batch_size:].detach()
        
        labels_gt = labels[:opt.batch_size]
        
        requires_grad(cEncoder, False)
        requires_grad(genModel, False)
        requires_grad(disEncModel, True)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style=[]
        style.append(noise_style[0]*z_c_code)
        if len(noise_style)>1:
            style.append(noise_style[1]*z_c_code)
        
        if opt.zAlone:
            #to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:,:opt.latent])
            if len(style)>1:
                newstyle.append(style[1][:,:opt.latent])
            style = newstyle
        
        fake_img,_ = genModel(style, input_is_latent=opt.input_latent)
        
        # #unsupervised code prediction on generated image
        # u_pred_code = disEncModel(fake_img, mode='enc')
        # uCost = uCriterion(u_pred_code, z_code)

        # #supervised code prediction on gt image
        # s_pred_code = disEncModel(gt_image_tensors, mode='enc')
        # sCost = uCriterion(s_pred_code, gt_phoc_tensors)

        #Domain discriminator
        fake_pred = disEncModel(fake_img)
        real_pred = disEncModel(real_image_tensors)
        disCost = d_logistic_loss(real_pred, fake_pred)

        # dis_cost = disCost + opt.gamma_e*uCost + opt.beta*sCost
        loss_avg_dis.add(disCost)
        # loss_avg_sup.add(opt.beta*sCost)
        # loss_avg_unsup.add(opt.gamma_e * uCost)

        disEncModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            real_image_tensors.requires_grad = True
            real_pred = disEncModel(real_image_tensors)
            
            r1_loss = d_r1_loss(real_pred, real_image_tensors)

            disEncModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()
        log_r1_val.add(r1_loss)
        
        # Recognizer update
        if not opt.ocrFixed and not opt.zAlone:
            requires_grad(disEncModel, False)
            requires_grad(ocrModel, True)

            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt)
            else:
                print("Prediction type not implemented")
                sys.exit()
            
            ocrModel.zero_grad()
            ocrCost.backward()
            # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            ocr_optimizer.step()
            loss_avg_ocr_sup.add(ocrCost)
        else:
            loss_avg_ocr_sup.add(torch.tensor(0.0))


        # [Word Generator] update
        # image_input_tensors, _, labels, _ = iter(train_loader).next()
        labels_z_c = next(iter(text_loader))

        # image_input_tensors = image_input_tensors.to(device)
        # gt_image_tensors = image_input_tensors[:opt.batch_size]
        # real_image_tensors = image_input_tensors[opt.batch_size:]
        
        # labels_gt = labels[:opt.batch_size]

        requires_grad(cEncoder, True)
        requires_grad(genModel, True)
        requires_grad(disEncModel, False)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        
        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style=[]
        style.append(noise_style[0]*z_c_code)
        if len(noise_style)>1:
            style.append(noise_style[1]*z_c_code)

        if opt.zAlone:
            #to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:,:opt.latent])
            if len(style)>1:
                newstyle.append(style[1][:,:opt.latent])
            style = newstyle
        
        fake_img,_ = genModel(style, input_is_latent=opt.input_latent)

        fake_pred = disEncModel(fake_img)
        disGenCost = g_nonsaturating_loss(fake_pred)

        if opt.zAlone:
            ocrCost = torch.tensor(0.0)
        else:
            #Compute OCR prediction (Reconstruction of content)
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            # length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device)
            
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(fake_img, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c)
            else:
                print("Prediction type not implemented")
                sys.exit()
        
        genModel.zero_grad()
        cEncoder.zero_grad()

        gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost
        grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, retain_graph=True)[0]
        loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
        grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, retain_graph=True)[0]
        loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
        
        if opt.grad_balance:
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0]
            a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon+torch.std(grad_fake_OCR))
            if a is None:
                print(ocrCost, disGenCost, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
            if a>1000 or a<0.0001:
                print(a)
            
            ocrCost = a.detach() * ocrCost
            gen_enc_cost = disGenCost + ocrCost
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=False, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=False, retain_graph=True)[0]
            loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
            loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
            with torch.no_grad():
                gen_enc_cost.backward()
        else:
            gen_enc_cost.backward()

        loss_avg_gen.add(disGenCost)
        loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost)

        optimizer.step()
        
        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink)
            # image_input_tensors, _, labels, _ = iter(train_loader).next()
            labels_z_c = next(iter(text_loader))

            # image_input_tensors = image_input_tensors.to(device)
            # gt_image_tensors = image_input_tensors[:path_batch_size]

            # labels_gt = labels[:path_batch_size]

            text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length)
            # text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)
        
            z_c_code = cEncoder(text_z_c)
            noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device)
            style=[]
            style.append(noise_style[0]*z_c_code)
            if len(noise_style)>1:
                style.append(noise_style[1]*z_c_code)

            if opt.zAlone:
                #to validate orig style gan results
                newstyle = []
                newstyle.append(style[0][:,:opt.latent])
                if len(style)>1:
                    newstyle.append(style[1][:,:opt.latent])
                style = newstyle

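            # Path-length regularization computed inline: penalize deviation of the
            # per-sample Jacobian-vector norm ||J^T y|| from its running average.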
            fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length)
            
            decay = 0.01
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

            mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()

            genModel.zero_grad()
            cEncoder.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            # mean_path_length_avg = (
            #     reduce_sum(mean_path_length).item() / get_world_size()
            # )
            #commented above for multi-gpu , non-distributed setting
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))
        

        if get_rank() == 0:
            if wandb and opt.wandb:
                wandb.log(
                    {
                        "Generator": disGenCost.item(),
                        "Discriminator": disCost.item(),
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_loss.item(),
                        "Path Length Regularization": path_loss.item(),
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_pred.mean().item(),
                        "Fake Score": fake_pred.mean().item(),
                        "Path Length": path_lengths.mean().item(),
                    }
                )
        
        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
            
            #generate paired content with similar style
            labels_z_c_1 = next(iter(text_loader))
            labels_z_c_2 = next(iter(text_loader))
            
            text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length)
            text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length)

            z_c_code_1 = cEncoder(text_z_c_1)
            z_c_code_2 = cEncoder(text_z_c_2)

            
            style_c1_s1 = []
            style_c2_s1 = []
            style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
            style_c1_s1.append(style_s1[0]*z_c_code_1)
            style_c2_s1.append(style_s1[0]*z_c_code_2)
            if len(style_s1)>1:
                style_c1_s1.append(style_s1[1]*z_c_code_1)
                style_c2_s1.append(style_s1[1]*z_c_code_2)
            
            noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
            style_c1_s2 = []
            style_c1_s2.append(noise_style[0]*z_c_code_1)
            if len(noise_style)>1:
                style_c1_s2.append(noise_style[1]*z_c_code_1)
            
            if opt.zAlone:
                #to validate orig style gan results
                newstyle = []
                newstyle.append(style_c1_s1[0][:,:opt.latent])
                if len(style_c1_s1)>1:
                    newstyle.append(style_c1_s1[1][:,:opt.latent])
                style_c1_s1 = newstyle
                style_c2_s1 = newstyle
                style_c1_s2 = newstyle
            
            fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent)
            fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent)
            fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent)

            if not opt.zAlone:
                #Run OCR prediction
                if 'CTC' in opt.Prediction:
                    preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(gt_image_tensors, text_gt, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0])
                    _, preds_index = preds.max(2)
                    preds_str_gt = converter.decode(preds_index.data, preds_size.data)

                else:
                    print("Prediction type not implemented")
                    sys.exit()
            else:
                preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0]
                preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0] 

            os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True)
            for trImgCntr in range(opt.batch_size):
                try:
                    save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s1_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'))
                    if not opt.zAlone:
                        save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c2_s1_'+labels_z_c_2[trImgCntr]+'_ocr:'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'))
                        save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s2_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s2[trImgCntr]+'.png'))
                        if trImgCntr<gt_image_tensors.shape[0]:
                            save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_gt_act:'+labels_gt[trImgCntr]+'_ocr:'+preds_str_gt[trImgCntr]+'.png'))
                except Exception as e:
                    print(f'Warning while saving training image: {e}')
            
            elapsed_time = time.time() - start_time
            # for log
            
            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Elapsed_time: {elapsed_time:0.5f}'
                
                
                #plotting
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item())

                
                print(loss_log)

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_ocr_unsup.reset()
                loss_avg_ocr_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()
                

            lib.plot.flush()

        lib.plot.tick()

        # save model every 1e+4 iterations
        if iteration % 10000 == 0:
            torch.save({
                'cEncoder':cEncoder.state_dict(),
                'genModel':genModel.state_dict(),
                'g_ema':g_ema.state_dict(),
                'ocrModel':ocrModel.state_dict(),
                'disEncModel':disEncModel.state_dict(),
                'optimizer':optimizer.state_dict(),
                'ocr_optimizer':ocr_optimizer.state_dict(),
                'dis_optimizer':dis_optimizer.state_dict()}, 
                os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
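
The gradient-balancing step in the generator update above can be isolated into a small helper. This is a compact sketch of the idea only (the name `balance_aux_loss` and the epsilon are illustrative, not from the original): rescale the auxiliary (OCR) loss so the spread of its image gradient matches the adversarial one.

import torch


def balance_aux_loss(adv_loss, aux_loss, fake_img, weight, eps=1e-8):
    # Match std(d aux / d img) to weight * std(d adv / d img), then detach the
    # scale so it acts as a constant during backprop.
    grad_aux = torch.autograd.grad(aux_loss, fake_img, retain_graph=True)[0]
    grad_adv = torch.autograd.grad(adv_loss, fake_img, retain_graph=True)[0]
    a = weight * grad_adv.std() / (eps + grad_aux.std())
    return adv_loss + a.detach() * aux_loss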
Example #24
def train(args, loader_src, loader_norm, generator, discriminator, ExpertModel,
          g_optim, d_optim, g_ema, device):

    # Save Path
    date = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    ImgSavePath = 'sample/{}'.format(date)
    CheckpointSavePath = 'checkpoint/{}'.format(date)
    if not os.path.exists(ImgSavePath): os.makedirs(ImgSavePath)
    if not os.path.exists(CheckpointSavePath): os.makedirs(CheckpointSavePath)
    shutil.copy('./train.py', './{}/train.py'.format(CheckpointSavePath))
    shutil.copy('./model.py', './{}/model.py'.format(CheckpointSavePath))

    loader_src = sample_data(loader_src)
    loader_norm = sample_data(loader_norm)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader_src)  # source set
        tgt_img = next(loader_norm)  # normal set
        real_img = real_img.to(device)
        tgt_img = tgt_img.to(device)

        #################################### Train discriminator ####################################
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        Profile_Fea, Profile_Map = ExpertModel(
            TrainingSize_Select(real_img, device, args), args)
        Profile_Syn_Img, _ = generator(Profile_Fea, Profile_Map)
        Front_Fea, Front_Map = ExpertModel(
            TrainingSize_Select(tgt_img, device, args), args)
        Front_Syn_Img, _ = generator(Front_Fea, Front_Map)

        Profile_Syn_Pred = discriminator(Profile_Syn_Img)
        Front_Syn_Pred = discriminator(Front_Syn_Img)
        Real_Pred = discriminator(tgt_img)
        d_loss = (d_logistic_loss(Real_Pred, Profile_Syn_Pred) +
                  d_logistic_loss(Real_Pred, Front_Syn_Pred)) / 2

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = Real_Pred.mean()
        loss_dict["profile_fake_score"] = Profile_Syn_Pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict["r1"] = r1_loss

        #################################### Train generator ####################################
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        Front_Fea, Front_Map = ExpertModel(
            TrainingSize_Select(tgt_img, device, args), args)
        Front_Syn_Img, _ = generator(Front_Fea, Front_Map)
        Front_Syn_Pred = discriminator(Front_Syn_Img)
        Front_Syn_Fea, _ = ExpertModel(
            TrainingSize_Select(Front_Syn_Img, device, args), args)

        Profile_Fea, Profile_Map = ExpertModel(
            TrainingSize_Select(real_img, device, args), args)
        Profile_Syn_Img, _ = generator(Profile_Fea, Profile_Map)
        Profile_Syn_Pred = discriminator(Profile_Syn_Img)
        Profile_Syn_Fea, _ = ExpertModel(
            TrainingSize_Select(Profile_Syn_Img, device, args), args)

        adv_g_loss = (g_nonsaturating_loss(Profile_Syn_Pred) +
                      g_nonsaturating_loss(Front_Syn_Pred)) / 2
        fea_loss = (feature_loss(Profile_Syn_Fea[0], Profile_Fea[0]) +
                    feature_loss(Front_Syn_Fea[0], Front_Fea[0])) / 2
        sym_loss = (SymLoss(Front_Syn_Img) + SymLoss(Profile_Syn_Img)) / 2
        L1_loss = L1Loss(Front_Syn_Img, tgt_img)
        g_loss = args.lambda_adv * adv_g_loss + args.lambda_fea * fea_loss + args.lambda_sym * sym_loss + args.lambda_l1 * L1_loss

        loss_dict["g"] = g_loss
        loss_dict["adv_g_loss"] = args.lambda_adv * adv_g_loss
        loss_dict["fea_loss"] = args.lambda_fea * fea_loss
        loss_dict["symmetry_loss"] = args.lambda_sym * sym_loss
        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            noise, noise_map = ExpertModel(
                TrainingSize_Select(real_img, device, args), args)
            fake_img, latents = generator(noise,
                                          noise_map,
                                          return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        fea_loss_val = loss_reduced["fea_loss"].mean().item()
        sym_loss_val = loss_reduced["symmetry_loss"].mean().item()

        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        profile_fake_score_val = loss_reduced["profile_fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g_total: {g_loss_val:.4f}; fea: {fea_loss_val:.4f}; sym: {sym_loss_val:.4f}; r1: {r1_val:.4f};"
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
            ))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Profile Score": profile_fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    pro_fea, pro_map = ExpertModel(
                        TrainingSize_Select(real_img, device, args), args)
                    pro_syn, _ = g_ema(pro_fea, pro_map)
                    tgt_fea, tgt_map = ExpertModel(
                        TrainingSize_Select(tgt_img, device, args), args)
                    tgt_syn, _ = g_ema(tgt_fea, tgt_map)

                    result = torch.cat([real_img, pro_syn, tgt_img, tgt_syn],
                                       2)
                    utils.save_image(
                        result,
                        f"{ImgSavePath}/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % 100 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"{CheckpointSavePath}/{str(i).zfill(6)}.pt",
                )
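
`SymLoss` and `feature_loss` are called above but not shown. These are hypothetical stand-ins consistent with a face-frontalization setup; the symmetry loss in particular is only a plausible guess:

import torch
import torch.nn.functional as F


def SymLoss(img):
    # Assumed: penalize left/right asymmetry via the horizontal flip (NCHW layout).
    return F.l1_loss(img, torch.flip(img, dims=[3]))


def feature_loss(fake_feat, real_feat):
    # Assumed: mean squared distance between expert-network feature maps.
    return torch.mean((fake_feat - real_feat) ** 2)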
Example #25
def train(args, loader, encoder, generator, discriminator, discriminator_z, g1,
          vggnet, pwcnet, e_optim, d_optim, dz_optim, g1_optim, e_ema, e_tf,
          g1_ema, device):
    mmd_eval = functools.partial(mix_rbf_mmd2,
                                 sigma_list=[2.0, 5.0, 10.0, 20.0, 40.0, 80.0])

    loader = sample_data(loader)
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
        g1_module = g1.module if args.train_latent_mlp else None
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator
        g1_module = g1 if args.train_latent_mlp else None

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    # sample_x = accumulate_batches(loader, args.n_sample).to(device)
    sample_x = load_real_samples(args, loader)

    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode

    # if args.no_ema or e_ema is None:
    #     e_ema = encoder

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        batch = real_img.shape[0]

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0.,
                                                                 device=device)
        kld_z = torch.tensor(0., device=device)
        mmd_z = torch.tensor(0., device=device)
        gan_z = torch.tensor(0., device=device)
        etf_z = torch.tensor(0., device=device)
        latent_real, logvar = encoder(real_img)
        if args.reparameterization:
            latent_real = reparameterize(latent_real, logvar)
        if args.train_latent_mlp:
            fake_img, _ = generator([g1(latent_real)],
                                    input_is_latent=True,
                                    return_latents=False)
        else:
            fake_img, _ = generator([latent_real],
                                    input_is_latent=False,
                                    return_latents=False)

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img - fake_img)**2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat)**2)

        if args.lambda_kld_z > 0:
            z_mean = latent_real.view(batch, -1)
            kld_z = -0.5 * torch.sum(1. + logvar - z_mean.pow(2) -
                                     logvar.exp()) / batch
            # print(kld_z)

        if args.lambda_mmd_z > 0:
            z_real = torch.randn(batch, args.latent_full, device=device)
            mmd_z = mmd_eval(latent_real, z_real)
            # print(mmd_z)

        if args.lambda_gan_z > 0:
            fake_pred = discriminator_z(latent_real)
            gan_z = g_nonsaturating_loss(fake_pred)
            # print(gan_z)

        if args.use_latent_teacher_forcing and args.lambda_etf > 0:
            w_tf, _ = e_tf(real_img)
            if args.train_latent_mlp:
                w_pred = g1(latent_real)
            else:
                w_pred = generator.get_latent(latent_real)
            etf_z = torch.mean((w_tf - w_pred)**2)
            # print(etf_z)

        if args.train_on_fake and args.lambda_rec > 0:
            z_real = torch.randn(args.batch, args.latent_full, device=device)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(z_real)],
                                        input_is_latent=True,
                                        return_latents=False)
            else:
                fake_img, _ = generator([z_real],
                                        input_is_latent=False,
                                        return_latents=False)
            # fake_img, _ = generator([z_real], input_is_latent=False, return_latents=True)
            z_fake, z_logvar = encoder(fake_img)
            if args.reparameterization:
                z_fake = reparameterize(z_fake, z_logvar)
            rec_loss = torch.mean((z_real - z_fake)**2)
            loss_dict["rec"] = rec_loss
            # print(rec_loss)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        e_loss = e_loss + args.lambda_kld_z * kld_z + args.lambda_mmd_z * mmd_z + args.lambda_gan_z * gan_z + args.lambda_etf * etf_z + rec_loss * args.lambda_rec

        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        if args.train_latent_mlp and g1 is not None:
            g1.zero_grad()
        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()
        if args.train_latent_mlp and g1_optim is not None:
            g1_optim.step()

        # if args.train_on_fake:
        #     e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
        #     if e_regularize and args.lambda_rec > 0:
        #         # noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        #         # fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True)
        #         z_real = torch.randn(args.batch, args.latent_full, device=device)
        #         fake_img, w_real = generator([z_real], input_is_latent=False, return_latents=True)
        #         z_fake, logvar = encoder(fake_img)
        #         if args.reparameterization:
        #             z_fake = reparameterize(z_fake, logvar)
        #         rec_loss = torch.mean((z_real - z_fake) ** 2)
        #         encoder.zero_grad()
        #         (rec_loss * args.lambda_rec).backward()
        #         e_optim.step()
        #         loss_dict["rec"] = rec_loss

        e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        if e_regularize:
            # why not regularize on augmented real?
            real_img.requires_grad = True
            real_pred, logvar = encoder(real_img)
            if args.reparameterization:
                real_pred = reparameterize(real_pred, logvar)
            r1_loss_e = d_r1_loss(real_pred, real_img)

            encoder.zero_grad()
            (args.r1 / 2 * r1_loss_e * args.e_reg_every +
             0 * real_pred.view(-1)[0]).backward()
            e_optim.step()

            loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            accumulate(e_ema, e_module, accum)
        if args.train_latent_mlp:
            accumulate(g1_ema, g1_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real, logvar = encoder(real_img)
            if args.reparameterization:
                latent_real = reparameterize(latent_real, logvar)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(latent_real)],
                                        input_is_latent=True,
                                        return_latents=False)
            else:
                fake_img, _ = generator([latent_real],
                                        input_is_latent=False,
                                        return_latents=False)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img

            fake_pred = discriminator(fake_img_aug)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            z_real = torch.randn(batch, args.latent_full, device=device)
            fake_pred = discriminator_z(latent_real.detach())
            real_pred = discriminator_z(z_real)
            d_loss_z = d_logistic_loss(real_pred, fake_pred)
            discriminator_z.zero_grad()
            d_loss_z.backward()
            dz_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss_d * args.d_reg_every +
                 0 * real_pred.view(-1)[0]).backward()
                # Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_x, _ = generator([g1_ema(latent_x)],
                                              input_is_latent=True,
                                              return_latents=False)
                    else:
                        fake_x, _ = generator([latent_x],
                                              input_is_latent=False,
                                              return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_eval(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_img, _ = generator([g1_ema(latent_real)],
                                                input_is_latent=True,
                                                return_latents=False)
                    else:
                        fake_img, _ = generator([latent_real],
                                                input_is_latent=False,
                                                return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": g1_module.state_dict() if args.train_latent_mlp else None,
                        "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None,
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": g1_module.state_dict() if args.train_latent_mlp else None,
                        "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None,
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', 'latest.pt'),
                )
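The numbered checkpoints and latest.pt above share one dictionary layout, so resuming only needs to read the same keys back. A minimal sketch, assuming the modules and optimizers are already constructed; load_checkpoint is a hypothetical helper, not part of this codebase:

import torch


def load_checkpoint(path, encoder, discriminator, g_ema, e_ema,
                    e_optim, d_optim, g1=None, g1_ema=None):
    # Restores exactly the keys the save blocks above write out.
    ckpt = torch.load(path, map_location='cpu')
    encoder.load_state_dict(ckpt['e'])
    discriminator.load_state_dict(ckpt['d'])
    g_ema.load_state_dict(ckpt['g_ema'])
    e_ema.load_state_dict(ckpt['e_ema'])
    if g1 is not None and ckpt.get('g1') is not None:
        g1.load_state_dict(ckpt['g1'])
        g1_ema.load_state_dict(ckpt['g1_ema'])
    e_optim.load_state_dict(ckpt['e_optim'])
    d_optim.load_state_dict(ckpt['d_optim'])
    # args, ada_aug_p and iter let the caller resume the schedule.
    return ckpt['args'], ckpt['ada_aug_p'], ckpt['iter']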
Example #26
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    loader = sample_data(loader)

    start_iter = args.start_iter // get_world_size() // args.batch
    pbar = range(args.iter // get_world_size() // args.batch)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    seg_loss = torch.tensor(0.0, device=device)
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg, seg_loss_val, shift_loss_val = 0, 0, 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))

    sample_condition_img, sample_conditions, condition_img_color = random_condition_img(
        args.n_sample)

    if get_rank() == 0:
        os.makedirs('sample', exist_ok=True)
        os.makedirs(f'sample/{args.name}', exist_ok=True)
        os.makedirs(f'ckpts/{args.name}', exist_ok=True)
        if args.with_tensorboard:
            os.makedirs(f'tensorboard/{args.name}', exist_ok=True)
            writer = SummaryWriter(f'tensorboard/{args.name}')

    for idx in pbar:
        i = idx + start_iter

        if i > args.iter:
            print('Done!')
            break

        real_img, condition_img = next(loader)
        real_img = real_img.to(device)

        if args.condition_path is not None:
            condition_img = condition_img.to(device)
        else:
            condition_img = None

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _, _, _ = generator(noise, condition_img=condition_img)
        if args.with_rgbs:
            condition_img_encoder = F.interpolate(condition_img,
                                                  size=args.resolution,
                                                  mode='nearest')
            real_img = torch.cat((real_img, condition_img_encoder), dim=1)
        fake_pred, _ = discriminator(fake_img)
        real_pred, real_pred_feat = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred, _ = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _, _, parsing_feature = generator(
            noise, condition_img=condition_img)

        fake_pred, fake_pred_feat = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        loss_dict['g'] = g_loss

        loss_dict['seg'] = seg_loss
        loss_dict['shift_loss'] = seg_loss
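        # note: 'shift_loss' mirrors seg_loss here; both stay at zero in this loop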

        loss = g_loss
        generator.zero_grad()
        loss.backward()
        g_optim.step()

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            if args.condition_path is not None:
                condition_img = condition_img[range(path_batch_size)]
                condition_img.requires_grad = True

            fake_img, latents, _, _ = generator(noise,
                                                return_latents=True,
                                                condition_img=condition_img)

            path_loss, mean_path_length, path_lengths, isNaN = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.g_reg_every * args.path_regularize * path_loss

            if args.path_batch_shrink:
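                # the zero-weighted pixel keeps fake_img in the autograd graph so no
                # parameter goes unused under DDP (see the stylegan2-pytorch issue
                # linked in Example #28)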
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            if not isNaN:
                g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

        if args.condition_path is not None and (0 == i % args.g_reg_every):
            seg_loss_val = loss_reduced['seg'].mean().item()
            shift_loss_val = loss_reduced['shift_loss'].mean().item()

        if get_rank() == 0:
            pbar.set_description((f'mean path: {mean_path_length_avg:.4f}'))

            if args.with_tensorboard:
                writer.add_scalar('Loss/Generator', g_loss_val, i)
                writer.add_scalar('Loss/Discriminator', d_loss_val, i)
                writer.add_scalar('Loss/R1', r1_val, i)
                writer.add_scalar('Loss/Path Length', path_length_val, i)
                writer.add_scalar('Loss/mean path', mean_path_length_avg, i)

                if args.condition_path is not None:
                    writer.add_scalar('Loss/seg_img', seg_loss_val, i)
                    writer.add_scalar('Loss/shift_loss', shift_loss_val, i)

            steps = get_world_size() * args.batch * (1 + i)
            if steps % 100000 < get_world_size() * args.batch or (
                    steps < 1000
                    and steps % 500 == get_world_size() * args.batch):
                with torch.no_grad():
                    g_ema.eval()
                    samples, featuresMaps, parsing_features = [], [], []
                    small_batch = args.n_sample // args.batch
                    if 0 != args.n_sample % args.batch:
                        small_batch += 1

                    # only condition change
                    rows = int(args.n_sample**0.5)
                    if args.condition_path is not None:
                        sample_z = mixing_noise(rows, args.latent, args.mixing,
                                                device)
                        sample_z = sample_z.unsqueeze(1).repeat(
                            1, rows, 1, 1).view(args.n_sample,
                                                sample_z.shape[1],
                                                sample_z.shape[2])
                    else:
                        sample_z = mixing_noise(args.n_sample, args.latent,
                                                args.mixing, device)

                    for k in range(small_batch):

                        start, end = k * args.batch, (k + 1) * args.batch
                        if k == small_batch - 1:
                            end = sample_z.shape[0]

                        if args.condition_path is not None:
                            sample_condition_img_sub = sample_condition_img[
                                start:end]
                            sample_condition_img_sub = random_affine(
                                sample_condition_img_sub.clone(),
                                Scale=0.0).to(device)
                        else:
                            sample_condition_img_sub = None

                        sample, _, _, _ = g_ema(
                            sample_z[start:end],
                            condition_img=sample_condition_img_sub)
                        samples.append(sample.cpu().detach())

                    samples = torch.cat(samples, dim=0)

                    nrow = int(args.n_sample**0.5)
                    c, h, w = samples.shape[-3:]
                    samples = samples.reshape(nrow, nrow, c, h, w).transpose(
                        1, 0).reshape(-1, c, h, w)
                    utils.save_image(
                        samples,
                        f'sample/{args.name}/{str(steps).zfill(6)}.png',
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )
                    if 0 == i:
                        c, h, w = condition_img_color.shape[-3:]
                        condition_img_color = condition_img_color.reshape(
                            nrow, nrow, c, h,
                            w).transpose(1, 0).reshape(-1, c, h, w)
                        utils.save_image(
                            condition_img_color,
                            f'sample/{args.name}/seg_vis.png',
                            nrow=nrow,
                            normalize=True,
                            range=(-1, 1),
                        )

            if (steps +
                    get_world_size() * args.batch) % 100000 < get_world_size(
                    ) * args.batch and steps != args.start_iter:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        # 'g_optim': g_optim.state_dict(),
                        # 'd_optim': d_optim.state_dict(),
                    },
                    f'ckpts/{args.name}/{str(steps).zfill(6)}.pt',
                )
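accumulate(g_ema, g_module, accum) with accum = 0.5 ** (32 / (10 * 1000)) is the usual StyleGAN2 EMA update: a per-step decay of about 0.9978, i.e. a half-life of roughly 10k images at batch size 32. A sketch of the helper as defined in rosinality/stylegan2-pytorch, which these loops appear to follow (an assumption, since the examples only call it):

def accumulate(model1, model2, decay=0.999):
    # Blend model2's weights into model1 in place as an exponential moving average.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)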
Example #27
0
def train(args, loader, generator, discriminator, contrast_learner, augment, g_optim, d_optim, scaler, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = th.zeros(size=(1,), device=device)
    g_loss_val = 0
    path_loss = th.zeros(size=(1,), device=device)
    path_lengths = th.zeros(size=(1,), device=device)
    loss_dict = {}
    mse = th.nn.MSELoss()

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    sample_z = th.randn(args.n_sample, args.latent_size, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        discriminator.zero_grad()

        loss_dict["d"], loss_dict["real_score"], loss_dict["fake_score"] = 0, 0, 0
        loss_dict["cl_reg"], loss_dict["bc_reg"] = (
            th.tensor(0, device=device).float(),
            th.tensor(0, device=device).float(),
        )
        for _ in range(args.num_accumulate):
            # sample = []
            # for _ in range(0, len(sample_z), args.batch_size):
            #     subsample = next(loader)
            #     sample.append(subsample)
            # sample = th.cat(sample)
            # utils.save_image(sample, "reals-no-augment.png", nrow=10, normalize=True)
            # utils.save_image(augment(sample), "reals-augment.png", nrow=10, normalize=True)

            real_img = next(loader)
            real_img = real_img.to(device)

            # with th.cuda.amp.autocast():
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_D:
                fake_pred = discriminator(augment(fake_img))
                real_pred = discriminator(augment(real_img))
            else:
                fake_pred = discriminator(fake_img)
                real_pred = discriminator(real_img)

            # logistic loss
            real_loss = F.softplus(-real_pred)
            fake_loss = F.softplus(fake_pred)
            d_loss = real_loss.mean() + fake_loss.mean()

            loss_dict["d"] += d_loss.detach()
            loss_dict["real_score"] += real_pred.mean().detach()
            loss_dict["fake_score"] += fake_pred.mean().detach()

            if i > 10000 or i == 0:
                if args.contrastive > 0:
                    contrast_learner(fake_img.clone().detach(), accumulate=True)
                    contrast_learner(real_img, accumulate=True)

                    contrast_loss = cl_module.calculate_loss()
                    loss_dict["cl_reg"] += contrast_loss.detach()

                    d_loss += args.contrastive * contrast_loss

                if args.balanced_consistency > 0:
                    aug_fake_pred = discriminator(augment(fake_img.clone().detach()))
                    aug_real_pred = discriminator(augment(real_img))

                    consistency_loss = mse(real_pred, aug_real_pred) + mse(fake_pred, aug_fake_pred)
                    loss_dict["bc_reg"] += consistency_loss.detach()

                    d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            # scaler.scale(d_loss).backward()
            d_loss.backward()

        # scaler.step(d_optim)
        d_optim.step()

        # R1 regularization
        if args.r1 > 0 and i % args.d_reg_every == 0:

            discriminator.zero_grad()

            loss_dict["r1"] = 0
            for _ in range(args.num_accumulate):
                real_img = next(loader)
                real_img = real_img.to(device)

                real_img.requires_grad = True

                # with th.cuda.amp.autocast():
                # if args.augment_D:
                #     real_pred = discriminator(
                #         augment(real_img)
                #     )  # RuntimeError: derivative for grid_sampler_2d_backward is not implemented :(
                # else:
                real_pred = discriminator(real_img)
                real_pred_sum = real_pred.sum()

                (grad_real,) = th.autograd.grad(outputs=real_pred_sum, inputs=real_img, create_graph=True)
                # (grad_real,) = th.autograd.grad(outputs=scaler.scale(real_pred_sum), inputs=real_img, create_graph=True)
                # grad_real = grad_real * (1.0 / scaler.get_scale())

                # with th.cuda.amp.autocast():
                r1_loss = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
                weighted_r1_loss = args.r1 / 2.0 * r1_loss * args.d_reg_every + 0 * real_pred[0]

                loss_dict["r1"] += r1_loss.detach()

                weighted_r1_loss /= args.num_accumulate
                # scaler.scale(weighted_r1_loss).backward()
                weighted_r1_loss.backward()

            # scaler.step(d_optim)
            d_optim.step()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        generator.zero_grad()
        loss_dict["g"] = 0
        for _ in range(args.num_accumulate):
            # with th.cuda.amp.autocast():
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_G:
                fake_img = augment(fake_img)

            fake_pred = discriminator(fake_img)

            # non-saturating loss
            g_loss = F.softplus(-fake_pred).mean()

            loss_dict["g"] += g_loss.detach()

            g_loss /= args.num_accumulate
            # scaler.scale(g_loss).backward()
            g_loss.backward()

        # scaler.step(g_optim)
        g_optim.step()

        # path length regularization
        if args.path_regularize > 0 and i % args.g_reg_every == 0:

            generator.zero_grad()

            loss_dict["path"], loss_dict["path_length"] = 0, 0
            for _ in range(args.num_accumulate):
                path_batch_size = max(1, args.batch_size // args.path_batch_shrink)

                # with th.cuda.amp.autocast():
                noise = make_noise(path_batch_size, args.latent_size, args.mixing_prob, device)
                fake_img, latents = generator(noise, return_latents=True)

                img_noise = th.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
                noisy_img_sum = (fake_img * img_noise).sum()

                (grad,) = th.autograd.grad(outputs=noisy_img_sum, inputs=latents, create_graph=True)
                # (grad,) = th.autograd.grad(outputs=scaler.scale(noisy_img_sum), inputs=latents, create_graph=True)
                # grad = grad * (1.0 / scaler.get_scale())

                # with th.cuda.amp.autocast():
                path_lengths = th.sqrt(grad.pow(2).sum(2).mean(1))
                path_mean = mean_path_length + 0.01 * (path_lengths.mean() - mean_path_length)
                path_loss = (path_lengths - path_mean).pow(2).mean()
                mean_path_length = path_mean.detach()

                loss_dict["path"] += path_loss.detach()
                loss_dict["path_length"] += path_lengths.mean().detach()

                weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
                if args.path_batch_shrink:
                    weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

                weighted_path_loss /= args.num_accumulate
                # scaler.scale(weighted_path_loss).backward()
                weighted_path_loss.backward()

            # scaler.step(g_optim)
            g_optim.step()

        # scaler.update()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item() / args.num_accumulate
        g_loss_val = loss_reduced["g"].mean().item() / args.num_accumulate
        cl_reg_val = loss_reduced["cl_reg"].mean().item() / args.num_accumulate
        bc_reg_val = loss_reduced["bc_reg"].mean().item() / args.num_accumulate
        r1_val = loss_reduced["r1"].mean().item() / args.num_accumulate
        path_loss_val = loss_reduced["path"].mean().item() / args.num_accumulate
        real_score_val = loss_reduced["real_score"].mean().item() / args.num_accumulate
        fake_score_val = loss_reduced["fake_score"].mean().item() / args.num_accumulate
        path_length_val = loss_reduced["path_length"].mean().item() / args.num_accumulate

        if get_rank() == 0:

            log_dict = {
                "Generator": g_loss_val,
                "Discriminator": d_loss_val,
                "Real Score": real_score_val,
                "Fake Score": fake_score_val,
                "Contrastive": cl_reg_val,
                "Consistency": bc_reg_val,
            }

            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)
                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)

                log_dict[f"Spectral Norms/G min spectral norm"] = np.log(G_norms).min()
                log_dict[f"Spectral Norms/G mean spectral norm"] = np.log(G_norms).mean()
                log_dict[f"Spectral Norms/G max spectral norm"] = np.log(G_norms).max()
                log_dict[f"Spectral Norms/D min spectral norm"] = np.log(D_norms).min()
                log_dict[f"Spectral Norms/D mean spectral norm"] = np.log(D_norms).mean()
                log_dict[f"Spectral Norms/D max spectral norm"] = np.log(D_norms).max()

            if args.r1 > 0 and i % args.d_reg_every == 0:
                log_dict["R1"] = r1_val

            if args.path_regularize > 0 and i % args.g_reg_every == 0:
                log_dict["Path Length Regularization"] = path_loss_val
                log_dict["Mean Path Length"] = mean_path_length
                log_dict["Path Length"] = path_length_val

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema([sample_z[sub : sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1))
                    # utils.save_image(sample, "fakes-no-augment.png", nrow=10, normalize=True)
                    # utils.save_image(augment(sample), "fakes-augment.png", nrow=10, normalize=True)
                    # exit()
                log_dict["Generated Images EMA"] = [wandb.Image(grid, caption=f"Step {i}")]

            if i % args.eval_every == 0:
                start_time = time.time()
                pbar.set_description((f"Calculating FID..."))
                fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample, args.fid_truncation, args.name)
                fid = fid_dict["FID"]
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]

                pbar.set_description("Calculating PPL...")
                ppl = validation.ppl(
                    g_ema, args.val_batch_size, args.ppl_n_sample, args.ppl_space, args.ppl_crop, args.latent_size,
                )

                pbar.set_description(
                    (
                        f"FID: {fid:.4f}; Density: {density:.4f}; Coverage: {coverage:.4f}; PPL: {ppl:.4f} in {time.time() - start_time:.1f}s"
                    )
                )
                log_dict["Evaluation/FID"] = fid
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

                gc.collect()
                th.cuda.empty_cache()

            wandb.log(log_dict)

            if i % args.checkpoint_every == 0:
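                # NOTE: the fid/ppl in the filename assume the eval block above has
                # already run; keep eval_every a divisor of checkpoint_every
                # (e.g. both fire when i == 0).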
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{args.name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}-{int(fid)}-{int(ppl)}-{str(i).zfill(6)}.pt",
                )
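Every loop above funnels its logging scalars through reduce_loss_dict before reading them on rank 0. A sketch consistent with the rosinality/stylegan2-pytorch helper these examples track (an assumption, not verified against each repo): the losses are stacked, summed onto rank 0, and divided by the world size, so only rank 0 sees the cross-rank average.

import torch
from torch import distributed as dist


def get_world_size():
    # Same contract as the helper used in the loops: 1 when not distributed.
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[k] for k in keys], 0)
        dist.reduce(losses, dst=0)  # element-wise sum onto rank 0
        if dist.get_rank() == 0:
            losses /= world_size  # other ranks keep the un-averaged buffer
        return {k: v for k, v in zip(keys, losses)}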
Example #28
0
def train(
    args,
    loader,
    encoder,
    generator,
    discriminator,
    discriminator3d,  # video discriminator
    posterior,
    prior,
    factor,  # a learnable matrix
    vggnet,
    e_optim,
    d_optim,
    dv_optim,
    q_optim,  # q for posterior
    p_optim,  # p for prior
    f_optim,  # f for factor
    e_ema,
    device
):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {"d": torch.tensor(0., device=device),
                 "real_score": torch.tensor(0., device=device),
                 "fake_score": torch.tensor(0., device=device),
                 "r1_d": torch.tensor(0., device=device),
                 "r1_e": torch.tensor(0., device=device),
                 "rec": torch.tensor(0., device=device),}

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    latent_full = args.latent_full
    factor_dim_full = args.factor_dim_full

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    sample_x = accumulate_batches(loader, args.n_sample).to(device)
    utils.save_image(
        sample_x.view(-1, *list(sample_x.shape)[2:]),
        os.path.join(args.log_dir, 'sample', f"real-img.png"),
        nrow=sample_x.shape[1],
        normalize=True,
        value_range=(-1, 1),
    )
    util.save_video(
        sample_x[0],
        os.path.join(args.log_dir, 'sample', f"real-vid.mp4")
    )

    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode
    if args.no_update_encoder:
        encoder = e_ema if e_ema is not None else encoder
        requires_grad(encoder, False)
        encoder.eval()
    from models.networks_3d import GANLoss
    criterionGAN = GANLoss()
    # criterionL1 = nn.L1Loss()

    # if args.no_ema or e_ema is None:
    #     e_ema = encoder
    
    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        data = next(loader)
        real_seq = data['frames']
        real_seq = real_seq.to(device)  # [N, T, C, H, W]
        shape = list(real_seq.shape)
        N, T = shape[:2]

        # Train Encoder with frame-level objectives
        if args.toggle_grads:
            if not args.no_update_encoder:
                requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = vid_loss = l1y_loss = torch.tensor(0., device=device)

        # TODO: real_seq -> encoder -> posterior -> generator -> fake_seq
        # f: [N, latent_full]; y: [N, T, D]
        fake_img, fake_seq, y_post = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior, i, ret_y=True)
        # if args.debug == 'no_lstm':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # elif args.debug == 'decomp':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))  # [N*T, latent_full]
        #     f_post = real_lat[::T, ...]
        #     z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
        #     if args.use_multi_head:
        #         y_post = []
        #         for z, w in zip(torch.split(z_post, 512, 2), factor.weight):
        #             y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1))
        #         y_post = torch.cat(y_post, 2)
        #     else:
        #         y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1)
        #     z_post_hat = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post_hat
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # else:
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     # single head: f_post [N, latent_full]; y_post [N, T, D]
        #     # multi head: f_post [N, n_latent, latent]; y_post [N, T, n_latent, d]
        #     f_post, y_post = posterior(real_lat.view(N, T, latent_full))
        #     z_post = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post  # shape [N, T, latent_full]
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])

        # TODO: sample frames
        real_img = real_seq.view(N*T, *shape[2:])
        # fake_img = fake_seq.view(N*T, *shape[2:])

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        # TODO: do we always put pix and vgg loss for all frames?
        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img - fake_img) ** 2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat) ** 2)
        
        # Train Encoder with video-level objectives
        # TODO: video adversarial loss
        if args.lambda_vid > 0:
            fake_pred = discriminator3d(flip_video(fake_seq.transpose(1, 2)))
            vid_loss = criterionGAN(fake_pred, True)
        
        if args.lambda_l1y > 0:
            # l1y_loss = criterionL1(y_post)
            l1y_loss = torch.mean(torch.abs(y_post))

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        e_loss = e_loss + args.lambda_vid * vid_loss + args.lambda_l1y * l1y_loss
        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss
        
        if not args.no_update_encoder:
            encoder.zero_grad()
        posterior.zero_grad()
        e_loss.backward()
        q_optim.step()
        if not args.no_update_encoder:
            e_optim.step()

        # if args.train_on_fake:
        #     e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
        #     if e_regularize and args.lambda_rec > 0:
        #         noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        #         fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True)
        #         latent_pred = encoder(fake_img)
        #         if latent_pred.ndim < 3:
        #             latent_pred = latent_pred.unsqueeze(1).repeat(1, latent_fake.size(1), 1)
        #         rec_loss = torch.mean((latent_fake - latent_pred) ** 2)
        #         encoder.zero_grad()
        #         (rec_loss * args.lambda_rec).backward()
        #         e_optim.step()
        #         loss_dict["rec"] = rec_loss

        # e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        # if e_regularize:
        #     # why not regularize on augmented real?
        #     real_img.requires_grad = True
        #     real_pred = encoder(real_img)
        #     r1_loss_e = d_r1_loss(real_pred, real_img)

        #     encoder.zero_grad()
        #     (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward()
        #     e_optim.step()

        #     loss_dict["r1_e"] = r1_loss_e

        if not args.no_update_encoder:
            if not args.no_ema and e_ema is not None:
                accumulate(e_ema, e_module, accum)
        
        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        fake_img, fake_seq = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior)
        # if args.debug == 'no_lstm':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # elif args.debug == 'decomp':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))  # [N*T, latent_full]
        #     f_post = real_lat[::T, ...]
        #     z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
        #     if args.use_multi_head:
        #         y_post = []
        #         for z, w in zip(torch.split(z_post, 512, 2), factor.weight):
        #             y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1))
        #         y_post = torch.cat(y_post, 2)
        #     else:
        #         y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1)
        #     z_post_hat = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post_hat
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # elif args.debug == 'coef':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))  # [N*T, latent_full]
        #     f_post = real_lat[::T, ...]
        #     z_post_hat = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
        #     y_post = torch.mm(z_post_hat.view(N*T, -1), factor.weight).view(N, T, -1)
        #     z_post = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # else:
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     f_post, y_post = posterior(real_lat.view(N, T, latent_full))
        #     z_post = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post  # shape [N, T, latent_full]
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # fake_img = fake_seq.view(N*T, *shape[2:])
        if not args.no_update_discriminator:
            if args.lambda_adv > 0:
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    fake_img_aug, _ = augment(fake_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    fake_img_aug = fake_img
                
                fake_pred = discriminator(fake_img_aug)
                real_pred = discriminator(real_img_aug)
                d_loss = d_logistic_loss(real_pred, fake_pred)

            # Train video discriminator
            if args.lambda_vid > 0:
                pred_real = discriminator3d(flip_video(real_seq.transpose(1, 2)))
                pred_fake = discriminator3d(flip_video(fake_seq.transpose(1, 2)))
                dv_loss_real = criterionGAN(pred_real, True)
                dv_loss_fake = criterionGAN(pred_fake, False)
                dv_loss = 0.5 * (dv_loss_real + dv_loss_fake)
                d_loss = d_loss + dv_loss

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            if args.lambda_adv > 0:
                discriminator.zero_grad()
            if args.lambda_vid > 0:
                discriminator3d.zero_grad()
            d_loss.backward()
            if args.lambda_adv > 0:
                d_optim.step()
            if args.lambda_vid > 0:
                dv_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat
            
            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
                # Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                    f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                    f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Encoder": e_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1 D": r1_d_val,
                        "R1 E": r1_e_val,
                        "Pix Loss": pix_loss_val,
                        "VGG Loss": vgg_loss_val,
                        "Adv Loss": adv_loss_val,
                        "Rec Loss": rec_loss_val,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                    }
                )

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    posterior.eval()
                    # N = sample_x.shape[0]
                    fake_img, fake_seq = reconstruct_sequence(args, sample_x, e_eval, generator, factor, posterior)
                    # if args.debug == 'no_lstm':
                    #     real_lat = encoder(sample_x.view(-1, *shape[2:]))
                    #     fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False)
                    #     fake_seq = fake_img.view(N, T, *shape[2:])
                    # elif args.debug == 'decomp':
                    #     real_lat = encoder(sample_x.view(-1, *shape[2:]))  # [N*T, latent_full]
                    #     f_post = real_lat[::T, ...]
                    #     z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
                    #     if args.use_multi_head:
                    #         y_post = []
                    #         for z, w in zip(torch.split(z_post, 512, 2), factor.weight):
                    #             y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1))
                    #         y_post = torch.cat(y_post, 2)
                    #     else:
                    #         y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1)
                    #     z_post_hat = factor(y_post)
                    #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
                    #     w_post = f_expand + z_post_hat
                    #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
                    #     fake_seq = fake_img.view(N, T, *shape[2:])
                    # else:
                    #     x_lat = encoder(sample_x.view(-1, *shape[2:]))
                    #     f_post, y_post = posterior(x_lat.view(N, T, latent_full))
                    #     z_post = factor(y_post)
                    #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
                    #     w_post = f_expand + z_post
                    #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
                    #     fake_seq = fake_img.view(N, T, *shape[2:])
                    utils.save_image(
                        torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_recon.png"),
                        nrow=T,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    util.save_video(
                        fake_seq[random.randint(0, args.n_sample-1)],
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-vid_recon.mp4")
                    )
                    fake_img, fake_seq = swap_sequence(args, sample_x, e_eval, generator, factor, posterior)
                    utils.save_image(
                        torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_swap.png"),
                        nrow=T,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()
                    posterior.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )
            
            if not args.debug and i % args.save_latest_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
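d_logistic_loss, g_nonsaturating_loss, and d_r1_loss recur in every example here. Sketches matching the standard stylegan2-pytorch definitions (an assumption, since the examples only import them); Example #27 inlines the same softplus formulas directly:

from torch import autograd
from torch.nn import functional as F


def d_logistic_loss(real_pred, fake_pred):
    # Logistic GAN loss for D: push real logits up, fake logits down.
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()


def g_nonsaturating_loss(fake_pred):
    # Non-saturating generator loss on D's logits for fakes.
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1 penalty: squared gradient norm of D's output w.r.t. real images.
    grad_real, = autograd.grad(outputs=real_pred.sum(),
                               inputs=real_img,
                               create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()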
Example #29
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    none_g_grads = set()
    test_in = torch.randn(1, args.latent, device=device)
    fake, latent = g_module([test_in], return_latents=True)
    path = g_path_regularize(fake, latent, 0)
    path[0].backward()

    for n, p in generator.named_parameters():
        if p.grad is None:
            none_g_grads.add(n)

    test_in = torch.randn(1, 3, args.size, args.size, requires_grad=True, device=device)
    pred = d_module(test_in)
    r1_loss = d_r1_loss(pred, test_in)
    r1_loss.backward()

    none_d_grads = set()
    for n, p in discriminator.named_parameters():
        if p.grad is None:
            none_d_grads.add(n)

    sample_z = torch.randn(2 * 2, args.latent, device=device)

    for i in pbar:
        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            set_grad_none(discriminator, none_d_grads)

            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator.proj, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        noise_proj_loss = sum([(generator.proj(noise_i) - noise_i).abs().sum() for noise_i in noise])
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        print(noise_proj_loss.item())

        loss_dict['g'] = g_loss

        generator.zero_grad()
        (g_loss + noise_proj_loss).backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            noise = mixing_noise(
                args.batch // args.path_batch_shrink, args.latent, args.mixing, device
            )
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            set_grad_none(g_module, none_g_grads)

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_loss_val = loss_reduced['path'].mean().item()
        real_score_val = loss_reduced['real_score'].mean().item()
        fake_score_val = loss_reduced['fake_score'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
                    f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        'Generator': g_loss_val,
                        'Discriminator': d_loss_val,
                        'R1': r1_val,
                        'Path Length Regularization': path_loss_val,
                        'Mean Path Length': mean_path_length,
                        'Real Score': real_score_val,
                        'Fake Score': fake_score_val,
                        'Path Length': path_length_val,
                    }
                )

            if i % 10000 == 0:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    f'checkpoint/{str(i).zfill(6)}.pt',
                )

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f'sample/{str(i).zfill(6)}.png',
                        nrow=2,
                        normalize=True,
                        range=(-1, 1),
                    )
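Example #27 spells the path-length regularizer out inline; the other loops wrap it in g_path_regularize. A sketch of the three-value variant that Examples #29 and #30 call, following the rosinality reference (Example #26's four-value variant is assumed to differ only in also returning a NaN flag):

import math

import torch
from torch import autograd


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Probe the Jacobian of G w.r.t. the latents with 1/sqrt(H*W)-scaled noise.
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad, = autograd.grad(outputs=(fake_img * noise).sum(),
                          inputs=latents,
                          create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # Exponential moving average of the path length, as inlined in Example #27.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths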
Example #30
0
def train(args, loader, generator, encoder, discriminator, vggnet, g_optim,
          e_optim, d_optim, g_ema, e_ema, device):
    kwargs_d = {'detach_aux': False}
    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    else:
        inception = real_mean = real_cov = None
    mean_latent = None

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        real_pred = discriminator(real_img)
        fake_pred = discriminator(fake_img)
        rec_pred = discriminator(rec_img)
        d_loss_real = F.softplus(-real_pred).mean()
        d_loss_fake = F.softplus(fake_pred).mean()
        d_loss_rec = F.softplus(rec_pred).mean()
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()
        loss_dict["rec_score"] = rec_pred.mean()

        d_loss = d_loss_real + d_loss_fake + d_loss_rec
        loss_dict["d"] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # # Train Encoder and Generator
        # requires_grad(generator, True)
        # requires_grad(encoder, True)
        # requires_grad(discriminator, False)
        # pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        # noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        # fake_img, _ = generator(noise)
        # latent_real, _ = encoder(real_img)
        # rec_img, _ = generator([latent_real], input_is_latent=True)
        # fake_pred = discriminator(fake_img)
        # rec_pred = discriminator(rec_img)
        # g_loss_fake = g_nonsaturating_loss(fake_pred)
        # g_loss_rec = g_nonsaturating_loss(rec_pred)
        # adv_loss = g_loss_fake + g_loss_rec
        # if args.lambda_pix > 0:
        #     if args.pix_loss == 'l2':
        #         pix_loss = torch.mean((rec_img - real_img) ** 2)
        #     else:
        #         pix_loss = F.l1_loss(rec_img, real_img)
        # if args.lambda_vgg > 0:
        #     vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)
        # e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        # loss_dict["e"] = e_loss
        # encoder.zero_grad()
        # generator.zero_grad()
        # e_loss.backward()
        # e_optim.step()
        # g_optim.step()

        # Train Encoder
        requires_grad(generator, False)
        requires_grad(encoder, True)
        requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        rec_pred = discriminator(rec_img)
        g_loss_rec = g_nonsaturating_loss(rec_pred)
        adv_loss = g_loss_rec
        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
        if args.lambda_vgg > 0:
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv

        loss_dict["e"] = e_loss
        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

        # Train Generator
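        # Same reconstruction objective plus an adversarial term on freshly
        # sampled fakes; this time only the generator's weights are trainable.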
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        fake_pred = discriminator(fake_img)
        rec_pred = discriminator(rec_img)
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        g_loss_rec = g_nonsaturating_loss(rec_pred)
        adv_loss = g_loss_fake + g_loss_rec
        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
        if args.lambda_vgg > 0:
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)

        g_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv

        loss_dict["g"] = g_loss
        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

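        # Lazy path-length regularization (StyleGAN2): encourage a constant
        # Jacobian norm from latent to image space, on a shrunken batch every
        # g_reg_every steps; `0 * fake_img[0, 0, 0, 0]` again just keeps the
        # generator output in the graph.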
        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

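        # Re-measure reconstruction quality on this batch, gradient-free,
        # purely for logging.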
        with torch.no_grad():
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            pix_loss_val = pix_loss.mean().item()
            vgg_loss_val = vgg_loss.mean().item()

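        # Update the EMA copies of the encoder and generator that are used
        # for evaluation and sample dumps.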
        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()  # needed for wandb logging below
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

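            # Periodic FID: Inception statistics of g_ema samples against the
            # precomputed real statistics (real_mean, real_cov).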
            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    # Without truncation there is no mean latent to blend toward.
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    else:
                        mean_latent = None
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean,
                                   real_cov)
                print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}: fid: {float(fid):.4f}\n")

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

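            # Dump two grids: fakes from the fixed sample_z, and real /
            # reconstruction pairs (interleaved row-wise) for sample_x.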
            if i % args.log_every == 0:
                with torch.no_grad():
                    # Fixed fake samples
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

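            # Periodic full checkpoint; a rolling `latest.pt` copy is written
            # below for easy resuming.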
            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
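
# For reference, a minimal sketch of the loss and EMA helpers the loop above
# assumes (requires_grad, g_nonsaturating_loss, d_r1_loss, accumulate), written
# in the common StyleGAN2 formulation; the repository's own definitions live
# elsewhere in the source and may differ in detail.
import torch
from torch import autograd
from torch.nn import functional as F


def requires_grad(model, flag=True):
    # Freeze or unfreeze every parameter of a network.
    for p in model.parameters():
        p.requires_grad = flag


def g_nonsaturating_loss(fake_pred):
    # Non-saturating logistic loss: -log(sigmoid(D(G(z)))).
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1 penalty: squared gradient norm of D's output w.r.t. real images.
    grad_real, = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()


def accumulate(model1, model2, decay=0.999):
    # EMA update: model1 <- decay * model1 + (1 - decay) * model2.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)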