Exemple #1
0
def get_all_comparison_metrics(denoised,
                               source,
                               noisy=None,
                               return_title_string=False,
                               clamp=True):
    """Compute PSNR/SSIM of `denoised` against the clean `source`.

    Args:
        denoised: model output image batch.
        source: clean reference image batch.
        noisy: optional noisy input; when given, 'psnr_delta'/'ssim_delta'
            report the improvement of `denoised` over `noisy`.
        return_title_string: if True, return the metrics formatted by
            convert_dict_to_string instead of the dict itself.
        clamp: if True, clamp `denoised` into the valid [0, 1] range first.

    Returns:
        dict of metrics, or a formatted string when return_title_string.
    """
    if clamp:
        denoised = torch.clamp(denoised, 0.0, 1.0)

    # No need to pre-allocate zero arrays: each entry is assigned directly.
    metrics = {}
    metrics['psnr'] = psnr(source, denoised)
    metrics['ssim'] = ssim(source, denoised)

    if noisy is not None:
        # Positive delta means the denoiser improved on the noisy input.
        metrics['psnr_delta'] = metrics['psnr'] - psnr(source, noisy)
        metrics['ssim_delta'] = metrics['ssim'] - ssim(source, noisy)

    if return_title_string:
        return convert_dict_to_string(metrics)
    else:
        return metrics
Exemple #2
0
    def _validate(self, validate, metrics, transform=None):
        """Run one validation pass and append per-batch scores to `metrics`.

        Args:
            validate: iterable yielding (cover, _) batches (e.g. a DataLoader).
            metrics: dict of lists keyed by 'val.*'; extended in place.
            transform: optional transform forwarded to _encode_decode.
        """
        for cover, _ in tqdm(validate, disable=not self.verbose):
            gc.collect()  # keep memory pressure down between batches
            cover = cover.to(self.device)
            # Encode the payload into the cover, then decode it both from the
            # generated image (g) and from its transformed variant (t).
            generated, payload, decoded_g, decoded_t = self._encode_decode(
                cover, quantize=True, transform=transform)
            encoder_mse, decoder_loss_g, decoder_acc_g = self._coding_scores(
                cover, generated, payload, decoded_g)
            _, decoder_loss_t, decoder_acc_t = self._coding_scores(cover,
                generated, payload, decoded_t)
            generated_score = self._critic(generated)
            cover_score = self._critic(cover)

            perceptual_loss = None
            with torch.no_grad():
                # Mean L2 distance between feature embeddings of cover and
                # generated images.
                phi_generated = self.perceptual_loss_fc.forward(generated).squeeze() # [batch size, 2048]
                phi_cover = self.perceptual_loss_fc.forward(cover).squeeze()
                perceptual_loss = torch.mean(torch.norm(phi_cover - phi_generated,
                    p=2, dim=1), dim=0)

            metrics['val.perceptual_loss'].append(perceptual_loss.item())
            metrics['val.encoder_mse'].append(encoder_mse.item())
            metrics['val.decoder_loss_g'].append(decoder_loss_g.item())
            metrics['val.decoder_loss_t'].append(decoder_loss_t.item())
            metrics['val.decoder_acc_g'].append(decoder_acc_g.item())
            metrics['val.decoder_acc_t'].append(decoder_acc_t.item())
            metrics['val.cover_score'].append(cover_score.item())
            metrics['val.generated_score'].append(generated_score.item())
            metrics['val.ssim'].append(ssim(cover, generated).item())
            # PSNR with peak signal 2 (2^2 = 4): assumes pixels span [-1, 1]
            # -- TODO confirm against the data pipeline.
            metrics['val.psnr'].append(10 * torch.log10(4 / encoder_mse).item())
            # Effective bits-per-pixel from decoder accuracy (acc 0.5 -> 0 bpp).
            metrics['val.bpp_g'].append(self.data_depth * (2 * decoder_acc_g.item() - 1))
            metrics['val.bpp_t'].append(self.data_depth * (2 * decoder_acc_t.item() - 1))
Exemple #3
0
def get_metrics_reduced(img1, img2):
    """Evaluate a pan-sharpened image `img1` against ground truth `img2`.

    Returns a list of floats [psnr, ssim, scc, sam, ergas]; the first three
    are higher-is-better, the last two lower-is-better.
    """
    tensor_scores = (
        psnr_loss(img1, img2, 1.),
        ssim(img1, img2, 11, 'mean', 1.),
        cc(img1, img2),
        sam(img1, img2),
        ergas(img1, img2),
    )
    return [score.item() for score in tensor_scores]
Exemple #4
0
def get_ssim(fake, real):
    """Average SSIM between corresponding images of two batches.

    Each tensor is moved to the CPU and converted to an HWC numpy array
    before scoring.
    """
    cpu_device = torch.device("cpu")

    def _to_hwc(tensor):
        # CHW tensor -> HWC numpy array on the CPU.
        return tensor.to(cpu_device).detach().clone().numpy().transpose([1, 2, 0])

    scores = [ssim(_to_hwc(fake[idx]), _to_hwc(real[idx]))
              for idx in range(len(fake))]
    return statistics.mean(scores)
Exemple #5
0
    def loss(self, input_seq, target):
        """Run the model on `input_seq` and score it against `target`.

        Returns (l1_loss, l2_loss, output, psnr, ssim); the losses and PSNR
        are computed on values scaled to the 0-255 range.
        """
        output = self(input_seq)

        scaled_output = output * 255
        scaled_target = target * 255
        l2_loss = F.mse_loss(scaled_output, scaled_target)
        l1_loss = F.l1_loss(scaled_output, scaled_target)
        psnr_ = 10 * log10(255 / l2_loss)
        ssim_ = ssim(output, target)
        return l1_loss, l2_loss, output, psnr_, ssim_
Exemple #6
0
    def recon_criterion_rmse(self, input, target, mask, denorm=True):
        """Masked RMSE, PSNR and SSIM between `input` and `target`.

        Args:
            input: predicted image, CHW tensor or a batch of them.
            target: ground-truth image(s), same shape as `input`.
            mask: region-of-interest mask(s); metrics are restricted to it.
            denorm: if True, map values from [-1, 1] back to [0, 1] first.

        Returns:
            (rmse, psnr, ssim) as Python floats; batched inputs are averaged.
        """
        if denorm:
            input = (input * 0.5 + 0.5)
            target = (target * 0.5 + 0.5)

        def _single(inp, tgt, msk):
            # RMSE restricted to the masked region.
            rmse = (torch.sum((torch.mul(inp, msk) - torch.mul(tgt, msk)) ** 2)
                    / torch.sum(msk)) ** 0.5
            # Renamed from `psnr` so the (module-level) psnr metric function
            # is not shadowed.
            psnr_val = 20 * torch.log10(1 / rmse)
            # Composite image: prediction inside the mask, truth outside, so
            # SSIM only penalizes the masked region.
            composite = torch.mul(inp, msk) + torch.mul(tgt, 1 - msk)
            ssim_val = ssim(torch.unsqueeze(composite, dim=0),
                            torch.unsqueeze(tgt, dim=0))
            return rmse, psnr_val, ssim_val

        if len(input.shape) == 3:
            # Single CHW image.
            rmse, psnr_val, ssim_val = _single(input, target, mask)
            return rmse.item(), psnr_val.item(), ssim_val.item()
        else:
            # Batch: average each metric over the batch.
            rmse_sum = 0
            psnr_sum = 0
            ssim_sum = 0
            for i in range(len(input)):
                rmse, psnr_val, ssim_val = _single(input[i], target[i], mask[i])
                rmse_sum += rmse
                psnr_sum += psnr_val
                ssim_sum += ssim_val
            n = len(input)
            return (rmse_sum / n).item(), (psnr_sum / n).item(), (
                ssim_sum / n).item()
Exemple #7
0
    def fit(self, X, Y, X_test, Y_test, layers, 
            max_iterations=10, 
            batch_size=1024, 
            learning_rate=0.01,
            output_each_iter=None):
        """Train the layer stack with mini-batch gradient descent.

        Args:
            X, Y: training inputs and targets (consumed by create_minibatches).
            X_test, Y_test: validation inputs/targets. NOTE(review): Y_test is
                unused and the full-pass costs below compare outputs against
                the inputs themselves (autoencoder-style) -- confirm intended.
            layers: list of layer objects; the last one provides get_cost.
            max_iterations: number of passes over the training data.
            batch_size: mini-batch size.
            learning_rate: step size used by update_params.
            output_each_iter: if set, print progress every that many iterations.
        """
        self.layers = layers
        self.max_iterations = max_iterations
        self.batch_size = batch_size
        self.learning_rate = learning_rate        
        self.batch_costs = []
        self.train_costs = []
        self.val_costs = []
        self.ssim_costs = []
        
        for iteration in range(self.max_iterations):
            for X_mini, Y_mini in self.create_minibatches(X, Y):
                activations = self.forward_step(X_mini)

                batch_cost = layers[-1].get_cost(activations[-1], Y_mini)
                self.batch_costs.append(batch_cost)
                
                param_grads = self.backward_step(activations, Y_mini)

                self.update_params(param_grads, iteration)

            # Full-pass costs; the targets are the inputs (reconstruction).
            activations = self.forward_step(X)
            train_cost = layers[-1].get_cost(activations[-1], X)
            self.train_costs.append(train_cost)

            activations = self.forward_step(X_test)
            validation_cost = layers[-1].get_cost(activations[-1], X_test)

            # Mean SSIM between up to 1000 validation inputs and their
            # reconstructions.
            ssim_cost = np.mean([ssim(x_noisy, x_clean) for x_noisy, x_clean in zip(X_test[0:1000], activations[-1][0:1000])])
            self.ssim_costs.append(ssim_cost)
            
            self.val_costs.append(validation_cost)
            if output_each_iter and iteration % output_each_iter == 0:
                print(f"Iteration: {iteration}; Train loss: {train_cost}; Validation loss: {validation_cost};")
Exemple #8
0
    def test(self):
        """Translate images using the trained TCN and report averaged
        L1/SSIM reconstruction-and-translation quality plus classifier
        accuracies; translated images are written to self.result_dir."""
        from sklearn.metrics import accuracy_score

        # Load the trained generator.
        self.restore_model(self.test_iters)
        self.restore_cls_model()

        # Running sums; averaged over the number of batches at the end.
        l1_rec = 0.
        ssim_rec = 0.
        l1_test = 0.
        ssim_test = 0.
        style_acc = 0.
        char_acc = 0.
        style_acc_rec = 0.
        char_acc_rec = 0.
        with torch.no_grad():
            for i, (x_real, x_style, x_char, y_trg, y_char) in enumerate(self.data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                y_trg = y_trg.to(self.device)
                x_char = x_char.to(self.device)
                x_char_onehot = self.label2onehot(x_char, self.char_cnt)
                y_char = y_char.to(self.device)
                y_char_onehot = self.label2onehot(y_char, self.char_cnt)

                # Translate images.
                fake_list = [x_real, y_trg]
                style_enc, char_enc, _, _ = self.E(x_real)
                # x_fake: reconstruction using the source character label.
                x_fake = self.G(x_char_onehot, style_enc, char_enc, x_char_onehot)
                fake_list.append(x_fake)
                _, _, style_cls_rec, char_cls_rec = self.C(x_fake)
                # y_fake: translation to the target character label.
                y_fake = self.G(x_char_onehot, style_enc, char_enc, y_char_onehot)
                fake_list.append(y_fake)
                _, _, style_cls, char_cls = self.C(y_fake)

                # Reconstruction / translation quality.
                loss_l1_rec = torch.mean(torch.abs(x_real - x_fake))
                loss_ssim_rec = utils.ssim(x_real, x_fake)
                loss_l1 = torch.mean(torch.abs(y_trg - y_fake))
                loss_ssim = utils.ssim(y_trg, y_fake)

                # Classifier accuracies on reconstructed/translated images;
                # torch.max(..., 1)[1] extracts the predicted class indices.
                acc_style_rec = accuracy_score(x_style.cpu().numpy(),
                                               torch.max(style_cls_rec, 1)[1].cpu().numpy())
                acc_char_rec = accuracy_score(x_char.cpu().numpy(),
                                              torch.max(char_cls_rec, 1)[1].cpu().numpy())
                acc_style = accuracy_score(x_style.cpu().numpy(),
                                           torch.max(style_cls, 1)[1].cpu().numpy())
                acc_char = accuracy_score(y_char.cpu().numpy(),
                                          torch.max(char_cls, 1)[1].cpu().numpy())

                l1_rec += loss_l1_rec.item()
                ssim_rec += loss_ssim_rec.item()
                l1_test += loss_l1.item()
                ssim_test += loss_ssim.item()
                style_acc_rec += acc_style_rec
                char_acc_rec += acc_char_rec
                style_acc += acc_style
                char_acc += acc_char

                # Save the translated images (concatenated side-by-side).
                x_concat = torch.cat(fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                save_image(x_concat.data.cpu(), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
            # i is the last loop index, so i+1 is the number of batches.
            print('[Rec L1] : {} [Rec SSIM] : {} [Rec Style Acc] : {} [Rec Char Acc] : {} \
                   [TC L1] : {} [TC SSIM] : {}, [Style Acc] : {} [Char Acc] : {}'.format(
                   l1_rec/(i+1), ssim_rec/(i+1), style_acc_rec/(i+1), char_acc_rec/(i+1),
                    l1_test/(i+1), ssim_test/(i+1), style_acc/(i+1), char_acc/(i+1)))
Exemple #9
0
    def train(self):
        """Train TCN: alternate discriminator/generator updates, then log,
        sample debug images, checkpoint, and decay learning rates.

        Side effects: loads/saves checkpoints under self.model_save_dir,
        writes sample images to self.sample_dir, optional tensorboard logs.
        """
        # Start training from scratch or resume training.
        start_iters = 0
        # NOTE(review): 'E.ckpt'.format(...) has no placeholder, so the
        # argument is ignored and the path is always 'E.ckpt'; possibly
        # '{}-E.ckpt'.format(self.enc_iters) was intended -- confirm before
        # changing, since it alters the checkpoint path.
        E_path = os.path.join(self.model_save_dir, 'E.ckpt'.format(self.enc_iters))
        if not os.path.isfile(E_path):
            self.pretrain()
        # Load only the pretrained encoder weights that match the current
        # architecture.
        pretrained_E_dict = torch.load(E_path, map_location=lambda storage, loc: storage)
        E_dict = self.E.state_dict()
        pretrained_E_dict = {k: v for k, v in pretrained_E_dict.items() if k in E_dict}
        E_dict.update(pretrained_E_dict)
        self.E.load_state_dict(E_dict)
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Fetch fixed inputs for debugging.
        data_iter = iter(self.data_loader)
        x_fixed, x_fixed_style, x_fixed_char, y_fixed, y_fixed_char = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        y_fixed = y_fixed.to(self.device)
        y_fixed_char = y_fixed_char.to(self.device)
        c_fixed_list = [(self.label2onehot(x_fixed_char, self.char_cnt),
                         self.label2onehot(y_fixed_char, self.char_cnt))]

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels; restart the loader when exhausted.
            try:
                x_real, x_style, x_char, y_trg, y_char = next(data_iter)
            except StopIteration:  # was a bare except; only exhaustion is expected
                data_iter = iter(self.data_loader)
                x_real, x_style, x_char, y_trg, y_char = next(data_iter)

            # Generate real labels
            x_real = x_real.to(self.device)
            x_style= x_style.to(self.device)
            x_char = x_char.to(self.device)
            x_char_onehot = self.label2onehot(x_char, self.char_cnt)
            # Character transfer. keep style. and thats' character index

            y_trg  = y_trg.to(self.device)
            y_char = y_char.to(self.device)
            y_char_onehot = self.label2onehot(y_char, self.char_cnt)
            # Style transfer. keep character. and thats' style index

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images (LSGAN-style squared error).
            out_src, out_style, out_char = self.D(y_trg)
            d_loss_real  = torch.mean((out_src - 1) ** 2)
            d_loss_style = self.classification_loss(out_style, x_style)
            d_loss_char  = self.classification_loss(out_char, y_char)
            d_acc_char  = self.classification_acc(out_char, y_char)

            # Compute loss with fake images.
            style_enc, char_enc, _, _  = self.E(x_real)
            y_fake = self.G(x_char_onehot, style_enc, char_enc, y_char_onehot)
            fake_src, _, _ = self.D(y_fake.detach())
            d_loss_fake = torch.mean(fake_src ** 2)

            # Compute loss for gradient penalty on interpolated samples.
            alpha = torch.rand(y_trg.size(0), 1, 1, 1).to(self.device)
            y_hat = (alpha * y_trg.data + (1 - alpha) * y_fake.data).requires_grad_(True)
            gp_src, _, _ = self.D(y_hat)
            d_loss_gp = self.gradient_penalty(gp_src, y_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * (d_loss_style + d_loss_char)\
                                               + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_style'] = d_loss_style.item()
            loss['D/loss_char'] = d_loss_char.item()
            loss['D/acc_char'] = d_acc_char.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # Update G only every n_critic discriminator steps.
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                style_enc, char_enc, _, _ = self.E(x_real)
                y_fake = self.G(x_char_onehot, style_enc, char_enc, y_char_onehot)
                out_src, out_style, out_char = self.D(y_fake)

                g_loss_fake  = torch.mean((out_src - 1) ** 2)
                g_loss_style = self.classification_loss(out_style, x_style)
                g_loss_char  = self.classification_loss(out_char, y_char)
                g_acc_style = self.classification_acc(out_style, x_style)
                g_acc_char  = self.classification_acc(out_char, y_char)

                # Training G to 'y_fake' and 'y_trg' are similar. L1 loss
                g_loss_l1 = torch.mean(torch.abs(y_trg - y_fake))

                # Compute Structural similarity measure of the Generator
                g_loss_ssim = utils.ssim(y_trg, y_fake)

                # Target-to-original domain (cycle reconstruction).
                style_fenc, char_fenc, _, _  = self.E(y_fake)
                x_reconst = self.G(y_char_onehot, style_fenc, char_fenc, x_char_onehot)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Reconstruct Perceptual Loss
                style_renc, char_renc, _, _ = self.E(x_reconst)
                g_loss_percept = torch.mean((style_enc - style_renc) ** 2) +\
                                 torch.mean((char_enc - char_renc) ** 2)

                # Identity mapping: translating to the same character should
                # return the input.
                x_fake = self.G(x_char_onehot, style_enc, char_enc, x_char_onehot)
                g_loss_id = torch.mean(torch.abs(x_real - x_fake))

                # Backward and optimize.
                # BUGFIX: the lambda_ssim term used to be a dangling,
                # over-indented statement (the previous line was missing its
                # continuation backslash), which is a syntax error and would
                # never have contributed to g_loss; it is now part of the sum.
                g_loss = g_loss_fake + g_loss_style \
                                     + self.lambda_cls * (g_loss_char) \
                                     + self.lambda_rec * (g_loss_rec + g_loss_percept + g_loss_id) \
                                     + self.lambda_ssim* (g_loss_l1 - g_loss_ssim)
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_style'] = g_loss_style.item()
                loss['G/loss_char'] = g_loss_char.item()
                loss['G/acc_char'] = g_acc_char.item()
                loss['G/loss_l1'] = g_loss_l1.item()
                loss['G/loss_ssim'] = g_loss_ssim.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_per'] = g_loss_percept.item()
                loss['G/loss_id'] = g_loss_id.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_real, y_trg, y_fake, x_fixed, y_fixed]
                    style_fixed_enc, char_fixed_enc, _, _ = self.E(x_fixed)
                    for (c_ffixed, c_tfixed) in c_fixed_list:
                        x_fake_list.append(self.G(c_ffixed, style_fixed_enc, char_fixed_enc, c_tfixed))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(x_concat.data.cpu(), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                E_path = os.path.join(self.model_save_dir, '{}-E.ckpt'.format(i+1))
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.E.state_dict(), E_path)
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
Exemple #10
0
def main():
    """Run RDN super-resolution inference on a test set, save the output
    images, and report average PSNR/SSIM of the model vs. bicubic upsampling.

    Relies on module-level config (scale, image_size, label_size, color_dim,
    test_folder) and helpers (Methods, load_test).
    """
    method = Methods(
        scale=scale,
        image_size=image_size,
        label_size=label_size,
        color_dim=color_dim,
        is_training=False,
       )

    X_pre_test, X_test, Y_test = load_test(scale=scale, test_folder=test_folder, color_dim=color_dim)

    predicted_list = []

    weight_filename = '../srcnn/S3Models/RDN/L1/model0040.hdf5'
    model = method.RDN()
    model.load_weights(weight_filename)

    for img in X_pre_test:
        img = img.astype('float')
        # Use only the first channel (presumably the luma channel -- confirm
        # against load_test).
        img = img[:,:,0]
        #img = img/255 
        test_sample = img.reshape(1,img.shape[0],img.shape[1],color_dim)
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # documented replacement for timing short spans.
        time_start = time.perf_counter()
        predicted = model.predict(test_sample)
        time_elapsed = time.perf_counter() - time_start
        print('Time:', time_elapsed)
        predicted_list.append(predicted.reshape(predicted.shape[1],predicted.shape[2],1))

    n_img = len(predicted_list)
    dirname = 'result'
    psnr_bic = 0
    psnr_cnn = 0
    ssim_bic = 0
    ssim_cnn = 0

    dirname2 = './TestDatasets/'+test_folder
    img_list = os.listdir(dirname2)
    for i in range(n_img):
        imgname = 'image{:02}'.format(i)
        low_res =  X_pre_test[i]
        bic = np.float32(X_test[i])
        gnd = np.float32(Y_test[i])
        cnn = predicted_list[i]
        cnn = np.float32((cnn*1))
        cnn = np.clip(cnn, 0, 255)

        # Crop `scale` pixels from each border before measuring, as is
        # conventional for super-resolution evaluation.
        cnn = cnn[scale:-scale, scale:-scale, :]
        bic = bic[scale:-scale, scale:-scale, :]
        gnd = gnd[scale:-scale, scale:-scale, :]

        name = os.path.splitext(os.path.basename(img_list[i]))[0]

        #cv2.imwrite(os.path.join(dirname,imgname+'_original.bmp'), low_res)
        cv2.imwrite(os.path.join(dirname,name+'_bic.bmp'), bic)
        cv2.imwrite(os.path.join(dirname, name+'_gnd.bmp'), gnd)
        cv2.imwrite(os.path.join(dirname,name+'_cnn.bmp'), cnn)

        bic_ps =  psnr(gnd, bic)
        cnn_ps = psnr(gnd, cnn)
        bic_ss  = ssim(gnd, bic)
        cnn_ss  = ssim(gnd, cnn)
        print(name+' bic:',bic_ps)
        print(name+' cnn',cnn_ps)

        psnr_bic += bic_ps
        psnr_cnn += cnn_ps
        ssim_bic += bic_ss
        ssim_cnn += cnn_ss

    #####################
    print('psnr_bic:',psnr_bic/n_img)
    print('psnr_cnn:',psnr_cnn/n_img)
    print('ssim_bic:',ssim_bic/n_img)
    print('ssim_cnn:',ssim_cnn/n_img)
		# Train the low-resolution branch: the network predicts a residual for
		# the 64px image and is supervised against a bilinear half-scale of
		# the 128px image. (Enclosing function is outside this excerpt.)
		for epoch in range(args.num_iters):
#			train_loader,val_loader,test_loader = read_data(args.batch_size)
			network.train()
			network_up.train()
			for idx, x in enumerate(train_loader):
				img64 = x['img64'].cuda()
				img128 = x['img128'].cuda()
				# Bilinear half-scale of the 128px image is the training target.
				low_img64 = nn.functional.interpolate(img128,scale_factor=0.5, mode='bilinear', align_corners = False)
				imgname = x['img_name']
				optimizer.zero_grad()
				# The network outputs a residual; subtract it from the input.
				g_img64_res = network(img64)
				g_img64 = img64-g_img64_res
				# Losses are computed in 0-255 pixel units.
				l2_loss = ((255*(low_img64-g_img64))**2).mean()
				l1_loss = (abs(255*(low_img64-g_img64))).mean()
				rmse_loss = rmse(low_img64,g_img64)
				ssim_loss = ssim(low_img64,g_img64)
				# tv_losss = tv_loss(255*img128,255*g_img128)
				# dloss = bce_loss(d(g_img128,img64),true_crit)
				# SSIM is maximized, hence subtracted (weighted by args.l1).
				loss = l2_loss + l1_loss  - args.l1*ssim_loss
				loss.backward(retain_graph=True)
				optimizer.step()

				# Log every 10 batches.
				if idx%10 ==0:
					print("LOW TRAINING {} {}: RMSE_LOSS:{} SSIM:{} L1:{} tv:{} TOTAL:{} ".format(epoch,idx,
						(rmse_loss.detach().cpu().numpy()),
						ssim_loss.detach().cpu().numpy(),
						l1_loss.detach().cpu().numpy(),
						0,#tv_losss.detach().cpu().numpy(),
						loss.detach().cpu().numpy()))

Exemple #12
0
def train():
    """Train the MS-LapSRN model and, after every fit() call, evaluate
    PSNR/SSIM of the x2 and x4 outputs against ground truth and bicubic
    baselines on a .mat test set, then plot the accumulated loss curves.

    Relies on module-level config (D, R, alpha, test_set, conv_side,
    total_history, adjust_learning_rate) and helpers (Basic_model, ps).
    """
    model = Basic_model.MS_LapSRN_model(D = D, R = R)
    
    # Weighted L1 loss over the x2 ('add_8') and x4 ('add_9') outputs.
    model.compile(optimizer = 'adam', loss = 'mean_absolute_error', loss_weights = {'add_8':1-alpha, 'add_9':alpha})
    print(model.summary())
    
    #model.load_weights('checkpoint/D{}R{}_DIV2K_alpha:{}.h5'.format(D, R, alpha))
    
    data, label_x2, label_x4 = ps.read_training_data('training_sample/train_DIV2K_scale4_RGB.h5')
    val_data, val_label_x2, val_label_x4 = ps.read_training_data('training_sample/val_DIV2K_scale4_RGB.h5')
    
    label = [label_x2, label_x4]
    val_label = [val_label_x2, val_label_x4]
    
    PATH_image = '../../Dataset/MS_LapSRN/Test/{}/'.format(test_set)
    
    names_image = os.listdir(PATH_image)
    names_image = sorted(names_image)
    
    nums = len(names_image)
    
    count = 0
    global total_history
    
    checkpoint_filepath = 'checkpoint/MS_LapSRN_DIV2K_alpha:{}_Wls.h5'.format(alpha)
    checkpoint_callbacks = [ModelCheckpoint(filepath = checkpoint_filepath, monitor = 'val_loss', verbose = 1, mode = 'min', 
                                            save_best_only = True), LearningRateScheduler(adjust_learning_rate)]
    
    # NOTE(review): the inner evaluation loop below reuses the name `i`,
    # shadowing this outer loop variable (harmless here, but confusing).
    for i in range(0, 2000):
        
        # Each fit() call runs 2 epochs; total epochs so far = count * 2.
        history = model.fit(x = data, y = label, batch_size = 16, epochs = 2, verbose = 1,
                            callbacks = checkpoint_callbacks, validation_data = (val_data, val_label), shuffle = True)
        
        count += 1
        
        # Per-image metric accumulators for this evaluation round.
        psnr_model_x2 = []
        psnr_bicubic_x2 = []
        
        psnr_model_x4 = []
        psnr_bicubic_x4 = []
        
        ssim_model_x2 = []
        ssim_bicubic_x2 = []
        
        ssim_model_x4 = []
        ssim_bicubic_x4 = []
        
        
        for i in range(nums):
            
            mat_image = io.loadmat(PATH_image + names_image[i])
            
            input_img = mat_image['im_input_rgb']
            
            hr_img_x2 = mat_image['im_hr_x2_rgb']
            bicubic_img_x2 = mat_image['im_bicubic_x2_rgb']
            
            hr_img_x4 = mat_image['im_hr_x4_rgb']
            bicubic_img_x4 = mat_image['im_bicubic_x4_rgb']
            
            shape_input = input_img.shape
            shape_x2 = hr_img_x2.shape
            shape_x4 = hr_img_x4.shape
            
            # Model expects a batch of images normalized to [0, 1].
            input_RGB = np.zeros([1, shape_input[0], shape_input[1], 3])
            input_RGB[0, :, :, :] = input_img / 255

            pre = model.predict(input_RGB, batch_size = 1)
            pre_x2 = pre[0]
            pre_x4 = pre[1]
            
            # Back to 0-255 and clip out-of-range predictions.
            pre_x2 = pre_x2 * 255
            pre_x4 = pre_x4 * 255
            
            pre_x2[pre_x2[:] > 255] = 255
            pre_x2[pre_x2[:] < 0] = 0
            pre_x4[pre_x4[:] > 255] = 255
            pre_x4[pre_x4[:] < 0] = 0
            

            # Channel order is reversed when storing the prediction; the
            # *_r/g/b slices below undo this for the Y-channel conversion.
            output_img_x2 = np.zeros([shape_x2[0], shape_x2[1], 3])
            output_img_x2[:, :, 2] = pre_x2[0, :, :, 0]
            output_img_x2[:, :, 1] = pre_x2[0, :, :, 1]
            output_img_x2[:, :, 0] = pre_x2[0, :, :, 2]
            
            hr_img_x2_r = hr_img_x2[:, :, 0]
            hr_img_x2_g = hr_img_x2[:, :, 1]
            hr_img_x2_b = hr_img_x2[:, :, 2]
            
            output_img_x2_r = output_img_x2[:, :, 2]
            output_img_x2_g = output_img_x2[:, :, 1]
            output_img_x2_b = output_img_x2[:, :, 0]
            
            bicubic_img_x2_r = bicubic_img_x2[:, :, 0]
            bicubic_img_x2_g = bicubic_img_x2[:, :, 1]
            bicubic_img_x2_b = bicubic_img_x2[:, :, 2]
            
            # ITU-R BT.601 RGB -> Y (luma) conversion.
            hr_img_x2_Y = 16 + (65.738 * hr_img_x2_r + 129.057 * hr_img_x2_g + 25.064 * hr_img_x2_b) / 255
            output_img_x2_Y = 16 + (65.738 * output_img_x2_r + 129.057 * output_img_x2_g + 25.064 * output_img_x2_b) / 255
            bicubic_img_x2_Y = 16 + (65.738 * bicubic_img_x2_r + 129.057 * bicubic_img_x2_g + 25.064 * bicubic_img_x2_b) / 255
            
            output_img_x4 = np.zeros([shape_x4[0], shape_x4[1], 3])
            output_img_x4[:, :, 2] = pre_x4[0, :, :, 0]
            output_img_x4[:, :, 1] = pre_x4[0, :, :, 1]
            output_img_x4[:, :, 0] = pre_x4[0, :, :, 2]

            hr_img_x4_r = hr_img_x4[:, :, 0]
            hr_img_x4_g = hr_img_x4[:, :, 1]
            hr_img_x4_b = hr_img_x4[:, :, 2]
            
            output_img_x4_r = output_img_x4[:, :, 2]
            output_img_x4_g = output_img_x4[:, :, 1]
            output_img_x4_b = output_img_x4[:, :, 0]
            
            bicubic_img_x4_r = bicubic_img_x4[:, :, 0]
            bicubic_img_x4_g = bicubic_img_x4[:, :, 1]
            bicubic_img_x4_b = bicubic_img_x4[:, :, 2]
            
            hr_img_x4_Y = 16 + (65.738 * hr_img_x4_r + 129.057 * hr_img_x4_g + 25.064 * hr_img_x4_b) / 255
            output_img_x4_Y = 16 + (65.738 * output_img_x4_r + 129.057 * output_img_x4_g + 25.064 * output_img_x4_b) / 255 
            bicubic_img_x4_Y = 16 + (65.738 * bicubic_img_x4_r + 129.057 * bicubic_img_x4_g + 25.064 * bicubic_img_x4_b) / 255
            
            
            # Measure PSNR on the Y channel of YCrCb, cropping conv_side
            # pixels from each border.
            hr_img_x2_measure = hr_img_x2_Y[conv_side:-conv_side, conv_side:-conv_side]
            output_img_x2_measure = output_img_x2_Y[conv_side:-conv_side, conv_side:-conv_side]
            bicubic_img_x2_measure = bicubic_img_x2_Y[conv_side:-conv_side, conv_side:-conv_side]
            
            psnr_x2 = psnr(output_img_x2_measure, hr_img_x2_measure)
            ssim_x2 = ssim(output_img_x2_measure, hr_img_x2_measure)
            
            psnr_x2_bicubic = psnr(bicubic_img_x2_measure, hr_img_x2_measure)
            ssim_x2_bicubic = ssim(bicubic_img_x2_measure, hr_img_x2_measure)
            
            hr_img_x4_measure = hr_img_x4_Y[conv_side:-conv_side, conv_side:-conv_side]
            output_img_x4_measure = output_img_x4_Y[conv_side:-conv_side, conv_side:-conv_side]
            bicubic_img_x4_measure = bicubic_img_x4_Y[conv_side:-conv_side, conv_side:-conv_side]
            
            psnr_x4 = psnr(output_img_x4_measure, hr_img_x4_measure)
            ssim_x4 = ssim(output_img_x4_measure, hr_img_x4_measure)
            
            psnr_x4_bicubic = psnr(bicubic_img_x4_measure, hr_img_x4_measure)
            ssim_x4_bicubic = ssim(bicubic_img_x4_measure, hr_img_x4_measure)
            
            print(i + 1)
            
            print('Bicubic_x2: ', psnr_x2_bicubic, 'ssim: ', ssim_x2_bicubic)
            print('Model_x2: ', psnr_x2, 'ssim: ', ssim_x2)
            
            print('Bicubic_x4: ', psnr_x4_bicubic, 'ssim: ', ssim_x4_bicubic)
            print('Model_x4: ', psnr_x4, 'ssim: ', ssim_x4)

            
            psnr_bicubic_x2.append(psnr_x2_bicubic)
            ssim_bicubic_x2.append(ssim_x2_bicubic)
            
            psnr_model_x2.append(psnr_x2)
            ssim_model_x2.append(ssim_x2)
            
            psnr_bicubic_x4.append(psnr_x4_bicubic)
            ssim_bicubic_x4.append(ssim_x4_bicubic)
            
            psnr_model_x4.append(psnr_x4)
            ssim_model_x4.append(ssim_x4)
            
            
        # Averages over all test images for this round.
        psnr_bicubic_x2_final = np.mean(psnr_bicubic_x2)
        ssim_bicubic_x2_final = np.mean(ssim_bicubic_x2)
        
        psnr_model_x2_final = np.mean(psnr_model_x2)
        ssim_model_x2_final = np.mean(ssim_model_x2)
        
        
        psnr_bicubic_x4_final = np.mean(psnr_bicubic_x4)
        ssim_bicubic_x4_final = np.mean(ssim_bicubic_x4)
        
        psnr_model_x4_final = np.mean(psnr_model_x4)
        ssim_model_x4_final = np.mean(ssim_model_x4)
        
        print('Epochs: ', count*2)
        
        print('Bicubic_x2')
        print('PSNR: ', psnr_bicubic_x2_final, 'SSIM: ', ssim_bicubic_x2_final)
        print('Model_x2')
        print('PSNR: ', psnr_model_x2_final, 'SSIM: ', ssim_model_x2_final)
        
        print('Bicubic_x4')
        print('PSNR: ', psnr_bicubic_x4_final, 'SSIM: ', ssim_bicubic_x4_final)
        print('Model_x4')
        print('PSNR: ', psnr_model_x4_final, 'SSIM: ', ssim_model_x4_final)
        
                # Draw the error graph: accumulate this round's Keras history
                # into the module-level total_history and plot the curves.
        for key, value in history.history.items():

            total_history[key] = sum([total_history[key], history.history[key]], [])
            
        length = len(total_history['loss'])

        # NOTE(review): the title/ylabel strings say 'accuracy' but the
        # plotted series are losses; also ylim(0.015, 0.003) is inverted,
        # which flips the y-axis in matplotlib -- confirm both are intended.
        plt.plot(total_history['loss'])
        plt.plot(total_history['val_loss'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        #plt.xlim(9, length)
        plt.ylim(0.015, 0.003)
        plt.legend(['train', 'val'], loc='upper left')
        plt.show()
    val_SSIM = 0
    for i, (in_img, RGBout_img, path) in enumerate(test_loader):
        # To device
        # A is for input image, B is for target image
        in_img = in_img.cuda()
        RGBout_img = RGBout_img.cuda()
        #print(path)

        # Forward propagation
        with torch.no_grad():
            out = generator(in_img)

        # Sample data every iter
        img_list = [out, RGBout_img]
        name_list = ['pred', 'gt']
        utils.save_sample_png(sample_folder = sample_folder, sample_name = '%d' % (i), img_list = img_list, name_list = name_list, pixel_max_cnt = 255)
        
        # PSNR
        val_PSNR_this = utils.psnr(out, RGBout_img, 1) * in_img.shape[0]
        print('The %d-th image PSNR %.4f' % (i, val_PSNR_this))
        val_PSNR = val_PSNR + val_PSNR_this
        # SSIM
        val_SSIM_this = utils.ssim(out, RGBout_img) * in_img.shape[0]
        print('The %d-th image SSIM %.4f' % (i, val_SSIM_this))
        val_SSIM = val_SSIM + val_SSIM_this
        
    val_PSNR = val_PSNR / len(namelist)
    val_SSIM = val_SSIM / len(namelist)
    print('The average PSNR equals to', val_PSNR)
    print('The average SSIM equals to', val_SSIM)
Exemple #14
0
		d_optimizer = torch.optim.Adam(d.parameters(), lr=args.lr)
		network.cuda()
		d.cuda()
		for epoch in range(args.num_iters):
			train_loader,val_loader,test_loader = read_data(args.batch_size)
			network.train()
			for idx, x in enumerate(train_loader):
				img64 = x['img64'].cuda()
				img128 = x['img128'].cuda()
				imgname = x['img_name']
				optimizer.zero_grad()
				g_img128 = network(img64)
				l2_loss = ((255*(img128-g_img128))**2).mean()
				l1_loss = (abs(255*(img128-g_img128))).mean()
				rmse_loss = rmse(img128,g_img128)
				ssim_loss = ssim(img128,g_img128)
				# tv_losss = tv_loss(255*img128,255*g_img128)
				# dloss = bce_loss(d(g_img128,img64),true_crit)
				loss = l2_loss + l1_loss  - args.l1*ssim_loss
				loss.backward()
				optimizer.step()

				# d_optimizer.zero_grad()
				# g_img128 = network(img64)
				# g_img128.detach()
				# dloss = bce_loss(d(torch.cat((img128,g_img128)),torch.cat((img64,img64))),torch.cat((true_crit,fake_crit)))#+bce_loss(d(g_img128,img64),false_crit)
				# dloss.backward()
				# d_optimizer.step()


				if idx%10 ==0:
				# Learn to denoise 128x128
				optimizer_128.zero_grad()
				optimizer_up.zero_grad()
				optimizer.zero_grad()
				g_img64_res = network(img64)
				g_img64 = img64-g_img64_res
				# g_img64.detach()
				g_img128 = network_up(torch.cat((g_img64,img64),1))
				# g_img128.detach()
				g_img128_res = network_128(g_img128)
				g_img128_denoised = g_img128-g_img128_res
				l2_loss = ((255*(g_img128_denoised-img128))**2).mean()
				l1_loss = (abs(255*(g_img128_denoised-img128))).mean()
				rmse_loss = rmse(g_img128_denoised,img128)
				ssim_loss = ssim(g_img128_denoised,img128)
				loss = l2_loss  + l1_loss - args.l1*ssim_loss
				loss.backward()
				optimizer_128.step()
				optimizer_up.step()
				optimizer.step()

				if idx%10 ==0:
					print("{} TRAINING {} {}: RMSE_LOSS:{} SSIM:{} L1:{} tv:{} TOTAL:{} ".format(args.model_name,
						epoch,idx,
						(rmse_loss.detach().cpu().numpy()),
						ssim_loss.detach().cpu().numpy(),
						l1_loss.detach().cpu().numpy(),
						0,#tv_losss.detach().cpu().numpy(),
						loss.detach().cpu().numpy()))
				
    def test(self, model: str) -> None:
        """Evaluate a trained EPCNN checkpoint on the Xu et al. test set.

        Builds the TF1 inference graph, restores the variables whose names
        start with 'EPCNN' from ``model`` (a checkpoint name under
        ``self.model_path``), runs every test patch through the network,
        computes PSNR/SSIM between the predicted and target edge maps, saves
        each predicted edge image as a PNG, and prints the mean scores.

        Args:
            model: checkpoint file name joined onto ``self.model_path``.
        """
        # Fixed dataset layout; predictions are written under save_path.
        test_input_path = './dataset/Xu et al.\'s dataset/TEST/INPUT/'
        test_gt_path = './dataset/Xu et al.\'s dataset/TEST/GT/'
        save_path = './dataset/Xu et al.\'s dataset/M0/EPCNN/'
        if not os.path.exists(save_path):
            os.mkdir(save_path)

        # NOTE(review): os.listdir order is arbitrary; inputs and GT files are
        # paired by position, so this assumes both dirs list in the same
        # order — confirm, or sort both lists.
        test_input_list = [
            im for im in os.listdir(test_input_path) if im.endswith('.png')
        ]
        test_gt_list = [
            im for im in os.listdir(test_gt_path) if im.endswith('.png')
        ]

        test_num = len(test_input_list)
        print('Num. of test patches: ', test_num)

        # Per-file metric accumulators, averaged at the end.
        psnr_file = np.zeros(test_num)
        ssim_file = np.zeros(test_num)

        # HR patch side length; the network input is downscaled by sr_scale.
        test_size = 200
        test_down_size = test_size // self.sr_scale

        with tf.Graph().as_default():
            # 4-channel low-resolution input tensor (batch, h, w, c).
            EPCNN_input = tf.placeholder(
                shape=[None, test_down_size, test_down_size, 4],
                dtype=tf.float32)
            # Target edge placeholder; fed below but not consumed by the
            # inference graph — kept for graph-construction compatibility.
            Tar_edge = tf.placeholder(shape=[None, test_size, test_size, 1],
                                      dtype=tf.float32)

            # Forward pass, clipped to valid 8-bit intensity range.
            EPCNN_output = self.inference(EPCNN_input)
            EPCNN_output = tf.clip_by_value(EPCNN_output, 0.0, 255.0)

            # Report total trainable parameter count.
            para_num = np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])
            print('Num. of Parameters: ', para_num)

            # Restore only the EPCNN-scoped variables from the checkpoint.
            var_list = [
                v for v in tf.all_variables() if v.name.startswith('EPCNN')
            ]
            saver = tf.train.Saver(var_list)

            with tf.Session() as sess:
                saver.restore(sess, os.path.join(self.model_path, model))

                for i in range(test_num):
                    # Synthesize the network input and target edge map from
                    # the paired input/GT images.
                    ep_input, _, target_edge, _ = im2tfrecord.generatingSyntheticEdge(
                        os.path.join(test_input_path, test_input_list[i]),
                        os.path.join(test_gt_path, test_gt_list[i]))
                    ep_input = ep_input.astype(np.float32)
                    target_edge = target_edge.astype(np.float32)
                    # Add batch (and channel) dims expected by the placeholders.
                    ep_input = np.expand_dims(ep_input, axis=0)
                    target_edge = np.expand_dims(target_edge, axis=0)
                    target_edge = np.expand_dims(target_edge, axis=3)

                    output = sess.run(EPCNN_output,
                                      feed_dict={
                                          EPCNN_input: ep_input,
                                          Tar_edge: target_edge
                                      })
                    # Drop singleton dims and quantize to 8-bit for metrics/IO.
                    output = np.squeeze(output)
                    target_edge = np.squeeze(target_edge)
                    output = output.astype('uint8')
                    target_edge = target_edge.astype('uint8')

                    psnr_file[i] = psnr(output, target_edge)
                    ssim_file[i] = ssim(output, target_edge)

                    # Strip extension plus a 5-char suffix to build the
                    # output name — assumes the input naming convention used
                    # by this dataset; TODO confirm.
                    save_name = test_input_list[i].split('.')[0][:-5]
                    cv2.imwrite(
                        os.path.join(save_path,
                                     save_name + '_output_edge.png'), output)

                print('EPCNN: ', model)
                print('Edge PSNR: ', str(np.mean(psnr_file)))
                print('Edge SSIM: ', str(np.mean(ssim_file)))
Exemple #17
0
            for _ in range(valid_count):
                inp, tar, gen, pl, bg, bg_mask, fg, fg_m = sess.run(
                    [valid_img_from, valid_img_to, valid_generated, valid_pose_loss, valid_model[1]['background'],
                     valid_model[1]['foreground_mask'], valid_model[1]['foreground'], valid_fg_mask])
                v_inp.append(inp[0, :256, :256] / 2 + .5)
                v_tar.append(tar[0, :256, :256] / 2 + .5)
                v_gen.append(gen[0, :256, :256] / 2 + .5)
                v_pl += [pl]
                v_bg.append(bg[0, :256, :256] / 2 + .5)
                v_bg_mask.append(np.tile(bg_mask[0, :256, :256], [1, 1, 3]))
                v_fg.append(fg[0, :256, :256] / 2 + .5)
                v_fg_m.append(fg_m[0, ..., np.newaxis])

            prefix = 'test' if params['with_valid'] else 'val'
            print('- computing SSIM scores')
            ssim_score, ssim_fg, ssim_bg = ssim(v_tar, v_gen, masks=v_fg_m)
            summary = tf.Summary(value=[tf.Summary.Value(tag=f'{prefix}_metrics/ssim', simple_value=ssim_score)])
            summary_writer.add_summary(summary, i)
            summary = tf.Summary(value=[tf.Summary.Value(tag=f'{prefix}_metrics/ssim_fg', simple_value=ssim_fg)])
            summary_writer.add_summary(summary, i)
            summary = tf.Summary(value=[tf.Summary.Value(tag=f'{prefix}_metrics/ssim_bg', simple_value=ssim_bg)])
            summary_writer.add_summary(summary, i)

            print('- computing pose score')
            pl = np.mean(v_pl)
            summary = tf.Summary(value=[tf.Summary.Value(tag=f'{prefix}_metrics/pose_loss', simple_value=pl)])
            summary_writer.add_summary(summary, i)

            print('- creating images for tensorboard')
            v_inp = np.concatenate(v_inp[:16], axis=0)
            v_tar = np.concatenate(v_tar[:16], axis=0)
Exemple #18
0
def train(exp=None):
    """
    Main function to run the training.

    Builds the encoder/decoder model, optionally resumes from a checkpoint,
    trains with MSE loss, validates each epoch with per-output-frame
    PSNR/SSIM, saves periodic/best checkpoints, and applies early stopping.

    Args:
        exp: optional experiment logger exposing ``log_metric(name, value)``
            (e.g. a comet_ml Experiment); metric logging is skipped when None.

    NOTE(review): relies on module-level names defined elsewhere in this
    file: args, TIMESTAMP, save_dir, trainLoader, validLoader,
    encoder_params, decoder_params, psnr, ssim, upload_images, Encoder,
    Decoder, ED, EarlyStopping.
    """
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = "./runs/" + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Build the optimizer once, BEFORE optionally restoring its state.
    # (Bug fix: the original created a fresh Adam *after* loading the
    # checkpointed optimizer state, silently discarding it on resume.)
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    if os.path.exists(args.checkpoint) and args.continue_train:
        # load existing model and optimizer state to resume training
        print("==> loading existing model")
        model_info = torch.load(args.checkpoint)
        net.load_state_dict(model_info["state_dict"])
        optimizer.load_state_dict(model_info["optimizer"])
        cur_epoch = model_info["epoch"] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    lossfunction = nn.MSELoss().cuda()
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # per-output-frame PSNR/SSIM history; index j = frame position
    avg_psnrs = {}
    avg_ssims = {}
    for j in range(args.frames_output):
        avg_psnrs[j] = []
        avg_ssims[j] = []
    if args.checkdata:
        # Sanity-check tensor shapes from both loaders before training.
        print("Checking Dataloader!")
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == torch.Size([
                args.batchsize, args.frames_output, 1, args.data_h, args.data_w
            ])
            assert inputVar.shape == torch.Size([
                args.batchsize, args.frames_input, 1, args.data_h, args.data_w
            ])
        print("TrainLoader checking is complete!")
        t = tqdm(validLoader, leave=False, total=len(validLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == torch.Size([
                args.batchsize, args.frames_output, 1, args.data_h, args.data_w
            ])
            assert inputVar.shape == torch.Size([
                args.batchsize, args.frames_input, 1, args.data_h, args.data_w
            ])
        print("ValidLoader checking is complete!")
        # mini_val_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        # to track the training loss as the model trains
        train_losses = []
        # to track the validation loss as the model trains
        valid_losses = []
        psnr_dict = {}
        ssim_dict = {}
        for j in range(args.frames_output):
            psnr_dict[j] = 0
            ssim_dict[j] = 0
        image_log = []
        if exp is not None:
            exp.log_metric("epoch", epoch)
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batchsize
            train_losses.append(loss_aver)
            loss.backward()
            # clip gradients to stabilize recurrent training
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                "trainloss": "{:.6f}".format(loss_aver),
                "epoch": "{:02d}".format(epoch),
            })
        # tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        num_val_batches = 0
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                num_val_batches = i + 1
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batchsize
                # record validation loss
                valid_losses.append(loss_aver)

                for j in range(args.frames_output):
                    psnr_dict[j] += psnr(pred[:, j], label[:, j])
                    ssim_dict[j] += ssim(pred[:, j], label[:, j])
                t.set_postfix({
                    "validloss": "{:.6f}".format(loss_aver),
                    "epoch": "{:02d}".format(epoch),
                })
                # periodically push sample prediction/target pairs to the logger
                if i % 500 == 499:
                    for k in range(args.frames_output):
                        image_log.append(label[0, k].unsqueeze(0).repeat(
                            1, 3, 1, 1))
                        image_log.append(pred[0, k].unsqueeze(0).repeat(
                            1, 3, 1, 1))
                    upload_images(
                        image_log,
                        epoch,
                        exp=exp,
                        im_per_row=2,
                        rows_per_log=int(len(image_log) / 2),
                    )
        # tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        # Bug fix: metrics were divided by ``i`` (the *last batch index*,
        # i.e. num_batches - 1), which skewed the averages and raised
        # ZeroDivisionError for a single-batch loader; use the batch count.
        for j in range(args.frames_output):
            avg_psnrs[j].append(psnr_dict[j] / num_val_batches)
            avg_ssims[j].append(ssim_dict[j] / num_val_batches)
        epoch_len = len(str(args.epochs))

        print_msg = (f"[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] " +
                     f"train_loss: {train_loss:.6f} " +
                     f"valid_loss: {valid_loss:.6f}" +
                     f"PSNR_1: {psnr_dict[0] / num_val_batches:.6f}" +
                     f"SSIM_1: {ssim_dict[0] / num_val_batches:.6f}")

        # print(print_msg)
        # clear lists to track next epoch
        if exp is not None:
            exp.log_metric("TrainLoss", train_loss)
            exp.log_metric("ValidLoss", valid_loss)
            exp.log_metric("PSNR_1", psnr_dict[0] / num_val_batches)
            exp.log_metric("SSIM_1", ssim_dict[0] / num_val_batches)
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "avg_psnrs": avg_psnrs,
            "avg_ssims": avg_ssims,
            "avg_valid_losses": avg_valid_losses,
            "avg_train_losses": avg_train_losses,
        }
        # Save at most one checkpoint per epoch, preferring (in order):
        # periodic save, best PSNR, best SSIM, best valid loss, latest.
        save_flag = False
        if epoch % args.save_every == 0:
            torch.save(
                model_dict,
                save_dir + "/" +
                "checkpoint_{}_{:.6f}.pth".format(epoch, valid_loss.item()),
            )
            print("Saved" +
                  "checkpoint_{}_{:.6f}.pth".format(epoch, valid_loss.item()))
            save_flag = True
        if avg_psnrs[0][-1] == max(avg_psnrs[0]) and not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "bestpsnr_1.pth",
            )
            print("Best psnr found and saved")
            save_flag = True
        if avg_ssims[0][-1] == max(avg_ssims[0]) and not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "bestssim_1.pth",
            )
            print("Best ssim found and saved")
            save_flag = True
        if avg_valid_losses[-1] == min(avg_valid_losses) and not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "bestvalidloss.pth",
            )
            print("Best validloss found and saved")
            save_flag = True
        if not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "checkpoint.pth",
            )
            print("The latest normal checkpoint saved")
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", "wt") as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", "wt") as f:
        for i in avg_valid_losses:
            print(i, file=f)
def run_inference(subj, R, mode, k, num_sampels, num_bootsamles, batch_size,
                  num_iter, step_size, phase_step, complex_rec, use_momentum,
                  log, device):
    """
    Run VAE-regularized reconstruction over every slice of one subject.

    Loads a pretrained VAE, iterates the subject's slices, reconstructs each
    with ``vaerecon``, prints per-slice NMSE/SSIM/PSNR, then reports
    subject-level metrics and pickles the reconstructed volume.

    Args:
        subj: subject identifier used by the data loader and output filename.
        R: undersampling factor.
        mode: reconstruction mode string forwarded to ``vaerecon``.
        k: mode parameter (logged and embedded in the output filename).
        num_sampels: number of VAE samples forwarded to ``vaerecon``.
        num_bootsamles: number of bootstrap samples forwarded to ``vaerecon``.
        batch_size: parallelization factor for the prior evaluation.
        num_iter: number of optimization iterations.
        step_size: gradient step size.
        phase_step: phase-regularization weight.
        complex_rec: unused here; kept for interface compatibility.
        use_momentum: momentum flag forwarded to ``vaerecon``.
        log: when truthy, metrics are logged to Weights & Biases.
        device: torch device identifier for the VAE and reconstruction.
    """
    # Some inits of paths... Edit these
    vae_model_name = 'T2-20210415-111101/450.pth'
    vae_path = '/cluster/scratch/jonatank/logs/ddp/vae/'
    data_path = '/cluster/work/cvl/jonatank/fastMRI_T2/validation/'
    log_path = '/cluster/scratch/jonatank/logs/ddp/restore/pytorch/'
    # Renamed from ``rss``: the old name was shadowed by the tensor unpacked
    # from each batch below.
    use_rss = True

    # Load pretrained VAE
    path = vae_path + vae_model_name
    vae = torch.load(path, map_location=torch.device(device))
    vae.eval()

    # Data loader setup
    subj_dataset = Subject(subj, data_path, R, rss=use_rss)
    subj_loader = data.DataLoader(subj_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=0)

    # Time model and init resulting matrices
    start_time = time.perf_counter()
    rec_subj = np.zeros((len(subj_loader), 320, 320))
    gt_subj = np.zeros((len(subj_loader), 320, 320))

    # Set basic parameters
    print('Subj: ', subj, ' R: ', R, ' mode: ', mode, ' k: ', k,
          ' num_sampels: ', num_sampels, ' num_bootsamles: ', num_bootsamles,
          ' batch_size: ', batch_size, ' num_iter: ', num_iter, ' step_size: ',
          step_size, ' phase_step: ', phase_step)

    # Log
    log_path = log_path + 'R' + str(R) + '_mode' + str(
        k) + mode + '_reg2lmb0.01_' + datetime.now().strftime("%Y%m%d-%H%M%S")
    if log:
        import wandb
        wandb.login()
        wandb.init(project='JDDP' + '_T2',
                   name=vae_model_name,
                   config={
                       "num_iter": num_iter,
                       "step_size": step_size,
                       "phase_step": phase_step,
                       "mode": mode,
                       'R': R,
                       'K': k,
                       'use_momentum': use_momentum
                   })
        #wandb.watch(vae)
    else:
        wandb = False

    print("num_iter", num_iter, " step_size ", step_size, " phase_step ",
          phase_step, " mode ", mode, ' R ', R, ' K ', k, 'use_momentum',
          use_momentum)

    for batch in tqdm(subj_loader, desc="Running inference"):
        ksp, coilmaps, rss, norm_fact, num_sli = batch

        rec_sli = vaerecon(ksp[0],
                           coilmaps[0],
                           mode,
                           vae,
                           rss[0],
                           log_path,
                           device,
                           writer=wandb,
                           norm=norm_fact.item(),
                           nsampl=num_sampels,
                           boot_samples=num_bootsamles,
                           k=k,
                           patchsize=28,
                           parfact=batch_size,
                           num_iter=num_iter,
                           stepsize=step_size,
                           lmb=phase_step,
                           use_momentum=use_momentum)

        # Store the cropped magnitude reconstruction and ground truth.
        rec_subj[num_sli] = np.abs(center_crop(rec_sli.detach().cpu().numpy()))
        gt_subj[num_sli] = np.abs(center_crop(rss[0]))

        # Per-slice metrics
        rmse_sli = nmse(rec_subj[num_sli], gt_subj[num_sli])
        ssim_sli = ssim(rec_subj[num_sli], gt_subj[num_sli])
        psnr_sli = psnr(rec_subj[num_sli], gt_subj[num_sli])
        print('Slice: ', num_sli.item(), ' RMSE: ', str(rmse_sli), ' SSIM: ',
              str(ssim_sli), ' PSNR: ', str(psnr_sli))
        end_time = time.perf_counter()

        print(f"Elapsed time for {str(num_sli)} slices: {end_time-start_time}")

    # Subject-level metrics.  Bug fix: the original computed all three with
    # nmse() and referenced an undefined ``recon_subj``.
    rmse_v = nmse(rec_subj, gt_subj)
    ssim_v = ssim(rec_subj, gt_subj)
    psnr_v = psnr(rec_subj, gt_subj)
    # Bug fix: report the subject-level values, not the last slice's.
    print('Subject Done: ', 'RMSE: ', str(rmse_v), ' SSIM: ', str(ssim_v),
          ' PSNR: ', str(psnr_v))

    # Bug fix: dump the array that actually exists (``rec_subj``) and drop
    # the undefined ``restore_sense`` from the filename (NameError before);
    # use a context manager so the file handle is closed.
    with open(log_path + subj + str(k) + mode + str(R), 'wb') as f:
        pickle.dump(rec_subj, f)

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(subj_loader)} slices: {end_time-start_time}")
Exemple #20
0
def vaerecon(ksp,
             coilmaps,
             mode,
             vae_model,
             gt,
             logdir,
             device,
             writer=False,
             norm=1,
             nsampl=100,
             boot_samples=500,
             k=1,
             patchsize=28,
             parfact=25,
             num_iter=200,
             stepsize=5e-4,
             lmb=0.01,
             num_priors=1,
             use_momentum=True):
    # Init data
    imcoils, imsizer, imsizec = ksp.shape
    ksp = ksp.to(device)
    coilmaps = coilmaps.to(device)
    vae_model = vae_model.to(device)
    uspat = (torch.abs(ksp[0]) > 0).type(torch.uint8).to(device)
    recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)
    rss = rss_pytorch(ksp)

    # Init coilmaps estimation with JSENSE
    if mode == 'JDDP':
        # Polynomial order
        max_basis_order = 6
        num_coeffs = (max_basis_order + 1)**2

        # Create the basis functions for the sense estimation estimation
        basis_funct = create_basis_functions(imsizer,
                                             imsizec,
                                             max_basis_order,
                                             show_plot=False)
        plot_basis = False
        if plot_basis:
            for i in range(num_coeffs):
                writer.log({
                    "Basis funcs": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.from_numpy(basis_funct[i, :, :]))),
                                     caption="")
                    ]
                })

        basis_funct = torch.from_numpy(
            np.tile(basis_funct[np.newaxis, :, :, :],
                    [coilmaps.shape[0], 1, 1, 1])).to(device)
        coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct, uspat)

        coilmaps = torch.sum(
            coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct,
            1).to(device)

        recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)

        if writer:
            for i in range(coilmaps.shape[0]):
                writer.log(
                    {
                        "abs Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.abs(coilmaps[i, :, :]))),
                                caption="")
                        ]
                    },
                    step=0)
                writer.log(
                    {
                        "phase Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.angle(coilmaps[i, :, :]))),
                                caption="")
                        ]
                    },
                    step=0)
        print("Coilmaps init done")

    # Log
    if writer:
        writer.log(
            {
                "Gt rss": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(gt)),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored rss": [
                    writer.Image(transforms.ToPILImage()(
                        normalize_tensor(rss)),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored abs": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        torch.abs(recs_gpu))),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored Phase": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        torch.angle(recs_gpu))),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "diff rss": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        (rss.detach().cpu() / norm - gt.detach().cpu()))),
                                 caption="")
                ]
            },
            step=0)
        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        nmse_v = nmse(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)
        writer.log({"SSIM": ssim_v, "NMSE": nmse_v, "PSNR": psnr_v}, step=0)

        lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize, parfact,
                              nsampl, vae_model)
        writer.log({"ELBO": lik}, step=0)
        writer.log({"DC err": dc}, step=0)

    t = 1
    for it in range(0, num_iter, 2):
        print('Itr: ', it)

        # Magnitude prior projection step
        for _ in range(num_priors):
            # Gradient descent of Prior
            if mode == 'TV':
                tvnorm, abstvgrad = tv_norm(torch.abs(rss))
                priorgrad = abstvgrad * recs_gpu / (torch.abs(recs_gpu))
                recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:  #and it%10 == 0:
                    writer.log(
                        {
                            "TVgrad": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(abstvgrad)),
                                             caption="")
                            ]
                        },
                        step=it + 1)
                    writer.log(
                        {
                            "TV": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(tvnorm)),
                                             caption="")
                            ]
                        },
                        step=it + 1)

            elif mode == 'DDP' or mode == 'JDDP':
                g_abs_lik, est_uncert, g_dc = prior_gradient(
                    rss, ksp, uspat, coilmaps, patchsize, parfact, nsampl,
                    vae_model, boot_samples, mode)
                priorgrad = g_abs_lik * recs_gpu / (torch.abs(recs_gpu))

                if it > -1:
                    recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:  # Log
                    writer.log(
                        {
                            "VAEgrad abs": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(torch.abs(g_abs_lik))),
                                             caption="")
                            ]
                        },
                        step=it + 1)
                    writer.log({"STD": torch.mean(torch.abs(est_uncert))},
                               step=it + 1)

                    tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
                    tmp2 = ksp * uspat.unsqueeze(0)
                    tmp = tmp1 + tmp2
                    rss = rss_pytorch(tmp)
                    nmse_v = nmse(
                        (rss[160:-160].detach().cpu().numpy() / norm),
                        gt[160:-160].detach().cpu().numpy())
                    ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ',
                          psnr_v)
                    writer.log({
                        "SSIM": ssim_v,
                        "NMSE": nmse_v,
                        "PSNR": psnr_v
                    },
                               step=it + 1)
            else:
                print("Error: Prior method does not exists.")
                exit()

        # Phase projection step
        if lmb > 0:
            tmpa = torch.abs(recs_gpu)
            tmpp = torch.angle(recs_gpu)

            # We apply phase regularization to prefer smooth phase images
            #tmpptv = reg2_proj(tmpp, imsizer, imsizec, alpha=lmb, niter=2)  # 0.1, 15
            tmpptv = tv_proj(tmpp, mu=0.125, lmb=lmb, IT=50)  # 0.1, 15
            # We combine back the phase and the magnitude
            recs_gpu = tmpa * torch.exp(1j * tmpptv)

        # Coilmaps estimation step (if JSENSE)
        if mode == 'JDDP':
            # computed on cpu since pytorch gpu can handle complex numbers...
            coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct,
                                               uspat)
            coilmaps = torch.sum(
                coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct,
                1).to(device)

            if writer:
                writer.log(
                    {
                        "abs Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.abs(coilmaps[0, :, :]))),
                                caption="")
                        ]
                    },
                    step=it + 1)
                writer.log(
                    {
                        "phase Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.angle(coilmaps[0, :, :]))),
                                caption="")
                        ]
                    },
                    step=it + 1)

        # Data consistency projection
        tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
        tmp2 = ksp * uspat.unsqueeze(0)
        tmp = tmp1 + tmp2
        recs_gpu = tFT_pytorch(tmp, coilmaps)
        # recs[it + 2] = recs_gpu.detach().cpu().numpy()
        rss = rss_pytorch(tmp)

        # Log
        nmse_v = nmse((rss[160:-160].detach().cpu().numpy() / norm),
                      gt[160:-160].detach().cpu().numpy())
        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)

        if writer:
            writer.log({
                "SSIM": ssim_v,
                "NMSE": nmse_v,
                "PSNR": psnr_v
            },
                       step=it + 1)
            writer.log(
                {
                    "Restored rss": [
                        writer.Image(transforms.ToPILImage()(
                            normalize_tensor(rss)),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "Restored Phase": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.angle(recs_gpu))),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "diff rss": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            (rss.detach().cpu() / norm - gt.detach().cpu()))),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "Restored 1ch kspace": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.log(torch.abs(tmp[0])))),
                                     caption="")
                    ]
                },
                step=it + 1)
            lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize,
                                  parfact, nsampl, vae_model)
            writer.log({"ELBO": lik}, step=it + 1)
            writer.log({"DC err": dc}, step=it + 1)

    return rss / norm
Exemple #21
0
    )  # Note: here use tf.reduce_sum, not use tf.reduce_mean
    loss_texture = -loss_discrim

    correct_predictions = tf.equal(tf.argmax(discrim_predictions, 1),
                                   tf.argmax(discrim_target, 1))
    discim_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

    # 2) content loss

    CX_LAYER = 'conv4_2'

    enhanced_vgg = vgg.net(vgg_dir, vgg.preprocess(enhanced * 255))
    dslr_vgg = vgg.net(vgg_dir, vgg.preprocess(dslr_image * 255))

    # SSIM loss
    ssim_loss = 25 * (1 - utils.ssim(dslr_image, enhanced) / batch_size)

    # CX loss
    cx_loss = 4 * CX_loss_helper(dslr_vgg[CX_LAYER], enhanced_vgg[CX_LAYER],
                                 config_CX)

    # content loss
    loss_content = ssim_loss + cx_loss

    # 3) color loss

    enhanced_blur = utils.blur(enhanced)
    dslr_blur = utils.blur(dslr_image)

    loss_color = tf.reduce_sum(tf.pow(dslr_blur - enhanced_blur,
                                      2)) / (2 * batch_size)
Exemple #22
0
 # 4. adjust the learning rate
 scheduler_G.step()  #更新学习率
 scheduler_D.step()
 ''' validation '''
 current_psnr_val = psnr_val
 psnr_val = 0.
 ssim_val = 0.
 with torch.no_grad():
     net.eval()
     for i, (ms, pan, gt) in enumerate(loader['validation']):
         ms, _ = normlization(ms.cuda())
         pan, _ = normlization(pan.cuda())
         gt, _ = normlization(gt.cuda())
         imgf = net(ms, pan)
         psnr_val += psnr_loss(imgf, gt, 1.)
         ssim_val += ssim(imgf, gt, 5, 'mean', 1.)
     psnr_val = float(psnr_val / loader['validation'].__len__())
     ssim_val = float(ssim_val / loader['validation'].__len__())
 writer.add_scalar('PSNR on validation data', psnr_val, epoch)
 writer.add_scalar('SSIM on validation data', ssim_val, epoch)
 ''' save model '''
 # Save the best weight
 #    if best_psnr_val<psnr_val and best_ssim_val<ssim_val:
 #    if best_ssim_val<ssim_val:
 if best_psnr_val < psnr_val:
     best_psnr_val = psnr_val
     best_ssim_val = ssim_val
     torch.save(
         {
             'G': net.state_dict(),
             'D': discriminator.state_dict(),
def test(data, lambda_y, lambda_m,
         weights=None,
         batch_size=16,
         img_size=416,
         conf_thres=0.001,
         iou_thres=0.6,  # for nms
         save_json=False,
         single_cls=False,
         augment=False,
         model=None,
         dataloader=None):
    """Evaluate a joint detection + depth model on the test split.

    Runs inference over the test set, accumulating detector statistics
    (P, R, mAP, F1 per class) and an SSIM loss between the predicted and
    reference depth maps. When ``model`` is None, a model is built here and
    loaded from ``weights``; otherwise the function assumes it was called
    from train.py with an already-built model on some device.

    Returns:
        ((mp, mr, map, mf1, ssim_loss, *mean_losses), maps) where ``maps``
        is a length-``nc`` array of per-class mAP values.
    """
    # Initialize/load model and set device
    if model is None:
        device = select_device(opt.device, batch_size=batch_size)
        verbose = opt.task == 'test'

        # Remove previous batch visualizations
        for f in glob.glob('test_batch*.png'):
            os.remove(f)

        # Initialize model
        model = MainModel(img_size)

        # Load weights
        attempt_download(weights)
        if weights.endswith('.pt'):  # pytorch format
            model.load_state_dict(torch.load(weights, map_location=device)['model'])
        else:  # darknet format
            load_darknet_weights(model, weights)

        # Fuse
        model.fuse()
        model.to(device)

        if device.type != 'cpu' and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
    else:  # called by train.py
        device = next(model.parameters()).device  # get model device
        verbose = False

    # Configure run
    #data = parse_data_cfg(data)
    nc = 4 # 1 if single_cls else int(data['classes'])  # number of classes 4
    path = "./data/customdata/custom_test.txt" #data['valid']  # path to test images
    names = ['hardhat', 'vest', 'mask', 'boots'] #load_classes(data['names'])  # class names ['hardhat', 'vest', 'mask', 'boots']
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    iouv = iouv[0].view(1)  # comment for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if dataloader is None:
        dataset = LoadImagesAndLabels(path, img_size, batch_size, rect=True, single_cls=opt.single_cls)
        batch_size = min(batch_size, len(dataset))
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]),
                                pin_memory=True,
                                collate_fn=dataset.collate_fn)

    seen = 0
    model.eval()
    _ = model(torch.zeros((1, 3, img_size, img_size), device=device)) if device.type != 'cpu' else None  # run once
    coco91class = coco80_to_coco91_class()
    # NOTE: header string restored to 'mAP@0.5' (it had been mangled into an
    # e-mail-like token, presumably by web scraping of the original source).
    s = ('%20s' + '%10s' * 7) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@0.5', 'F1', 'SSIM Loss')
    p, r, f1, mp, mr, map, mf1, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    yolo_loss = torch.zeros(3, device=device)
    ssim_loss = torch.zeros(1, device=device)
    # Bug fix: initialize `loss` up front. Previously it was assigned only
    # inside the `hasattr(model, 'hyp')` branch below, so the final `return`
    # raised NameError whenever the model carried no loss hyperparameters.
    loss = lambda_y * yolo_loss + lambda_m * ssim_loss
    jdict, stats, ap, ap_class = [], [], [], []
    for batch_i, (imgs, targets, paths, shapes, midas) in enumerate(tqdm(dataloader, desc=s)):
        imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        nb, _, height, width = imgs.shape  # batch size, channels, height, width
        whwh = torch.Tensor([width, height, width, height]).to(device)
        midas = midas.to(device).float() / 255.0
        # Plot images with bounding boxes
        f = 'test_batch%g.png' % batch_i  # filename
        if batch_i < 1 and not os.path.exists(f):
            plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)

        # Disable gradients
        with torch.no_grad():
            # Run model
            t = time_synchronized()
            inf_out, train_out, midas_out = model(imgs, augment=augment)  # inference and training outputs
            t0 += time_synchronized() - t

            # Compute loss
            if hasattr(model, 'hyp'):  # if model has loss hyperparameters
                yolo_loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls
                midas = midas.unsqueeze(1)
                ssim_loss += 1 - ssim(midas_out, midas)
                loss = lambda_y * yolo_loss + lambda_m * ssim_loss

            # Run NMS
            t = time_synchronized()
            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)  # nms
            t1 += time_synchronized() - t

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            seen += 1

            if pred is None:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Append to text file
            # with open('test.txt', 'a') as file:
            #    [file.write('%11.5g' * 7 % tuple(x) + '\n') for x in pred]

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = int(Path(paths[si]).stem.split('_')[-1])
                box = pred[:, :4].clone()  # xyxy
                scale_coords(imgs[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                box = xyxy2xywh(box)  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({'image_id': image_id,
                                  'category_id': coco91class[int(p[5])],
                                  'bbox': [round(x, 3) for x in b],
                                  'score': round(p[4], 5)})

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5]) * whwh

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero().view(-1)  # prediction indices
                    pi = (cls == pred[:, 5]).nonzero().view(-1)  # target indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices

                        # Append detections
                        for j in (ious > iouv[0]).nonzero():
                            d = ti[i[j]]  # detected target
                            if d not in detected:
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats):
        p, r, ap, f1, ap_class = ap_per_class(*stats)
        if niou > 1:
            p, r, ap, f1 = p[:, 0], r[:, 0], ap.mean(1), ap[:, 0]  # [P, R, mAP@0.5:0.95, mAP@0.5]
        mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%10.3g' * 7  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1, ssim_loss))

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))

    # Print speeds
    if verbose or save_json:
        t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (img_size, img_size, batch_size)  # tuple
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)


    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map, mf1, ssim_loss, *(loss.cpu() / len(dataloader)).tolist()), maps
Exemple #24
0
def main():
    """Evaluate super-resolution models trained under each loss in
    ``loss_lists`` against a bicubic baseline.

    For every loss, the matching RDN weights are loaded, the test set is
    super-resolved, PSNR/SSIM are accumulated for both the bicubic and CNN
    outputs, intermediate images are written to disk, and the running
    accuracy table is saved to a .mat file.
    """
    ii = 1
    color_dim = 1

    for loss in loss_lists:
        print('loss:', loss)
        # The perceptual-loss variant operates on 3-channel input; all other
        # losses use single-channel images.
        if loss == 'perceptual':
            color_dim = 3
        else:
            color_dim = 1

        method = Methods(
            scale=scale,
            image_size=image_size,
            label_size=label_size,
            color_dim=color_dim,
            is_training=False,
        )

        X_pre_test, X_test, Y_test = load_test(scale=scale,
                                               test_folder=test_folder,
                                               color_dim=color_dim)

        predicted_list = []

        weight_filename = '../srcnn/S' + str(
            scale) + 'Models/' + methodName + '/' + loss + '/model0040.hdf5'
        model = method.RDN()
        model.load_weights(weight_filename)

        # Run the network on every pre-processed (low-resolution) test image.
        for img in X_pre_test:
            img = img.astype('float')

            test_sample = img.reshape(1, img.shape[0], img.shape[1], color_dim)
            predicted = model.predict(test_sample)
            predicted_list.append(
                predicted.reshape(predicted.shape[1], predicted.shape[2],
                                  color_dim))

        n_img = len(predicted_list)
        # Bug fix: the output directory previously hard-coded 'S3_' even
        # though the weight path above is parameterized by `scale`; keep the
        # two consistent.
        dirname = 'resultNew/S' + str(scale) + '_' + methodName + '/' + loss
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        psnr_bic = 0
        psnr_cnn = 0
        ssim_bic = 0
        ssim_cnn = 0

        dirname_gnd = './TestDatasets/' + test_folder
        # NOTE(review): os.listdir returns entries in arbitrary order; this
        # assumes load_test enumerates the same folder the same way so that
        # img_list[i] corresponds to X_pre_test[i] -- confirm against load_test.
        img_list = os.listdir(dirname_gnd)
        for i in range(n_img):
            imgname = 'image{:02}'.format(i)
            low_res = X_pre_test[i]
            bic = np.float32(X_test[i])
            gnd = np.float32(Y_test[i])
            cnn = np.float32(predicted_list[i])

            # For multi-channel output only the first channel is scored.
            if color_dim > 1:
                cnn = cnn[:, :, 0]
                gnd = gnd[:, :, 0]
                bic = bic[:, :, 0]

            # Crop `scale` pixels from every border before scoring, the
            # common SR convention to exclude boundary artifacts.
            cnn = cnn[scale:-scale, scale:-scale]
            bic = bic[scale:-scale, scale:-scale]
            gnd = gnd[scale:-scale, scale:-scale]

            name = os.path.splitext(os.path.basename(img_list[i]))[0]

            #cv2.imwrite(os.path.join(dirname,imgname+'_original.bmp'), low_res)
            cv2.imwrite(os.path.join(dirname, name + '_bic.bmp'), bic)
            cv2.imwrite(os.path.join(dirname, name + '_gnd.bmp'), gnd)
            cv2.imwrite(os.path.join(dirname, name + '_cnn.bmp'), cnn)

            bic_ps = psnr(gnd, bic)
            cnn_ps = psnr(gnd, cnn)
            bic_ss = ssim(gnd, bic)
            cnn_ss = ssim(gnd, cnn)
            print('bic:', bic_ps)
            print('cnn', cnn_ps)

            psnr_bic += bic_ps
            psnr_cnn += cnn_ps
            ssim_bic += bic_ss
            ssim_cnn += cnn_ss

        # Average the accumulated scores over the test set.
        psnr_bic_m = psnr_bic / n_img
        psnr_cnn_m = psnr_cnn / n_img
        ssim_bic_m = ssim_bic / n_img
        ssim_cnn_m = ssim_cnn / n_img

        # Row 0 holds the bicubic baseline; row `ii` holds this loss's CNN.
        srcnn_cnn_accs[0][0] = psnr_bic_m
        srcnn_cnn_accs[0][1] = ssim_bic_m
        srcnn_cnn_accs[ii][0] = psnr_cnn_m
        srcnn_cnn_accs[ii][1] = ssim_cnn_m
        ii += 1
        sio.savemat(
            methodName + str(scale) + '_' + test_folder + '_cnn_accs', {
                methodName + str(scale) + '_' + test_folder + '_cnn_accs':
                srcnn_cnn_accs
            },
            appendmat=True)
        #####################
        print('psnr_bic_m:', psnr_bic_m)
        print('psnr_cnn_m:', psnr_cnn_m)
        print('ssim_bic_m:', ssim_bic_m)
        print('ssim_cnn_m:', ssim_cnn_m)
Exemple #25
0
def main(args):
	"""Train a denoising network with MSE loss and periodic validation.

	Builds the model/optimizer/scheduler from `args`, optionally resumes
	from a checkpoint, then for each epoch: adds synthetic noise to clean
	inputs, minimizes MSE, logs PSNR/SSIM moving averages to TensorBoard,
	and every `args.valid_interval` epochs validates at a fixed noise level
	and checkpoints the best validation PSNR.
	"""
	device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
	utils.setup_experiment(args)
	utils.init_logging(args)

	# Build data loaders, a model and an optimizer
	model = models.build_model(args).to(device)
	print(model)
	optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80, 90, 100], gamma=0.5)
	logging.info(f"Built a model consisting of {sum(p.numel() for p in model.parameters()):,} parameters")
	
	if args.resume_training:
		# Recover the step counter and derive the starting epoch from it.
		# NOTE(review): 403200 looks like a hard-coded dataset size -- confirm
		# it matches the training set used for the checkpoint.
		state_dict = utils.load_checkpoint(args, model, optimizer, scheduler)
		global_step = state_dict['last_step']
		start_epoch = int(state_dict['last_step']/(403200/state_dict['args'].batch_size))+1
	else:
		global_step = -1
		start_epoch = 0
		
	train_loader, valid_loader, _ = data.build_dataset(args.dataset, args.data_path, batch_size=args.batch_size)
	
	# Track moving average of loss values
	train_meters = {name: utils.RunningAverageMeter(0.98) for name in (["train_loss", "train_psnr", "train_ssim"])}
	valid_meters = {name: utils.AverageMeter() for name in (["valid_psnr", "valid_ssim"])}
	writer = SummaryWriter(log_dir=args.experiment_dir) if not args.no_visual else None

	for epoch in range(start_epoch, args.num_epochs):
		if args.resume_training:
			# NOTE(review): when resuming, the LR is additionally halved every
			# 10 epochs on top of the MultiStepLR schedule -- confirm intended.
			if epoch %10 == 0:
				optimizer.param_groups[0]["lr"] /= 2
				print('learning rate reduced by factor of 2')
				
		train_bar = utils.ProgressBar(train_loader, epoch)
		for meter in train_meters.values():
			meter.reset()

		for batch_id, inputs in enumerate(train_bar):
			model.train()

			global_step += 1
			inputs = inputs.to(device)
			# Synthesize noise per args; noise levels are given in [0, 255]
			# and scaled to the [0, 1] image range here.
			noise = utils.get_noise(inputs, mode = args.noise_mode, 
												min_noise = args.min_noise/255., max_noise = args.max_noise/255.,
												noise_std = args.noise_std/255.)

			noisy_inputs = noise + inputs;
			outputs = model(noisy_inputs)
			# Sum-reduced MSE normalized by 2 * batch size.
			loss = F.mse_loss(outputs, inputs, reduction="sum") / (inputs.size(0) * 2)

			model.zero_grad()
			loss.backward()
			optimizer.step()

			train_psnr = utils.psnr(outputs, inputs)
			train_ssim = utils.ssim(outputs, inputs)
			train_meters["train_loss"].update(loss.item())
			train_meters["train_psnr"].update(train_psnr.item())
			train_meters["train_ssim"].update(train_ssim.item())
			train_bar.log(dict(**train_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

			if writer is not None and global_step % args.log_interval == 0:
				writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)
				writer.add_scalar("loss/train", loss.item(), global_step)
				writer.add_scalar("psnr/train", train_psnr.item(), global_step)
				writer.add_scalar("ssim/train", train_ssim.item(), global_step)
				# Histogram of all gradients for debugging training dynamics.
				gradients = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None], dim=0)
				writer.add_histogram("gradients", gradients, global_step)
				sys.stdout.flush()

		if epoch % args.valid_interval == 0:
			model.eval()
			for meter in valid_meters.values():
				meter.reset()

			valid_bar = utils.ProgressBar(valid_loader)
			for sample_id, sample in enumerate(valid_bar):
				with torch.no_grad():
					sample = sample.to(device)
					# Validate at a single fixed sigma: the midpoint of the
					# training noise range.
					noise = utils.get_noise(sample, mode = 'S', 
												noise_std = (args.min_noise +  args.max_noise)/(2*255.))

					noisy_inputs = noise + sample;
					output = model(noisy_inputs)
					valid_psnr = utils.psnr(output, sample)
					valid_meters["valid_psnr"].update(valid_psnr.item())
					valid_ssim = utils.ssim(output, sample)
					valid_meters["valid_ssim"].update(valid_ssim.item())

					# Log a few clean/noisy/denoised triplets for inspection.
					if writer is not None and sample_id < 10:
						image = torch.cat([sample, noisy_inputs, output], dim=0)
						image = torchvision.utils.make_grid(image.clamp(0, 1), nrow=3, normalize=False)
						writer.add_image(f"valid_samples/{sample_id}", image, global_step)

			if writer is not None:
				writer.add_scalar("psnr/valid", valid_meters['valid_psnr'].avg, global_step)
				writer.add_scalar("ssim/valid", valid_meters['valid_ssim'].avg, global_step)
				sys.stdout.flush()

			logging.info(train_bar.print(dict(**train_meters, **valid_meters, lr=optimizer.param_groups[0]["lr"])))
			# Keep only the checkpoint with the highest validation PSNR.
			utils.save_checkpoint(args, global_step, model, optimizer, score=valid_meters["valid_psnr"].avg, mode="max")
		scheduler.step()

	logging.info(f"Done training! Best PSNR {utils.save_checkpoint.best_score:.3f} obtained after step {utils.save_checkpoint.best_step}.")
Exemple #26
0
def display_denoising(DnCNN,
                      BF_DnCNN,
                      set12_path,
                      image_num=7,
                      noise_level=90,
                      l=0,
                      h=10,
                      model='DnCNN'):
    """Visualize denoising of one Set12 image.

    Shows four panels side by side -- the clean image, its noisy version
    at `noise_level`, and the outputs of `DnCNN` and `BF_DnCNN` -- each
    annotated with PSNR/SSIM against the clean image. `l` and `h` label
    the training noise range in the figure title.
    """
    clean = single_image_loader(set12_path, image_num)
    clean_t = torch.from_numpy(clean).unsqueeze(0).unsqueeze(0).to(
        device).float()

    # Corrupt with fixed-sigma Gaussian noise.
    corruption = utils.get_noise(clean_t,
                                 noise_std=noise_level / 255.,
                                 mode='S')
    noisy_t = clean_t + corruption

    def _scores(estimate):
        # PSNR/SSIM of `estimate` vs. the clean tensor, rounded for display.
        return (np.round(utils.psnr(clean_t, estimate), 2),
                np.round(utils.ssim(clean_t, estimate), 2))

    def _to_numpy(tensor):
        # Drop the batch and channel dims and move to host memory.
        return tensor.cpu().data.squeeze(0).squeeze(0).numpy()

    noisy_psnr, noisy_ssim = _scores(noisy_t)

    out_dncnn = DnCNN(noisy_t)
    dncnn_psnr, dncnn_ssim = _scores(out_dncnn)
    dncnn_np = _to_numpy(out_dncnn)

    out_bf = BF_DnCNN(noisy_t)
    bf_psnr, bf_ssim = _scores(out_bf)
    bf_np = _to_numpy(out_bf)
    noisy_np = _to_numpy(noisy_t)

    fig, axs = plt.subplots(1, 4, figsize=(15, 4), squeeze=True)

    fig.suptitle(r'Training range: $\sigma \in [ $' + str(l) + ' , ' + str(h) +
                 ']',
                 fontname='Times New Roman',
                 fontsize=15)

    panels = [
        (clean, 'clean image', None),
        (noisy_np,
         r'noisy image, $\sigma$ = ' + str(noise_level),
         'psnr ' + str(noisy_psnr) + '\n ssim ' + str(noisy_ssim)),
        (dncnn_np,
         'denoised, ' + model,
         'psnr ' + str(dncnn_psnr) + '\n ssim ' + str(dncnn_ssim)),
        (bf_np,
         'denoised, BF_' + model,
         'psnr ' + str(bf_psnr) + '\n ssim ' + str(bf_ssim)),
    ]

    for ax, (image, title, label) in zip(axs, panels):
        ax.imshow(image, 'gray', vmin=0, vmax=1)
        ax.set_title(title, fontname='Times New Roman', fontsize=15)
        if label is not None:
            ax.set_xlabel(label, fontname='Times New Roman', fontsize=15)
        # Hide axis ticks and tick labels on every panel.
        ax.tick_params(bottom=False,
                       left=False,
                       labelleft=False,
                       labelbottom=False)
Exemple #27
0
            ki = k[i,:,:,d]
            tmp[d,:,:] += conv2(phi_u,ki[::-1,::-1],'full')
    u_t[t+1] = np.clip(u_t[t] - crop_zero(tmp, padding, padding) - l*bwd_bayer(fwd_bayer(u_t[t]) - f), 0.0,255.0)
    print '.',

#Evaluate
# NOTE: this block is Python 2 (statement-form `print`). It scores the final
# reconstruction iterate against the ground-truth image and saves the result.
print "\nTest image: %d" % data_config['indices'][example]
#get the result
result = u_t[num_steps]
plt.figure(1)
# Display as HW3 uint8 without interpolation artifacts.
plt.imshow(swapimdims_3HW_HW3(result).astype('uint8'), interpolation="none")
plt.show()
target = data.target[example]
#compute psnr and ssim on the linear space result image
print "PSNR linear: %.2f dB" % psnr(target, np.round(result), 255.0)
print "SSIM linear: %.3f" % ssim(target, result, 255.0)

#also compute psnr and ssim on the sRGB transformed result image
# (gamma/color transform gives display-referred quality numbers)
srgb_params = init_colortransformation_gamma()
result_rgb = apply_colortransformation_gamma(np.expand_dims(result,0), srgb_params)
target_rgb = apply_colortransformation_gamma(np.expand_dims(target,0), srgb_params)
print "PSNR sRGB: %.2f dB" % psnr(target_rgb[0], result_rgb[0], 255.0)
print "SSIM sRGB: %.3f" % ssim(target_rgb[0], result_rgb[0], 255.0)

#save result
plt.imsave("results/" + str(data_config['indices'][example]), swapimdims_3HW_HW3(result).astype('uint8'))