Example #1
    def initialize_validate_losses(self):
        total_losses = {}
        if 'sbir' in self.args.task_types:
            total_losses['triplet'] = AverageMeter()
        if 'sbir_cls' in self.args.task_types:
            total_losses['prediction'] = AverageMeter()
        return total_losses
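Every example below leans on an AverageMeter helper that is never shown. A minimal sketch consistent with how it is used here (.update(val, n), .val, .avg; the .out() calls in the wandb snippets further down are assumed to be an alias for .avg):

class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0   # last value passed to update()
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight (e.g. number of samples)
        self.avg = 0.0   # running average, sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def out(self):
        # Assumed alias for .avg, matching the wandb logging below.
        return self.avg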
Example #2
    def validate(self):
        netG, netD = self.models['netG'], self.models['netD']
        total_losses = {'gan_gen': AverageMeter(), 'gan_dis': AverageMeter()}

        total_evaluations = {'No': AverageMeter()}
        topk = (1, 5)
        with torch.no_grad():
            for data_i, batch_data in enumerate(self.datas['val_loader']):
                batch_data = [tensor.to(device=self.gpu_ids['netG'],dtype=torch.float32) for tensor in batch_data]
                mask_input_states, input_states, length_mask, input_mask, target = batch_data
                # Size Checking: input_states[batch, seq_len, 5], length_mask[batch,seq_len], input_mask[batch, seq_len], target[batch]

                target = target.to(dtype=torch.long)

                # Fake Sequence Generation
                noise_input = self.get_random_input(input_states.size(0), input_states.size(1), self.args.noise_dim, device=self.gpu_ids['netG'])

                fake_states = netG(noise_input, attention_mask=None, head_mask=None, **self.running_paras)

                # Add gan label for input
                prefix_input = torch.ones(input_states.size(0), 1, input_states.size(2)).to(dtype=torch.float, device=self.gpu_ids['netG'])
                dis_input_states = torch.cat([prefix_input*(-1), input_states], dim=1)
                dis_fake_states = torch.cat([prefix_input*(-1), fake_states], dim=1)

                # compute real and fake scores
                scores_real = netD(dis_input_states, attention_mask=None, head_mask=None, **self.running_paras)
                scores_fake = netD(dis_fake_states, attention_mask=None, head_mask=None, **self.running_paras)

                # Discriminator losses (validation only; nothing is updated under no_grad)
                # Size Checking: scores_real, scores_fake [batch, 2]
                total_d_loss, d_losses = self.calculate_d_losses(scores_real, scores_fake)

                # Generator losses, reusing the fake scores computed above
                total_g_loss, g_losses = self.calculate_g_losses(scores_fake)

                evaluations = self.evaluate()
                losses = {**d_losses, **g_losses}
                for key in total_losses:
                    total_losses[key].update(losses[key].item(), input_states.size(0))
                for key in total_evaluations:
                    total_evaluations[key].update(evaluations[key], input_states.size(0))

        log_losses = {key:torch.tensor(total_losses[key].avg) for key in total_losses}
        log_evaluations = {key:total_evaluations[key].avg for key in total_evaluations}
        # Update Logs
        self.update_log('val', log_losses, log_evaluations)
        strokes_pred = fake_states
        self.update_stroke_results('val', input_states, strokes_pred)
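self.get_random_input is not defined in this snippet; a plausible standalone sketch matching the call site in Example #2 (standard-normal noise shaped [batch, seq_len, noise_dim] on the generator's device):

import torch

def get_random_input(batch_size, seq_len, noise_dim, device=None):
    # Hypothetical stand-in for self.get_random_input in Example #2:
    # Gaussian noise fed to the sequence generator.
    return torch.randn(batch_size, seq_len, noise_dim, device=device)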
Example #3
    def initialize_validate_losses(self):
        total_losses = {}
        if 'maskrec' in self.args.task_types:
            types = ['mask_axis', 'rec_axis', 'mask_type', 'rec_type']
            for t in types:
                if t in self.valid_losses:
                    total_losses[t] = AverageMeter()
        if 'maskgmm' in self.args.task_types:
            types = ['mask_gmm', 'rec_gmm', 'mask_type', 'rec_type']
            for t in types:
                if t in self.valid_losses:
                    total_losses[t] = AverageMeter()

        if 'maskdisc' in self.args.task_types:
            types = [
                'x_mask_disc', 'y_mask_disc', 'type_mask_disc', 'x_rec_disc',
                'y_rec_disc', 'type_rec_disc'
            ]
            for t in types:
                if t in self.valid_losses:
                    total_losses[t] = AverageMeter()
        if 'sketchcls' in self.args.task_types:
            total_losses['prediction'] = AverageMeter()
        if 'sketchclsinput' in self.args.task_types:
            total_losses['prediction'] = AverageMeter()
        if 'sketchretrieval' in self.args.task_types:
            total_losses['triplet'] = AverageMeter()
        return total_losses
Example #4
    def initialize_validate_losses(self):
        total_losses = {}
        if 'maskrec' in self.args.task_types:
            types = ['mask_l1', 'rec_l1', 'mask_type', 'rec_type']
            for t in types:
                if t in self.valid_losses:
                    total_losses[t] = AverageMeter()
        if 'maskgmm' in self.args.task_types:
            types = ['mask_gmm', 'rec_gmm', 'mask_type', 'rec_type']
            for t in types:
                if t in self.valid_losses:
                    total_losses[t] = AverageMeter()

        if 'maskdisc' in self.args.task_types:
            types = [
                'x_mask_disc', 'y_mask_disc', 'type_mask_disc', 'x_rec_disc',
                'y_rec_disc', 'type_rec_disc'
            ]
            for t in types:
                if t in self.valid_losses:
                    total_losses[t] = AverageMeter()
        return total_losses
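The three initialize_validate_losses variants above repeat one pattern: map each task type to candidate loss names, then keep a meter for every name that also appears in self.valid_losses. A compact, hypothetical refactor of that pattern (method form, as in the originals):

    def initialize_validate_losses(self):
        # Hypothetical table-driven version of the variants above.
        task_to_losses = {
            'maskrec': ['mask_l1', 'rec_l1', 'mask_type', 'rec_type'],
            'maskgmm': ['mask_gmm', 'rec_gmm', 'mask_type', 'rec_type'],
            'maskdisc': ['x_mask_disc', 'y_mask_disc', 'type_mask_disc',
                         'x_rec_disc', 'y_rec_disc', 'type_rec_disc'],
        }
        total_losses = {}
        for task, names in task_to_losses.items():
            if task in self.args.task_types:
                for name in names:
                    if name in self.valid_losses:
                        total_losses[name] = AverageMeter()
        return total_losses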
Example #5
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=device,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        #masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = imgs / 255.  # scale to [0, 1]
        # mask is 1 on masked region

        coarse_imgs, recon_imgs = netG(imgs, masks)
        #print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)
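GANLoss and DLoss are passed in but never defined in these examples. For SN-PatchGAN the paper uses hinge losses, so a sketch under that assumption, with forward signatures matching the call sites DLoss(pred_pos, pred_neg) and GANLoss(pred_neg):

import torch
import torch.nn as nn

class SNDisLoss(nn.Module):
    """Hinge loss for the discriminator (assumed SN-PatchGAN form)."""

    def forward(self, pred_pos, pred_neg):
        # Push real scores above +1 and fake scores below -1.
        return torch.mean(torch.relu(1. - pred_pos)) + \
               torch.mean(torch.relu(1. + pred_neg))

class SNGenLoss(nn.Module):
    """Generator adversarial loss: maximize the critic score on fakes."""

    def forward(self, pred_neg):
        return -torch.mean(pred_neg)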
Example #6
def validate(net, dataloader, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    net.eval()
    with torch.no_grad():
        end = time.time()
        for i, (imgs, cls_ids) in enumerate(dataloader):
            imgs, cls_ids = imgs.to(device), cls_ids.to(device)
            masks = generate_mask()
            masks = masks.to(device)
            pred = net(imgs * (1 - masks))
            loss = criterion(pred, cls_ids)
            #measure
            prec1, prec5 = accuracy(pred, cls_ids, topk=(1, 5))
            losses.update(loss.item(), imgs.size(0))
            top1.update(prec1[0], imgs.size(0))
            top5.update(prec5[0], imgs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0:
                logger.info(
                    'Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        i,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    return top1.avg, top5.avg
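accuracy(pred, cls_ids, topk=(1, 5)) is assumed to be the standard top-k precision helper (it returns an indexable result per k, hence prec1[0] above); a sketch:

import torch

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)  # [batch, maxk] class indices
    pred = pred.t()                             # [maxk, batch]
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res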
Example #7
def validate(nets,
             loss_terms,
             opts,
             dataloader,
             epoch,
             network_type,
             devices=(cuda0, cuda1),
             batch_n="whole_test_show"):
    """
    validate phase
    """
    netD, netG = nets["netD"], nets["netG"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms[
            "GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "p_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_comp_dir = os.path.join(val_save_dir, "comp")
    for size in SIZES_TAGS:
        if not os.path.exists(os.path.join(val_save_real_dir, size)):
            os.makedirs(os.path.join(val_save_real_dir, size))
        if not os.path.exists(os.path.join(val_save_gen_dir, size)):
            os.makedirs(os.path.join(val_save_gen_dir, size))
        if not os.path.exists(os.path.join(val_save_comp_dir, size)):
            os.makedirs(os.path.join(val_save_comp_dir, size))
    info = {}
    t = 0
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        pre_complete_imgs = (pre_imgs / 127.5 - 1)

        for s_i, size in enumerate(TRAIN_SIZES):

            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if imgs.size(1) != 3:
                print(t, imgs.size())
            pre_inter_imgs = F.interpolate(pre_complete_imgs, size)

            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            # masks = (masks > 0).type(torch.FloatTensor)

            # imgs, masks = imgs.to(device), masks.to(device)
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region
            # forward
            if network_type == 'l2h_unet':
                recon_imgs = netG(imgs, masks, pre_complete_imgs,
                                  pre_inter_imgs, size)
            elif network_type == 'l2h_gated':
                recon_imgs = netG(imgs, masks, pre_inter_imgs)
            elif network_type == 'sa_gated':
                recon_imgs, _ = netG(imgs, masks)
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [recon_imgs, masks,
                 torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

            g_loss = GANLoss(pred_neg)

            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(
                imgs, complete_imgs)
            p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss  # g_loss + r_loss

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(s_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)

            # Logger logging

            # if t < config.STATIC_VIEW_SIZE:
            print(i, size)
            real_img = img2photo(imgs)
            gen_img = img2photo(recon_imgs)
            comp_img = img2photo(complete_imgs)

            real_img = Image.fromarray(real_img[0].astype(np.uint8))
            gen_img = Image.fromarray(gen_img[0].astype(np.uint8))
            comp_img = Image.fromarray(comp_img[0].astype(np.uint8))
            real_img.save(
                os.path.join(val_save_real_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))
            gen_img.save(
                os.path.join(val_save_gen_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))
            comp_img.save(
                os.path.join(val_save_comp_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))

            end = time.time()
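Example #7 calls img2photo without defining it; later examples define it inline, and a standalone version consistent with those (maps [batch, C, H, W] tensors in [-1, 1] to [batch, H, W, C] numpy arrays in [0, 255]) would be:

def img2photo(imgs):
    # [batch, C, H, W] in [-1, 1] -> [batch, H, W, C] in [0, 255]
    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(2, 3).detach().cpu().numpy()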
Example #8
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()

    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']

        #traditional inpainting
        for item in range(len(imgs)):
            img = np.array(transforms.ToPILImage()(imgs[item]))
            mask = np.array(transforms.ToPILImage()(masks[item]))
            res = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
            res = transforms.ToTensor()(res)
            res = (res * 255) / 127.5 - 1
            if item:
                traditional_inpaint = torch.cat((traditional_inpaint, res))
            else:
                traditional_inpaint = res
        traditional_inpaint = torch.reshape(traditional_inpaint,
                                            (config.BATCH_SIZE, 3, 256, 256))
        traditional_inpaint = traditional_inpaint.to(device)

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        #print(type(masks))
        #print(type(guidence))
        #exit()

        coarse_imgs, recon_imgs_with_weight = netG(imgs, masks)
        recon_imgs = recon_imgs_with_weight[:, 0:3, :, :]
        weight_layer = (recon_imgs_with_weight[:, 3:, :, :] + 1.0) / 2
        recon_imgs = weight_layer * recon_imgs + (
            1 - weight_layer) * traditional_inpaint
        #print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            validate(netG,
                     netD,
                     GANLoss,
                     ReconLoss,
                     DLoss,
                     optG,
                     optD,
                     val_datas,
                     epoch,
                     device,
                     batch_n=i)
            netG.train()
            netD.train()
        end = time.time()
Example #9
def train(net, dataloader, epoch, opt, criterion):
    net.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (imgs, cls_ids) in enumerate(dataloader):
        data_time.update(time.time() - end)

        imgs, cls_ids = imgs.to(device), cls_ids.to(device)
        opt.zero_grad()
        masks = generate_mask()
        masks = masks.to(device)
        if np.random.rand() < mask_rate:
            pred = net(imgs * (1 - masks))
        else:
            pred = net(imgs)
        loss = criterion(pred, cls_ids)

        loss.backward()
        opt.step()

        #measure
        prec1, prec5 = accuracy(pred, cls_ids, topk=(1, 5))
        losses.update(loss.item(), imgs.size(0))
        top1.update(prec1[0], imgs.size(0))
        top5.update(prec5[0], imgs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            epoch,
                            i,
                            len(dataloader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            top1=top1,
                            top5=top5))
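generate_mask (also used by the validate function in Example #6) is not shown; a hedged sketch that returns one random rectangular hole as a [1, 1, H, W] float mask (1 on the masked region), broadcastable over the batch:

import numpy as np
import torch

def generate_mask(height=256, width=256, max_frac=0.5):
    # Hypothetical helper: one random rectangle set to 1 (masked region).
    mask = torch.zeros(1, 1, height, width)
    h = np.random.randint(height // 8, int(height * max_frac))
    w = np.random.randint(width // 8, int(width * max_frac))
    top = np.random.randint(0, height - h)
    left = np.random.randint(0, width - w)
    mask[:, :, top:top + h, left:left + w] = 1.
    return mask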
Example #10
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        guide = []
        transform = transforms.Compose([transforms.ToPILImage()])
        for k in range(imgs.shape[0]):
            im = transform(imgs[k])
            im = np.array(im)
            # cv2.imwrite('test.jpg', im)

            im = cv2.Canny(image=im, threshold1=20, threshold2=220)
            # cv2.imwrite('test1.jpg', im)
            # exit(1)

            guide.append(im)
        guide = torch.FloatTensor(guide)
        guide = guide[:, None, :, :]
        imgs, masks, guide = imgs.to(device), masks.to(device), guide.to(
            device)

        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        guide = guide / 255.0

        coarse_imgs, recon_imgs, attention = netG(imgs, masks, guide)
        # print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        # pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info(
                "Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f},\t"
                "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time, data_time=data_time,
                        whole_loss=losses['whole_loss'], r_loss=losses['r_loss'],
                        g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scalars and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(netG,
                     netD,
                     GANLoss,
                     ReconLoss,
                     DLoss,
                     optG,
                     optD,
                     val_datas,
                     epoch,
                     device,
                     batch_n=i)
            netG.train()
            netD.train()
        end = time.time()
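The per-image Canny loop in the train function above can be isolated into a small helper; a sketch under the same assumptions (pre-normalization image tensors that ToPILImage can convert, edge guide scaled to [0, 1] downstream):

import cv2
import numpy as np
import torch
from torchvision import transforms

def make_edge_guide(imgs, low=20, high=220):
    # imgs: [batch, 3, H, W] tensor -> [batch, 1, H, W] Canny edge maps.
    to_pil = transforms.ToPILImage()
    edges = [cv2.Canny(np.array(to_pil(img)), low, high) for img in imgs]
    return torch.FloatTensor(np.stack(edges))[:, None, :, :]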
Example #11
def validate(netG,
             netD,
             GANLoss,
             ReconLoss,
             DLoss,
             optG,
             optD,
             dataloader,
             epoch,
             device=cuda0,
             batch_n="whole"):
    """
    validate phase
    """
    netG.to(device)
    netD.to(device)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_raw_dir = os.path.join(val_save_dir, 'raw')
    # val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_raw_dir)
    info = {}

    for i, (imgs, masks, mean, std) in enumerate(dataloader):

        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.to(device), masks.to(device)
        # imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)

        # Logger logging

        if i + 1 < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return (imgs * 255).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()
                # return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()

            # info = { 'val/ori_imgs':img2photo(imgs),
            #          'val/coarse_imgs':img2photo(coarse_imgs),
            #          'val/recon_imgs':img2photo(recon_imgs),
            #          'val/comp_imgs':img2photo(complete_imgs),
            info['val/whole_imgs/{}'.format(i)] = img2photo(
                torch.cat(
                    [((imgs * std.cpu().numpy()[0]) + mean.cpu().numpy()[0]) *
                     (1 - masks) + masks,
                     ((coarse_imgs * std.cpu().numpy()[0]) +
                      mean.cpu().numpy()[0]),
                     ((recon_imgs * std.cpu().numpy()[0]) +
                      mean.cpu().numpy()[0]),
                     ((imgs * std.cpu().numpy()[0]) + mean.cpu().numpy()[0]),
                     ((complete_imgs * std.cpu().numpy()[0]) +
                      mean.cpu().numpy()[0])],
                    dim=3))

        else:
            logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))

            # vis.line([[r_loss.item(), (1 / config.GAN_LOSS_ALPHA) * g_loss.item(), d_loss.item()]],
            #          epoch*len(dataloader)+i, win='validation_loss', update='append')
            j = 0
            for tag, images in info.items():
                h, w = images.shape[1], images.shape[2] // 5
                for val_img in images:
                    raw_img = val_img[:, 0:w, :]
                    # raw_img = ((raw_img - np.min(raw_img)) / np.max(raw_img)) * 255
                    real_img = val_img[:, (3 * w):(4 * w), :]
                    # real_img = ((real_img - np.min(real_img)) / np.max(real_img)) * 255
                    gen_img = val_img[:, (4 * w):, :]
                    # gen_img = ((gen_img - np.min(gen_img)) / np.max(gen_img)) * 255

                    cv2.imwrite(
                        os.path.join(val_save_real_dir, "{}.png".format(j)),
                        real_img)
                    cv2.imwrite(
                        os.path.join(val_save_gen_dir, "{}.png".format(j)),
                        gen_img)
                    cv2.imwrite(
                        os.path.join(val_save_raw_dir, "{}.png".format(j)),
                        raw_img)
                    j += 1
                # tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            # vis.line([[fid_score.item(),ssim_score.item()]], [epoch*len(dataloader)+i], win='validation_metric', update='append')
            # tensorboardlogger.scalar_summary('val/fid', fid_score.item(), epoch*len(dataloader)+i)
            # tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(), epoch*len(dataloader)+i)
            break

        end = time.time()
    # vis.line([[losses['r_loss'].out(), (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(), losses['d_loss'].out()]],
    #          [epoch], win='validation_loss', update='append')
    wandb.log({
        "val_r_loss": losses['r_loss'].out(),
        "val_g_loss": (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(),
        "val_d_loss": losses['d_loss'].out()
    })
Example #12
    def validate(self):
        netE = self.models['netE']
        netE_rnn = self.models['netE_rnn']
        total_losses = self.initialize_validate_losses()
        topk = (1, 5)
        total_evaluations = {
            'accuracy_{}'.format(k): AverageMeter()
            for k in topk
        }
        total_strokes, total_strokes_pred, total_strokes_pred_rnn = [], [], []
        total_input_masks, total_retrieval_outputs, total_targets = [], [], []
        with torch.no_grad():
            start_time = time.time()
            for data_i, batch_data in enumerate(self.datas['val_loader']):
                end_time = time.time()
                self.data_time = end_time - start_time
                if 'sketchretrieval' not in self.args.task_types:
                    batch_data = [[term] for term in batch_data]

                batch_data = [[
                    t.to(device=self.gpu_ids['netE'], dtype=torch.float32)
                    for t in tensor
                ] for tensor in batch_data]
                mask_input_states, input_states, segments, length_masks, input_masks, targets = batch_data
                #print('mask', mask_input_states[0].size(), 'len',len(mask_input_states))
                # Size Checking: input_states[batch, seq_len, 5], length_mask[batch,seq_len], input_mask[batch, seq_len], target[batch]
                segments = [
                    segment.to(dtype=torch.long) for segment in segments
                ]
                targets = [t.to(dtype=torch.long) for t in targets]
                output_states, output_states_rnn, pooled_outputs, pooled_outputs_rnn = [], [], [], []

                for mask_input_state, length_mask, segment in zip(
                        mask_input_states, length_masks, segments):
                    output_state = netE(mask_input_state,
                                        length_mask,
                                        segments=segment,
                                        head_mask=None,
                                        **self.running_paras)
                    output_state_rnn = netE_rnn(mask_input_state,
                                                length_mask,
                                                segments=segment,
                                                head_mask=None,
                                                **self.running_paras)
                    if self.args.output_attentions:
                        output_state, attention_prob = output_state
                        output_state_rnn, attention_prob = output_state_rnn
                    output_states.append(output_state)
                    output_states_rnn.append(output_state_rnn)

                    pooled_output = {
                        task: self.models[task](output_state)
                        for task in self.args.task_types
                    }
                    pooled_output_rnn = {
                        task: self.models[task](output_state_rnn)
                        for task in self.args.task_types
                    }
                    pooled_outputs.append(pooled_output)
                    pooled_outputs_rnn.append(pooled_output_rnn)
                    #print(pooled_output['sketchclsinput'].topk(5, 1 , True, True) )
                strokes_pred = None
                strokes_pred_rnn = None
                if 'maskrec' in self.args.task_types:
                    strokes_pred = pooled_outputs[0]['maskrec'].cpu()
                    strokes_pred_rnn = pooled_outputs_rnn[0]['maskrec'].cpu()
                if 'maskgmm' in self.args.task_types:
                    strokes_pred = pooled_outputs[0]['maskgmm'].cpu()
                if 'maskdisc' in self.args.task_types:
                    strokes_pred = torch.cat(
                        [
                            pooled_outputs[0]['maskdisc'][0].argmax(
                                dim=2, keepdim=True),
                            pooled_outputs[0]['maskdisc'][1].argmax(
                                dim=2, keepdim=True),
                            pooled_outputs[0]['maskdisc'][2].argmax(
                                dim=2, keepdim=True)
                        ],
                        dim=2).to(dtype=torch.float).cpu()

                total_strokes.append(input_states[0].cpu())
                total_strokes_pred.append(strokes_pred)
                total_strokes_pred_rnn.append(strokes_pred_rnn)
                total_input_masks.append(input_masks[0].cpu())

                total_targets.append(targets[0])
                total_loss, losses = self.calculate_losses(
                    input_states, length_masks, input_masks, pooled_outputs,
                    targets)
                acc_evaluations = None
                if 'sketchclsinput' in self.args.task_types:
                    acc_evaluations = self.accuracy_evaluation(
                        pooled_outputs[0]['sketchclsinput'], targets[0])
                if 'sketchretrieval' in self.args.task_types:
                    total_retrieval_outputs.append(
                        pooled_outputs[0]['sketchretrieval'])
                if losses is not None:
                    for key in total_losses:
                        total_losses[key].update(losses[key].item(),
                                                 input_states[0].size(0))
                if acc_evaluations is not None:
                    for key in total_evaluations:
                        total_evaluations[key].update(acc_evaluations[key],
                                                      targets[0].size(0))
                self.run_time = time.time() - end_time
                self.batch_time = self.data_time + self.run_time
        retrieval_evaluations = {}
        if 'sketchretrieval' in self.args.task_types:
            retrieval_evaluations = self.retrieval_evaluation(
                torch.cat(total_retrieval_outputs, dim=0),
                torch.cat(total_targets, dim=0),
                self.cate_num,
                topk=topk)
        log_losses = {
            key: torch.tensor(total_losses[key].avg)
            for key in total_losses
        }
        log_evaluations = {
            key: total_evaluations[key].avg
            for key in total_evaluations
        }
        if self.best_accuracy <= log_evaluations['accuracy_1']:
            self.best_accuracy = log_evaluations['accuracy_1']
            self.best_t = self.counters['t']
        log_evaluations = {**log_evaluations, **retrieval_evaluations}
        # Update Logs
        self.update_log('val', log_losses, log_evaluations)

        if total_strokes_pred[0] is not None:
            total_strokes = torch.cat(total_strokes, dim=0)
            total_strokes_pred = torch.cat(total_strokes_pred, dim=0)
            total_strokes_pred_rnn = torch.cat(total_strokes_pred_rnn, dim=0)
            total_input_masks = torch.cat(total_input_masks, dim=0)
            self.update_stroke_results('val', total_strokes,
                                       total_strokes_pred,
                                       total_strokes_pred_rnn,
                                       total_input_masks)
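self.accuracy_evaluation is not shown; given that its results feed the accuracy_1/accuracy_5 meters above, a standalone sketch reusing the standard top-k computation:

import torch

def accuracy_evaluation(logits, targets, topk=(1, 5)):
    # Hypothetical stand-in for self.accuracy_evaluation in Example #12:
    # returns {'accuracy_k': precision@k in percent} for each k.
    maxk = max(topk)
    _, pred = logits.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(targets.view(1, -1).expand_as(pred))
    batch_size = targets.size(0)
    return {
        'accuracy_{}'.format(k):
        correct[:k].reshape(-1).float().sum().item() * 100.0 / batch_size
        for k in topk
    }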
Example #13
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    # wandb.watch(netG, netD)
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()

    for i, (imgs, masks, mean, std) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        # imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region

        coarse_imgs, recon_imgs = netG(imgs, masks)
        #print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            # vis.line([[r_loss.item(), (1 / config.GAN_LOSS_ALPHA) * g_loss.item(), d_loss.item()]],
            #          [epoch*len(dataloader)+i], win='train_loss', update='append')

            # for tag, value in info_terms.items():
            #     tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)
            #
            # for tag, value in losses.items():
            #     tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                # return (((imgs*0.263)+0.472)*255).transpose(1,2).transpose(2,3).detach().cpu().numpy()
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks) + masks, coarse_imgs, recon_imgs,
                        imgs, complete_imgs
                    ],
                              dim=3))
            }

            # for tag, images in info.items():
            #     tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            validate(netG,
                     netD,
                     GANLoss,
                     ReconLoss,
                     DLoss,
                     optG,
                     optD,
                     val_datas,
                     epoch,
                     device,
                     batch_n=i)
            netG.train()
            netD.train()
        end = time.time()
    # vis.line([[losses['r_loss'].out(), (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(), losses['d_loss'].out()]],
    #          [epoch], win='train_loss', update='append')
    wandb.log({
        "train_r_loss": losses['r_loss'].out(),
        "train_g_loss": (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(),
        "train_d_loss": losses['d_loss'].out()
    })
Example #14
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)#
        #print(len([i for i in masks.numpy().flatten() if i != 0]))

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        coarse_imgs, recon_imgs = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        #print(pred_pos.size())
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if i % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            info = {
                'train/ori_imgs':
                img2photo(imgs),
                'train/coarse_imgs':
                img2photo(coarse_imgs),
                'train/recon_imgs':
                img2photo(recon_imgs),
                'train/comp_imgs':
                img2photo(complete_imgs),
                'train/whole_imgs':
                img2photo(
                    torch.cat([imgs, coarse_imgs, recon_imgs, complete_imgs],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, i)
        end = time.time()
Example #15
def validate(netG,
             netD,
             GANLoss,
             ReconLoss,
             DLoss,
             optG,
             optD,
             dataloader,
             epoch,
             device=cuda0):
    """
    validate phase
    """
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)

        # Logger logging
        logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                    "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                    .format(epoch, i, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                    ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))

        if i * config.BATCH_SIZE < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            info = {
                'val/ori_imgs':
                img2photo(imgs),
                'val/coarse_imgs':
                img2photo(coarse_imgs),
                'val/recon_imgs':
                img2photo(recon_imgs),
                'val/comp_imgs':
                img2photo(complete_imgs),
                'val/whole_imgs':
                img2photo(
                    torch.cat([imgs, coarse_imgs, recon_imgs, complete_imgs],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, i)
        end = time.time()
Example #16
def validate(nets,
             loss_terms,
             opts,
             dataloader,
             epoch,
             devices=(cuda0, cuda1),
             batch_n="whole"):
    """
    validate phase
    """
    netD, netG, netSR = nets["netD"], nets["netG"], nets["netSR"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms[
            "GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netSR.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "p_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }


    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}

    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        pre_complete_imgs = (pre_imgs / 127.5 - 1)
        pre_complete_imgs = pre_complete_imgs * (
            1 - ori_masks['val']) + ori_masks['val']
        pre_inter_imgs = F.interpolate(pre_complete_imgs, TRAIN_SIZES[0])
        for s_j, size in enumerate(TRAIN_SIZES):

            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if s_j == 0:
                pre_inter_imgs = F.interpolate(pre_complete_imgs, size)
            else:
                pre_complete_imgs = (pre_complete_imgs + 1) * 127.5
                pre_inter_imgs = netSR(pre_complete_imgs, 2)
                pre_inter_imgs = (pre_inter_imgs / 127.5 - 1)
            #upsampled_imgs = pre_inter_imgs
            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            #masks = (masks > 0).type(torch.FloatTensor)
            upsampled_imgs = pre_inter_imgs
            #imgs, masks = imgs.to(device), masks.to(device)
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region
            # forward
            recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs,
                              size)

            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [complete_imgs, masks,
                 torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

            g_loss = GANLoss(pred_neg)

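            # this netG returns no separate coarse output, so recon_imgs stands
            # in for the coarse prediction in the ReconLoss signature below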
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            #s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            #p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            p_loss = p_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss + g_loss  #+ s_loss#g_loss + r_loss

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(0, imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)

            # Logger logging

            if i + 1 < config.STATIC_VIEW_SIZE:

                def img2photo(imgs):
                    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                        2, 3).detach().cpu().numpy()

                # info = { 'val/ori_imgs':img2photo(imgs),
                #          'val/coarse_imgs':img2photo(coarse_imgs),
                #          'val/recon_imgs':img2photo(recon_imgs),
                #          'val/comp_imgs':img2photo(complete_imgs),
                info['val/{}whole_imgs/{}'.format(size, i)] = img2photo(
                    torch.cat([
                        imgs * (1 - masks), upsampled_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))

            else:
                logger.info("Validation Epoch {0}, [{1}/{2}]: Size:{size}, Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                            "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f},\t Perc Loss:{p_loss.val:.4f},\tStyle Loss:{s_loss.val:.4f}"
                            .format(epoch, i+1, len(dataloader),size=size, batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                            ,g_loss=losses['g_loss'], d_loss=losses['d_loss'], p_loss=losses['p_loss'], s_loss=losses['s_loss']))
                j = 0
                # use a distinct name so the outer `size` loop variable is not clobbered
                for size_dir in SIZES_TAGS:
                    if not os.path.exists(os.path.join(val_save_real_dir, size_dir)):
                        os.makedirs(os.path.join(val_save_real_dir, size_dir))
                        os.makedirs(os.path.join(val_save_gen_dir, size_dir))

                for tag, images in info.items():
                    h, w = images.shape[1], images.shape[2] // 5
                    # default to the first size so size_tag is always bound
                    s_i, size_tag = 0, "{}".format(TRAIN_SIZES[0])
                    for i_, s in enumerate(TRAIN_SIZES):
                        if "{}".format(s) in tag:
                            size_tag = "{}".format(s)
                            s_i = i_
                            break

                    for val_img in images:
                        real_img = val_img[:, (3 * w):(4 * w), :]
                        gen_img = val_img[:, (4 * w):, :]
                        real_img = Image.fromarray(real_img.astype(np.uint8))
                        gen_img = Image.fromarray(gen_img.astype(np.uint8))
                        real_img.save(
                            os.path.join(val_save_real_dir, SIZES_TAGS[s_i],
                                         "{}_{}.png".format(size_tag, j)))
                        gen_img.save(
                            os.path.join(val_save_gen_dir, SIZES_TAGS[s_i],
                                         "{}_{}.png".format(size_tag, j)))
                        j += 1
                    tensorboardlogger.image_summary(tag, images, epoch)
                path1, path2 = os.path.join(
                    val_save_real_dir,
                    SIZES_TAGS[len(SIZES_TAGS) - 1]), os.path.join(
                        val_save_gen_dir, SIZES_TAGS[len(SIZES_TAGS) - 1])
                fid_score = metrics['fid']([path1, path2], cuda=False)
                ssim_score = metrics['ssim']([path1, path2])
                tensorboardlogger.scalar_summary('val/fid', fid_score.item(),
                                                 epoch * len(dataloader) + i)
                tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(),
                                                 epoch * len(dataloader) + i)
                break

            end = time.time()
    saved_model = {
        'epoch': epoch + 1,
        'netG_state_dict': netG.to(cpu0).state_dict(),
        'netD_state_dict': netD.to(cpu0).state_dict(),
        # 'optG' : optG.state_dict(),
        # 'optD' : optD.state_dict()
    }
    torch.save(saved_model, '{}/latest_ckpt.pth.tar'.format(log_dir))
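
The DLoss and GANLoss callables are taken as given throughout. Under the SN-PatchGAN hinge objective that these loops suggest, a sketch could look like the following (the class names and nn.Module packaging are illustrative, not the original definitions):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SNDisLoss(nn.Module):
    """Hinge loss for the discriminator: push scores on real inputs above +1
    and scores on completed (fake) inputs below -1."""

    def forward(self, pred_pos, pred_neg):
        return torch.mean(F.relu(1. - pred_pos)) + torch.mean(F.relu(1. + pred_neg))

class SNGenLoss(nn.Module):
    """Hinge generator loss: raise the discriminator's score on completions."""

    def forward(self, pred_neg):
        return -torch.mean(pred_neg)
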
def train(nets,
          loss_terms,
          opts,
          dataloader,
          epoch,
          devices=(cuda0, cuda1),
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)
    """
    netD, netG, netSR = nets["netD"], nets["netG"], nets["netSR"]
    ReconLoss, DLoss, GANLoss, PercLoss, StyleLoss = (
        loss_terms['ReconLoss'], loss_terms['DLoss'], loss_terms['GANLoss'],
        loss_terms['PercLoss'], loss_terms['StyleLoss'])
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netSR.to(device0)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        'p_loss': AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        ff_mask, rect_mask = ori_masks['random_free_form'], ori_masks[
            'random_bbox']
        if np.random.rand() < 0.3:
            ori_masks = rect_mask
        else:
            ori_masks = ff_mask
        # Optimize Discriminator

        # mask is 1 on masked region
        pre_complete_imgs = ori_imgs
        pre_complete_imgs = (pre_complete_imgs / 127.5 - 1)
        pre_complete_imgs = pre_complete_imgs * (1 - ori_masks) + ori_masks
        pre_complete_imgs = F.interpolate(pre_complete_imgs, TRAIN_SIZES[0])

        for s_j, size in enumerate(TRAIN_SIZES):
            data_time.update(time.time() - end)
            optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
            #Reshape
            masks = F.interpolate(ori_masks, size)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if s_j == 0:
                pre_inter_imgs = F.interpolate(pre_complete_imgs, size)
            else:
                pre_complete_imgs = (pre_complete_imgs + 1) * 127.5
                pre_inter_imgs = netSR(pre_complete_imgs, 2)
                pre_inter_imgs = (pre_inter_imgs / 127.5 - 1)
            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            imgs = (imgs / 127.5 - 1)
            upsampled_imgs = pre_inter_imgs

            recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs,
                              size)
            #print(attention.size(), )
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [complete_imgs, masks,
                 torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            #print(size)
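            # update D only when i % 3 != 0, i.e. on two of every three batches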
            if i % 3:
                d_loss.backward(retain_graph=True)

                optD.step()

            # Optimize Generator
            optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
            pred_neg = netD(neg_imgs)
            #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
            g_loss = GANLoss(pred_neg)
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)

            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            #s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            #p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            p_loss = p_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)

            whole_loss = r_loss + p_loss + g_loss

            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(0, imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
            whole_loss.backward(retain_graph=True)

            optG.step()

            pre_complete_imgs = complete_imgs

            # Update time recorder
            batch_time.update(time.time() - end)

            if (i + 1) % config.SUMMARY_FREQ == 0:
                # Logger logging
                logger.info("Epoch {0}, [{1}/{2}]:Size:{size} Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                            "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}, \t Perc Loss:{p_loss.val:.4f}, \t Style Loss:{s_loss.val:.4f}" \
                            .format(epoch, i+1, len(dataloader), size=size, batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                            ,g_loss=losses['g_loss'], d_loss=losses['d_loss'], p_loss=losses['p_loss'], s_loss=losses['s_loss']))
                # Tensorboard logger for scaler and images
                info_terms = {
                    '{}WGLoss'.format(size): whole_loss.item(),
                    '{}ReconLoss'.format(size): r_loss.item(),
                    "{}GANLoss".format(size): g_loss.item(),
                    "{}DLoss".format(size): d_loss.item(),
                    "{}PercLoss".format(size): p_loss.item()
                }

                for tag, value in info_terms.items():
                    tensorboardlogger.scalar_summary(
                        tag, value,
                        epoch * len(dataloader) + i)

                for tag, value in losses.items():
                    tensorboardlogger.scalar_summary(
                        'avg_' + tag, value.avg,
                        epoch * len(dataloader) + i)

                def img2photo(imgs):
                    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                        2, 3).detach().cpu().numpy()

                # info = { 'train/ori_imgs':img2photo(imgs),
                #          'train/coarse_imgs':img2photo(coarse_imgs),
                #          'train/recon_imgs':img2photo(recon_imgs),
                #          'train/comp_imgs':img2photo(complete_imgs),
                info = {
                    'train/{}whole_imgs'.format(size):
                    img2photo(
                        torch.cat([
                            imgs * (1 - masks), upsampled_imgs, recon_imgs,
                            imgs, complete_imgs
                        ],
                                  dim=3))
                }

                for tag, images in info.items():
                    tensorboardlogger.image_summary(
                        tag, images,
                        epoch * len(dataloader) + i)
            end = time.time()
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(nets,
                     loss_terms,
                     opts,
                     val_datas,
                     epoch,
                     devices,
                     batch_n=i)
            netG.train()
            netD.train()
            netG.to(device0)
            netD.to(device0)
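
ReconLoss(imgs, coarse_imgs, recon_imgs, masks) is another assumed callable. A plausible sketch, following the common inpainting recipe of per-region L1 terms over both the coarse and refined outputs (the weights and their defaults are illustrative):

import torch
import torch.nn as nn

class ReconLoss(nn.Module):
    """L1 reconstruction over coarse and refined outputs, weighted separately
    on hole (masks == 1) and valid (masks == 0) regions."""

    def __init__(self, w_c_hole=1.2, w_c_valid=1.2, w_r_hole=1.2, w_r_valid=1.2):
        super(ReconLoss, self).__init__()
        self.weights = (w_c_hole, w_c_valid, w_r_hole, w_r_valid)

    def forward(self, imgs, coarse_imgs, recon_imgs, masks):
        w_ch, w_cv, w_rh, w_rv = self.weights
        def masked_l1(pred, weight_map):
            # mean absolute error restricted to the given region map
            return torch.mean(torch.abs(pred - imgs) * weight_map)
        return (w_ch * masked_l1(coarse_imgs, masks) +
                w_cv * masked_l1(coarse_imgs, 1. - masks) +
                w_rh * masked_l1(recon_imgs, masks) +
                w_rv * masked_l1(recon_imgs, 1. - masks))
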
Example #18
0
    def validate(self):
        netE = self.models['netE']
        total_losses = self.initialize_validate_losses()
        topk = (1,5)
        total_evaluations = {'accuracy_{}'.format(k): AverageMeter() for k in topk}
        total_strokes, total_strokes_pred, total_input_masks, total_retrieval_outputs, total_targets = [], [], [], [], []
        with torch.no_grad():
            start_time = time.time()
            for data_i, batch_data in enumerate(self.datas['val_loader']):
                end_time = time.time()
                self.data_time = end_time - start_time
                if 'sketchretrieval' not in self.args.task_types:
                    batch_data = [[term] for term in batch_data]

                batch_data = [[t.to(device=self.gpu_ids['netE'],dtype=torch.float32) for t in tensor ] for tensor in batch_data]
                masked_stroke_segments, stroke_segments, segments, segment_indexs, segment_atten_masks, input_masks, targets = batch_data

                # Size Checking: input_states[batch, seq_len, 5], length_mask[batch,seq_len], input_mask[batch, seq_len], target[batch]
                segments = [segment.to(dtype=torch.long) for segment in segments]
                targets = [t.to(dtype=torch.long) for t in targets]

                output_states, pooled_outputs = [], []
                for masked_stroke_segment, segment_atten_mask, segment, segment_index in zip(masked_stroke_segments, segment_atten_masks, segments, segment_indexs):
                    output_state = netE(masked_stroke_segment, segment, segment_atten_mask, segment_index, head_mask=None, **self.running_paras)

                    if self.args.output_attentions:
                        output_state, attention_prob = output_state
                    output_states.append(output_state)

                    pooled_output = {task:self.models[task](output_state) for task in self.args.task_types if task != 'segorder'}
                    if 'segorder' in self.args.task_types:
                        pooled_output['segorder'] = self.models['segorder'](output_state, segment_index)
                    pooled_outputs.append(pooled_output)

                # Size Checking: output_states, pooled_output[[batch, seq_len, 6*M+3],[]]
                # All for display, so all need remove segment
                strokes_pred = None
                if 'maskrec' in self.args.task_types:
                    strokes_pred = self.remove_segment(pooled_outputs[0]['maskrec'], segment_indexs[0]).cpu()
                if 'maskgmm' in self.args.task_types:
                    strokes_pred = self.remove_segment(pooled_outputs[0]['maskgmm'], segment_indexs[0]).cpu()
                if 'maskdisc' in self.args.task_types:
                    strokes_pred = self.remove_segment(torch.cat([pooled_outputs[0]['maskdisc'][0].argmax(dim=2,keepdim=True), pooled_outputs[0]['maskdisc'][1].argmax(dim=2,keepdim=True), pooled_outputs[0]['maskdisc'][2].argmax(dim=2,keepdim=True)], dim=2).to(dtype=torch.float), segment_indexs[0]).cpu()

                total_strokes.append(self.remove_segment(stroke_segments[0], segment_indexs[0]).cpu())
                total_strokes_pred.append(strokes_pred)
                total_input_masks.append(self.remove_segment(input_masks[0], segment_indexs[0]).cpu())

                total_targets.append(targets[0])
                total_loss, losses = self.calculate_losses(stroke_segments, segment_atten_masks, segments, segment_indexs, input_masks, pooled_outputs, targets)
                acc_evaluations = None
                if 'sketchcls' in self.args.task_types:
                    acc_evaluations = self.accuracy_evaluation(pooled_outputs[0]['sketchcls'], targets[0])
                if 'sketchclsinput' in self.args.task_types:
                    acc_evaluations = self.accuracy_evaluation(pooled_outputs[0]['sketchclsinput'], targets[0])
                if 'sketchretrieval' in self.args.task_types:
                    total_retrieval_outputs.append(pooled_outputs[0]['sketchretrieval'])
                if losses is not None:
                    for key in total_losses:
                        total_losses[key].update(losses[key].item(), stroke_segments[0].size(0))
                if acc_evaluations is not None:
                    for key in total_evaluations:
                        total_evaluations[key].update(acc_evaluations[key], stroke_segments[0].size(0))
                self.run_time = time.time() - end_time
                self.batch_time = self.data_time + self.run_time

        retrieval_evaluations = {}
        if 'sketchretrieval' in self.args.task_types:
            retrieval_evaluations = self.retrieval_evaluation(torch.cat(total_retrieval_outputs, dim=0), torch.cat(total_targets, dim=0), topk=topk)
        log_losses = {key:torch.tensor(total_losses[key].avg) for key in total_losses}
        log_evaluations = {key:total_evaluations[key].avg for key in total_evaluations}
        log_evaluations = {**log_evaluations, **retrieval_evaluations}
        # Update Logs
        self.update_log('val', log_losses, log_evaluations)

        if total_strokes_pred[0] is not None:
            total_strokes = torch.cat(total_strokes, dim=0)
            total_strokes_pred = torch.cat(total_strokes_pred, dim=0)
            total_input_masks = torch.cat(total_input_masks, dim=0)
            self.update_stroke_results('val', total_strokes, total_strokes_pred, total_input_masks)
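
The accuracy_{k} meters above are filled by an accuracy_evaluation helper that is not shown. A self-contained top-k accuracy sketch matching those dictionary keys:

import torch

def accuracy_evaluation(logits, targets, topk=(1, 5)):
    """Fraction of samples whose target class appears among the top-k logits."""
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1)        # [batch, maxk] class indices
    correct = pred.eq(targets.view(-1, 1))    # broadcast against [batch, 1]
    return {'accuracy_{}'.format(k): correct[:, :k].any(dim=1).float().mean().item()
            for k in topk}
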
Example #19
0
def pretrainD(netG, netD, GANLoss, ReconLoss, DLoss, NLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    logger.info("Pretraining D epoch %d"%epoch)
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "whole_loss":AverageMeter(), "d_loss":AverageMeter(), 'n_loss': AverageMeter()}

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks, gray) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['mine']
        # masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks, gray = imgs.to(device), masks.to(device), gray.to(device)
        # print(imgs.shape)
        masks = 1 - masks / 255.0 
        # masks = masks / 255.0 
        # 1 for masks, areas with holes
        # print(masks.min(), masks.max())
        imgs = (imgs / 127.5 - 1)
        gray = (gray / 127.5 - 1)
        # mask is 1 on masked region

        coarse_imgs, refined, mixed = netG(gray, masks)
        # coarse_imgs, mixed = netG(imgs, masks)
        # coarse_imgs, mixed, attention = netG(imgs, masks)
        #print(attention.size(), )
        # complete_imgs = mixed * masks + imgs * (1 - masks)
        complete_imgs = mixed # * masks + imgs * (1 - masks)
        # print(imgs.cpu().detach().max(), imgs.cpu().detach().min(), mixed.cpu().detach().max(), mixed.cpu().detach().min(), masks.cpu().detach().max(), masks.cpu().detach().min(), complete_imgs.cpu().detach().max(), complete_imgs.cpu().detach().min())

        pos_imgs = imgs
        neg_imgs = complete_imgs
        # pos_imgs = torch.cat([imgs], dim=1)
        # pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        # neg_imgs = torch.cat([complete_imgs], dim=1)
        # neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()
        batch_time.update(time.time() - end)

        if (i+1) % config.SUMMARY_FREQ == 0:
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, \t D Loss: {d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, d_loss=losses['d_loss']))
Example #20
0
def train(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG, netD = nets["netG"], nets["netD"]
    GANLoss, ReconLoss, L1ReconLoss, DLoss = loss_terms["GANLoss"], loss_terms["ReconLoss"], loss_terms["L1ReconLoss"], loss_terms["DLoss"]
    optG, optD = opts["optG"], opts["optD"]
    device0, device1 = devices[0], devices[1]
    netG.to(device0)
    netD.to(device0)
    # maskNetD.to(device1)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "r_ex_loss":AverageMeter(), "whole_loss":AverageMeter(), 'd_loss':AverageMeter(),}
              # 'mask_d_loss':AverageMeter(), 'mask_rec_loss':AverageMeter(),'mask_whole_loss':AverageMeter()}

    netG.train()
    netD.train()
    # maskNetD.train()
    end = time.time()
    for i, data in enumerate(dataloader):
        data_time.update(time.time() - end)
        imgs, img_exs, masks = data
        masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, img_exs, masks = imgs.to(device0), img_exs.to(device0), masks.to(device0)
        imgs = (imgs / 127.5 - 1)
        img_exs = (img_exs / 127.5 - 1)
        # mask is 1 on masked region
        coarse_imgs, recon_imgs, recon_ex_imgs = netG(imgs, img_exs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        #mask_pos_neg_imgs = torch.cat([imgs, complete_imgs], dim=0)

        # Discriminator Loss
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        d_loss.backward(retain_graph=True)
        optD.step()


        # Mask Discriminator Loss
        # mask_pos_neg_imgs = mask_pos_neg_imgs.to(device1)
        # masks = masks.to(device1)
        # mask_pred_pos_neg = maskNetD(mask_pos_neg_imgs)
        # mask_pred_pos, mask_pred_neg = torch.chunk(mask_pred_pos_neg, 2, dim=0)
        # mask_d_loss = DLoss(mask_pred_pos*masks , mask_pred_neg*masks )
        # mask_rec_loss = L1ReconLoss(mask_pred_neg, masks, masks=masks)

        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # losses['mask_d_loss'].update(mask_d_loss.item(), imgs.size(0))
        # losses['mask_rec_loss'].update(mask_rec_loss.item(), imgs.size(0))
        # mask_whole_loss = mask_rec_loss
        # losses['mask_whole_loss'].update(mask_whole_loss.item(), imgs.size(0))
        # mask_whole_loss.backward(retain_graph=True)
        # optMD.step()


        # Optimize Generator
        # masks = masks.to(device0)
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()  # optMD.zero_grad(), maskNetD.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        r_ex_loss = L1ReconLoss(img_exs, recon_ex_imgs)

        whole_loss = g_loss + r_loss + r_ex_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['r_ex_loss'].update(r_ex_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i+1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f}, \t Ex Recon Loss:{r_ex_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}, " \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time \
                                ,whole_loss=losses['whole_loss'], r_loss=losses['r_loss'], r_ex_loss=losses['r_ex_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {'WGLoss':whole_loss.item(), 'ReconLoss':r_loss.item(), "GANLoss":g_loss.item(), "DLoss":d_loss.item(), }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()

            info = {
                     'train/whole_imgs':img2photo(torch.cat([imgs * (1 - masks), coarse_imgs, recon_imgs, imgs, complete_imgs], dim=3))
                     }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)

        if (i+1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            validate(nets, loss_terms, opts, val_datas , epoch, devices, batch_n=i)
            netG.train()
            netD.train()
            #maskNetD.train()
        end = time.time()
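
L1ReconLoss is also assumed rather than defined. A minimal sketch consistent with the calls above, where an optional masks argument restricts the average to the masked area (assuming a single-channel mask broadcast over the image channels):

import torch
import torch.nn as nn

class L1ReconLoss(nn.Module):
    """Plain L1 between target and reconstruction; if masks are given, the
    error is averaged over the masked pixels instead of the whole image."""

    def __init__(self, weight=1.0):
        super(L1ReconLoss, self).__init__()
        self.weight = weight

    def forward(self, imgs, recon_imgs, masks=None):
        if masks is None:
            return self.weight * torch.mean(torch.abs(imgs - recon_imgs))
        # divide by (#masked pixels * channels) since the mask broadcasts over channels
        return self.weight * torch.sum(torch.abs(imgs - recon_imgs) * masks) / \
            torch.clamp(masks.sum() * imgs.size(1), min=1.0)
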
Example #21
0
def train(netG, netD, GANLoss, ReconLoss, DLoss, NLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "whole_loss":AverageMeter(), "d_loss":AverageMeter(), 'n_loss': AverageMeter()}

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks, gray) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['mine']
        # masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()

        imgs, masks, gray = imgs.to(device), masks.to(device), gray.to(device)
        # print(imgs.shape)
        masks = 1 - masks / 255.0 
        # masks = masks / 255.0 
        # 1 for masks, areas with holes
        # print(masks.min(), masks.max())
        imgs = (imgs / 127.5 - 1)
        gray = (gray / 127.5 - 1)
        # mask is 1 on masked region

        coarse_imgs, refined, mixed = netG(gray, masks)
        # coarse_imgs, mixed = netG(imgs, masks)
        # coarse_imgs, mixed, attention = netG(imgs, masks)
        #print(attention.size(), )
        # complete_imgs = mixed * masks + imgs * (1 - masks)
        complete_imgs = mixed # * masks + imgs * (1 - masks)
        # print(imgs.cpu().detach().max(), imgs.cpu().detach().min(), mixed.cpu().detach().max(), mixed.cpu().detach().min(), masks.cpu().detach().max(), masks.cpu().detach().min(), complete_imgs.cpu().detach().max(), complete_imgs.cpu().detach().min())

        pos_imgs = imgs
        neg_imgs = complete_imgs
        # pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        # neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()


        # Optimize Generator
        optD.zero_grad(), netD.zero_grad()
        optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        # do not re-chunk pred_pos_neg here: that would overwrite the freshly
        # computed pred_neg with the discriminator's pre-update scores
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, mixed, masks)
        n_loss = NLoss(coarse_imgs, refined, mixed, imgs)

        # whole_loss = r_loss + n_loss
        whole_loss = g_loss + r_loss + n_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['n_loss'].update(n_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        whole_loss.backward()

        optG.step()

        # print('w?', imgs.min(), imgs.max())

        # Update time recorder
        batch_time.update(time.time() - end)

        # print(((imgs+1)*127.5).min(), ((imgs+1)*127.5).max())
        if (i+1) % config.SUMMARY_FREQ == 0:
            # Logger logging
                        # "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}, \t New Loss: {n_loss.val:.4f}"
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f}, \t New Loss: {n_loss.val:.4f}, \t D Loss: {d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss'], n_loss=losses['n_loss']))
                        # , n_loss=losses['n_loss']))
            # Tensorboard logger for scaler and images
            # info_terms = {'WGLoss':whole_loss.item(), 'ReconLoss':r_loss.item()}
            info_terms = {'WGLoss':whole_loss.item(), 'ReconLoss':r_loss.item(), "GANLoss":g_loss.item(), "DLoss":d_loss.item()}

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                # return ((imgs+1)*127.5).detach().cpu().numpy()
                return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()
            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/mixed':img2photo(mixed),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                     'train/whole_imgs':img2photo(torch.cat([imgs * (1 - masks) + masks, refined, imgs * masks, complete_imgs, imgs], dim=3))
                     }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)
        if (i+1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:

            with torch.no_grad():
                validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, val_datas , epoch, device, batch_n=i)
            netG.train()
            # netD.train()
        end = time.time()
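
tensorboardlogger is assumed throughout to expose scalar_summary(tag, value, step) and image_summary(tag, images, step), with images shaped [batch, H, W, C] in a 0-255 range as produced by img2photo. A thin adapter over torch.utils.tensorboard would satisfy that interface (a sketch, not the original logger):

from torch.utils.tensorboard import SummaryWriter

class TensorBoardLogger(object):
    """Minimal adapter exposing the scalar_summary/image_summary calls used above."""

    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)

    def image_summary(self, tag, images, step):
        # images: numpy array [batch, H, W, C] in [0, 255]; rescale to [0, 1]
        self.writer.add_images(tag, images / 255.0, step, dataformats='NHWC')
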
Example #22
0
def validate(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), batch_n="whole"):
    """
    validate phase
    """
    netG, netD  = nets["netG"], nets["netD"]
    GANLoss, ReconLoss, L1ReconLoss, DLoss = loss_terms["GANLoss"], loss_terms["ReconLoss"], loss_terms["L1ReconLoss"], loss_terms["DLoss"]
    optG, optD = opts["optG"], opts["optD"]
    device0, device1 = devices[0], devices[1]
    netG.to(device0)
    netD.to(device0)
    # maskNetD.to(device1)

    netG.eval()
    netD.eval()
    # maskNetD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "r_ex_loss":AverageMeter(), "whole_loss":AverageMeter(), 'd_loss':AverageMeter(),
              'mask_d_loss':AverageMeter(), 'mask_rec_loss':AverageMeter(),'mask_whole_loss':AverageMeter()}

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}

    for i, data in enumerate(dataloader):

        data_time.update(time.time() - end, 1)
        imgs, img_exs, masks = data
        masks = masks['val']
        #masks = (masks > 0).type(torch.FloatTensor)

        imgs, img_exs, masks = imgs.to(device0), img_exs.to(device0), masks.to(device0)
        imgs = (imgs / 127.5 - 1)
        img_exs = (img_exs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs, recon_ex_imgs = netG(imgs, img_exs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        #mask_pos_neg_imgs = torch.cat([imgs, complete_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)

        # # Mask Gan
        # mask_pos_neg_imgs = mask_pos_neg_imgs.to(device1)
        # mask_pred_pos_neg = maskNetD(mask_pos_neg_imgs)
        # mask_pred_pos, mask_pred_neg = torch.chunk(mask_pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        r_ex_loss = L1ReconLoss(img_exs, recon_ex_imgs)

        whole_loss = g_loss + r_loss + r_ex_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['r_ex_loss'].update(r_ex_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))

        # masks = masks.to(device1)
        # mask_d_loss = DLoss(mask_pred_pos*masks + (1-masks), mask_pred_neg*masks + (1-masks))
        # mask_rec_loss = L1ReconLoss(mask_pred_neg, masks)
        # mask_whole_loss = mask_rec_loss

        # masks = masks.to(device0)
        # losses['mask_d_loss'].update(mask_d_loss.item(), imgs.size(0))
        # losses['mask_rec_loss'].update(mask_rec_loss.item(), imgs.size(0))
        # losses['mask_whole_loss'].update(mask_whole_loss.item(), imgs.size(0))

        # Update time recorder
        batch_time.update(time.time() - end, 1)


        # Logger logging

        if (i+1) < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()
            # info = { 'val/ori_imgs':img2photo(imgs),
            #          'val/coarse_imgs':img2photo(coarse_imgs),
            #          'val/recon_imgs':img2photo(recon_imgs),
            #          'val/comp_imgs':img2photo(complete_imgs),
            info['val/whole_imgs/{}'.format(i)] = {"img":img2photo(torch.cat([imgs * (1 - masks), coarse_imgs, recon_imgs, imgs, complete_imgs], dim=3)),
                                                   }

        else:
            logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t Ex Recon Loss:{r_ex_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], \
                                r_loss=losses['r_loss'], r_ex_loss=losses['r_ex_loss'], g_loss=losses['g_loss'], d_loss=losses['d_loss']))

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('val/avg_'+tag, value.avg, epoch*len(dataloader)+i)
            j = 0
            for tag, datas in info.items():
                images = datas["img"]
                h, w = images.shape[1], images.shape[2] // 5
                for kv, val_img in enumerate(images):
                    real_img = val_img[:,(3*w):(4*w),:]
                    gen_img = val_img[:,(4*w):(5*w),:]
                    real_img = Image.fromarray(real_img.astype(np.uint8))
                    gen_img = Image.fromarray(gen_img.astype(np.uint8))
                    #pkl.dump({datas[term][kv] for term in datas if term != "img"}, open(os.path.join(val_save_inf_dir, "{}.png".format(j)), 'wb'))
                    real_img.save(os.path.join(val_save_real_dir, "{}.png".format(j)))
                    gen_img.save(os.path.join(val_save_gen_dir, "{}.png".format(j)))
                    j += 1
                tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            tensorboardlogger.scalar_summary('val/fid', fid_score.item(), epoch*len(dataloader)+i)
            tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(), epoch*len(dataloader)+i)
            break
            
        end = time.time()
Example #23
0
def validate(netG,
             netD,
             GANLoss,
             ReconLoss,
             DLoss,
             optG,
             optD,
             dataloader,
             epoch,
             device=cuda0,
             batch_n="whole"):
    """
    validate phase
    """
    netG.to(device)
    netD.to(device)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    # keep the nets in eval mode for validation
    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}
    for i, (imgs, masks) in enumerate(dataloader):

        data_time.update(time.time() - end)
        masks = masks['val']
        # masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs, attention = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)

        # Logger logging

        if i + 1 < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'val/ori_imgs':img2photo(imgs),
            #          'val/coarse_imgs':img2photo(coarse_imgs),
            #          'val/recon_imgs':img2photo(recon_imgs),
            #          'val/comp_imgs':img2photo(complete_imgs),
            info['val/whole_imgs/{}'.format(i)] = img2photo(
                torch.cat([
                    imgs *
                    (1 - masks), coarse_imgs, recon_imgs, imgs, complete_imgs
                ],
                          dim=3))

        else:
            logger.info(
                "Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f},\t "
                "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time, data_time=data_time,
                        whole_loss=losses['whole_loss'], r_loss=losses['r_loss'],
                        g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            j = 0
            for tag, images in info.items():
                h, w = images.shape[1], images.shape[2] // 5
                for val_img in images:
                    real_img = val_img[:, (3 * w):(4 * w), :]
                    gen_img = val_img[:, (4 * w):, :]
                    real_img = Image.fromarray(real_img.astype(np.uint8))
                    gen_img = Image.fromarray(gen_img.astype(np.uint8))
                    real_img.save(
                        os.path.join(val_save_real_dir, "{}.png".format(j)))
                    gen_img.save(
                        os.path.join(val_save_gen_dir, "{}.png".format(j)))
                    j += 1
                tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            tensorboardlogger.scalar_summary('val/fid', fid_score.item(),
                                             epoch * len(dataloader) + i)
            tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(),
                                             epoch * len(dataloader) + i)
            break

        end = time.time()
Example #24
0
    def validate(self):
        netSE, netIE = self.models['netSE'], self.models['netIE']
        total_losses = self.initialize_validate_losses()
        topk = (1, 5)
        total_evaluations = {
            'accuracy_{}'.format(k): AverageMeter()
            for k in topk
        }
        total_sketch_states, total_image_states, total_sketch_classes, total_image_classes = [], [], [], []
        # Record the sketch states[batch, rel_feat_dim]
        with torch.no_grad():
            start_time = time.time()
            for data_i, batch_data in enumerate(self.datas['val_loader']):
                end_time = time.time()
                self.data_time = end_time - start_time

                batch_data = [
                    tensor.to(device=self.gpu_ids['netSE'],
                              dtype=torch.float32) for tensor in batch_data
                ]
                strokes, pos_images, neg_images, segments, length_masks, targets = batch_data

                # Size Checking: input_states[batch, seq_len, 5], length_mask[batch,seq_len], input_mask[batch, seq_len], target[batch]
                segments = segments.to(dtype=torch.long)
                targets = targets.to(dtype=torch.long)
                # Sketch feature
                output_states = netSE(strokes,
                                      length_masks,
                                      segments=segments,
                                      head_mask=None,
                                      **self.running_paras)

                sketch_states = self.models['sbir'](
                    output_states)  # [batch, rel_feat_dim]
                sketch_states = F.normalize(sketch_states)
                total_sketch_states.append(sketch_states)
                total_sketch_classes.append(targets)
                # Size Checking: output_states, pooled_output[[batch, seq_len, 6*M+3],[]]

                self.run_time = time.time() - end_time
                self.batch_time = self.data_time + self.run_time

        image_loader = self.datas['val_loader'].dataset.get_image_loader(
            shuffle=False)
        # Record the Image states[batch, rel_feat_dim]
        with torch.no_grad():
            start_time = time.time()
            for data_i, batch_data in enumerate(image_loader):
                end_time = time.time()
                self.data_time = end_time - start_time

                batch_data = [
                    tensor.to(device=self.gpu_ids['netSE'],
                              dtype=torch.float32) for tensor in batch_data
                ]
                images, targets = batch_data

                targets = targets.to(dtype=torch.long)
                # Sketch feature
                image_states = netIE(images)
                image_states = F.normalize(image_states)
                total_image_states.append(image_states)
                total_image_classes.append(targets)

                self.run_time = time.time() - end_time
                self.batch_time = self.data_time + self.run_time

        sbir_evaluations = self.retrieval_evaluation(
            torch.cat(total_sketch_states, dim=0),
            torch.cat(total_sketch_classes, dim=0),
            torch.cat(total_image_states, dim=0),
            torch.cat(total_image_classes, dim=0),
            topk=topk)
        log_losses = {
            key: torch.tensor(total_losses[key].avg)
            for key in total_losses
        }
        log_evaluations = sbir_evaluations
        if self.best_metric <= log_evaluations['retrieval_1']:
            self.best_metric = log_evaluations['retrieval_1']
            self.best_t = self.counters['t']
        # Update Logs
        self.update_log('val', log_losses, log_evaluations)
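
The retrieval_evaluation call above receives L2-normalized sketch and image embeddings plus their class labels. A category-level SBIR sketch consistent with the 'retrieval_{k}' keys it must return (the metric the original method computes may differ):

import torch

def retrieval_evaluation(sketch_states, sketch_classes, image_states, image_classes, topk=(1, 5)):
    """A query scores a hit when any of its top-k gallery images shares the
    query's class; states are L2-normalized, so dot product = cosine similarity."""
    sims = sketch_states @ image_states.t()      # [n_sketch, n_image]
    results = {}
    for k in topk:
        _, idx = sims.topk(k, dim=1)             # [n_sketch, k] gallery indices
        hits = image_classes[idx].eq(sketch_classes.view(-1, 1)).any(dim=1)
        results['retrieval_{}'.format(k)] = hits.float().mean().item()
    return results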