Example #1
    def sample_stats(self,
                     stats_path='data',
                     sample_size=1000,
                     train_s_list=None,
                     train_d_list=None):
        """
		Computes statistics for each frequency component of the instantaneous a priori SNR
		in dB over a sample of the training set. The statistics are then used to map the
		instantaneous a priori SNR in dB between 0 and 1 using its cumulative distribution
		function. This forms the mapped a priori SNR (the training target).

		Argument/s:
			stats_path - path to the saved statistics.
			sample_size - number of training examples to compute the statistics from.
			train_s_list - train clean speech list.
			train_d_list - train noise list.
		"""
        if os.path.exists(stats_path + '/stats.npz'):
            print('Loading sample statistics...')
            with np.load(stats_path + '/stats.npz') as stats:
                self.mu = stats['mu_hat']
                self.sigma = stats['sigma_hat']
        elif train_s_list is None:
            raise ValueError(
                'No stats.npz file exists. data/stats.npz is available here: https://github.com/anicolson/DeepXi/blob/master/data/stats.npz.'
            )
        else:
            print('Finding sample statistics...')
            s_sample_list = random.sample(train_s_list, sample_size)
            d_sample_list = random.sample(train_d_list, sample_size)
            s_sample, d_sample, s_sample_len, d_sample_len, snr_sample = self.wav_batch(
                s_sample_list, d_sample_list)
            snr_sample = np.array(
                random.choices(self.snr_levels, k=sample_size))
            # snr_sample = np.random.randint(self.min_snr, self.max_snr + 1, sample_size)
            samples = []
            for i in tqdm(range(s_sample.shape[0])):
                s_STMS, d_STMS, _, _ = self.mix(s_sample[i:i + 1],
                                                d_sample[i:i + 1],
                                                s_sample_len[i:i + 1],
                                                d_sample_len[i:i + 1],
                                                snr_sample[i:i + 1])
                xi_db = self.xi_db(s_STMS,
                                   d_STMS)  # instantaneous a priori SNR (dB).
                samples.append(np.squeeze(xi_db.numpy()))
            samples = np.vstack(samples)
            if len(samples.shape) != 2:
                raise ValueError('Incorrect shape for sample.')
            stats = {
                'mu_hat': np.mean(samples, axis=0),
                'sigma_hat': np.std(samples, axis=0)
            }
            self.mu, self.sigma = stats['mu_hat'], stats['sigma_hat']
            if not os.path.exists(stats_path): os.makedirs(stats_path)
            np.savez(stats_path + '/stats.npz',
                     mu_hat=stats['mu_hat'],
                     sigma_hat=stats['sigma_hat'])
            save_mat(stats_path + '/stats.mat', stats, 'stats')
            print('Sample statistics saved.')
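
The per-frequency statistics gathered above drive the CDF mapping described in the docstring. A minimal sketch of that mapping, assuming a Gaussian cumulative distribution function per frequency component (the helper name map_xi_db is hypothetical, not part of the class):

from scipy.stats import norm

def map_xi_db(xi_db, mu, sigma):
    # xi_db: instantaneous a priori SNR (dB), shape (frames, bins).
    # mu, sigma: per-frequency sample mean and standard deviation
    # (mu_hat and sigma_hat above).
    # Returns the mapped a priori SNR in [0, 1] (the training target).
    return norm.cdf(xi_db, loc=mu, scale=sigma)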
Example #2
    def sample(
        self,
        sample_size,
        sample_dir='data',
    ):
        """
		Gathers a sample of the training set. The sample can be used to compute
		statistics for mapping functions.

		Argument/s:
			sample_size - number of training examples included in the sample.
			sample_dir - path to the saved sample.
		"""
        sample_path = sample_dir + '/sample'
        if os.path.exists(sample_path + '.npz'):
            print('Loading sample...')
            with np.load(sample_path + '.npz') as sample:
                s_sample = sample['s_sample']
                d_sample = sample['d_sample']
                x_sample = sample['x_sample']
                wav_len = sample['wav_len']
        elif self.train_s_list is None:
            raise ValueError('No sample.npz file exists.')
        else:
            if sample_size is None: raise ValueError("sample_size is not set.")
            print('Gathering a sample of the training set...')
            s_sample_list = random.sample(self.train_s_list, sample_size)
            d_sample_list = random.sample(self.train_d_list, sample_size)
            s_sample_int, d_sample_int, s_sample_len, d_sample_len, snr_sample = self.wav_batch(
                s_sample_list, d_sample_list)
            s_sample = np.zeros_like(s_sample_int, np.float32)
            d_sample = np.zeros_like(s_sample_int, np.float32)
            x_sample = np.zeros_like(s_sample_int, np.float32)
            for i in tqdm(range(s_sample.shape[0])):
                s, d, x, _ = self.inp_tgt.mix(s_sample_int[i:i + 1],
                                              d_sample_int[i:i + 1],
                                              s_sample_len[i:i + 1],
                                              d_sample_len[i:i + 1],
                                              snr_sample[i:i + 1])
                s_sample[i, 0:s_sample_len[i]] = s
                d_sample[i, 0:s_sample_len[i]] = d
                x_sample[i, 0:s_sample_len[i]] = x
            wav_len = s_sample_len
            if not os.path.exists(sample_dir): os.makedirs(sample_dir)
            np.savez(sample_path + '.npz',
                     s_sample=s_sample,
                     d_sample=d_sample,
                     x_sample=x_sample,
                     wav_len=wav_len)
            sample = {
                's_sample': s_sample,
                'd_sample': d_sample,
                'x_sample': x_sample,
                'wav_len': wav_len
            }
            save_mat(sample_path + '.mat', sample, 'sample')
            print('Sample of the training set saved.')
        return s_sample, d_sample, x_sample, wav_len
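
A usage sketch for this method; the instance name deepxi and the sample size are placeholders:

# Gathers a fresh sample, or loads data/sample.npz if it already exists.
s, d, x, wav_len = deepxi.sample(sample_size=1000, sample_dir='data')
# s, d, x are zero-padded clean, noise, and noisy waveforms; wav_len
# holds the true length of each example.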
Example #3
	def sample_old(
		self,
		sample_size,
		sample_dir='data',
		):
		"""
		Gathers a sample of the training set. The sample can be used to compute
		statistics for mapping functions.

		Argument/s:
			sample_size - number of training examples included in the sample.
			sample_dir - path to the saved sample.
		"""

		sample_path = sample_dir + '/sample'
		if os.path.exists(sample_path + '.npz'):
			print('Loading sample...')
			with np.load(sample_path + '.npz') as sample:
				samples_s_STMS = sample['samples_s_STMS']
				samples_d_STMS = sample['samples_d_STMS']
				samples_x_STMS = sample['samples_x_STMS']
		elif self.train_s_list is None:
			raise ValueError('No sample.npz file exists. data/sample.npz is available here: https://github.com/anicolson/DeepXi/blob/master/data/sample.npz.')
		else:
			if sample_size is None: raise ValueError("sample_size is not set.")
			print('Gathering a sample of the training set...')
			s_sample_list = random.sample(self.train_s_list, sample_size)
			d_sample_list = random.sample(self.train_d_list, sample_size)
			s_sample, d_sample, s_sample_len, d_sample_len, snr_sample = self.wav_batch(s_sample_list, d_sample_list)
			snr_sample = np.array(random.choices(self.snr_levels, k=sample_size))
			samples_s_STMS = []
			samples_d_STMS = []
			samples_x_STMS = []
			for i in tqdm(range(s_sample.shape[0])):
				s, d, x, _ = self.inp_tgt.mix(s_sample[i:i+1], d_sample[i:i+1], s_sample_len[i:i+1],
					d_sample_len[i:i+1], snr_sample[i:i+1])
				s_STMS, _ = self.inp_tgt.polar_analysis(s)
				d_STMS, _ = self.inp_tgt.polar_analysis(d)
				x_STMS, _ = self.inp_tgt.polar_analysis(x)
				samples_s_STMS.append(np.squeeze(s_STMS.numpy()))
				samples_d_STMS.append(np.squeeze(d_STMS.numpy()))
				samples_x_STMS.append(np.squeeze(x_STMS.numpy()))
			samples_s_STMS = np.vstack(samples_s_STMS)
			samples_d_STMS = np.vstack(samples_d_STMS)
			samples_x_STMS = np.vstack(samples_x_STMS)
			if len(samples_s_STMS.shape) != 2: raise ValueError('Incorrect shape for s_STMS sample.')
			if len(samples_d_STMS.shape) != 2: raise ValueError('Incorrect shape for d_STMS sample.')
			if len(samples_x_STMS.shape) != 2: raise ValueError('Incorrect shape for x_STMS sample.')
			if not os.path.exists(sample_dir): os.makedirs(sample_dir)
			np.savez(sample_path + '.npz', samples_s_STMS=samples_s_STMS,
				samples_d_STMS=samples_d_STMS, samples_x_STMS=samples_x_STMS)
			sample = {'samples_s_STMS': samples_s_STMS,
				'samples_d_STMS': samples_d_STMS,
				'samples_x_STMS': samples_x_STMS}
			save_mat(sample_path + '.mat', sample, 'sample')
			print('Sample of the training set saved.')
		return samples_s_STMS, samples_d_STMS, samples_x_STMS
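
The returned STMS frames can be reduced to the per-frequency statistics consumed by sample_stats. A minimal sketch, assuming the standard a priori SNR definition xi = |S|^2 / |D|^2; the floor of 1e-12 is an assumption to keep the logarithm finite:

import numpy as np

xi = np.square(samples_s_STMS) / np.maximum(np.square(samples_d_STMS), 1e-12)
xi_db = 10.0 * np.log10(np.maximum(xi, 1e-12))  # instantaneous a priori SNR (dB)
mu_hat = np.mean(xi_db, axis=0)    # per-frequency mean
sigma_hat = np.std(xi_db, axis=0)  # per-frequency standard deviation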
Example #4
    def train(
        self,
        train_s_list,
        train_d_list,
        model_path='model',
        val_s=None,
        val_d=None,
        val_s_len=None,
        val_d_len=None,
        val_snr=None,
        val_flag=True,
        val_save_path=None,
        mbatch_size=8,
        max_epochs=200,
        resume_epoch=0,
        stats_path=None,
        sample_size=None,
        eval_example=False,
        save_model=True,
        log_iter=False,
    ):
        """
		Deep Xi training.

		Argument/s:
			train_s_list - clean-speech training list.
			train_d_list - noise training list.
			model_path - model save path.
			val_s - clean-speech validation batch.
			val_d - noise validation batch.
			val_s_len - clean-speech validation sequence length batch.
			val_d_len - noise validation sequence length batch.
			val_snr - SNR validation batch.
			val_flag - perform validation.
			val_save_path - validation batch save path.
			mbatch_size - mini-batch size.
			max_epochs - maximum number of epochs.
			resume_epoch - epoch to resume training from.
			stats_path - path to save sample statistics.
			sample_size - sample size.
			eval_example - evaluate a mini-batch of training examples.
			save_model - save architecture, weights, and training configuration.
			log_iter - log training loss for each training iteration.
		"""
        self.train_s_list = train_s_list
        self.train_d_list = train_d_list
        self.mbatch_size = mbatch_size
        self.n_examples = len(self.train_s_list)
        self.n_iter = math.ceil(self.n_examples / mbatch_size)

        self.sample_stats(stats_path, sample_size, train_s_list, train_d_list)

        train_dataset = self.dataset(max_epochs - resume_epoch)

        if val_flag:
            val_set = self.val_batch(val_save_path, val_s, val_d, val_s_len,
                                     val_d_len, val_snr)
            val_steps = len(val_set[0])
        else:
            val_set, val_steps = None, None

        if eval_example:
            print("Saving a mini-batch of training examples in .mat files...")
            x_STMS_batch, xi_bar_batch, seq_mask_batch = list(
                train_dataset.take(1).as_numpy_iterator())[0]
            save_mat('./x_STMS_batch.mat', x_STMS_batch, 'x_STMS_batch')
            save_mat('./xi_bar_batch.mat', xi_bar_batch, 'xi_bar_batch')
            save_mat('./seq_mask_batch.mat', seq_mask_batch, 'seq_mask_batch')
            print("Testing if add_noise() works correctly...")
            s, d, s_len, d_len, snr_tgt = self.wav_batch(
                train_s_list[0:mbatch_size], train_d_list[0:mbatch_size])
            (_, s, d) = self.add_noise_batch(self.normalise(s),
                                             self.normalise(d), s_len, d_len,
                                             snr_tgt)
            for (i, _) in enumerate(s):
                snr_act = self.snr_db(s[i][0:s_len[i]], d[i][0:d_len[i]])
                print('SNR target|actual: {:.2f}|{:.2f} (dB).'.format(
                    snr_tgt[i], snr_act))

        if not os.path.exists(model_path): os.makedirs(model_path)
        if not os.path.exists("log"): os.makedirs("log")
        if not os.path.exists("log/iter"): os.makedirs("log/iter")

        callbacks = []
        callbacks.append(
            CSVLogger("log/" + self.ver + ".csv", separator=',', append=True))
        if save_model: callbacks.append(SaveWeights(model_path))
        # if log_iter: callbacks.append(CSVLoggerIter("log/iter/" + self.ver + ".csv", separator=',', append=True))

        if resume_epoch > 0:
            self.model.load_weights(model_path + "/epoch-" +
                                    str(resume_epoch - 1) +
                                    "/variables/variables")

        self.model.compile(sample_weight_mode="temporal",
                           loss="binary_crossentropy",
                           optimizer=Adam(learning_rate=0.001, clipvalue=1.0))

        self.model.fit(x=train_dataset,
                       initial_epoch=resume_epoch,
                       epochs=max_epochs,
                       steps_per_epoch=self.n_iter,
                       callbacks=callbacks,
                       validation_data=val_set,
                       validation_steps=val_steps)
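
A hedged invocation sketch; deepxi is an instance of the surrounding class, and all list contents and paths are placeholders:

deepxi.train(
    train_s_list=train_s_list,  # clean-speech .wav paths
    train_d_list=train_d_list,  # noise .wav paths
    model_path='model',
    mbatch_size=8,
    max_epochs=200,
    stats_path='data',          # stats.npz is loaded from or saved to here
    sample_size=1000,
)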
Example #5
    def infer(
        self,
        test_x,
        test_x_len,
        test_x_base_names,
        test_epoch,
        model_path='model',
        out_type='y',
        gain='mmse-lsa',
        out_path='out',
        stats_path=None,
        n_filters=40,
    ):
        """
		Deep Xi inference. The specified 'out_type' is saved.

		Argument/s:
			test_x - noisy-speech test batch.
			test_x_len - noisy-speech test batch lengths.
			test_x_base_names - noisy-speech base names.
			test_epoch - epoch to test.
			model_path - path to model directory.
			out_type - output type (see deepxi/args.py).
			gain - gain function (see deepxi/args.py).
			out_path - path to save output files.
			stats_path - path to the saved statistics.
			n_filters - number of filters in the mel-scale filter bank.
		"""
        if out_type == 'xi_hat': out_path = out_path + '/xi_hat'
        elif out_type == 'y': out_path = out_path + '/y/' + gain
        elif out_type == 'deepmmse': out_path = out_path + '/deepmmse'
        elif out_type == 'ibm_hat': out_path = out_path + '/ibm_hat'
        elif out_type == 'subband_ibm_hat':
            out_path = out_path + '/subband_ibm_hat'
        else:
            raise ValueError('Invalid output type.')
        if not os.path.exists(out_path): os.makedirs(out_path)

        if test_epoch < 1:
            raise ValueError("test_epoch must be greater than 0.")

        # The mel-scale filter bank is used to compute an ideal binary mask
        # (IBM) estimate for log-spectral subband energies (LSSE).
        if out_type == 'subband_ibm_hat':
            mel_filter_bank = self.mel_filter_bank(n_filters)

        self.sample_stats(stats_path)
        self.model.load_weights(model_path + '/epoch-' + str(test_epoch - 1) +
                                '/variables/variables')

        print("Processing observations...")
        x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(
            test_x, test_x_len)
        print("Performing inference...")
        xi_bar_hat_batch = self.model.predict(x_STMS_batch,
                                              batch_size=1,
                                              verbose=1)

        print("Performing synthesis...")
        batch_size = len(test_x_len)
        for i in tqdm(range(batch_size)):
            base_name = test_x_base_names[i]
            x_STMS = x_STMS_batch[i, :n_frames[i], :]
            x_STPS = x_STPS_batch[i, :n_frames[i], :]
            xi_bar_hat = xi_bar_hat_batch[i, :n_frames[i], :]
            xi_hat = self.xi_hat(xi_bar_hat)
            if out_type == 'xi_hat':
                save_mat(out_path + '/' + base_name + '.mat', xi_hat,
                         'xi_hat')
            elif out_type == 'y':
                y_STMS = np.multiply(x_STMS,
                                     gfunc(xi_hat, xi_hat + 1, gtype=gain))
                y = self.polar_synthesis(y_STMS, x_STPS).numpy()
                save_wav(out_path + '/' + base_name + '.wav', y, self.f_s)
            elif out_type == 'deepmmse':
                d_PSD_hat = np.multiply(
                    np.square(x_STMS),
                    gfunc(xi_hat, xi_hat + 1, gtype='deepmmse'))
                save_mat(out_path + '/' + base_name + '.mat', d_PSD_hat,
                         'd_psd_hat')
            elif out_type == 'ibm_hat':
                ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
                save_mat(out_path + '/' + base_name + '.mat', ibm_hat,
                         'ibm_hat')
            elif out_type == 'subband_ibm_hat':
                xi_hat_subband = np.matmul(xi_hat, mel_filter_bank.transpose())
                subband_ibm_hat = np.greater(xi_hat_subband, 1.0).astype(bool)
                save_mat(out_path + '/' + base_name + '.mat', subband_ibm_hat,
                         'subband_ibm_hat')
            else:
                raise ValueError('Invalid output type.')
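
self.xi_hat inverts the CDF mapping used to form the training target. A minimal sketch of that inversion, assuming the Gaussian per-frequency statistics from sample_stats; the clipping bounds are an assumption to keep the inverse CDF finite:

import numpy as np
from scipy.stats import norm

def xi_hat(xi_bar_hat, mu, sigma):
    xi_bar_hat = np.clip(xi_bar_hat, 1e-7, 1.0 - 1e-7)     # avoid +/- inf
    xi_db_hat = norm.ppf(xi_bar_hat, loc=mu, scale=sigma)   # inverse CDF (dB)
    return np.power(10.0, xi_db_hat / 10.0)                 # dB -> linear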
Example #6
    def infer(
        self,
        test_x,
        test_x_len,
        test_x_base_names,
        test_epoch,
        model_path='model',
        out_type='y',
        gain='mmse-lsa',
        out_path='out',
        n_filters=40,
        saved_data_path=None,
    ):
        """
		Deep Xi inference. The specified 'out_type' is saved.

		Argument/s:
			test_x - noisy-speech test batch.
			test_x_len - noisy-speech test batch lengths.
			test_x_base_names - noisy-speech base names.
			test_epoch - epoch to test.
			model_path - path to model directory.
			out_type - output type (see deepxi/args.py).
			gain - gain function (see deepxi/args.py).
			out_path - path to save output files.
			n_filters - number of filters in the mel-scale filter bank.
			saved_data_path - path to saved data necessary for enhancement.
		"""
        out_path_base = out_path
        if not isinstance(test_epoch, list): test_epoch = [test_epoch]
        if not isinstance(gain, list): gain = [gain]

        # The mel-scale filter bank is used to compute an ideal binary mask
        # (IBM) estimate for log-spectral subband energies (LSSE).
        if out_type == 'subband_ibm_hat':
            mel_filter_bank = self.mel_filter_bank(n_filters)

        for e in test_epoch:
            if e < 1: raise ValueError("test_epoch must be greater than 0.")
            for g in gain:

                out_path = out_path_base + '/' + self.ver + '/' + 'e' + str(
                    e)  # output path.
                if out_type == 'xi_hat': out_path = out_path + '/xi_hat'
                elif out_type == 'gamma_hat':
                    out_path = out_path + '/gamma_hat'
                elif out_type == 's_STPS_hat':
                    out_path = out_path + '/s_STPS_hat'
                elif out_type == 'y':
                    if self.inp_tgt_type == 'MagIRM':
                        out_path = out_path + '/y'
                    else:
                        out_path = out_path + '/y/' + g
                elif out_type == 'deepmmse':
                    out_path = out_path + '/deepmmse'
                elif out_type == 'ibm_hat':
                    out_path = out_path + '/ibm_hat'
                elif out_type == 'subband_ibm_hat':
                    out_path = out_path + '/subband_ibm_hat'
                elif out_type == 'cd_hat':
                    out_path = out_path + '/cd_hat'
                else:
                    raise ValueError('Invalid output type.')
                if not os.path.exists(out_path): os.makedirs(out_path)

                self.model.load_weights(model_path + '/epoch-' + str(e - 1) +
                                        '/variables/variables')

                print("Processing observations...")
                inp_batch, supplementary_batch, n_frames = self.observation_batch(
                    test_x, test_x_len)

                print("Performing inference...")
                tgt_hat_batch = self.model.predict(inp_batch,
                                                   batch_size=1,
                                                   verbose=1)

                print("Saving outputs...")
                batch_size = len(test_x_len)
                for i in tqdm(range(batch_size)):
                    base_name = test_x_base_names[i]
                    inp = inp_batch[i, :n_frames[i], :]
                    tgt_hat = tgt_hat_batch[i, :n_frames[i], :]

                    # if tf.is_tensor(supplementary_batch):
                    supplementary = supplementary_batch[i, :n_frames[i], :]

                    if saved_data_path is not None:
                        saved_data = read_mat(saved_data_path + '/' +
                                              base_name + '.mat')
                        supplementary = (supplementary, saved_data)

                    if out_type == 'xi_hat':
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        save_mat(out_path + '/' + base_name + '.mat', xi_hat,
                                 'xi_hat')
                    elif out_type == 'gamma_hat':
                        gamma_hat = self.inp_tgt.gamma_hat(tgt_hat)
                        save_mat(out_path + '/' + base_name + '.mat',
                                 gamma_hat, 'gamma_hat')
                    elif out_type == 's_STPS_hat':
                        s_STPS_hat = self.inp_tgt.s_stps_hat(tgt_hat)
                        save_mat(out_path + '/' + base_name + '.mat',
                                 s_STPS_hat, 's_STPS_hat')
                    elif out_type == 'y':
                        y = self.inp_tgt.enhanced_speech(
                            inp, supplementary, tgt_hat, g).numpy()
                        save_wav(out_path + '/' + base_name + '.wav', y,
                                 self.inp_tgt.f_s)
                    elif out_type == 'deepmmse':
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        d_PSD_hat = np.multiply(
                            np.square(inp),
                            gfunc(xi_hat, xi_hat + 1.0, gtype='deepmmse'))
                        save_mat(out_path + '/' + base_name + '.mat',
                                 d_PSD_hat, 'd_psd_hat')
                    elif out_type == 'ibm_hat':
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
                        save_mat(out_path + '/' + base_name + '.mat', ibm_hat,
                                 'ibm_hat')
                    elif out_type == 'subband_ibm_hat':
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        xi_hat_subband = np.matmul(xi_hat,
                                                   mel_filter_bank.transpose())
                        subband_ibm_hat = np.greater(xi_hat_subband,
                                                     1.0).astype(bool)
                        save_mat(out_path + '/' + base_name + '.mat',
                                 subband_ibm_hat, 'subband_ibm_hat')
                    elif out_type == 'cd_hat':
                        cd_hat = self.inp_tgt.cd_hat(tgt_hat)
                        save_mat(out_path + '/' + base_name + '.mat', cd_hat,
                                 'cd_hat')
                    else:
                        raise ValueError('Invalid output type.')
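
A usage sketch showing the list-valued arguments; every epoch/gain combination is processed in turn. The epoch numbers are placeholders and the gain identifiers are illustrative (see deepxi/args.py for the valid set):

deepxi.infer(test_x, test_x_len, test_x_base_names,
             test_epoch=[180, 200],           # one pass per epoch
             gain=['mmse-lsa', 'mmse-stsa'],  # one pass per gain
             out_type='y',
             model_path='model',
             out_path='out')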
Example #7
    def train(
        self,
        train_s_list,
        train_d_list,
        mbatch_size,
        max_epochs,
        loss_fnc,
        model_path='model',
        val_s=None,
        val_d=None,
        val_s_len=None,
        val_d_len=None,
        val_snr=None,
        val_flag=True,
        val_save_path=None,
        resume_epoch=0,
        eval_example=False,
        save_model=True,
        log_iter=False,
    ):
        """
		Deep Xi training.

		Argument/s:
			train_s_list - clean-speech training list.
			train_d_list - noise training list.
			mbatch_size - mini-batch size.
			max_epochs - maximum number of epochs.
			loss_fnc - loss function.
			model_path - model save path.
			val_s - clean-speech validation batch.
			val_d - noise validation batch.
			val_s_len - clean-speech validation sequence length batch.
			val_d_len - noise validation sequence length batch.
			val_snr - SNR validation batch.
			val_flag - perform validation.
			val_save_path - validation batch save path.
			resume_epoch - epoch to resume training from.
			eval_example - evaluate a mini-batch of training examples.
			save_model - save architecture, weights, and training configuration.
			log_iter - log training loss for each training iteration.
		"""
        self.train_s_list = train_s_list
        self.train_d_list = train_d_list
        self.mbatch_size = mbatch_size
        self.n_examples = len(self.train_s_list)
        self.n_iter = math.ceil(self.n_examples / mbatch_size)

        train_dataset = self.dataset(max_epochs - resume_epoch)

        if val_flag:
            val_set = self.val_batch(val_save_path, val_s, val_d, val_s_len,
                                     val_d_len, val_snr)
            val_steps = len(val_set[0])
        else:
            val_set, val_steps = None, None

        if not os.path.exists(model_path): os.makedirs(model_path)
        if not os.path.exists("log/loss"): os.makedirs("log/loss")

        callbacks = []
        callbacks.append(
            CSVLogger("log/loss/" + self.ver + ".csv",
                      separator=',',
                      append=True))
        if save_model: callbacks.append(SaveWeights(model_path))

        # if log_iter: callbacks.append(CSVLoggerIter("log/iter/" + self.ver + ".csv", separator=',', append=True))

        if resume_epoch > 0:
            self.model.load_weights(model_path + "/epoch-" +
                                    str(resume_epoch - 1) +
                                    "/variables/variables")

        if eval_example:
            print("Saving a mini-batch of training examples in .mat files...")
            inp_batch, tgt_batch, seq_mask_batch = list(
                train_dataset.take(1).as_numpy_iterator())[0]
            save_mat('./inp_batch.mat', inp_batch, 'inp_batch')
            save_mat('./tgt_batch.mat', tgt_batch, 'tgt_batch')
            save_mat('./seq_mask_batch.mat', seq_mask_batch, 'seq_mask_batch')
            print("Testing if add_noise() works correctly...")
            s, d, s_len, d_len, snr_tgt = self.wav_batch(
                self.train_s_list[0:mbatch_size],
                self.train_d_list[0:mbatch_size])
            (_, s, d) = self.inp_tgt.add_noise_batch(self.inp_tgt.normalise(s),
                                                     self.inp_tgt.normalise(d),
                                                     s_len, d_len, snr_tgt)
            for (i, _) in enumerate(s):
                snr_act = self.inp_tgt.snr_db(s[i][0:s_len[i]],
                                              d[i][0:d_len[i]])
                print('SNR target|actual: {:.2f}|{:.2f} (dB).'.format(
                    snr_tgt[i], snr_act))

        if self.network_type == "MHANet":
            lr_schedular = TransformerSchedular(self.network.d_model,
                                                self.network.warmup_steps)
            opt = Adam(learning_rate=lr_schedular,
                       clipvalue=1.0,
                       beta_1=0.9,
                       beta_2=0.98,
                       epsilon=1e-9)
        else:
            opt = Adam(learning_rate=0.001, clipvalue=1.0)

        if loss_fnc == "BinaryCrossentropy": loss = BinaryCrossentropy()
        elif loss_fnc == "MeanSquaredError": loss = MeanSquaredError()
        else: raise ValueError("Invalid loss function")

        self.model.compile(sample_weight_mode="temporal",
                           loss=loss,
                           optimizer=opt)
        print("SNR levels used for training:")
        print(self.snr_levels)
        self.model.fit(x=train_dataset,
                       initial_epoch=resume_epoch,
                       epochs=max_epochs,
                       steps_per_epoch=self.n_iter,
                       callbacks=callbacks,
                       validation_data=val_set,
                       validation_steps=val_steps)
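
TransformerSchedular is defined elsewhere in the repository; a minimal sketch of a compatible schedule, assuming the warmup rule from "Attention Is All You Need" (lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)):

import tensorflow as tf

class TransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = float(warmup_steps)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Linear warmup for warmup_steps, then inverse-square-root decay.
        return tf.math.rsqrt(self.d_model) * tf.minimum(
            tf.math.rsqrt(step), step * self.warmup_steps ** -1.5)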