Ejemplo n.º 1
0
 def sents_to_Tensors(batch_stacked_sents,
                      batch_labels=None,
                      toTorch=False):
     """Zero-pad groups of variable-length sentences into dense 2-D arrays.

     Args:
         batch_stacked_sents: iterable of sentence groups; each group is a
             sequence of 1-D integer sequences (numpy arrays or lists).
         batch_labels: optional labels. They are only converted when
             ``toTorch`` is True. When the first label is a list or numpy
             array, every label is right-padded with -1 to a common width.
         toTorch (bool): when True, results are wrapped in
             ``torch.LongTensor`` and moved to ``get_device()``; when False,
             numpy arrays are returned and ``batch_labels`` is passed through
             untouched.

     Returns:
         tuple ``(embeds, lengths, batch_labels)`` where ``embeds[i]`` is the
         zero-padded ``(n_sents, max_len)`` matrix of group ``i`` and
         ``lengths[i]`` holds the unpadded length of each sentence in it.
     """
     lengths = []
     embeds = []
     # Width of the most recently padded sentence group; reused for labels.
     max_len = None
     for batch_sents in batch_stacked_sents:
         lengths_sents = np.array([len(x) for x in batch_sents])
         max_len = np.max(lengths_sents)
         sent_embeds = np.zeros((len(batch_sents), max_len), dtype=np.int32)
         for s_index, sent in enumerate(batch_sents):
             sent_embeds[s_index, :len(sent)] = sent
         if toTorch:
             sent_embeds = torch.LongTensor(sent_embeds).to(get_device())
             lengths_sents = torch.LongTensor(lengths_sents).to(
                 get_device())
         lengths.append(lengths_sents)
         embeds.append(sent_embeds)
     if batch_labels is not None and toTorch:
         if isinstance(batch_labels[0], (list, np.ndarray)):
             # Labels share the width of the last sentence group. Fall back to
             # the longest label when no sentence group was supplied.
             if max_len is not None:
                 label_width = max_len
             else:
                 label_width = max(len(lab) for lab in batch_labels)
             padded_labels = np.full(
                 (len(batch_labels), label_width), -1, dtype=np.int32)
             for label_index, lab in enumerate(batch_labels):
                 # len() works for both lists and arrays (lists have no .shape).
                 padded_labels[label_index, :len(lab)] = np.asarray(lab)
             batch_labels = padded_labels
         batch_labels = torch.LongTensor(np.array(batch_labels)).to(
             get_device())
     return embeds, lengths, batch_labels
Ejemplo n.º 2
0
def load_model(model_dir: str, do_lower_case: bool):
    """Restore a fine-tuned model and its tokenizer from ``model_dir``.

    Args:
        model_dir: directory handed to both ``from_pretrained`` calls.
        do_lower_case: forwarded to the tokenizer's ``from_pretrained``.

    Returns:
        tuple ``(model_pipeline, tokenizer)``; the model has been moved to the
        device reported by ``model.get_device()``.
    """
    logger.info("Loading model from %s", model_dir)

    # Weights first, then the matching vocabulary from the same directory.
    restored_model = model.MODEL_CLASS.from_pretrained(model_dir)
    restored_tokenizer = model.TOKENIZER_CLASS.from_pretrained(
        model_dir,
        do_lower_case=do_lower_case,
    )

    restored_model.to(model.get_device())
    return restored_model, restored_tokenizer
Ejemplo n.º 3
0
    def object_to_Tensors(some_object, toTorch=False):
        """Wrap ``some_object`` in a torch tensor or a numpy array.

        Args:
            some_object: value to wrap, e.g. a list or tuple of ints.
            toTorch (bool): when True, build a ``torch.LongTensor`` and move
                it to the device returned by ``get_device()``; when False,
                build a numpy array instead.

        Returns:
            ``torch.LongTensor(some_object)`` placed on ``get_device()``, or
            ``np.array(some_object)``.
        """
        if not toTorch:
            return np.array(some_object)

        return torch.LongTensor(some_object).to(get_device())
Ejemplo n.º 4
0
# Load one original image per mask so both arrays line up index-for-index.
images = [read_orig_image(None, 256) for _ in range(len(masks))]
try:
    # Reorder axes (0, 3, 1, 2); presumably NHWC -> NCHW -- TODO confirm.
    masks = np.array(masks).transpose((0, 3, 1, 2))
    images = np.array(images).transpose((0, 3, 1, 2))
except ValueError:
    print("masks:", masks[0].shape)
    print("images:", images[0].shape)
    # Re-raise the original error so its message and traceback are kept.
    raise
mask_tens = torch.from_numpy(masks)
image_tens = torch.from_numpy(images)
print("done creating masks")

print("loading model....")
dev = get_device()
# NOTE(review): the model is not moved to `dev` here; presumably `load_model`
# takes care of device placement -- confirm.
model = load_model(model_name)
# Element-wise product of every image with its mask.
applied_images = image_tens * mask_tens
print("done loading model from file")

print("running model....")
total = len(applied_images)

for i in range(total):
    ind = i + 1
    print(F"copying image {ind}/{total} to device...")
    # Add a leading batch dimension and cast to float before the transfer.
    ap_im = applied_images[i].unsqueeze(0).float().to(dev)
    msk = mask_tens[i].unsqueeze(0).float().to(dev)
    res = model(ap_im, msk)
Ejemplo n.º 5
0
 def _create_CrossEntropyLoss(weight=None, ignore_index=-1):
     """Build an ``nn.CrossEntropyLoss`` and move it to ``get_device()``.

     Args:
         weight: forwarded to ``nn.CrossEntropyLoss`` as ``weight``.
         ignore_index: forwarded to ``nn.CrossEntropyLoss`` as
             ``ignore_index`` (defaults to -1).

     Returns:
         The loss module after ``.to(get_device())``.
     """
     criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
     return criterion.to(get_device())
Ejemplo n.º 6
0
def _apply_gradient_step(model, optimizer, loss, bs):
    """Average ``loss`` over ``bs`` samples, backpropagate and step the optimizer.

    Returns the averaged loss as a Python float.
    """
    loss = loss / bs
    train_loss = loss.item()

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 0.5 before the parameter update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()

    return train_loss


def train(model,
          dataset,
          from_epoch,
          batch_size,
          valid_loss_threshold,
          loss_coef,
          loss_fn,
          feat_loss_fn,
          optimizer,
          save_filename,
          dataset_path='.',
          use_elec=True,
          elec_only=False,
          use_zero_pad=False,
          early_stop_step=5):
    """Train ``model`` until the validation loss stops improving.

    Every epoch iterates over each training noise type, accumulating the
    per-sample loss and stepping the optimizer every ``batch_size`` samples.
    It then evaluates on ``dataset['Valid']``, saves the model when the
    validation loss improves, and plots the loss history.

    Args:
        model: network called as ``model(noisy_spec, elec, elec_only)``; must
            also provide ``get_loss``, ``save_model`` and ``use_norm``.
        dataset: mapping with ``'Train'`` and ``'Valid'`` lists of
            ``(sample_id, elec, clean_spec)`` tuples. ``dataset['Train']`` is
            shuffled in place.
        from_epoch: last finished epoch; training resumes at ``from_epoch + 1``.
        batch_size: number of samples accumulated per optimizer step.
        valid_loss_threshold: initial "best" validation loss; a model is only
            saved once the validation loss drops below it.
        loss_coef, loss_fn, feat_loss_fn: forwarded to ``model.get_loss``.
        optimizer: optimizer stepping ``model``'s parameters.
        save_filename: path prefix; the best model goes to
            ``f'{save_filename}.pt'`` and a rolling checkpoint to
            ``f'{save_filename} test.pt'``.
        dataset_path: forwarded to ``load_wave_data``.
        use_elec: when False, ``None`` is passed to the model instead of
            ``elec``.
        elec_only: when True, no noisy spectrum is loaded and only the
            ``'none'`` noise type is used.
        use_zero_pad: when True (and ``use_elec``), randomly replace either
            ``elec`` or the noisy spectrum with zeros, each with probability
            one third.
        early_stop_step: number of consecutive non-improving epochs after
            which training stops.
    """
    # dirname is '' for a bare file name, and os.makedirs('') raises.
    save_dir = os.path.dirname(save_filename)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    device = get_device(model)

    epoch = from_epoch + 1
    saved_epoch = from_epoch
    early_stop = 0
    min_valid_loss = valid_loss_threshold
    loss_hist = {'train loss': [], 'valid loss': []}
    # Shown in the progress bar; stays NaN until the first optimizer step.
    train_loss = float('nan')

    if elec_only:
        train_noise_type = ['none']

    else:
        train_noise_type = TRAIN_NOISE_TYPE

    with tqdm(total=len(train_noise_type)) as pbar1, \
         tqdm(total=len(train_noise_type)) as pbar2:

        while True:
            # ===== Training =====
            loss_hist['train loss'] = []

            pbar1.reset()
            model.train()

            for noise_type in train_noise_type:
                bs = 0
                loss = 0

                pbar1.set_description_str(
                    f'(Epoch {epoch}) noise type: {noise_type}')
                random.shuffle(dataset['Train'])
                for sample_id, elec, clean_spec in dataset['Train']:

                    # data augmentation (segment): train on a random
                    # sub-sequence of at least 64 frames (or the whole
                    # sequence when it is shorter than that)
                    seq_len = clean_spec.shape[1]
                    sub_len = random.randint(min(64, seq_len), seq_len)
                    sub_from = random.randint(0, seq_len - sub_len)

                    clean_spec = clean_spec[:, sub_from:sub_from + sub_len, :]

                    if elec_only:
                        noisy_spec = None

                    else:
                        # load noisy training wave data and convert to spectrum
                        noisy_spec, _, _, _ = load_wave_data(
                            sample_id=sample_id,
                            noise_type=noise_type,
                            dataset_path=dataset_path,
                            norm=model.use_norm)
                        noisy_spec = noisy_spec[:, sub_from:sub_from + sub_len]
                        noisy_spec = torch.Tensor([noisy_spec.T]).to(device)

                    if not use_elec:
                        elec = None

                    else:
                        elec = elec[:, sub_from:sub_from + sub_len, :]

                    # random mask: zero out one of the two inputs
                    if use_zero_pad and use_elec:
                        r = random.random() * 3
                        if r <= 1:
                            elec = torch.zeros(elec.shape).to(device)
                        elif r <= 2 and noisy_spec is not None:
                            # noisy_spec is None in elec_only mode; nothing
                            # to mask in that case.
                            noisy_spec = torch.zeros(
                                noisy_spec.shape).to(device)

                    pred = model(noisy_spec, elec, elec_only)
                    loss += model.get_loss(loss_fn, feat_loss_fn, pred,
                                           clean_spec, elec, loss_coef)

                    bs += 1
                    if bs >= batch_size:
                        train_loss = _apply_gradient_step(
                            model, optimizer, loss, bs)
                        loss_hist['train loss'].append(train_loss)
                        loss = 0
                        bs = 0

                    pbar1.refresh()

                # flush the last, partially filled batch
                if bs != 0:
                    train_loss = _apply_gradient_step(
                        model, optimizer, loss, bs)
                    loss_hist['train loss'].append(train_loss)
                    loss = 0
                    bs = 0

                pbar1.set_postfix(loss=train_loss)
                pbar1.update()

                # rolling checkpoint, written once per noise type
                model.save_model(f'{save_filename} test.pt', 0,
                                 valid_loss_threshold)

            # ===== Validation =====
            pbar2.reset()
            model.eval()
            valid_loss = 0
            valid_sample = 0
            for noise_type in train_noise_type:
                pbar2.set_description_str(
                    f'(Saved Epoch {saved_epoch}), min valid loss: {min_valid_loss:.3f}, noise type: {noise_type}'
                )
                for sample_id, elec, clean_spec in dataset['Valid']:

                    if elec_only:
                        noisy_spec = None

                    else:
                        noisy_spec, _, _, _ = load_wave_data(
                            sample_id=sample_id,
                            noise_type=noise_type,
                            dataset_path=dataset_path,
                            norm=model.use_norm)
                        noisy_spec = torch.Tensor([noisy_spec.T]).to(device)

                    if not use_elec:
                        elec = None

                    with torch.no_grad():
                        pred = model(noisy_spec, elec, elec_only)
                        valid_loss += model.get_loss(loss_fn, feat_loss_fn,
                                                     pred, clean_spec).item()
                    valid_sample += 1

                pbar2.set_postfix(valid_loss=valid_loss / valid_sample)
                pbar1.refresh()
                pbar2.update()

            valid_loss /= valid_sample
            loss_hist['valid loss'].append(valid_loss)

            # ===== Save model =====
            if valid_loss < min_valid_loss:
                min_valid_loss = valid_loss
                saved_epoch = epoch
                model.save_model(f'{save_filename}.pt', saved_epoch,
                                 valid_loss)
                early_stop = 0
            else:
                early_stop += 1

            pbar2.set_description_str(
                f'(Saved Epoch {saved_epoch}), min valid loss: {min_valid_loss:.3f}'
            )
            epoch += 1

            # ===== Plot loss =====
            # Train loss holds the current epoch only; valid loss holds one
            # point per epoch, stretched over the same x-range.
            plt.plot(loss_hist['train loss'])
            plt.plot(
                np.linspace(0, len(loss_hist['train loss']),
                            len(loss_hist['valid loss'])),
                loss_hist['valid loss'])
            plt.legend(['Train', 'Valid'])
            plt.tight_layout(pad=0.2)
            plt.show()

            # ===== Early stop =====
            if early_stop >= early_stop_step:
                pbar1.close()
                pbar2.close()
                break
Ejemplo n.º 7
0
def analyze(model,
            dataset,
            model_name,
            processes=None,
            use_S=True,
            elec_only=False,
            use_griffin=False,
            evaluation_path='Evaluation',
            dataset_path='.'):
    """Evaluate ``model`` on ``dataset['Test']`` and write averaged metrics.

    For every noise type a text file
    ``<evaluation_path>/<model_name>/<noise_type>.txt`` is (re)created. For
    every SNR type the mean PESQ, STOI and ESTOI over the test samples is
    appended to it. The metrics are computed in a worker pool.

    Args:
        model: network called as ``model(noisy, elec, elec_only)`` returning
            five values, of which the fourth is the predicted spectrum; must
            expose ``use_norm``.
        dataset: mapping whose ``'Test'`` entry is a list of
            ``(sample_id, elec, clean)`` tuples.
        model_name: sub-directory name for the result files.
        processes: worker count passed to ``Pool``.
        use_S: when True, the noisy spectrum is loaded with ``load_wave_data``
            and fed to the model.
        elec_only: forwarded to the model; with ``use_S`` False the model
            receives ``None`` as noisy input.
        use_griffin: when True (or when ``use_S`` is False), the waveform is
            reconstructed with Griffin-Lim instead of the noisy phase.
        evaluation_path: root directory for the result files.
        dataset_path: forwarded to ``load_wave_data``.
    """
    evaluation_dir = os.path.join(evaluation_path, model_name)
    os.makedirs(evaluation_dir, exist_ok=True)

    device = get_device(model)

    sr = 16000

    if not use_S or elec_only:
        test_noise_type = ['none']
        test_snr_type = ['none']

    else:
        test_noise_type = TEST_NOISE_TYPE
        test_snr_type = TEST_SNR_TYPE

    # The SNR loop uses the `test_snr_type` selected above. The previous code
    # computed it but iterated the global TEST_SNR_TYPE unconditionally.
    with Pool(processes) as p, \
            tqdm(test_noise_type) as noise_bar, \
            tqdm(total=len(test_snr_type)) as SNR_bar, \
            tqdm(total=len(dataset['Test'])) as test_bar:

        # Iterating `noise_bar` advances it automatically; no manual update().
        for noise_type in noise_bar:
            noise_bar.set_description(noise_type)
            result_file = os.path.join(evaluation_dir, f'{noise_type}.txt')
            # Truncate any previous result file and close the handle at once.
            with open(result_file, 'w'):
                pass

            SNR_bar.reset()
            for SNR_type in test_snr_type:
                SNR_bar.set_description(SNR_type)
                total_sample = 0
                folder_result = []

                test_bar.reset()
                for sample_id, elec, clean in dataset['Test']:
                    test_bar.set_description(
                        f'{to_TMHINT_name(sample_id)}.wav')

                    if use_S:
                        noisy, phasex, _, _ = load_wave_data(
                            sample_id=sample_id,
                            noise_type=noise_type,
                            SNR_type=SNR_type,
                            is_training=False,
                            dataset_path=dataset_path,
                            norm=model.use_norm)

                        noisy = torch.Tensor([noisy.T]).to(device)

                    elif elec_only:
                        noisy = None

                    else:
                        # Zero placeholder with 257 features per frame.
                        noisy = torch.zeros((1, elec.shape[1], 257)).to(device)

                    with torch.no_grad():
                        _, _, _, pred_y, _ = model(noisy, elec, elec_only)
                    pred_y = pred_y[0].cpu().detach().numpy().T

                    if not use_griffin and use_S:
                        enhanced = spec2wave(pred_y, phasex)
                    else:
                        # No usable phase: estimate it with Griffin-Lim.
                        enhanced = librosa.core.griffinlim(
                            10**(pred_y / 2),
                            n_iter=5,
                            hop_length=Const.HOP_LENGTH,
                            win_length=Const.WIN_LENGTH,
                            window=Const.WINDOW)

                    # Queue PESQ, STOI and ESTOI on the worker pool.
                    folder_result.append([
                        p.apply_async(pesq, (clean, enhanced, sr)),
                        p.apply_async(stoi, (clean, enhanced, sr, False)),
                        p.apply_async(stoi, (clean, enhanced, sr, True)),
                    ])

                    total_sample += 1
                    test_bar.refresh()

                if total_sample:
                    results = [0] * 3
                    for single_result in folder_result:
                        for i, result in enumerate(single_result):
                            # Blocks until the worker has finished the metric.
                            results[i] += result.get()
                        noise_bar.refresh()
                        SNR_bar.refresh()
                        test_bar.update()

                    results = [total / total_sample for total in results]

                    with open(result_file, 'a') as writer:
                        writer.write(f'SNR: {SNR_type}\n')
                        writer.write(f'PESQ:  {results[0]}\n')
                        writer.write(f'STOI:  {results[1]}\n')
                        writer.write(f'ESTOI: {results[2]}\n')
                        writer.write('\n')

                SNR_bar.update()
Ejemplo n.º 8
0
def test(model,
         noise_type,
         SNR_type,
         test_sample,
         pca=None,
         elec_channel=(1, 124),
         dataset_path='.',
         use_S=True,
         use_E=True,
         elec_only=False,
         display_audio=False,
         show_graph=True,
         enhanced_path=None):
    """Run ``model`` on one test sample, print its metrics and show results.

    Args:
        model: network called as ``model(noisy, elec, elec_only=...)`` that
            returns five values ``(Ss, Se, Sf, Sy_, e_)``; must expose
            ``use_norm`` and ``is_use_E()``.
        noise_type, SNR_type: select the noisy recording to load.
        test_sample: sample id passed to the data loaders.
        pca: optional object whose ``transform`` is applied to the elec data.
        elec_channel: forwarded to ``load_elec_data``.
        dataset_path: forwarded to the data loaders.
        use_S: when True, the noisy spectrum is loaded and fed to the model.
        use_E: when True (and ``model.is_use_E()``), real elec data is loaded;
            otherwise an all-zero array with 124 columns is used.
        elec_only: forwarded to the model.
        display_audio: when True, show audio players for the clean, noisy and
            enhanced waves.
        show_graph: when True, plot the input, intermediate and output
            spectrograms.
        enhanced_path: when given, the enhanced wave is resampled from 16 kHz
            to 24 kHz and written there as a wav file.

    Returns:
        None. Returns early, before any metric is printed, when the model's
        fourth output is ``None``.
    """

    def _first_item_to_numpy(tensor):
        # First batch element as a transposed numpy array; None passes through.
        if tensor is None:
            return None
        return tensor[0].cpu().detach().numpy().T

    print(f'{noise_type}, {SNR_type}, {test_sample}')

    device = get_device(model)
    if use_S:
        Sx, phasex, meanx, stdx = load_wave_data(sample_id=test_sample,
                                                 noise_type=noise_type,
                                                 SNR_type=SNR_type,
                                                 is_training=False,
                                                 dataset_path=dataset_path,
                                                 norm=model.use_norm)
        noisy = torch.Tensor([Sx.T]).to(device)
    else:
        Sx = None
        noisy = None

    # Clean reference: loaded without noise arguments and without norm.
    Sy, phasey, _, _ = load_wave_data(sample_id=test_sample,
                                      is_training=False,
                                      dataset_path=dataset_path,
                                      norm=False)

    if use_E and model.is_use_E():
        elec_data = load_elec_data(test_sample, Sy.shape[1], elec_channel,
                                   dataset_path)
    else:
        elec_data = np.zeros((Sy.shape[1], 124))
    if pca:
        elec_data = pca.transform(elec_data)
    elec = torch.Tensor([elec_data]).to(device)
    # Transposed copy is only used for plotting below.
    elec_data = elec_data.T

    with torch.no_grad():
        Ss, Se, Sf, Sy_, e_ = model(noisy, elec, elec_only=elec_only)

    Ss = _first_item_to_numpy(Ss)
    Se = _first_item_to_numpy(Se)
    Sf = _first_item_to_numpy(Sf)
    e_ = _first_item_to_numpy(e_)
    Sy_ = _first_item_to_numpy(Sy_)
    if Sy_ is None:
        return

    if noisy is not None:
        # Reuse the phase of the noisy input for reconstruction.
        enhanced = spec2wave(Sy_, phasex)
    else:
        # No input phase available: estimate it with Griffin-Lim.
        enhanced = librosa.core.griffinlim(10**(Sy_ / 2),
                                           n_iter=5,
                                           hop_length=Const.HOP_LENGTH,
                                           win_length=Const.WIN_LENGTH,
                                           window=Const.WINDOW)
    clean = spec2wave(Sy, phasey)

    if use_S:
        noisy = spec2wave(Sx, phasex, meanx, stdx)

    sr = 16000
    # NOTE(review): on Windows both PESQ implementations are printed; the
    # `else` that used to separate them was commented out. Confirm this is
    # intentional.
    if _platform == 'Windows':
        print('PESQ: ',
              pesq_windows(clean, enhanced, test_sample, sr, dataset_path))
    print('PESQ: ', pesq(clean, enhanced, sr))
    print('STOI: ', stoi(clean, enhanced, sr, False))
    print('ESTOI:', stoi(clean, enhanced, sr, True))

    saved_sr = 24000
    if enhanced_path is not None:
        test_wav_filename = os.path.join(enhanced_path,
                                         f'{to_TMHINT_name(test_sample)}.wav')
        # Keyword arguments: the sample rates are keyword-only in librosa 0.10
        # and later, and the same names work in earlier releases.
        enhanced = librosa.resample(enhanced, orig_sr=sr, target_sr=saved_sr)
        wavfile.write(test_wav_filename, saved_sr, enhanced)

    if display_audio:
        display(Audio(clean, rate=sr, autoplay=False))
        if use_S:
            display(Audio(noisy, rate=sr, autoplay=False))

        # `enhanced` was resampled to saved_sr only when it was written out.
        if enhanced_path is None:
            display(Audio(enhanced, rate=sr, autoplay=False))
        else:
            display(Audio(enhanced, rate=saved_sr, autoplay=False))

    if show_graph:
        show_data = [
            (Sx, 'lower', 'jet'),
            (elec_data, 'lower', 'jet'),
            (Ss, 'lower', 'jet'),
            (Se, 'lower', 'jet'),
            (Sf, 'lower', 'jet'),
            (Sy_, 'lower', 'jet'),
            (Sy, 'lower', 'jet'),
        ]

        f, axes = plt.subplots(len(show_data),
                               1,
                               sharex=True,
                               figsize=(18, 12))
        axes[0].set_xlim(0, Sy.shape[1])

        # Rows whose data is None are left as empty axes.
        for i, (data, origin, cmap) in enumerate(show_data):
            if data is not None:
                axes[i].imshow(data, origin=origin, aspect='auto', cmap=cmap)

        plt.tight_layout(pad=0.2)
        plt.show()