Example #1
    def test_step(self, batch: Tuple[Tensor, Tensor],
                  batch_idx: int) -> None:
        """Save predictions on a small set of images from the generator."""
        # Manually put the model into training mode
        # When not in training mode, instance normalization layers
        # seem to be disabled, which leads to bad colourization
        self.train()

        make_image_grid(self, batch, self.config)

        # Don't save over existing file
        count = 0
        while os.path.exists(self.config.result_path + f'/{count}.png'):
            count += 1
        plt.savefig(self.config.result_path + f'/{count}.png')

        # Close figures to prevent too much memory usage
        plt.close('all')
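Example #1 assumes that make_image_grid draws the grid onto the current matplotlib figure, which plt.savefig then writes to disk. A minimal sketch of such a helper, built on torchvision.utils.make_grid, is shown below; the signature and normalization choices are assumptions for illustration (the original project's helper also receives the model and config).

    import matplotlib.pyplot as plt
    import torchvision.utils as vutils

    def make_image_grid(images, nrow=8):
        """Hypothetical sketch: render a batch of (N, C, H, W) tensors onto the current figure."""
        grid = vutils.make_grid(images, nrow=nrow, normalize=True)  # (C, H, W), values in [0, 1]
        plt.figure(figsize=(8, 8))
        plt.axis('off')
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC for imshow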
Example #2
 def log_image_grid(self, tag, ngrid, images, step):
     """Log a list of grid of images."""
     
     grid = make_image_grid(images, ngrid)
     img_summaries = []
     for nr, img in enumerate(grid):
         # Encode the image as a PNG byte string
         s = BytesIO()
         scipy.misc.toimage(img).save(s, format="png")
         # Create an Image object
         img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), height=img.shape[0], width=img.shape[1])
         # Create a Summary value
         img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, nr), image=img_sum))
     # Create and write Summary
     summary = tf.Summary(value=img_summaries)
     self.writer.add_summary(summary, step)
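Note that scipy.misc.toimage was removed in SciPy 1.3 and tf.Summary is the TensorFlow 1.x API. On newer SciPy versions the PNG-encoding step can be done with Pillow instead; a hedged replacement for just that step (the helper name is illustrative, not part of the original project):

    from io import BytesIO
    import numpy as np
    from PIL import Image

    def encode_png(img):
        """Encode an image array as PNG bytes (stand-in for scipy.misc.toimage(img).save(...))."""
        arr = np.asarray(img)
        if arr.dtype != np.uint8:
            # Assumption: float images are in [0, 1]; rescale to 8-bit
            arr = (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
        buf = BytesIO()
        Image.fromarray(arr).save(buf, format='PNG')
        return buf.getvalue()

In the loop above, s.getvalue() would then be replaced by encode_png(img).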
Example #3
 def evaluate_swa(self, test, plot_samples_images=False):
     self.network_swa.eval()  # set the SWA-averaged network (used below) to evaluation mode
     test_total = 0
     test_correct = 0
     test_loader = torch.utils.data.DataLoader(test, 128, shuffle=False)
     with torch.no_grad():
         with tqdm.tqdm(total=len(test_loader)) as test_pbar:
             test_accuracy = 0
             for batch_idx, (x_test, y_test) in enumerate(test_loader):
                 outputs = self.network_swa(x_test.to('cuda'))
                 test_correct += self.score(outputs, y_test.to('cuda'))
                 test_total += len(y_test)
                 test_accuracy = test_correct / test_total
                 test_pbar.set_description('[swa test - Accuracy %.2f]' % test_accuracy)
                 test_pbar.update(1)
             if plot_samples_images:
                 random_4 = torch.randperm(len(x_test))[:4]
                 fig = utils.make_image_grid(x_test[random_4], outputs[random_4], y_test[random_4])
                 return test_accuracy, test_correct, test_total, fig
             else:
                 return test_accuracy, test_correct, test_total
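A possible call site for Example #3, assuming utils.make_image_grid returns a matplotlib figure as the code implies (the trainer object, dataset, and file name are illustrative):

    # Hypothetical usage of evaluate_swa
    accuracy, correct, total, fig = trainer.evaluate_swa(test_dataset, plot_samples_images=True)
    print('SWA test accuracy: %.4f (%d/%d)' % (accuracy, correct, total))
    fig.savefig('swa_samples.png')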
Example #4
 def add_image_grid(self, index, nrow, x, niter):
     grid = make_image_grid(x, nrow)
     self.writer.add_image(index, grid, niter)
Example #5
 def add_image_grid(self, index, ngrid, x, niter):
     grid = utils.make_image_grid(x, ngrid)
     self.writer.add_image(index, grid, niter)
Example #6
 def add_image_grid(self, index, ngrid, x, niter, logtype):
     grid = utils.make_image_grid(x, ngrid)
     self.writer[logtype].add_image(index, grid, niter)
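Examples #4 through #6 pass the grid straight to SummaryWriter.add_image, which by default expects a single (C, H, W) tensor. A make_image_grid for that use can be a thin wrapper around torchvision.utils.make_grid; the sketch below is an assumption about the utility, not the projects' actual code.

    import torchvision.utils as vutils

    def make_image_grid(x, ngrid):
        """Hypothetical sketch: tile a batch (N, C, H, W) into one (C, H', W') grid tensor."""
        return vutils.make_grid(x, nrow=ngrid, normalize=True, padding=2)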
Example #7
        label = label[0].cuda()
        label = Variable(label)
        inputs = Variable(img)

        feats = res50(inputs)
        output = seg(feats)

        seg.zero_grad()
        res50.zero_grad()
        loss = criterion(output, label)
        loss.backward()
        optimizer_feat.step()
        optimizer_seg.step()

        ## visualization: log the inputs, labels, and predictions as image grids
        inputs = make_image_grid(img, mean, std)
        label = make_label_grid(label.data)
        label = Colorize()(label).type(torch.FloatTensor)
        output = make_label_grid(torch.max(output, dim=1)[1].data)
        output = Colorize()(output).type(torch.FloatTensor)
        writer.add_image('image', inputs, i)
        writer.add_image('label', label, i)
        writer.add_image('pred', output, i)
        writer.add_scalar('loss', loss.data[0], i)
        metric = FCN_metric(output, label)
        writer.add_scalar('mIU', metric['MIU'], i)
        #         writer.add_scalar('iou', compute_mean_iou(output,label),i)
        if i % 100 == 0:
            print(output.shape)
            # plt.imshow(np.asarray(output))
#             plt.show()
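Example #7 targets an old PyTorch release: Variable wrappers have been no-ops since 0.4, and loss.data[0] fails on newer versions. If the snippet is run on a current release, the scalar-logging line would instead read (equivalent behaviour, modern API):

    writer.add_scalar('loss', loss.item(), i)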
Example #8
    def solve(self, epoch):
        '''
        Solve for one epoch.
        Args:
            xr: raw, target model image.
            xc: clean, relevant product image.
            xi: clean, irrelevant product image.
        '''
        batch_timer = AverageMeter()
        data_timer = AverageMeter()
        since = time.time()
        bar = Bar('[PixelDtGan] Training ...', max=len(self.dataloader))

        for batch_index, x in enumerate(self.dataloader):
            self.globalIter = self.globalIter + 1
            # measure data loading time
            data_timer.update(time.time() - since)

            # convert to cuda, variable
            xr = x['raw']
            xc = x['clean']
            xi = x['irre']
            xr = __to_var__(__cuda__(xr))
            xc = __to_var__(__cuda__(xc))
            xi = __to_var__(__cuda__(xi))

            # xr_test for test with fixed input.
            if self.globalIter == 1:
                xr_test = xr.clone()
                xc_test = xc.clone()

            # zero gradients.
            self.gen.zero_grad()
            self.dis.zero_grad()
            self.dom.zero_grad()
            '''update discriminator. (dis, dom)
            '''
            since = time.time()
            # train dis (real/fake)
            dl_xc = self.dis(xc)  # real, relevant
            dl_xi = self.dis(xi)  # real, irrelevant
            xc_tilde = self.gen(xr)
            dl_xc_tilde = self.dis(xc_tilde.detach())  # fake (detach)
            real_label = dl_xc.clone().fill_(1).detach()
            fake_label = dl_xc.clone().fill_(0).detach()
            loss_dis = self.mse(dl_xc, real_label) + self.mse(
                dl_xi, real_label) + self.mse(dl_xc_tilde, fake_label)

            # train dom (associated-pair/non-associated-pair)
            xp_ass = torch.cat((xr, xc), dim=1)
            xp_noass = torch.cat((xr, xi), dim=1)
            xp_tilde = torch.cat((xr, xc_tilde.detach()), dim=1)
            dl_xp_ass = self.dom(xp_ass)
            dl_xp_noass = self.dom(xp_noass)
            dl_xp_tilde = self.dom(xp_tilde)
            loss_dom = self.mse(dl_xp_ass, real_label) + self.mse(
                dl_xp_noass, fake_label) + self.mse(dl_xp_tilde, fake_label)
            loss_D_total = 0.5 * (loss_dis + loss_dom)
            loss_D_total.backward()
            self.opt_dis.step()
            self.opt_dom.step()
            '''update generator. (gen)
            '''
            # train gen (real/fake)
            gl_xc_tilde = self.dis(xc_tilde)
            gl_xp_tilde = self.dom(xp_tilde)
            loss_gen = self.mse(gl_xc_tilde, real_label) + self.mse(
                gl_xp_tilde, real_label)
            loss_gen.backward()
            self.opt_gen.step()

            # measure batch process time
            batch_timer.update(time.time() - since)

            # print log
            log_msg = '\n[Epoch:{EPOCH:}][Iter:{ITER:}][lr:{LR:}] Loss_dis:{LOSS_DIS:.3f} | Loss_dom:{LOSS_DOM:.3f} | Loss_gen:{LOSS_GEN:.3f} | eta:(data:{DATA_TIME:.3f}),(batch:{BATCH_TIME:.3f}),(total:{TOTAL_TIME:})' \
            .format(
                EPOCH=epoch+1,
                ITER=batch_index+1,
                LR=self.config.lr,
                LOSS_DIS=loss_dis.data.sum(),
                LOSS_DOM=loss_dom.data.sum(),
                LOSS_GEN=loss_gen.data.sum(),
                DATA_TIME=data_timer.val,
                BATCH_TIME=batch_timer.val,
                TOTAL_TIME=bar.elapsed_td)
            print(log_msg)
            bar.next()

            # visualization
            if self.config.use_tensorboard:
                self.tb.add_scalar('data/loss_dis', float(loss_dis.data.cpu()),
                                   self.globalIter)
                self.tb.add_scalar('data/loss_dom', float(loss_dom.data.cpu()),
                                   self.globalIter)
                self.tb.add_scalar('data/loss_gen', float(loss_gen.data.cpu()),
                                   self.globalIter)

                if self.globalIter % self.config.save_image_every == 0:
                    xall = torch.cat((xc_tilde, xc, xr), dim=0)
                    xall = adjust_pixel_range(xall,
                                              range_from=[-1, 1],
                                              range_to=[0, 1])
                    self.tb.add_image_grid('grid/output', 8,
                                           xall.cpu().data, self.globalIter)

                    xc_tilde_test = self.gen(xr_test)
                    xall_test = torch.cat((xc_tilde_test, xc_test, xr_test),
                                          dim=0)
                    xall_test = adjust_pixel_range(xall_test,
                                                   range_from=[-1, 1],
                                                   range_to=[0, 1])
                    self.tb.add_image_grid('grid/output_fixed', 8,
                                           xall_test.cpu().data,
                                           self.globalIter)

                    # save image as png.
                    mkdir_p(os.path.join(self.prefix, 'image'))
                    image = make_image_grid(xc_tilde_test.cpu().data, 5)
                    image = F.upsample(image.unsqueeze(0),
                                       size=(800, 800),
                                       mode='bilinear').squeeze()
                    filename = 'Epoch_{}_Iter{}.png'.format(
                        self.epoch, self.globalIter)
                    vutils.save_image(image,
                                      os.path.join(self.prefix, 'image',
                                                   filename),
                                      nrow=1)

        bar.finish()
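F.upsample in Example #8 is deprecated in favour of F.interpolate; an equivalent call on recent PyTorch versions, with align_corners made explicit and everything else unchanged, would be:

    image = F.interpolate(image.unsqueeze(0), size=(800, 800),
                          mode='bilinear', align_corners=False).squeeze()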