def get_separated_audio(outputs, batch_data, opt):
    # fetch data and predictions
    mag_mix = batch_data['audio_mix_mags']
    phase_mix = batch_data['audio_mix_phases']
    pred_masks_ = outputs['pred_mask']
    mag_mix_ = outputs['audio_mix_mags']
    # unwarp log scale
    B = mag_mix.size(0)
    if opt.log_freq:
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B, opt.stft_frame // 2 + 1, pred_masks_.size(3), warp=False)).to(opt.device)
        pred_masks_linear = F.grid_sample(pred_masks_, grid_unwarp)
    else:
        pred_masks_linear = pred_masks_
    # convert into numpy
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    pred_masks_linear = pred_masks_linear.detach().cpu().numpy()
    pred_mag = mag_mix[0, 0] * pred_masks_linear[0, 0]
    preds_wav = utils.istft_reconstruction(pred_mag, phase_mix[0, 0],
                                           hop_length=opt.stft_hop, length=opt.audio_window)
    return preds_wav
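# utils.istft_reconstruction is not defined in this file. A minimal sketch of
# what it is assumed to do, given the call above: recombine the predicted
# magnitude with the mixture phase and invert the STFT. The librosa-based
# implementation and the function body below are assumptions, not the
# project's exact code.
import librosa
import numpy as np

def istft_reconstruction_sketch(mag, phase, hop_length=256, length=65535):
    # complex spectrogram from predicted magnitude and mixture phase
    spec = mag.astype(np.complex64) * np.exp(1j * phase)
    wav = librosa.istft(spec, hop_length=hop_length, length=length)
    return np.clip(wav, -1., 1.)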
def __init__(self, model, main_device=0):
    super(CUNetWrapper, self).__init__()
    self.L = len(SOURCES_SUBSET)
    self.model = model
    self.main_device = main_device
    # precompute the log-frequency warp grid once for full-size batches
    self.grid_warp = torch.from_numpy(
        warpgrid(BATCH_SIZE, 256, STFT_WIDTH, warp=True)).to(self.main_device)
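# warpgrid is imported from elsewhere. A sketch consistent with how it is used
# with F.grid_sample throughout this file: it builds a normalized sampling grid
# that remaps the frequency axis to a log-like scale (warp=True) or back to
# linear (warp=False). The exact warping curve below is an assumption modeled
# on the Sound-of-Pixels-style implementation.
import numpy as np

def warpgrid_sketch(bs, h_out, w_out, warp=True):
    # grid_sample expects coordinates normalized to [-1, 1]
    x = np.linspace(-1, 1, w_out)
    y = np.linspace(-1, 1, h_out)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((bs, h_out, w_out, 2), dtype=np.float32)
    grid[:, :, :, 0] = xv
    if warp:
        # compress high frequencies, expand low ones
        grid[:, :, :, 1] = (np.power(21, (yv + 1) / 2) - 11) / 10
    else:
        # inverse mapping, used to unwarp predictions back to linear frequency
        grid[:, :, :, 1] = np.log(yv * 10 + 11) / np.log(21) * 2 - 1
    return grid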
def forward(self, input):
    labels = input['labels']
    labels = labels.squeeze(1).long()  # convert back to LongTensor
    vids = input['vids']
    audio_mags = input['audio_mags']
    audio_mix_mags = input['audio_mix_mags']
    visuals = input['visuals']
    audio_mix_mags = audio_mix_mags + 1e-10

    # warp the spectrogram
    B = audio_mix_mags.size(0)
    T = audio_mix_mags.size(3)
    if self.opt.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(self.opt.device)
        audio_mix_mags = F.grid_sample(audio_mix_mags, grid_warp)
        audio_mags = F.grid_sample(audio_mags, grid_warp)

    # calculate ground-truth masks
    gt_masks = audio_mags / audio_mix_mags
    # clamp to avoid large numbers in ratio masks
    gt_masks.clamp_(0., 5.)

    # pass through visual stream and extract visual features
    visual_feature = self.net_visual(Variable(visuals, requires_grad=False))

    # audio-visual feature fusion through UNet and predict mask
    audio_log_mags = torch.log(audio_mix_mags).detach()
    mask_prediction = self.net_unet(audio_log_mags, visual_feature)

    # mask the spectrogram of the mixed audio to perform separation
    separated_spectrogram = audio_mix_mags * mask_prediction

    # generate log spectrogram for the classifier
    spectrogram2classify = torch.log(separated_spectrogram + 1e-10)

    # calculate loss weighting coefficient
    if self.opt.weighted_loss:
        weight = torch.log1p(audio_mix_mags)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = None

    # classify the predicted spectrogram
    label_prediction = self.net_classifier(spectrogram2classify)

    output = {'gt_label': labels, 'pred_label': label_prediction,
              'pred_mask': mask_prediction, 'gt_mask': gt_masks,
              'pred_spectrogram': separated_spectrogram, 'visual_object': visuals,
              'audio_mix_mags': audio_mix_mags, 'weight': weight, 'vids': vids}
    return output
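# How the 'weight' entry might be consumed by the training loop. The loop is
# not part of this file and the helper below is hypothetical; it only
# illustrates why log1p(mix magnitude) is clamped and stored: to emphasize
# time-frequency bins that carry more mixture energy in a per-bin mask loss.
import torch

def mask_loss_sketch(output):
    # per-bin weighted L1 between predicted and ground-truth ratio masks
    diff = torch.abs(output['pred_mask'] - output['gt_mask'])
    if output['weight'] is not None:
        return torch.mean(diff * output['weight'])
    return torch.mean(diff)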
def forward(self, x):
    if x.shape[0] == BATCH_SIZE:
        mags = F.grid_sample(x, self.grid_warp)
    else:
        # last batch, where the number of samples is usually smaller than BATCH_SIZE
        custom_grid_warp = torch.from_numpy(
            warpgrid(x.shape[0], 256, STFT_WIDTH, warp=True)).to(self.main_device)
        mags = F.grid_sample(x, custom_grid_warp)
    # ratio masks: per-source magnitudes over the mixture magnitude (last channel)
    gt_masks = torch.div(mags[:, :-1],
                         mags[:, -1].unsqueeze(1).expand(x.shape[0], self.L, *mags.shape[2:]))
    gt_masks.clamp_(0., 10.)
    gt_mags = x[:, :-1]
    mix_mag = x[:, -1].unsqueeze(1)
    pred_mags_sq = self.model(mags[:, -1].unsqueeze(1))
    pred_mags_sq = torch.relu(pred_mags_sq)
    mag_mix_sq = mags[:, -1].unsqueeze(1)
    gt_mags_sq = gt_masks * mag_mix_sq
    # BxKx256x256, BxKx256x256, BxKx512x256, Bx1x512x256, BxKx256x256, BxKx256x256
    network_output = [gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, gt_masks]
    return network_output
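# A hedged sketch of how a trainer might consume network_output. The training
# loop is not shown in this file and the objective below is an assumption: a
# plain L1 between the predicted warped magnitudes and the mask-derived ground
# truth (the first two entries of the list, both BxKx256x256).
import torch

def cunet_loss_sketch(network_output):
    gt_mags_sq, pred_mags_sq = network_output[0], network_output[1]
    return torch.mean(torch.abs(pred_mags_sq - gt_mags_sq))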
def save_visualization(vis_rows, outputs, batch_data, save_dir, opt):
    # fetch data and predictions
    mag_mix = batch_data['audio_mix_mags']
    phase_mix = batch_data['audio_mix_phases']
    visuals = batch_data['visuals']
    pred_masks_ = outputs['pred_mask']
    gt_masks_ = outputs['gt_mask']
    mag_mix_ = outputs['audio_mix_mags']
    weight_ = outputs['weight']
    visual_object = outputs['visual_object']
    gt_label = outputs['gt_label']
    _, pred_label = torch.max(outputs['pred_label'], 1)
    label_list = ['Banjo', 'Cello', 'Drum', 'Guitar', 'Harp', 'Harmonica', 'Oboe', 'Piano',
                  'Saxophone', 'Trombone', 'Trumpet', 'Violin', 'Flute', 'Accordion', 'Horn']

    # unwarp log scale
    B = mag_mix.size(0)
    if opt.log_freq:
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B, opt.stft_frame // 2 + 1, gt_masks_.size(3), warp=False)).to(opt.device)
        pred_masks_linear = F.grid_sample(pred_masks_, grid_unwarp)
        gt_masks_linear = F.grid_sample(gt_masks_, grid_unwarp)
    else:
        pred_masks_linear = pred_masks_
        gt_masks_linear = gt_masks_

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    pred_masks_ = pred_masks_.detach().cpu().numpy()
    pred_masks_linear = pred_masks_linear.detach().cpu().numpy()
    gt_masks_ = gt_masks_.detach().cpu().numpy()
    gt_masks_linear = gt_masks_linear.detach().cpu().numpy()
    visual_object = visual_object.detach().cpu().numpy()
    gt_label = gt_label.detach().cpu().numpy()
    pred_label = pred_label.detach().cpu().numpy()

    # loop over each example
    for j in range(min(B, opt.num_visualization_examples)):
        row_elements = []
        # video names
        prefix = str(j) + '-' + label_list[int(gt_label[j])] + '-' + label_list[int(pred_label[j])]
        utils.mkdirs(os.path.join(save_dir, prefix))
        # save mixture
        mix_wav = utils.istft_coseparation(mag_mix[j, 0], phase_mix[j, 0], hop_length=opt.stft_hop)
        mix_amp = utils.magnitude2heatmap(mag_mix_[j, 0])
        weight = utils.magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        imsave(os.path.join(save_dir, filename_mixmag), mix_amp[::-1, :, :])
        imsave(os.path.join(save_dir, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(save_dir, filename_mixwav), opt.audio_sampling_rate, mix_wav)
        row_elements += [{'text': prefix},
                         {'image': filename_mixmag, 'audio': filename_mixwav}]

        # GT and predicted audio reconstruction
        gt_mag = mag_mix[j, 0] * gt_masks_linear[j, 0]
        gt_wav = utils.istft_coseparation(gt_mag, phase_mix[j, 0], hop_length=opt.stft_hop)
        pred_mag = mag_mix[j, 0] * pred_masks_linear[j, 0]
        preds_wav = utils.istft_coseparation(pred_mag, phase_mix[j, 0], hop_length=opt.stft_hop)

        # output masks
        filename_gtmask = os.path.join(prefix, 'gtmask.jpg')
        filename_predmask = os.path.join(prefix, 'predmask.jpg')
        gt_mask = (np.clip(gt_masks_[j, 0], 0, 1) * 255).astype(np.uint8)
        pred_mask = (np.clip(pred_masks_[j, 0], 0, 1) * 255).astype(np.uint8)
        imsave(os.path.join(save_dir, filename_gtmask), gt_mask[::-1, :])
        imsave(os.path.join(save_dir, filename_predmask), pred_mask[::-1, :])

        # output spectrograms (log of magnitude, shown as colormap)
        filename_gtmag = os.path.join(prefix, 'gtamp.jpg')
        filename_predmag = os.path.join(prefix, 'predamp.jpg')
        gt_mag = utils.magnitude2heatmap(gt_mag)
        pred_mag = utils.magnitude2heatmap(pred_mag)
        imsave(os.path.join(save_dir, filename_gtmag), gt_mag[::-1, :, :])
        imsave(os.path.join(save_dir, filename_predmag), pred_mag[::-1, :, :])

        # output audio
        filename_gtwav = os.path.join(prefix, 'gt.wav')
        filename_predwav = os.path.join(prefix, 'pred.wav')
        wavfile.write(os.path.join(save_dir, filename_gtwav), opt.audio_sampling_rate, gt_wav)
        wavfile.write(os.path.join(save_dir, filename_predwav), opt.audio_sampling_rate, preds_wav)

        row_elements += [{'image': filename_predmag, 'audio': filename_predwav},
                         {'image': filename_gtmag, 'audio': filename_gtwav},
                         {'image': filename_predmask},
                         {'image': filename_gtmask}]
        row_elements += [{'image': filename_weight}]
        vis_rows.append(row_elements)
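# utils.magnitude2heatmap is not defined in this file. A plausible sketch of
# its behavior, inferred from the call sites above (optional log compression,
# scaling to 8-bit, then a colormap); the cv2-based rendering is an assumption.
import cv2
import numpy as np

def magnitude2heatmap_sketch(mag, log=True, scale=200.):
    if log:
        mag = np.log10(mag + 1.)
    mag = np.clip(mag * scale, 0, 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(mag, cv2.COLORMAP_JET)
    return heatmap[:, :, ::-1]  # BGR -> RGB for imsave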
def forward(self, input):
    labels = input['labels']
    labels = labels.squeeze(1).long()  # convert back to LongTensor
    vids = input['vids']
    audio_mags = input['audio_mags']
    audio_mix_mags = input['audio_mix_mags']
    visuals = input['visuals']
    # visuals_256 = input['visuals_256']
    audio_mix_mags = audio_mix_mags + 1e-10

    '''1. warp the spectrogram'''
    B = audio_mix_mags.size(0)
    T = audio_mix_mags.size(3)
    if self.opt.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(self.opt.device)
        audio_mix_mags = F.grid_sample(audio_mix_mags, grid_warp)
        audio_mags = F.grid_sample(audio_mags, grid_warp)

    '''2. calculate ground-truth masks'''
    gt_masks = audio_mags / audio_mix_mags
    # clamp to avoid large numbers in ratio masks
    gt_masks.clamp_(0., 5.)

    '''3. pass through visual stream and extract visual features'''
    visual_feature, _ = self.net_visual(Variable(visuals, requires_grad=False))

    '''4. audio-visual feature fusion through UNet and predict mask'''
    audio_log_mags = torch.log(audio_mix_mags).detach()
    # audio_norm_mags = torch.sigmoid(torch.log(audio_mags + 1e-10))
    mask_prediction = self.net_unet(audio_log_mags, visual_feature)

    '''5. mask the spectrogram of the mixed audio to perform separation and predict the classification label'''
    separated_spectrogram = audio_mix_mags * mask_prediction
    # generate log spectrogram for the classifier
    spectrogram2classify = torch.log(separated_spectrogram + 1e-10)

    # calculate loss weighting coefficient
    if self.opt.weighted_loss:
        weight = torch.log1p(audio_mix_mags)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = None

    '''6. classify the predicted spectrogram; the classifier also returns the
       audio feature after the ResNet18 layer4 (512x8x8)'''
    label_prediction, _ = self.net_classifier(spectrogram2classify)

    # if self.opt.visual_unet_encoder:
    #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, None)  # visuals_256)
    # elif self.opt.visual_cat:
    #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, visual_feature)
    # else:
    #     refine_mask, left_mask = self.refine_iteration(mask_prediction, audio_mix_mags, None)
    refine_masks = [None for i in range(self.opt.refine_iteration)]
    temp_mask = mask_prediction
    left_energy = [None for i in range(self.opt.refine_iteration)]
    for i in range(self.opt.refine_iteration):
        refine_mask, left_mask, left_mags = self.refine_iteration(temp_mask, audio_mix_mags, visual_feature)
        refine_masks[i] = refine_mask
        temp_mask = refine_mask
        left_energy[i] = torch.mean(left_mags)

    # spectrogram after refinement
    refine_spec = audio_mix_mags * refine_mask
    # refine_norm_mags = torch.sigmoid(torch.log(refine_spec + 1e-10))
    refine2classify = torch.log(refine_spec + 1e-10)
    _, fake_audio_feature = self.net_classifier(refine2classify)
    '''7. down-channel the audio features for computing the loss'''
    if self.opt.audio_extractor:
        real_audio_mags = torch.log(audio_mags + 1e-10)
        _, real_audio_feature = self.net_classifier(real_audio_mags)
        real_audio_feature = self.audio_extractor(real_audio_feature)
        fake_audio_feature = self.audio_extractor(fake_audio_feature)

    output = {'gt_label': labels, 'pred_label': label_prediction,
              'pred_mask': mask_prediction, 'gt_mask': gt_masks,
              'pred_spectrogram': separated_spectrogram, 'visual_object': visuals,
              'audio_mags': audio_mags, 'audio_mix_mags': audio_mix_mags,
              'weight': weight, 'vids': vids,
              'refine_mask': refine_mask, 'refine_spec': refine_spec,
              'left_mask': left_mask, 'refine_masks': refine_masks,
              'left_mags': left_mags, 'left_energy': left_energy}
    if self.opt.audio_extractor:
        output['real_audio_feat'] = real_audio_feature
        output['fake_audio_feat'] = fake_audio_feature
    return output
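# self.audio_extractor is configured elsewhere. A hypothetical down-channeling
# module consistent with step 7 above: reduce the 512-channel layer4 feature
# map before comparing real and fake audio features. The class name, channel
# sizes, and the feature-matching term below are all assumptions.
import torch
import torch.nn as nn

class AudioExtractorSketch(nn.Module):
    def __init__(self, in_channels=512, out_channels=128):
        super().__init__()
        # 1x1 conv reduces channels without changing the 8x8 spatial map
        self.down = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, feat):
        return self.down(feat)

# e.g. a feature-matching loss over the returned dict:
# torch.mean(torch.abs(output['real_audio_feat'] - output['fake_audio_feat']))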