Code Example #1
import torch
import torch.nn.functional as F

import utils  # project-local helpers: warpgrid, istft_reconstruction, ...


def get_separated_audio(outputs, batch_data, opt):
    # fetch data and predictions
    mag_mix = batch_data['audio_mix_mags']
    phase_mix = batch_data['audio_mix_phases']
    pred_masks_ = outputs['pred_mask']
    mag_mix_ = outputs['audio_mix_mags']
    # unwarp log scale
    B = mag_mix.size(0)
    if opt.log_freq:
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B,
                           opt.stft_frame // 2 + 1,
                           pred_masks_.size(3),
                           warp=False)).to(opt.device)
        pred_masks_linear = F.grid_sample(pred_masks_, grid_unwarp)
    else:
        pred_masks_linear = pred_masks_
    # convert into numpy
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    pred_masks_linear = pred_masks_linear.detach().cpu().numpy()
    pred_mag = mag_mix[0, 0] * pred_masks_linear[0, 0]
    preds_wav = utils.istft_reconstruction(pred_mag,
                                           phase_mix[0, 0],
                                           hop_length=opt.stft_hop,
                                           length=opt.audio_window)
    return preds_wav
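`utils.istft_reconstruction` is referenced above but not listed. Below is a minimal sketch of such a helper, assuming it simply recombines magnitude and phase and runs an inverse STFT through librosa; only the call signature is taken from the snippet above, the librosa-based body is an assumption.

import numpy as np
import librosa


def istft_reconstruction(mag, phase, hop_length, length=None):
    # assumed helper: rebuild a waveform from magnitude and phase spectrograms
    spec = mag.astype(np.complex64) * np.exp(1j * phase)
    # inverse STFT; `length` trims/pads the result to the original audio window
    wav = librosa.istft(spec, hop_length=hop_length, length=length)
    return np.clip(wav, -1., 1.)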
Code Example #2
    def __init__(self, model, main_device=0):
        super(CUNetWrapper, self).__init__()
        self.L = len(SOURCES_SUBSET)
        self.model = model
        self.main_device = main_device
        # precompute the sampling grid once for the fixed training batch size
        self.grid_warp = torch.from_numpy(
            warpgrid(BATCH_SIZE, 256, STFT_WIDTH, warp=True)).to(self.main_device)
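`warpgrid` itself is not shown in these snippets. The sketch below reflects how such a helper is commonly written in audio-visual separation codebases: it returns a `(B, H_out, W_out, 2)` grid for `F.grid_sample` that leaves the time axis alone and remaps the frequency axis between linear and (approximately) log scale. Treat the exact warp curve as an assumption.

import numpy as np


def warpgrid(bs, HO, WO, warp=True):
    # sampling grid for F.grid_sample; only the frequency (y) axis is remapped
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((bs, HO, WO, 2), dtype=np.float32)
    grid[:, :, :, 0] = xv  # time axis: identity
    if warp:
        # linear -> log frequency (applied before the network)
        grid[:, :, :, 1] = (np.power(21, (yv + 1) / 2) - 11) / 10
    else:
        # log -> linear frequency (used to unwarp predictions)
        grid[:, :, :, 1] = np.log(yv * 10 + 11) / np.log(21) * 2 - 1
    return grid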
Code Example #3
    def forward(self, input):
        labels = input['labels']
        labels = labels.squeeze(1).long()  # convert back to LongTensor
        vids = input['vids']
        audio_mags = input['audio_mags']
        audio_mix_mags = input['audio_mix_mags']
        visuals = input['visuals']
        audio_mix_mags = audio_mix_mags + 1e-10

        # warp the spectrogram
        B = audio_mix_mags.size(0)
        T = audio_mix_mags.size(3)
        if self.opt.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(
                self.opt.device)
            audio_mix_mags = F.grid_sample(audio_mix_mags, grid_warp)
            audio_mags = F.grid_sample(audio_mags, grid_warp)

        # calculate ground-truth masks
        gt_masks = audio_mags / audio_mix_mags
        # clamp to avoid large numbers in ratio masks
        gt_masks.clamp_(0., 5.)

        # pass through visual stream and extract visual features
        visual_feature = self.net_visual(Variable(visuals,
                                                  requires_grad=False))

        # audio-visual feature fusion through UNet and predict mask
        audio_log_mags = torch.log(audio_mix_mags).detach()
        mask_prediction = self.net_unet(audio_log_mags, visual_feature)

        # masking the spectrogram of mixed audio to perform separation
        separated_spectrogram = audio_mix_mags * mask_prediction

        # generate spectrogram for the classifier
        spectrogram2classify = torch.log(separated_spectrogram +
                                         1e-10)  #get log spectrogram

        # calculate loss weighting coefficient
        if self.opt.weighted_loss:
            weight = torch.log1p(audio_mix_mags)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = None

        #classify the predicted spectrogram
        label_prediction = self.net_classifier(spectrogram2classify)

        output = {'gt_label': labels, 'pred_label': label_prediction,
                  'pred_mask': mask_prediction, 'gt_mask': gt_masks,
                  'pred_spectrogram': separated_spectrogram, 'visual_object': visuals,
                  'audio_mix_mags': audio_mix_mags, 'weight': weight, 'vids': vids}
        return output
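The `weight` entry is presumably consumed by a weighted mask loss during training. The sketch below shows one plausible way to do that (an L1 ratio-mask loss); the function name and the exact loss form are assumptions, not the repository's actual criterion.

import torch


def weighted_mask_loss(output):
    # hypothetical criterion over the dict returned by forward()
    pred_mask = output['pred_mask']
    gt_mask = output['gt_mask']
    weight = output['weight']
    err = torch.abs(pred_mask - gt_mask)  # per-bin L1 error
    if weight is None:
        return err.mean()
    # emphasize loud time-frequency bins via log1p(mixture magnitude)
    return (weight * err).sum() / weight.sum()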
Code Example #4
    def forward(self, x):
        if x.shape[0] == BATCH_SIZE:
            mags = F.grid_sample(x, self.grid_warp)
        else:  # last batch, where the number of samples is usually smaller than BATCH_SIZE
            custom_grid_warp = torch.from_numpy(
                warpgrid(x.shape[0], 256, STFT_WIDTH, warp=True)).to(self.main_device)
            mags = F.grid_sample(x, custom_grid_warp)

        # ratio masks: each (warped) source magnitude divided by the mixture magnitude
        gt_masks = torch.div(
            mags[:, :-1],
            mags[:, -1].unsqueeze(1).expand(x.shape[0], self.L, *mags.shape[2:]))
        gt_masks.clamp_(0., 10.)

        gt_mags = x[:, :-1]              # per-source magnitudes (linear frequency)
        mix_mag = x[:, -1].unsqueeze(1)  # mixture magnitude (linear frequency)
        # the network only sees the warped mixture magnitude
        pred_mags_sq = self.model(mags[:, -1].unsqueeze(1))
        pred_mags_sq = torch.relu(pred_mags_sq)
        mag_mix_sq = mags[:, -1].unsqueeze(1)
        gt_mags_sq = gt_masks * mag_mix_sq

        # shapes: BxKx256x256, BxKx256x256, BxKx512x256, Bx1x512x256, BxKx256x256, BxKx256x256
        network_output = [gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, gt_masks]
        return network_output
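`pred_mags_sq` lives on the warped 256-bin (log-frequency) grid, so before any waveform reconstruction it has to be mapped back to the linear STFT grid, mirroring Code Example #1. A minimal sketch under that assumption, reusing the `warpgrid` sketch above with `warp=False`; the function name and the 512-bin target height (taken from the shape comment) are assumptions.

import torch
import torch.nn.functional as F


def unwarp_to_linear(pred_mags_sq, stft_bins=512, device=0):
    # map warped (log-frequency) magnitudes back onto the linear STFT grid
    B, _, _, T = pred_mags_sq.shape
    grid_unwarp = torch.from_numpy(
        warpgrid(B, stft_bins, T, warp=False)).to(device)
    return F.grid_sample(pred_mags_sq, grid_unwarp)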
Code Example #5
def save_visualization(vis_rows, outputs, batch_data, save_dir, opt):
    # fetch data and predictions
    mag_mix = batch_data['audio_mix_mags']
    phase_mix = batch_data['audio_mix_phases']
    visuals = batch_data['visuals']

    pred_masks_ = outputs['pred_mask']
    gt_masks_ = outputs['gt_mask']
    mag_mix_ = outputs['audio_mix_mags']
    weight_ = outputs['weight']
    visual_object = outputs['visual_object']
    gt_label = outputs['gt_label']
    _, pred_label = torch.max(outputs['pred_label'], 1)
    label_list = ['Banjo', 'Cello', 'Drum', 'Guitar', 'Harp', 'Harmonica', 'Oboe', 'Piano',
                  'Saxophone', 'Trombone', 'Trumpet', 'Violin', 'Flute', 'Accordion', 'Horn']

    # unwarp log scale
    B = mag_mix.size(0)
    if opt.log_freq:
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B,
                           opt.stft_frame // 2 + 1,
                           gt_masks_.size(3),
                           warp=False)).to(opt.device)
        pred_masks_linear = F.grid_sample(pred_masks_, grid_unwarp)
        gt_masks_linear = F.grid_sample(gt_masks_, grid_unwarp)
    else:
        pred_masks_linear = pred_masks_
        gt_masks_linear = gt_masks_

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    pred_masks_ = pred_masks_.detach().cpu().numpy()
    pred_masks_linear = pred_masks_linear.detach().cpu().numpy()
    gt_masks_ = gt_masks_.detach().cpu().numpy()
    gt_masks_linear = gt_masks_linear.detach().cpu().numpy()
    visual_object = visual_object.detach().cpu().numpy()
    gt_label = gt_label.detach().cpu().numpy()
    pred_label = pred_label.detach().cpu().numpy()

    # loop over each example
    for j in range(min(B, opt.num_visualization_examples)):
        row_elements = []

        # video names
        prefix = str(j) + '-' + label_list[int(
            gt_label[j])] + '-' + label_list[int(pred_label[j])]
        utils.mkdirs(os.path.join(save_dir, prefix))

        # save mixture
        mix_wav = utils.istft_coseparation(mag_mix[j, 0],
                                           phase_mix[j, 0],
                                           hop_length=opt.stft_hop)
        mix_amp = utils.magnitude2heatmap(mag_mix_[j, 0])
        weight = utils.magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        imsave(os.path.join(save_dir, filename_mixmag), mix_amp[::-1, :, :])
        imsave(os.path.join(save_dir, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(save_dir, filename_mixwav),
                      opt.audio_sampling_rate, mix_wav)
        row_elements += [{
            'text': prefix
        }, {
            'image': filename_mixmag,
            'audio': filename_mixwav
        }]

        # GT and predicted audio reconstruction
        gt_mag = mag_mix[j, 0] * gt_masks_linear[j, 0]
        gt_wav = utils.istft_coseparation(gt_mag,
                                          phase_mix[j, 0],
                                          hop_length=opt.stft_hop)
        pred_mag = mag_mix[j, 0] * pred_masks_linear[j, 0]
        preds_wav = utils.istft_coseparation(pred_mag,
                                             phase_mix[j, 0],
                                             hop_length=opt.stft_hop)

        # output masks
        filename_gtmask = os.path.join(prefix, 'gtmask.jpg')
        filename_predmask = os.path.join(prefix, 'predmask.jpg')
        gt_mask = (np.clip(gt_masks_[j, 0], 0, 1) * 255).astype(np.uint8)
        pred_mask = (np.clip(pred_masks_[j, 0], 0, 1) * 255).astype(np.uint8)
        imsave(os.path.join(save_dir, filename_gtmask), gt_mask[::-1, :])
        imsave(os.path.join(save_dir, filename_predmask), pred_mask[::-1, :])

        # output spectrogram (log magnitude, shown as a colormap)
        filename_gtmag = os.path.join(prefix, 'gtamp.jpg')
        filename_predmag = os.path.join(prefix, 'predamp.jpg')
        gt_mag = utils.magnitude2heatmap(gt_mag)
        pred_mag = utils.magnitude2heatmap(pred_mag)
        imsave(os.path.join(save_dir, filename_gtmag), gt_mag[::-1, :, :])
        imsave(os.path.join(save_dir, filename_predmag), pred_mag[::-1, :, :])

        # output audio
        filename_gtwav = os.path.join(prefix, 'gt.wav')
        filename_predwav = os.path.join(prefix, 'pred.wav')
        wavfile.write(os.path.join(save_dir, filename_gtwav),
                      opt.audio_sampling_rate, gt_wav)
        wavfile.write(os.path.join(save_dir, filename_predwav),
                      opt.audio_sampling_rate, preds_wav)

        row_elements += [{
            'image': filename_predmag,
            'audio': filename_predwav
        }, {
            'image': filename_gtmag,
            'audio': filename_gtwav
        }, {
            'image': filename_predmask
        }, {
            'image': filename_gtmask
        }]

        row_elements += [{'image': filename_weight}]
        vis_rows.append(row_elements)
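`utils.magnitude2heatmap` is referenced but not listed. A sketch of the usual behavior, assuming an OpenCV colormap (log-compress, scale to uint8, then colorize); the scaling constants and the cv2-based body are assumptions.

import numpy as np
import cv2


def magnitude2heatmap(mag, log=True, scale=200.):
    # assumed helper: turn a magnitude spectrogram into a uint8 heatmap image
    if log:
        mag = np.log10(mag + 1.)  # compress dynamic range
    mag = np.clip(mag * scale, 0, 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(mag, cv2.COLORMAP_JET)
    return heatmap[:, :, ::-1]  # BGR -> RGB for imsave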
Code Example #6
    def forward(self, input):
        labels = input['labels']
        labels = labels.squeeze(1).long()  # convert back to LongTensor
        vids = input['vids']
        audio_mags = input['audio_mags']
        audio_mix_mags = input['audio_mix_mags']
        visuals = input['visuals']
        # visuals_256 = input['visuals_256']

        audio_mix_mags = audio_mix_mags + 1e-10

        '''1. warp the spectrogram'''
        B = audio_mix_mags.size(0)
        T = audio_mix_mags.size(3)
        if self.opt.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(self.opt.device)
            audio_mix_mags = F.grid_sample(audio_mix_mags, grid_warp)
            audio_mags = F.grid_sample(audio_mags, grid_warp)

        '''2. calculate ground-truth masks'''
        gt_masks = audio_mags / audio_mix_mags
        # clamp to avoid large numbers in ratio masks
        gt_masks.clamp_(0., 5.)

        '''3. pass through visual stream and extract visual features'''
        visual_feature, _ = self.net_visual(Variable(visuals, requires_grad=False))

        '''4. audio-visual feature fusion through UNet and predict mask'''
        audio_log_mags = torch.log(audio_mix_mags).detach()

        # audio_norm_mags = torch.sigmoid(torch.log(audio_mags + 1e-10))


        mask_prediction = self.net_unet(audio_log_mags, visual_feature)

        '''5. masking the spectrogram of mixed audio to perform separation and predict classification label'''
        separated_spectrogram = audio_mix_mags * mask_prediction

        # generate spectrogram for the classifier
        spectrogram2classify = torch.log(separated_spectrogram + 1e-10)  # get log spectrogram

        # calculate loss weighting coefficient
        if self.opt.weighted_loss:
            weight = torch.log1p(audio_mix_mags)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = None

        ''' 6. classify the predicted spectrogram; the classifier returns both the
            label prediction and the audio feature after the ResNet-18 layer4 (512x8x8)'''
        label_prediction, _ = self.net_classifier(spectrogram2classify)


        # if self.opt.visual_unet_encoder:
        #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, None) #visuals_256)
        # elif self.opt.visual_cat:
        #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, visual_feature)
        # else:
        #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, None)
        # iteratively refine the predicted mask, keeping every intermediate mask
        refine_masks = [None for i in range(self.opt.refine_iteration)]
        temp_mask = mask_prediction
        left_energy = [None for i in range(self.opt.refine_iteration)]
        for i in range(self.opt.refine_iteration):
            refine_mask, left_mask, left_mags = self.refine_iteration(temp_mask, audio_mix_mags, visual_feature)
            refine_masks[i] = refine_mask
            temp_mask = refine_mask
            left_energy[i] = torch.mean(left_mags)

        # spectrogram after refinement
        refine_spec = audio_mix_mags * refine_mask
        # refine_norm_mags = torch.sigmoid(torch.log(refine_spec + 1e-10))

        refine2classify = torch.log(refine_spec + 1e-10)
        _, fake_audio_feature = self.net_classifier(refine2classify)

        ''' 7. reduce the audio feature channels for the feature loss'''
        if self.opt.audio_extractor:
            real_audio_mags = torch.log(audio_mags + 1e-10)
            _, real_audio_feature = self.net_classifier(real_audio_mags)
            real_audio_feature = self.audio_extractor(real_audio_feature)
            fake_audio_feature = self.audio_extractor(fake_audio_feature)


        output = {'gt_label': labels, 'pred_label': label_prediction, 'pred_mask': mask_prediction, 'gt_mask': gt_masks,
                'pred_spectrogram': separated_spectrogram, 'visual_object': visuals, 'audio_mags': audio_mags,
                  'audio_mix_mags': audio_mix_mags, 'weight': weight, 'vids': vids,
                  'refine_mask': refine_mask, 'refine_spec': refine_spec, 'left_mask':left_mask, 'refine_masks':refine_masks,
                  'left_mags': left_mags, 'left_energy':left_energy}
        if self.opt.audio_extractor:
            output['real_audio_feat'] = real_audio_feature
            output['fake_audio_feat'] = fake_audio_feature
        return output
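The extra keys ('refine_masks', 'left_energy', 'real_audio_feat', 'fake_audio_feat') suggest per-iteration mask supervision plus a feature-matching term. The sketch below is one plausible way to consume them; it is illustrative only and not necessarily the training objective actually used.

import torch
import torch.nn.functional as F


def refinement_losses(output):
    # hypothetical losses over the refinement outputs (illustrative only)
    gt_mask = output['gt_mask']
    weight = output['weight'] if output['weight'] is not None else 1.0
    # supervise every refinement iteration against the same ground-truth mask
    mask_loss = sum((weight * torch.abs(m - gt_mask)).mean()
                    for m in output['refine_masks'])
    # pull separated-audio features toward the real-audio features, if extracted
    feat_loss = torch.zeros(())
    if 'real_audio_feat' in output:
        feat_loss = F.l1_loss(output['fake_audio_feat'],
                              output['real_audio_feat'].detach())
    return mask_loss, feat_loss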