Example #1
    def forward(self, imgs, sentence=None, word_emb=None, train_perm=True):
        feat_64 = self.from_64(imgs[0])
        #feat_128 = self.from_128(imgs[1])

        logit_1 = self.rf_64(feat_64).view(-1)
        #logit_3 = self.rf_128(feat_128).view(-1)

        b = imgs[0].shape[0]
        #img_feat = torch.cat([feat_64, feat_128], dim=1)
        img_feat = feat_64

        match = match_perm = match_word = match_perm_word = None

        #match = self.it_match(img_feat, sentence)
        #match = self.matcher_2(torch.cat([self.matcher_1(img_feat), sentence.unsqueeze(-1).unsqueeze(-1)],dim=1).squeeze(-1).squeeze(-1)) #b,256
        match_word = self.wordLevelDis(img_feat, word_emb)

        if train_perm:
            perm = true_randperm(b)
            #match_perm = self.it_match(img_feat[perm], sentence)
            #match_perm = self.matcher_2(torch.cat([self.matcher_1(img_feat[perm]), sentence.unsqueeze(-1).unsqueeze(-1)],dim=1).squeeze(-1).squeeze(-1)) #b,256
            match_perm_word = self.wordLevelDis(img_feat[perm], word_emb)

        return torch.cat([logit_1]), match, match_perm, match_word, match_perm_word
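
`true_randperm` is not defined in any of these snippets, yet every example relies on it to build mismatched pairs. A minimal sketch, assuming it returns a permutation with no fixed points (so `x[perm]` never pairs a sample with itself):

import torch

def true_randperm(size, device='cuda'):
    # Hypothetical sketch (assumes size >= 2): resample until the permutation
    # moves every index, so each element is paired with a *different* one.
    perm = torch.randperm(size, device=device)
    while (perm == torch.arange(size, device=device)).any():
        perm = torch.randperm(size, device=device)
    return perm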
Example #2
    def forward(self, imgs, sentence=None, word_emb=None, train_perm=True):
        feat_64 = self.from_64(imgs[0])
        feat_256 = self.from_256(imgs[1])

        logit_1 = self.rf_64(feat_64).view(-1)
        logit_3 = self.rf_256(feat_256).view(-1)

        b = imgs[0].shape[0]
        img_feat = torch.cat([feat_64, feat_256], dim=1)

        match = pred_text = match_perm = pred_text_perm = None

        #match = self.it_match(img_feat, sentence)
        match = self.matcher(
            torch.cat([
                img_feat,
                sentence.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 8, 8)
            ], dim=1))  # b, 256

        if train_perm:
            perm = true_randperm(b)
            #match_perm = self.it_match(img_feat[perm], sentence)
            match_perm = self.matcher(
                torch.cat([
                    img_feat[perm],
                    sentence.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 8, 8)
                ], dim=1))  # b, 256

        return torch.cat([logit_1, logit_3]), match, match_perm
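
The forward pass returns scores for matched and permuted (mismatched) image–text combinations, but how they are consumed is not shown here. A plausible sketch, assuming a standard hinge formulation where matched pairs should score high and mismatched pairs low (the function name is hypothetical):

import torch.nn.functional as F

def matching_hinge_loss(match, match_perm):
    # Drive matched scores above +1 and permuted (mismatched) scores below -1.
    return F.relu(1. - match).mean() + F.relu(1. + match_perm).mean()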
Example #3
def loss_for_list_perm(loss, fl1, fl2, detach_second=True):
    result_loss = 0
    for f_idx in range(len(fl1)):
        perm = true_randperm(fl1[0].shape[0], fl1[0].device)
        if detach_second:
            result_loss += F.relu(2 + loss(fl1[f_idx], fl2[f_idx].detach()) -
                                  loss(fl1[f_idx][perm], fl2[f_idx].detach()))
        else:
            result_loss += F.relu(2 + loss(fl1[f_idx], fl2[f_idx]) -
                                  loss(fl1[f_idx][perm], fl2[f_idx]))
    return result_loss
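
For each feature level this is a margin loss, F.relu(2 + d_aligned - d_permuted): the distance between aligned features must sit at least 2 below the distance between randomly permuted (misaligned) ones. A hedged usage sketch with illustrative dummy tensors:

import torch
import torch.nn.functional as F

# Two levels of nearly-aligned feature maps (shapes are illustrative).
feats_a = [torch.randn(4, 64, 8, 8), torch.randn(4, 128, 4, 4)]
feats_b = [f + 0.05 * torch.randn_like(f) for f in feats_a]
loss = loss_for_list_perm(F.mse_loss, feats_a, feats_b)  # margin loss per level, summed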
Example #4
def image_generator_perm(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500):
    counter = 0
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE,
                   shuffle=False,
                   num_workers=4,
                   pin_memory=False))
    n_batches = min(n_batches, len(dataset) // BATCH_SIZE - 1)
    while counter < n_batches:
        counter += 1
        rgb_img, _, _, skt_img = next(dataloader)
        rgb_img = F.interpolate(rgb_img, size=512).cuda()
        skt_img = F.interpolate(skt_img, size=512).cuda()

        perm = true_randperm(rgb_img.shape[0], device=rgb_img.device)

        gimg_ae, style_feat = net_ae(skt_img, rgb_img[perm])
        g_image = net_ig(gimg_ae, style_feat)
        if counter == 1:
            vutils.save_image(0.5 * (g_image + 1), 'tmp.jpg')
        yield g_image
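
This generator yields batches synthesized from style-permuted reference images; Example #5 below feeds it into FID computation. A hedged driver, reusing the names imported there from the benchmark module (load_patched_inception_v3, extract_feature_from_generator_fn):

inception = load_patched_inception_v3().cuda()
inception.eval()

gen = image_generator_perm(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=50)
feats_perm = extract_feature_from_generator_fn(gen, inception, total=50)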
Example #5
def train():
    from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, image_generator, image_generator_perm
    import lpips

    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder = saved_model_folder = None
    log_file_path = None
    if saved_image_folder is None:
        saved_image_folder, saved_model_folder = make_folders(
            SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
        log_file_path = saved_image_folder + '/../gan_log.txt'
        log_file = open(log_file_path, 'w')
        log_file.close()

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()

    RefineGenerator = None
    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    # we use the patch_gan, so the im_size for D should be 512 even if the training image size is 1024
    net_id = Discriminator(nc=3).cuda()

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    ## loss meters for logging
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    losses_rec_s = AverageMeter()

    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None

    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            rd = random.randint(0, 3)  # inclusive upper bound: rd in {0, 1, 2, 3}, so skt_img_3 is drawn twice as often
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            if iteration == 0:
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(
                gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) +
                                 F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  # show the current images
                with torch.no_grad():

                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    gimg_ae_perm, style_feats = net_ae(fixed_skt,
                                                       fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1),
                                      IM_SIZE_GAN), gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN), gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg'
                    )

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Epoch/Iter: [{0}/{1}] G: {losses_g_img.avg:.4f}  D: {losses_d_img.avg:.4f}  MSE: {losses_mse.avg:.4f}  Rec: {losses_rec_s.avg:.5f}  FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_file.write(log_msg + '\n')
                    log_file.close()

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    if os.path.exists('%s_fid_feats.npy' % (DATA_NAME)):
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))
                    else:
                        real_features = extract_feature_from_generator_fn(
                            real_image_loader(dataloader,
                                              n_batches=fid_batch_images),
                            inception)
                        real_mean = np.mean(real_features, 0)
                        real_cov = np.cov(real_features, rowvar=False)
                        pickle.dump(
                            {
                                'feats': real_features,
                                'mean': real_mean,
                                'cov': real_cov
                            }, open('%s_fid_feats.npy' % (DATA_NAME), 'wb'))
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))

                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset,
                                    net_ae,
                                    net_ig,
                                    n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset,
                                         net_ae,
                                         net_ig,
                                         n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_msg = 'fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1])
                    log_file.write(log_msg + '\n')
                    log_file.close()
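
The GAN loop above leans on several helpers not defined in these snippets. Minimal sketches, under the assumption that d_hinge_loss/g_hinge_loss are the standard hinge GAN losses and that copy_G_params/load_params implement the usual EMA weight shadowing (consistent with the avg_p.mul_(0.999).add_(p.data, alpha=0.001) update in the loop):

import torch.nn.functional as F
from copy import deepcopy

def d_hinge_loss(pred_real, pred_fake):
    # Discriminator hinge: real scores above +1, fake scores below -1.
    return F.relu(1. - pred_real).mean() + F.relu(1. + pred_fake).mean()

def g_hinge_loss(pred_fake):
    # Generator hinge: raise the discriminator's score on fakes.
    return -pred_fake.mean()

def copy_G_params(model):
    # Detached copies of the generator weights, used as the EMA shadow.
    return deepcopy([p.data for p in model.parameters()])

def load_params(model, new_params):
    # Swap a parameter list (e.g. the EMA shadow) into the live model.
    for p, new_p in zip(model.parameters(), new_params):
        p.data.copy_(new_p)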
Example #6
def train():
    from config import IM_SIZE_AE, BATCH_SIZE_AE, NFC, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, TRIAL_NAME, LOG_INTERVAL
    from config import DATA_NAME
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_AE,
                                 rand_crop=True)
    print(len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_AE,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    dataset_ss = SelfSupervisedDataset(data_root_colorful,
                                       data_root_sketch_3,
                                       im_size=IM_SIZE_AE,
                                       nbr_cls=NBR_CLS,
                                       rand_crop=True)
    print(len(dataset_ss), len(dataset_ss.frame))
    dataloader_ss = iter(
        DataLoader(dataset_ss,
                   BATCH_SIZE_AE,
                   sampler=InfiniteSamplerWrapper(dataset_ss),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    style_encoder = StyleEncoder(nfc=NFC, nbr_cls=NBR_CLS).cuda()
    content_encoder = ContentEncoder(nfc=NFC).cuda()
    decoder = Decoder(nfc=NFC).cuda()

    opt_c = optim.Adam(content_encoder.parameters(),
                       lr=2e-4,
                       betas=(0.5, 0.999))
    opt_s = optim.Adam(style_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()

    from config import PRETRAINED_AE_PATH, PRETRAINED_AE_ITER
    if PRETRAINED_AE_PATH is not None:
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER
        ckpt = torch.load(PRETRAINED_AE_PATH)

        print(PRETRAINED_AE_PATH)

        style_encoder.load_state_dict(ckpt['s'])
        content_encoder.load_state_dict(ckpt['c'])
        decoder.load_state_dict(ckpt['d'])

        opt_c.load_state_dict(ckpt['opt_c'])
        opt_s.load_state_dict(ckpt['opt_s'])
        opt_d.load_state_dict(ckpt['opt_d'])
        print('loaded pre-trained AE')

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()
    opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                           lr=2e-4,
                           betas=(0.5, 0.999))

    saved_image_folder, saved_model_folder = make_folders(
        SAVE_FOLDER, 'AE_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../ae_log.txt'
    log_file = open(log_file_path, 'w')
    log_file.close()
    ## for logging
    losses_sf_consist = AverageMeter()
    losses_cf_consist = AverageMeter()
    losses_cls = AverageMeter()
    losses_rec_rd = AverageMeter()
    losses_rec_org = AverageMeter()
    losses_rec_grey = AverageMeter()

    import lpips
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    for iteration in tqdm(range(ITERATION_AE)):

        if iteration % ((NBR_CLS * 100) // BATCH_SIZE_AE) == 0 and iteration > 1:
            dataset_ss._next_set()
            dataloader_ss = iter(
                DataLoader(dataset_ss,
                           BATCH_SIZE_AE,
                           sampler=InfiniteSamplerWrapper(dataset_ss),
                           num_workers=DATALOADER_WORKERS,
                           pin_memory=True))
            style_encoder.reset_cls()
            opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                                   lr=2e-4,
                                   betas=(0.5, 0.999))

            opt_s.param_groups[0]['lr'] = 1e-4
            opt_d.param_groups[0]['lr'] = 1e-4

        ### 1. train the encoder with self-supervision methods
        rgb_img_rd, rgb_img_org, skt_org, skt_bold, skt_erased, skt_erased_bold, img_idx = next(
            dataloader_ss)
        rgb_img_rd = rgb_img_rd.cuda()
        rgb_img_org = rgb_img_org.cuda()
        img_idx = img_idx.cuda()

        skt_org = F.interpolate(skt_org, size=512).cuda()
        skt_bold = F.interpolate(skt_bold, size=512).cuda()
        skt_erased = F.interpolate(skt_erased, size=512).cuda()
        skt_erased_bold = F.interpolate(skt_erased_bold, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
        style_vector_org, pred_cls_org = style_encoder(rgb_img_org)

        content_feats = content_encoder(skt_org)
        content_feats_bold = content_encoder(skt_bold)
        content_feats_erased = content_encoder(skt_erased)
        content_feats_eb = content_encoder(skt_erased_bold)

        rd = random.randint(0, 3)
        gimg_rd = None
        if rd == 0:
            gimg_rd = decoder(content_feats, style_vector_rd)
        elif rd == 1:
            gimg_rd = decoder(content_feats_bold, style_vector_rd)
        elif rd == 2:
            gimg_rd = decoder(content_feats_erased, style_vector_rd)
        elif rd == 3:
            gimg_rd = decoder(content_feats_eb, style_vector_rd)


        loss_cf_consist = (
            loss_for_list_perm(F.mse_loss, content_feats_bold, content_feats) +
            loss_for_list_perm(F.mse_loss, content_feats_erased, content_feats) +
            loss_for_list_perm(F.mse_loss, content_feats_eb, content_feats))

        loss_sf_consist = 0
        for loss_idx in range(3):
            loss_sf_consist += (
                -F.cosine_similarity(style_vector_rd[loss_idx],
                                     style_vector_org[loss_idx].detach()).mean() +
                F.cosine_similarity(
                    style_vector_rd[loss_idx],
                    style_vector_org[loss_idx][torch.randperm(BATCH_SIZE_AE)].detach()).mean())

        loss_cls = F.cross_entropy(pred_cls_rd, img_idx) + F.cross_entropy(
            pred_cls_org, img_idx)
        loss_rec_rd = F.mse_loss(gimg_rd, rgb_img_org)
        if DATA_NAME != 'shoe':
            loss_rec_rd += percept(
                F.adaptive_avg_pool2d(gimg_rd, output_size=256),
                F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
        else:
            loss_rec_rd += F.l1_loss(gimg_rd, rgb_img_org)

        loss_total = loss_cls + loss_sf_consist + loss_rec_rd + loss_cf_consist  #+ loss_kl_c + loss_kl_s
        loss_total.backward()

        opt_s.step()
        opt_s_cls.step()
        opt_c.step()
        opt_d.step()

        ### 2. train as AutoEncoder
        rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

        rgb_img = rgb_img.cuda()

        rd = random.randint(0, 3)  # inclusive upper bound: skt_img_3 is picked for both rd == 2 and rd == 3
        if rd == 0:
            skt_img = skt_img_1
        elif rd == 1:
            skt_img = skt_img_2
        else:
            skt_img = skt_img_3

        skt_img = F.interpolate(skt_img, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector, _ = style_encoder(rgb_img)
        content_feats = content_encoder(skt_img)
        gimg = decoder(content_feats, style_vector)

        loss_rec_org = F.mse_loss(gimg, rgb_img)
        if DATA_NAME != 'shoe':
            loss_rec_org += percept(
                F.adaptive_avg_pool2d(gimg, output_size=256),
                F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
        #else:
        #    loss_rec_org += F.l1_loss(gimg, rgb_img)

        loss_rec = loss_rec_org
        if DATA_NAME == 'shoe':
            ### the grey image reconstruction
            perm = true_randperm(BATCH_SIZE_AE)
            gimg_perm = decoder(content_feats, [s[perm] for s in style_vector])
            gimg_grey = gimg_perm.mean(dim=1, keepdim=True)
            real_grey = rgb_img.mean(dim=1, keepdim=True)
            loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
            loss_rec += loss_rec_grey
        loss_rec.backward()

        opt_s.step()
        opt_d.step()
        opt_c.step()

        ### Logging
        losses_cf_consist.update(loss_cf_consist.mean().item(), BATCH_SIZE_AE)
        losses_sf_consist.update(loss_sf_consist.mean().item(), BATCH_SIZE_AE)
        losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
        losses_rec_rd.update(loss_rec_rd.item(), BATCH_SIZE_AE)
        losses_rec_org.update(loss_rec_org.item(), BATCH_SIZE_AE)
        if DATA_NAME == 'shoe':
            losses_rec_grey.update(loss_rec_grey.item(), BATCH_SIZE_AE)

        if iteration % LOG_INTERVAL == 0:
            log_msg = ('Train Stage 1: AE: \n'
                       'rec_rd: %.4f  rec_org: %.4f  cls: %.4f  style_consist: %.4f  '
                       'content_consist: %.4f  rec_grey: %.4f' %
                       (losses_rec_rd.avg, losses_rec_org.avg, losses_cls.avg,
                        losses_sf_consist.avg, losses_cf_consist.avg,
                        losses_rec_grey.avg))

            print(log_msg)

            if log_file_path is not None:
                log_file = open(log_file_path, 'a')
                log_file.write(log_msg + '\n')
                log_file.close()

            losses_sf_consist.reset()
            losses_cls.reset()
            losses_rec_rd.reset()
            losses_rec_org.reset()
            losses_cf_consist.reset()
            losses_rec_grey.reset()

        if iteration % SAVE_IMAGE_INTERVAL == 0:
            vutils.save_image(torch.cat([
                rgb_img_rd,
                F.interpolate(skt_org.repeat(1, 3, 1, 1), size=512), gimg_rd
            ]),
                              '%s/rd_%d.jpg' % (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))
            # for 'shoe', gimg_perm was already produced in the grey-reconstruction step above
            if DATA_NAME != 'shoe':
                with torch.no_grad():
                    perm = true_randperm(BATCH_SIZE_AE)
                    gimg_perm = decoder([c for c in content_feats],
                                        [s[perm] for s in style_vector])
            vutils.save_image(torch.cat([
                rgb_img,
                F.interpolate(skt_img.repeat(1, 3, 1, 1), size=512), gimg,
                gimg_perm
            ]),
                              '%s/org_%d.jpg' %
                              (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))

        if iteration % SAVE_MODEL_INTERVAL == 0:
            print('Saving history model')
            torch.save(
                {
                    's': style_encoder.state_dict(),
                    'd': decoder.state_dict(),
                    'c': content_encoder.state_dict(),
                    'opt_c': opt_c.state_dict(),
                    'opt_s_cls': opt_s_cls.state_dict(),
                    'opt_s': opt_s.state_dict(),
                    'opt_d': opt_d.state_dict(),
                }, '%s/%d.pth' % (saved_model_folder, iteration))

    torch.save(
        {
            's': style_encoder.state_dict(),
            'd': decoder.state_dict(),
            'c': content_encoder.state_dict(),
            'opt_c': opt_c.state_dict(),
            'opt_s_cls': opt_s_cls.state_dict(),
            'opt_s': opt_s.state_dict(),
            'opt_d': opt_d.state_dict(),
        }, '%s/%d.pth' % (saved_model_folder, ITERATION_AE))
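
Both training loops track running losses with AverageMeter, which is also not shown. A standard sketch matching the update(val, n) / avg / reset() usage above:

class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is a per-sample average; n is the batch size it covers.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count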