예제 #1
0
 def optimization_loop(self):
     """Run one optimisation step and log/save the network output.

     Optionally perturbs the 4-D/5-D network parameters, builds the network
     input (``self.input_old`` plus optional fresh noise plus a weighted data
     term during the first ``data_forgetting_factor`` iterations), runs the
     forward pass, computes the masked data loss and the regularisation loss,
     calls ``backward()`` on their weighted sum, records metrics in
     ``self.history`` and keeps the output with the lowest total loss so far.
     The optimizer step itself is not taken here.

     Returns:
         The total loss tensor, on which ``backward()`` has been called.
     """
     def to_image(tensor):
         # >4-D outputs are squeezed; otherwise the leading axis is moved last.
         arr = u.torch_to_np(tensor)
         return arr.squeeze() if tensor.ndim > 4 else arr.transpose((1, 2, 0))

     # add normal noise (2% of each tensor's std) to the 4-D/5-D parameters,
     # presumably the conv kernels
     if self.args.param_noise:
         for n in [x for x in self.net.parameters() if len(x.size()) in [4, 5]]:
             # Update in place through ``.data``: rebinding ``n = n + ...``
             # (as before) left the network parameters unchanged.
             n.data.add_(n.detach().clone().normal_() * n.detach().std() * 0.02)

     # add normal noise to input noise tensor
     input_ = self.input_old
     if self.args.reg_noise_std > 0:
         input_ = self.input_old + (self.add_noise_.normal_() * self.args.reg_noise_std)

     # add data to input noise tensor during the first iterations
     if self.iiter < self.args.data_forgetting_factor:
         # Out-of-place on purpose: when reg_noise_std == 0, ``input_`` is
         # ``self.input_old`` itself and ``+=`` would permanently add the
         # data term to it, accumulating across iterations.
         input_ = input_ + self.add_data_weight[self.iiter] * self.add_data_
         self.input_list.append(u.torch_to_np(input_[0, 0]))

     # compute output
     out_ = self.net(input_)

     # data-fidelity loss on the masked output
     main_loss = self.loss_fn(out_ * self.mask_, self.coarse_img_)

     # regularisation loss between the output and self.reg(out_)
     reg_loss = self.loss_reg_fn(out_, self.reg(out_, thresh=self.reg_th))

     # total loss: fixed weight 0.01 up to iteration 2000, then scheduled
     # weights from self.reg_weight
     # NOTE(review): the first scheduled index read is 1 (at iiter == 2001),
     # so reg_weight[0] is never used — confirm this offset is intended.
     reg_w = self.reg_weight[self.iiter - 2000] if self.iiter > 2000 else 0.01
     total_loss = main_loss + reg_w * reg_loss

     total_loss.backward()

     # save loss and metrics, and print log
     loss_val = total_loss.item()
     reg_val = reg_loss.item()
     snr_val = u.snr(output=out_, target=self.img_).item()
     pcorr_val = u.pcorr(output=out_, target=self.img_).item()
     self.history.append((loss_val, reg_val, snr_val, pcorr_val))
     self.history.lr.append(self.optimizer.param_groups[0]['lr'])
     print(colored(self.history.log_message(self.iiter), 'yellow'), '\r', end='')

     # keep the output with the lowest loss seen so far (ties also update)
     if self.iiter == 0 or self.history.loss[-1] <= self.loss_min:
         self.loss_min = self.history.loss[-1]
         self.out_best = to_image(out_)

     # save intermediate outputs
     if self.iiter in self.iter_to_be_saved and self.iiter != 0:
         stem = self.image_name.split('.')[0]
         filename = stem + '_output%s.npy' % str(self.iiter).zfill(self.zfill)
         np.save(os.path.join(self.outpath, filename), to_image(out_))

     self.iiter += 1

     return total_loss
예제 #2
0
def splot(y, y0, yd, title="Denoising", n=2000):
    """Plot a raw, a noisy and a denoised signal on a 2x2 grid.

    Subplots: raw signal, noisy signal, denoised signal overlaid on the raw
    one (its title reports ``utils.snr(y0, yd)`` in dB, computed on the FULL
    signals), and the residual ``y0 - yd``. Only the first ``n`` samples are
    drawn.

    y : noisy signal
    y0 : raw signal
    yd : denoised signal
    title : figure title
    n : number of leading samples drawn in each subplot (default 2000)

    Returns the created matplotlib figure.
    """
    fig = plt.figure(figsize=(20, 12))
    raw, noisy, denoised = y0[:n], y[:n], yd[:n]

    plt.subplot(221)
    plt.plot(raw)
    plt.title('Raw signal :')

    plt.subplot(222)
    plt.plot(noisy)
    plt.title('Noised signal')

    plt.subplot(223)
    plt.plot(denoised, "r")
    plt.plot(raw, linewidth=2.5, alpha=0.3)
    plt.title('Denoised signal - SNR : %0.2f dB' % utils.snr(y0, yd))

    plt.subplot(224)
    plt.plot(raw - denoised)
    plt.title('Differences between raw and denoised signal :')

    fig.suptitle(title, fontsize=30, fontweight="bold")
    return fig
예제 #3
0
    def test_inference(self, signal, sample_rate=16000):
        """Encode ``signal`` and plot it with its reconstruction and spikegram.

        Draws three stacked subplots sharing the time axis: the signal padded
        with ``kernel_size - 1`` trailing zeros, the reconstruction from the
        activations, and a spikegram of the activations. Prints the SNR
        between the padded signal and the reconstruction.

        Args:
            signal: input signal; squeezed to 1-D before plotting.
            sample_rate: samples per second, used for the time axis.

        Returns:
            The ``(acts, meta)`` pair returned by ``self.infer(signal)``.
        """
        acts, meta = self.infer(signal)

        # Pad with kernel_size - 1 zeros so the signal length matches the
        # reconstruction (assumes a full-length convolution — TODO confirm).
        signal = np.squeeze(signal)
        padded_signal = np.concatenate([signal, np.zeros(self.kernel_size - 1)])
        times = np.arange(len(padded_signal)) / sample_rate

        _, axes = plt.subplots(3, 1, sharex=True)
        axes[0].plot(times, padded_signal)
        axes[0].set_title('Original signal')

        recon = np.squeeze(self.reconstruction(acts).detach().cpu().numpy())
        axes[1].plot(times, recon)
        axes[1].set_title('Reconstruction')

        # Prepend kernel_size - 1 zero columns to the activations so their
        # time axis has the same length as the padded signal.
        np_acts = np.squeeze(acts.detach().cpu().numpy())
        np_acts = np.concatenate([np.zeros([self.n_kernel, self.kernel_size - 1]), np_acts],
                                 axis=1)
        utils.plot_spikegram(np_acts,
                             sample_rate=sample_rate, markerSize=1, ax=axes[2])
        print("Signal-noise ratio: {:f} dB".format(utils.snr(padded_signal, recon)))
        return acts, meta
예제 #4
0
def opt(parameters, filt, y, y0, metric=None):
    """Return the candidate parameter that maximises the filtered signal's SNR.

    Each candidate ``p`` is scored with ``metric(y0, filt(y, p))``; the
    best-scoring candidate is returned (the first one on ties). Nothing is
    plotted.

    Parameters
    ----------
    parameters : sequence of candidate values for the filter parameter
    filt : callable ``filt(y, p)`` returning the filtered signal
    y : noisy signal
    y0 : original (clean) signal
    metric : callable ``metric(reference, estimate)`` to maximise;
        defaults to ``utils.snr``.

    Raises
    ------
    ValueError
        If ``parameters`` is empty.
    """
    if len(parameters) == 0:
        raise ValueError("parameters must contain at least one candidate")
    if metric is None:
        metric = utils.snr
    scores = np.asarray([metric(y0, filt(y, p)) for p in parameters], dtype=float)
    return parameters[int(np.argmax(scores))]
def PoiSelection(x_train, y_train, x_test, y_test, poi_type='corr', poi_num=10, poi_idx=None):
    """Select points of interest (PoIs), i.e. columns, from train and test data.

    - ``poi_type='corr'``: keep the ``poi_num`` columns with the largest
      absolute ``corr(x_train, y_train)``.
    - ``poi_type='snr'``: keep the ``poi_num`` columns with the largest
      ``snr(x_train, y_train)``.
    - ``poi_type='custom'``: keep the caller-supplied ``poi_idx`` (no columns
      if it is omitted).

    The same columns are taken from ``x_train`` and ``x_test``; ``y_test`` is
    not used.

    Example:
        x_train_poi, x_test_poi = PoiSelection(x_train, y_train, x_test, y_test,
                                               poi_type='corr', poi_num=20)

    Returns:
        ``(x_train, x_test)`` restricted to the selected columns.

    Raises:
        ValueError: if ``poi_type`` is not 'corr', 'snr' or 'custom'.
    """
    if poi_type == 'corr':
        # Negate so the ascending argsort puts the highest scores first.
        scores = -np.abs(corr(x_train, y_train))
        poi_idx = np.argsort(scores, axis=0)[:poi_num]
    elif poi_type == 'snr':
        scores = -snr(x_train, y_train)
        poi_idx = np.argsort(scores, axis=0)[:poi_num]
    elif poi_type == 'custom':
        if poi_idx is None:
            poi_idx = []
    else:
        # Previously this only printed and then returned empty arrays.
        raise ValueError('Error:unknown poi type %r' % (poi_type,))

    return x_train[:, poi_idx], x_test[:, poi_idx]
# Table of SNRs (dB): one row per thresholding mode, one column per method
snrframe = pd.DataFrame(index=['hard', 'soft'])

# Denoising and filter-parameter optimisation
Yh = np.zeros(Y.shape)
Ys = np.zeros(Y.shape)
Yg = np.zeros(Y.shape)
parameters = np.linspace(1, 5, 600)


def filt_gauss(y, tau):
    """Apply ``utils.gaussian_filter`` with parameter ``tau * sigma``."""
    return utils.gaussian_filter(y, tau * sigma)


# Score and apply the SAME filter: previously tau was optimised for
# gaussian_filter(y, tau) but then applied as gaussian_filter(y, tau * sigma).
for i in range(Y.shape[0]):
    mug = sc.opt(parameters, filt_gauss, Yb[i], Y[i])
    Yg[i] = filt_gauss(Yb[i], mug)

# Gaussian filtering and the noisy baseline have no hard/soft variant, so the
# same value fills both rows.
snr_gauss = utils.snr(Y, Yg)
snrframe.loc[:, 'gaussian'] = pd.Series(data=[snr_gauss, snr_gauss], index=snrframe.index)
snr_noisy = utils.snr(Y, Yb)
snrframe.loc[:, 'noisy'] = pd.Series(data=[snr_noisy, snr_noisy], index=snrframe.index)

for w in pywt.wavelist('db'):
    print(w)

    # ``w=w`` binds the current wavelet explicitly instead of relying on
    # late binding of the loop variable.
    def filt_hard(y, tau, w=w):
        return utils.wave_hard_filter(y, sigma, tau, w)

    def filt_soft(y, tau, w=w):
        return utils.wave_soft_filter(y, sigma, tau, w)

    for i in range(Y.shape[0]):
        tauhard = sc.opt(parameters, filt_hard, Yb[i], Y[i])
        tausoft = sc.opt(parameters, filt_soft, Yb[i], Y[i])
        Yh[i] = utils.wave_hard_filter(Yb[i], sigma, tauhard, w)
        Ys[i] = utils.wave_soft_filter(Yb[i], sigma, tausoft, w)
    snrframe.loc[:, w] = pd.Series(data=[utils.snr(Y, Yh), utils.snr(Y, Ys)], index=snrframe.index)
예제 #7
0
    def optimization_loop(self):
        """Run one optimisation step with a POCS-based regulariser.

        Optionally perturbs the 4-D/5-D network parameters, builds the network
        input (``self.input_old`` plus optional fresh noise plus a weighted data
        term during the first ``data_forgetting_factor`` iterations), runs the
        forward pass, and computes the masked data loss and a regularisation
        loss against the detached ``self.pocs(out_)``. It then calls
        ``backward()`` on their weighted sum, records metrics in
        ``self.history`` and keeps the output with the lowest total loss so
        far. The optimizer step itself is not taken here.

        Returns:
            The total loss tensor, on which ``backward()`` has been called.
        """
        def to_image(tensor):
            # >4-D outputs use torch_to_np(..., True); others are converted
            # with False, squeezed, and their leading axis moved last.
            if tensor.ndim > 4:
                return u.torch_to_np(tensor, True)
            return u.torch_to_np(tensor, False).squeeze().transpose((1, 2, 0))

        # add normal noise (2% of each tensor's std) to the 4-D/5-D parameters,
        # presumably the conv kernels
        if self.args.param_noise:
            for n in [x for x in self.net.parameters() if len(x.size()) in [4, 5]]:
                # Update in place through ``.data``: rebinding ``n = n + ...``
                # (as before) left the network parameters unchanged.
                n.data.add_(n.detach().clone().normal_() * n.detach().std() * 0.02)

        # adding normal noise to the input tensor
        input_ = self.input_old
        if self.args.reg_noise_std > 0:
            input_ = self.input_old + (self.add_noise_.normal_() * self.args.reg_noise_std)

        # adding data to the input noise during the first iterations
        if self.iiter < self.args.data_forgetting_factor:
            # Out-of-place on purpose: when reg_noise_std == 0, ``input_`` is
            # ``self.input_old`` itself and ``+=`` would permanently add the
            # data term to it, accumulating across iterations.
            input_ = input_ + self.add_data_weight[self.iiter] * self.add_data_
            self.input_list.append(u.torch_to_np(input_, True))

        # compute output
        out_ = self.net(input_)

        # data-fidelity loss on the masked output
        main_loss = self.loss_fn(out_ * self.mask_, self.coarse_img_)

        # regularisation loss against the (detached) POCS projection of the output
        reg_data = self.pocs(out_).detach()
        reg_loss = self.loss_reg_fn(out_, reg_data)
        self.reg_data = u.torch_to_np(reg_data.squeeze(0), bc_del=False)

        # compute total loss
        if self.args.pocs_weight is None:
            # Adaptive weight; it must be detached. Previously ``eps.detach()``
            # discarded its result, so eps * reg_loss == main_loss inside the
            # graph and the regulariser's gradient cancelled out entirely.
            eps = (main_loss / reg_loss).detach()
        else:
            # NOTE(review): the condition tests ``pocs_weight`` but the value
            # used is ``reg_weight`` — confirm this is intended.
            eps = self.args.reg_weight
        total_loss = main_loss + eps * reg_loss

        total_loss.backward()

        # save loss and metrics, and print log
        self.history.append((total_loss.item(),
                             main_loss.item(),
                             reg_loss.item(),
                             u.snr(output=out_, target=self.img_).item(),
                             u.pcorr(output=out_, target=self.img_).item()))
        self.history.lr.append(self.optimizer.param_groups[0]['lr'])
        print(colored(self.history.log_message(self.iiter), 'yellow'), '\r', end='')

        # keep the output with the lowest loss seen so far (ties also update)
        if self.iiter == 0 or self.history.loss[-1] <= self.loss_min:
            self.loss_min = self.history.loss[-1]
            self.out_best = to_image(out_)

        # saving intermediate outputs
        if self.iiter in self.iter_to_be_saved and self.iiter != 0:
            stem = self.image_name.split('.')[0]
            filename = stem + '_output%s.npy' % str(self.iiter).zfill(self.zfill)
            np.save(os.path.join(self.outpath, filename), to_image(out_))

        self.iiter += 1

        return total_loss
예제 #8
0
import pywt

# Raw, noise-free signal
plt.figure(figsize=(20, 3))
X = np.linspace(0, 10, 5000)
# np.load on an .npz returns an NpzFile that keeps the file open; use it as a
# context manager so the handle is closed once "data" has been read.
# NOTE(review): hard-coded absolute path — consider making it configurable.
with np.load("/home/cardiologs/Workspace/denoising/code/data/[DE-IDENTIFIED]_#DICOM#_0000A0A5_0926_000E_0926_286321_635440525514417877.dcm.npz") as npz:
    Y = npz["data"]
#plt.plot(X, Y-1)
#plt.title('Clear signal')

# Corrupt the signal with additive Gaussian noise of standard deviation sigma
sigma = 0.04
Yb = Y + sigma * np.random.standard_normal(Y.shape)
#plt.plot(X, Yb)
#plt.title('Noisy signal')
print(utils.snr(Y, Yb))

# Choice of the filter (wavelet) used for denoising
w = pywt.Wavelet('bior6.8')
#Yd = utils.wave_hard_filter(Yb, sigma, 20, w)
#plt.plot(X, Yd+1)
#plt.title("Denoised signal")
#plt.text(0, -1, "SNR :" + str(utils.snr(Y,Yd)))


# Optimise one parameter (tau) of the filter
def filt(y, tau):
    return sc.invariant_wave_filter(y, sigma, tau, w)


parameters = np.linspace(1, 5, 600)
tauopt = utils.opt(parameters, filt, Yb, Y)
예제 #9
0
def show_results(res_dir: "Path | str",
                 opts: "dict | None" = None,
                 curves: int = 0,
                 savefig=False):
    """Reconstruct a run's output volume and plot it, optionally with curves.

    Reads the run arguments from ``res_dir/args.txt``, loads the input volume
    ``args.imgdir/args.imgname``, reconstructs the output from its patches and
    shows it with ``u.explode_volume``. When ``curves > 0``, the loss / SNR /
    PCORR / learning-rate histories of up to ``curves`` randomly chosen patches
    are plotted as well.

    Args:
        res_dir: directory of the run (must contain ``args.txt``).
        opts: options forwarded to ``u.explode_volume``. Missing ``'clipval'``
            (default ``u.clim(inputs, 98)``) and ``'save_opts'`` (png, 150 dpi,
            tight bbox) entries are filled in on a copy; the caller's dict is
            not modified.
        curves: maximum number of patch histories to plot (0 disables).
        savefig: if True, also save the volume and the curves under ``res_dir``.
    """
    res_dir = Path(res_dir)
    args = u.read_args(res_dir / "args.txt")
    print(args.__dict__)

    # NOTE(review): allow_pickle=True can execute code from a crafted file;
    # only load inputs from trusted locations.
    inputs = np.load(os.path.join(args.imgdir, args.imgname),
                     allow_pickle=True)

    # Copy so that filling in defaults does not mutate the caller's dict.
    opts = {} if opts is None else dict(opts)
    if 'clipval' not in opts:
        opts['clipval'] = u.clim(inputs, 98)
    if 'save_opts' not in opts:
        opts['save_opts'] = {
            'format': 'png',
            'dpi': 150,
            'bbox_inches': 'tight'
        }

    outputs, hist = reconstruct_patches(args,
                                        return_history=True,
                                        verbose=True)
    if outputs.shape != inputs.shape:
        print("\n\tWarning! Outputs and Inputs have different shape! %s - %s" %
              (outputs.shape, inputs.shape))
        # Crop the inputs to the output extent so the metrics below compare
        # arrays of equal shape.
        inputs = inputs[:outputs.shape[0], :outputs.shape[1]]
        if inputs.ndim == 3:
            inputs = inputs[:, :, :outputs.shape[2]]

    # plot output volume
    if savefig:
        u.explode_volume(outputs, filename=res_dir / "output", **opts)
    else:
        u.explode_volume(outputs, **opts)

    if curves <= 0:
        return

    # plot curves of all patches, or of a sorted random subset of them
    if len(hist) <= curves:
        idx = range(len(hist))
    else:
        idx = sorted(sample(range(len(hist)), curves))

    fig, axs = plt.subplots(1, 4, figsize=(18, 4))

    for i in idx:
        label = 'patch %d' % i
        axs[0].plot(hist[i].loss, label=label)
        axs[1].plot(hist[i].snr, label=label)
        axs[2].plot(hist[i].pcorr, label=label)
        try:
            axs[3].plot(hist[i].lr, label=label)
        except AttributeError:
            # histories saved without a learning-rate record
            pass

    try:
        axs[0].set_title('LOSS %s' % args.loss)
    except AttributeError:
        axs[0].set_title('LOSS mae')
    axs[1].set_title('SNR = %.2f dB' % u.snr(outputs, inputs))
    axs[2].set_title('PCORR = %.2f %%' % (u.pcorr(outputs, inputs) * 100))
    axs[3].set_title('Learning Rate')

    for a in axs:
        a.legend()
        a.set_xlim(0, args.epochs)
        a.grid()

    axs[0].set_ylim(0)
    axs[1].set_ylim(0)
    axs[2].set_ylim(0, 1)
    axs[3].set_ylim(0, args.lr * 10)

    plt.suptitle(res_dir)
    plt.tight_layout(pad=.5)
    if savefig:
        plt.savefig(res_dir / f"curves.{opts['save_opts']['format']}",
                    **opts['save_opts'])
    plt.show()
예제 #10
0
import numpy as np
import cv2
from skimage.util import random_noise
import matplotlib
from utils import conv, median, snr, threshold

EXPORT_IMAGES = True

IMAGE_PATH = "../data/pup.jpg"
image = cv2.imread(IMAGE_PATH, cv2.IMREAD_GRAYSCALE)
if image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise fail later with an obscure TypeError.
    raise FileNotFoundError("Could not read image: " + IMAGE_PATH)
image = image / 255.0  # scale pixel intensities to [0, 1]

# Generate salt and pepper noise
seasoned_image = random_noise(image, mode='s&p', seed=0)
seasoned_snr = snr(image, seasoned_image)
print(f"Salt and pepper SNR: {seasoned_snr} dB")

# Generate Gaussian noise
gaussed_image = random_noise(image, mode='gaussian', seed=0)
gaussed_snr = snr(image, gaussed_image)
print(f"Gaussian SNR: {gaussed_snr} dB")

# Apply an averaging (box) filter over both noisy images:
# 5x5 averaging filter kernel (low pass)
avg_kernel = np.ones((5, 5)) / 25.0
averaged_simage = conv(seasoned_image, avg_kernel)
averaged_gimage = conv(gaussed_image, avg_kernel)

# Apply a 5x5 median filter over both noisy images
median_simage = median(seasoned_image, 5)
median_gimage = median(gaussed_image, 5)

# Sobel edge detection filters
예제 #11
0
    def run_single_song(self,
                        hq_path,
                        filter_,
                        cutoff,
                        duration=None,
                        start=0,
                        save=True,
                        overwrite=True):
        """
        Runs the model for a single song.
        Chunks of audio are processed and the outputs are later concatenated to create the full song.
        The generator is put in eval mode for the duration of the call and its
        previous train/eval mode is restored afterwards.

        Args:
            hq_path (str): Path to high-quality audio
            filter_ (tuple): Type and order of lowpass filter to apply on hq audio
            cutoff (int): Cutoff frequency of the low-pass filter
            duration (int, optional): Duration of audio to process. Defaults to None,
                which processes the entire audio.
            start (int, optional): Starting point, in seconds of audio to be processed. Defaults to 0.
            save (bool, optional): Setting False skips saving output, only calculates SNR and MSE.
                Defaults to True.
            overwrite (bool, optional): To overwrite samples at different iterations during training.
                Setting False can be useful to inspect generations during GAN training. Defaults to True.

        Returns:
            performance (dict): averaged 'loss' (MSE) and output 'snr'; on the
                first epoch also 'input_snr'.
        """
        # This is also called during GAN training; leaving the generator in
        # eval mode afterwards would change training behaviour (e.g. if it has
        # dropout or batch-norm layers), so remember and restore the mode.
        was_training = self.gen_model.training
        self.gen_model.eval()  # switch model to inference mode
        try:
            with torch.no_grad():  # no training is done here
                # initialize dataloader
                song_data = SingleSong(c.WAV_SAMPLE_LEN,
                                       filter_,
                                       hq_path,
                                       cutoff=cutoff,
                                       duration=duration,
                                       start=start)
                song_loader = DataLoader(song_data,
                                         batch_size=c.WAV_BATCH_SIZE,
                                         shuffle=False,
                                         num_workers=c.NUM_WORKERS)

                # preallocation to keep individual output chunks
                y_full = song_data.preallocate()
                song_averager = u.MovingAverages()

                # model works on chunks of audio, these are concatenated later
                idx_start_chunk = 0
                for x, t in song_loader:
                    x = x.to(c.DEVICE)  # input
                    t = t.to(c.DEVICE)  # target
                    y = self.gen_model(x)  # output
                    loss = F.mse_loss(y, t)
                    song_averager({'loss': loss})
                    idx_end_chunk = idx_start_chunk + y.shape[0]
                    y_full[idx_start_chunk:idx_end_chunk] = y
                    idx_start_chunk = idx_end_chunk

                y_full = u.torch2np(y_full)  # to cpu-numpy
                # create full song out of chunks
                y_full = np.concatenate(y_full, axis=-1)

                x_full, t_full = song_data.get_full_signals()

                y_full = np.clip(y_full, -1, 1 - np.finfo(np.float32).eps)

                # Measure performance
                performance = song_averager.get()
                song_averager.reset()

                snr_ = u.snr(y_full, t_full)
                performance.update({'snr': snr_})

                if self.first_epoch:
                    # Only need to see input SNR once
                    snr_input = u.snr(x_full, t_full)
                    performance.update({'input_snr': snr_input})

                if save:
                    # Save audio; os.path handles the platform's separators.
                    song_name = os.path.basename(hq_path).split('.')[0]
                    if 'mixture' in song_name:  # DSD100 dataset has mixture.wav for all file names
                        # folder name is song name
                        song_name = os.path.basename(os.path.dirname(hq_path))
                    # Remove problematic characters (order matters: multi-char
                    # patterns are replaced before the single space)
                    problem_str = [' - ', ' & ', ' &', '\'', ' ']
                    for s in problem_str:
                        song_name = song_name.replace(s, '_')

                    if not overwrite:
                        song_name = u.pad_str_zeros(str(self.iter_total),
                                                    7) + '_' + song_name

                    wavfile.write(
                        os.path.join(c.GENERATION_DIR,
                                     song_name + '_' + filter_[0] + '.wav'),
                        c.SAMPLE_RATE, y_full.T)

                return performance
        finally:
            self.gen_model.train(was_training)