def _save_grounding_overlay(mask, frame, heatmap_path, overlay_path,
                            threshold=0.5, cell=16):
    """Save a thresholded, block-upsampled mask as a heatmap and an overlay.

    Args:
        mask: 2-D array of mask values for one sample.
        frame: image array blended 50/50 with the colour-mapped heatmap.
            Assumed to have the same height/width as the upsampled mask
            (``mask.shape * cell``) -- TODO confirm against the caller.
        heatmap_path: destination of the 'hot' colour-mapped heatmap image.
        overlay_path: destination of the heatmap/frame blend.
        threshold: mask values strictly below this are set to 0.
        cell: every mask entry is expanded to a ``cell x cell`` pixel block.
    """
    # Promote to float64 before scaling so the uint8 conversion below gives
    # the same result as filling a float64 buffer cell by cell.
    values = np.asarray(mask, dtype=np.float64)
    values = np.where(values < threshold, 0.0, values)
    heatmap = np.repeat(np.repeat(values, cell, axis=0), cell, axis=1)
    heatmap = (heatmap * 255).astype(np.uint8)
    plt.imsave(heatmap_path, heatmap, cmap='hot')

    colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # NOTE(review): applyColorMap/imwrite work in BGR order while `frame` is
    # also saved through matplotlib (RGB) by the caller -- the frame channels
    # may be swapped in the overlay; confirm before changing.
    blended = cv2.addWeighted(colored, 0.5, frame, 0.5, 0, dtype=cv2.CV_32F)
    cv2.imwrite(overlay_path, blended)


def output_visuals_PosNeg(vis_rows, batch_data, masks_pos, masks_neg, idx_pos,
                          idx_neg, pred_masks_, gt_masks_, mag_mix_, weight_,
                          args):
    """Write separation results and positive/negative grounding maps to disk.

    For every sample ``j`` of the batch, a directory named after the mixed
    clips is created under ``args.vis`` and filled with:

    * ``mix.wav`` / ``mix.jpg`` / ``weight.jpg`` for the mixture,
    * ``gtmask{n}.jpg`` / ``predmask{n}.jpg``, ``gtamp{n}.jpg`` /
      ``predamp{n}.jpg`` and ``gt{n}.wav`` / ``pred{n}.wav`` per source,
    * ``frame{idx}.png``, ``heatmap_{pos}_{idx}.jpg`` and
      ``overlay_{pos}_{idx}.jpg`` for the ``idx_pos`` and ``idx_neg`` videos.

    Args:
        vis_rows: list that receives one row (a list of dicts with
            'text' / 'image' / 'audio' entries) per sample.
        batch_data: dict providing 'mag_mix', 'phase_mix', 'frames' and
            'infos'. 'mag_mix' and 'phase_mix' are converted with
            ``.numpy()`` and indexed as ``[j, 0]``.
        masks_pos: tensor, squeezed on dim 1 and indexed per sample as a
            2-D map for the positive pair.
        masks_neg: same as ``masks_pos`` for the negative pair.
        idx_pos: single-element tensor; index of the positive video in
            ``batch_data['frames']``.
        idx_neg: single-element tensor; index of the negative video.
        pred_masks_: sequence of at least ``args.num_mix`` predicted mask
            tensors. Left untouched; conversions happen on local copies.
        gt_masks_: sequence of at least ``args.num_mix`` ground-truth mask
            tensors. Left untouched as well.
        mag_mix_: tensor used for the spectrogram heatmaps.
        weight_: tensor rendered as ``weight.jpg``.
        args: namespace; reads ``num_mix``, ``log_freq``, ``stft_frame``,
            ``device``, ``binary_mask``, ``mask_thres``, ``vis``,
            ``stft_hop``, ``audRate`` and ``num_frames``.
    """
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']

    # masks to cpu, numpy
    masks_pos = torch.squeeze(masks_pos, dim=1).cpu().float().numpy()
    masks_neg = torch.squeeze(masks_neg, dim=1).cpu().float().numpy()

    N = args.num_mix
    B = mag_mix.size(0)

    # unwarp log scale; the grid only depends on the batch and STFT sizes, so
    # it is built once instead of once per source.
    if args.log_freq and N > 0:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, gt_masks_[0].size(3),
                     warp=False)).to(args.device)
        pred_linear = [F.grid_sample(pred_masks_[n], grid_unwarp)
                       for n in range(N)]
        gt_linear = [F.grid_sample(gt_masks_[n], grid_unwarp)
                     for n in range(N)]
    else:
        pred_linear = [pred_masks_[n] for n in range(N)]
        gt_linear = [gt_masks_[n] for n in range(N)]

    # convert into numpy -- into local lists, so the caller's tensors in
    # pred_masks_ / gt_masks_ are not replaced by arrays.
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    idx_pos = int(idx_pos.detach().cpu().numpy())
    idx_neg = int(idx_neg.detach().cpu().numpy())
    pred_masks = [pred_masks_[n].detach().cpu().numpy() for n in range(N)]
    gt_masks = [gt_masks_[n].detach().cpu().numpy() for n in range(N)]
    pred_linear = [m.detach().cpu().numpy() for m in pred_linear]
    gt_linear = [m.detach().cpu().numpy() for m in gt_linear]

    # threshold if binary mask
    if args.binary_mask:
        pred_masks = [(m > args.mask_thres).astype(np.float32)
                      for m in pred_masks]
        pred_linear = [(m > args.mask_thres).astype(np.float32)
                       for m in pred_linear]

    threshold = 0.5

    # loop over each sample
    for j in range(B):
        row_elements = []

        # video names: last two path components joined by '-', extension
        # dropped, one entry per mixed source joined by '+'
        prefix = '+'.join(
            '-'.join(infos[n][0][j].split('/')[-2:]).split('.')[0]
            for n in range(N))
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0], phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        matplotlib.image.imsave(os.path.join(args.vis, filename_mixmag),
                                mix_amp[::-1, :, :])
        matplotlib.image.imsave(os.path.join(args.vis, filename_weight),
                                weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav),
                      args.audRate, mix_wav)
        row_elements += [{'text': prefix},
                         {'image': filename_mixmag, 'audio': filename_mixwav}]

        # save each component
        for n in range(N):
            # GT and predicted audio recovery (unwarped masks on the input
            # magnitude); the *_vis magnitudes feed the spectrogram images.
            gt_mag = mag_mix[j, 0] * gt_linear[n][j, 0]
            gt_mag_vis = mag_mix_[j, 0] * gt_masks[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag, phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_linear[n][j, 0]
            pred_mag_vis = mag_mix_[j, 0] * pred_masks[n][j, 0]
            pred_wav = istft_reconstruction(pred_mag, phase_mix[j, 0],
                                            hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(
                prefix, 'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(
                prefix, 'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmask),
                                    gt_mask[::-1, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmask),
                                    pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(
                prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(
                prefix, 'predamp{}.jpg'.format(n + 1))
            gt_heat = magnitude2heatmap(gt_mag_vis)
            pred_heat = magnitude2heatmap(pred_mag_vis)
            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmag),
                                    gt_heat[::-1, :, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmag),
                                    pred_heat[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(
                prefix, 'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav),
                          args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, pred_wav)

        # frame + heatmap + overlay, first for the positive pair and then for
        # the negative pair (file names always start with the positive index)
        for idx, masks in ((idx_pos, masks_pos), (idx_neg, masks_neg)):
            # save frame (the one at index num_frames // 2)
            frame_rgb = np.asarray(
                recover_rgb(frames[idx][j, :, int(args.num_frames // 2)]))
            filename_frame = os.path.join(
                prefix, 'frame{}.png'.format(idx + 1))
            matplotlib.image.imsave(os.path.join(args.vis, filename_frame),
                                    frame_rgb)

            # get heatmap and overlay
            filename_heatmap = os.path.join(
                prefix, 'heatmap_{}_{}.jpg'.format(idx_pos + 1, idx + 1))
            path_overlay = os.path.join(
                args.vis, prefix,
                'overlay_{}_{}.jpg'.format(idx_pos + 1, idx + 1))
            _save_grounding_overlay(
                masks[j], frame_rgb.copy(),
                os.path.join(args.vis, filename_heatmap), path_overlay,
                threshold=threshold)

        vis_rows.append(row_elements)
def output_visuals(vis_rows, batch_data, outputs, args):
    """Write per-sample separation results (audio, images, video) to disk.

    For every sample ``j`` of the batch, a directory named after the mixed
    clips is created under ``args.vis`` and filled with the mixture
    (``mix.wav``, ``mix.jpg``, ``weight.jpg``) and, per source ``n``, the
    ground-truth/predicted masks, spectrogram heatmaps, wavs, a
    ``video{n}.mp4`` built from the frames and an ``av{n}.mp4`` that combines
    that video with the ground-truth wav.

    Args:
        vis_rows: list that receives one row (a list of dicts with 'text' /
            'image' / 'audio' / 'video' entries) per sample.
        batch_data: dict providing 'mag_mix', 'phase_mix', 'frames' and
            'infos'. 'mag_mix' and 'phase_mix' are converted with
            ``.numpy()`` and indexed as ``[j, 0]``.
        outputs: dict providing 'pred_masks', 'gt_masks' (sequences of at
            least ``args.num_mix`` tensors), 'mag_mix' and 'weight'. The
            mask sequences are left untouched; conversions happen on local
            copies.
        args: namespace; reads ``num_mix``, ``log_freq``, ``stft_frame``,
            ``device``, ``binary_mask``, ``mask_thres``, ``vis``,
            ``stft_hop``, ``audRate``, ``num_frames``, ``frameRate`` and
            ``stride_frames``.
    """
    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']

    pred_masks_ = outputs['pred_masks']
    gt_masks_ = outputs['gt_masks']
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    N = args.num_mix
    B = mag_mix.size(0)

    # unwarp log scale; the grid only depends on the batch and STFT sizes, so
    # it is built once instead of once per source.
    if args.log_freq and N > 0:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, gt_masks_[0].size(3),
                     warp=False)).to(args.device)
        pred_linear = [F.grid_sample(pred_masks_[n], grid_unwarp)
                       for n in range(N)]
        gt_linear = [F.grid_sample(gt_masks_[n], grid_unwarp)
                     for n in range(N)]
    else:
        pred_linear = [pred_masks_[n] for n in range(N)]
        gt_linear = [gt_masks_[n] for n in range(N)]

    # convert into numpy -- into local lists, so the tensors stored in
    # `outputs` are not replaced by arrays behind the caller's back.
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    pred_masks = [pred_masks_[n].detach().cpu().numpy() for n in range(N)]
    gt_masks = [gt_masks_[n].detach().cpu().numpy() for n in range(N)]
    pred_linear = [m.detach().cpu().numpy() for m in pred_linear]
    gt_linear = [m.detach().cpu().numpy() for m in gt_linear]

    # threshold if binary mask
    if args.binary_mask:
        pred_masks = [(m > args.mask_thres).astype(np.float32)
                      for m in pred_masks]
        pred_linear = [(m > args.mask_thres).astype(np.float32)
                       for m in pred_linear]

    # loop over each sample
    for j in range(B):
        row_elements = []

        # video names: last two path components joined by '-', extension
        # dropped, one entry per mixed source joined by '+'
        prefix = '+'.join(
            '-'.join(infos[n][0][j].split('/')[-2:]).split('.')[0]
            for n in range(N))
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0], phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        imsave(os.path.join(args.vis, filename_mixmag), mix_amp[::-1, :, :])
        imsave(os.path.join(args.vis, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav),
                      args.audRate, mix_wav)
        row_elements += [{'text': prefix},
                         {'image': filename_mixmag, 'audio': filename_mixwav}]

        # save each component
        for n in range(N):
            # GT and predicted audio recovery
            gt_mag = mag_mix[j, 0] * gt_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag, phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_linear[n][j, 0]
            pred_wav = istft_reconstruction(pred_mag, phase_mix[j, 0],
                                            hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(
                prefix, 'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(
                prefix, 'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imsave(os.path.join(args.vis, filename_gtmask), gt_mask[::-1, :])
            imsave(os.path.join(args.vis, filename_predmask),
                   pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(
                prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(
                prefix, 'predamp{}.jpg'.format(n + 1))
            gt_heat = magnitude2heatmap(gt_mag)
            pred_heat = magnitude2heatmap(pred_mag)
            imsave(os.path.join(args.vis, filename_gtmag),
                   gt_heat[::-1, :, :])
            imsave(os.path.join(args.vis, filename_predmag),
                   pred_heat[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(
                prefix, 'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav),
                          args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, pred_wav)

            # output video
            frames_tensor = np.asarray([
                recover_rgb(frames[n][j, :, t])
                for t in range(args.num_frames)
            ])
            path_video = os.path.join(
                args.vis, prefix, 'video{}.mp4'.format(n + 1))
            save_video(path_video, frames_tensor,
                       fps=args.frameRate / args.stride_frames)

            # combine gt video and audio
            filename_av = os.path.join(prefix, 'av{}.mp4'.format(n + 1))
            combine_video_audio(path_video,
                                os.path.join(args.vis, filename_gtwav),
                                os.path.join(args.vis, filename_av))

            row_elements += [
                {'video': filename_av},
                {'image': filename_predmag, 'audio': filename_predwav},
                {'image': filename_gtmag, 'audio': filename_gtwav},
                {'image': filename_predmask},
                {'image': filename_gtmask},
            ]

        row_elements += [{'image': filename_weight}]
        vis_rows.append(row_elements)
def output_visuals(batch_data, outputs, args):
    """Write per-sample separation results and point clouds to disk.

    For every sample ``j`` of the batch, a directory named after the mixed
    clips is created under ``args.vis`` and filled with the mixture
    (``mix.wav``, ``mix.jpg``, ``weight.jpg``) and, per source ``n``, the
    ground-truth/predicted masks, spectrogram heatmaps, wavs and one
    ``point{n}_frame{f}.ply`` per frame ``f``.

    NOTE(review): an ``output_visuals(vis_rows, batch_data, outputs, args)``
    with a different signature also appears earlier in this source; if both
    live in the same module this definition shadows it -- confirm intent.

    Args:
        batch_data: dict providing 'mag_mix', 'phase_mix', 'feats', 'coords'
            and 'infos'. 'mag_mix' and 'phase_mix' are converted with
            ``.numpy()`` and indexed as ``[j, 0]``. Column 0 of
            ``coords[n][f]`` is compared against the sample index and
            columns 1:4 are used as xyz when ``args.rgbs_feature`` is set.
        outputs: dict providing 'pred_masks', 'gt_masks' (sequences of at
            least ``args.num_mix`` tensors), 'mag_mix' and 'weight'. The
            mask sequences are left untouched; conversions happen on local
            copies.
        args: namespace; reads ``num_mix``, ``num_frames``, ``log_freq``,
            ``stft_frame``, ``device``, ``binary_mask``, ``mask_thres``,
            ``vis``, ``stft_hop``, ``audRate``, ``rgbs_feature`` and
            ``voxel_size``.
    """
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    features = batch_data['feats']
    coords = batch_data['coords']
    infos = batch_data['infos']

    pred_masks_ = outputs['pred_masks']
    gt_masks_ = outputs['gt_masks']
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    N = args.num_mix
    FN = args.num_frames
    B = mag_mix.size(0)

    # unwarp log scale; the grid only depends on the batch and STFT sizes, so
    # it is built once instead of once per source.
    if args.log_freq and N > 0:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, gt_masks_[0].size(3),
                     warp=False)).to(args.device)
        pred_linear = [F.grid_sample(pred_masks_[n], grid_unwarp)
                       for n in range(N)]
        gt_linear = [F.grid_sample(gt_masks_[n], grid_unwarp)
                     for n in range(N)]
    else:
        pred_linear = [pred_masks_[n] for n in range(N)]
        gt_linear = [gt_masks_[n] for n in range(N)]

    # convert into numpy -- into local lists, so the tensors stored in
    # `outputs` are not replaced by arrays behind the caller's back.
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    pred_masks = [pred_masks_[n].detach().cpu().numpy() for n in range(N)]
    gt_masks = [gt_masks_[n].detach().cpu().numpy() for n in range(N)]
    pred_linear = [m.detach().cpu().numpy() for m in pred_linear]
    gt_linear = [m.detach().cpu().numpy() for m in gt_linear]

    # threshold if binary mask
    if args.binary_mask:
        pred_masks = [(m > args.mask_thres).astype(np.float32)
                      for m in pred_masks]
        pred_linear = [(m > args.mask_thres).astype(np.float32)
                       for m in pred_linear]

    # loop over each sample
    for j in range(B):
        # video names: last two path components joined by '-', extension
        # dropped, one entry per mixed source joined by '+'
        prefix = '+'.join(
            '-'.join(infos[n][0][j].split('/')[-2:]).split('.')[0]
            for n in range(N))
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0], phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        imageio.imwrite(os.path.join(args.vis, filename_mixmag),
                        mix_amp[::-1, :, :])
        imageio.imwrite(os.path.join(args.vis, filename_weight),
                        weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav),
                      args.audRate, mix_wav)

        # save each component
        for n in range(N):
            # GT and predicted audio recovery
            gt_mag = mag_mix[j, 0] * gt_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag, phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_linear[n][j, 0]
            pred_wav = istft_reconstruction(pred_mag, phase_mix[j, 0],
                                            hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(
                prefix, 'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(
                prefix, 'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imageio.imwrite(os.path.join(args.vis, filename_gtmask),
                            gt_mask[::-1, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmask),
                            pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(
                prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(
                prefix, 'predamp{}.jpg'.format(n + 1))
            gt_heat = magnitude2heatmap(gt_mag)
            pred_heat = magnitude2heatmap(pred_mag)
            imageio.imwrite(os.path.join(args.vis, filename_gtmag),
                            gt_heat[::-1, :, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmag),
                            pred_heat[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(
                prefix, 'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav),
                          args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, pred_wav)

            # output pointclouds: keep only the rows whose first coordinate
            # column equals the current sample index
            for f in range(FN):
                idx = torch.where(coords[n][f][:, 0] == j)
                path_point = os.path.join(
                    args.vis, prefix,
                    'point{}_frame{}.ply'.format(n + 1, f + 1))
                if args.rgbs_feature:
                    colors = np.asarray(features[n][f][idx])
                    xyz = np.asarray(coords[n][f][idx][:, 1:4])
                    xyz = xyz * args.voxel_size
                    save_points(path_point, xyz, colors)
                else:
                    xyz = np.asarray(features[n][f][idx])
                    save_points(path_point, xyz)