コード例 #1
0
def _calc_metrics(batch_data, outputs, args):
    """Return mean ``[SDR, SIR, SAR]`` of the predicted separations of a batch.

    Each source estimate is the mixture magnitude times its predicted mask,
    turned into a waveform by ``istft_reconstruction`` and scored against the
    ground-truth waveform with ``bss_eval_sources``. A sample is skipped when
    any of its ground-truth or predicted waveforms has a total absolute
    amplitude <= 1e-5.

    Args:
        batch_data: dict with ``'mag_mix'`` (tensor, batch first),
            ``'phase_mix'`` and ``'audios'`` (list of ``args.num_mix``
            waveform tensors, batch first).
        outputs: dict with ``'pred_masks'``, a list of at least
            ``args.num_mix`` mask tensors (batch first). Not modified.
        args: options; reads ``num_mix``, ``use_mel``, ``binary_mask``,
            ``mask_thres``, ``stft_hop``, ``sr`` and ``stft_frame``.

    Returns:
        list ``[sdr, sir, sar]`` as averaged by three ``AverageMeter``s.
    """
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    audios = batch_data['audios']
    N = args.num_mix
    B = mag_mix.size(0)

    # Remove singleton dims but keep the batch axis: a bare ``squeeze()``
    # also drops it when B == 1, after which ``mag_mix[j]`` indexes the
    # wrong axis.
    mag_mix = mag_mix.detach().cpu().numpy()
    mag_mix = mag_mix.reshape(
        (B,) + tuple(d for d in mag_mix.shape[1:] if d != 1))
    if not args.use_mel:
        phase_mix = phase_mix.detach().cpu().numpy()

    # Build a new list of numpy masks so outputs['pred_masks'] is untouched.
    pred_masks = []
    for n in range(N):
        mask = outputs['pred_masks'][n].detach().cpu().numpy()
        if args.binary_mask:
            mask = (mask > args.mask_thres).astype(np.float32)
        pred_masks.append(mask)

    for j in range(B):
        # No phase is passed to the reconstruction in mel mode.
        phase = None if args.use_mel else phase_mix[j]
        preds_wav = [
            istft_reconstruction(mag_mix[j] * pred_masks[n][j], phase,
                                 use_mel=args.use_mel,
                                 hop_length=args.stft_hop, sr=args.sr,
                                 n_fft=args.stft_frame, n_mels=256)
            for n in range(N)
        ]
        gts_wav = [audios[n][j].detach().cpu().numpy() for n in range(N)]

        # Skip samples with a (near-)silent reference or estimate.
        if not all(np.sum(np.abs(wav)) > 1e-5 for wav in gts_wav + preds_wav):
            continue

        # bss_eval_sources needs equally long signals, and the reconstruction
        # length may differ from the reference length, so score the common
        # prefix (a no-op when the lengths already agree).
        L = min(wav.shape[-1] for wav in gts_wav + preds_wav)
        sdr, sir, sar, _ = bss_eval_sources(
            np.asarray([wav[..., :L] for wav in gts_wav]),
            np.asarray([wav[..., :L] for wav in preds_wav]), False)
        sdr_meter.update(sdr.mean())
        sir_meter.update(sir.mean())
        sar_meter.update(sar.mean())

    return [sdr_meter.average(),
            sir_meter.average(),
            sar_meter.average()]
コード例 #2
0
ファイル: main_silent.py プロジェクト: YapengTian/CCOL-CVPR21
def output_visuals(vis_rows, batch_data, outputs, args):
    """Save mixture/per-source audio, spectrograms, masks and videos of a batch.

    For every sample ``j`` a directory ``args.vis/<prefix>`` is created, where
    ``<prefix>`` joins the source clip names taken from
    ``batch_data['infos']``. It receives:

    - the mixture waveform and magnitude heatmap;
    - a heatmap of ``outputs['weight']``;
    - for each of the ``args.num_mix`` sources: ground-truth and predicted
      masks, magnitude heatmaps and waveforms, the source video, and that
      video combined with its ground-truth audio.

    One row of element dicts (``'text'``, ``'image'``, ``'audio'`` or
    ``'video'`` keys; paths relative to ``args.vis``) is appended to
    ``vis_rows`` per sample.

    Args:
        vis_rows: list that receives one row per sample (mutated).
        batch_data: dict with ``'mag_mix'``, ``'phase_mix'``, ``'frames'``
            and ``'infos'``.
        outputs: dict with ``'pred_masks'``, ``'gt_masks'`` (lists of
            ``args.num_mix`` tensors), ``'mag_mix'`` and ``'weight'``. The
            dict and its lists are not modified.
        args: options; reads ``num_mix``, ``log_freq``, ``stft_frame``,
            ``stft_hop``, ``device``, ``binary_mask``, ``mask_thres``,
            ``vis``, ``audRate``, ``num_frames``, ``frameRate`` and
            ``stride_frames``.
    """
    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']

    # Shallow-copy the mask lists: their entries are replaced by numpy arrays
    # below, which must not leak back into the caller's ``outputs`` dict.
    pred_masks_ = list(outputs['pred_masks'])
    gt_masks_ = list(outputs['gt_masks'])
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    N = args.num_mix  #-1
    B = mag_mix.size(0)

    # unwarp log-frequency masks back to linear frequency; the grid only
    # depends on the batch/spectrogram size, so it is built once
    if args.log_freq:
        grid_unwarp = torch.from_numpy(
            warpgrid(B,
                     args.stft_frame // 2 + 1,
                     gt_masks_[0].size(3),
                     warp=False)).to(args.device)
        pred_masks_linear = [F.grid_sample(pred_masks_[n], grid_unwarp)
                             for n in range(N)]
        gt_masks_linear = [F.grid_sample(gt_masks_[n], grid_unwarp)
                           for n in range(N)]
    else:
        pred_masks_linear = [pred_masks_[n] for n in range(N)]
        gt_masks_linear = [gt_masks_[n] for n in range(N)]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(
                np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        row_elements = []

        # video names
        prefix = []
        for n in range(N):
            prefix.append('-'.join(
                infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        imsave(os.path.join(args.vis, filename_mixmag), mix_amp[::-1, :, :])
        imsave(os.path.join(args.vis, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate,
                      mix_wav)
        row_elements += [{
            'text': prefix
        }, {
            'image': filename_mixmag,
            'audio': filename_mixwav
        }]

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):

            # GT and predicted audio recovery
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag,
                                          phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag,
                                                phase_mix[j, 0],
                                                hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(prefix,
                                           'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(prefix,
                                             'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imsave(os.path.join(args.vis, filename_gtmask), gt_mask[::-1, :])
            imsave(os.path.join(args.vis, filename_predmask),
                   pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(prefix,
                                            'predamp{}.jpg'.format(n + 1))
            gt_mag = magnitude2heatmap(gt_mag)
            pred_mag = magnitude2heatmap(pred_mag)
            imsave(os.path.join(args.vis, filename_gtmag), gt_mag[::-1, :, :])
            imsave(os.path.join(args.vis, filename_predmag),
                   pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav), args.audRate,
                          gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, preds_wav[n])

            # output video
            frames_tensor = [
                recover_rgb(frames[n][j, :, t]) for t in range(args.num_frames)
            ]
            frames_tensor = np.asarray(frames_tensor)
            path_video = os.path.join(args.vis, prefix,
                                      'video{}.mp4'.format(n + 1))
            save_video(path_video,
                       frames_tensor,
                       fps=args.frameRate / args.stride_frames)

            # combine gt video and audio
            filename_av = os.path.join(prefix, 'av{}.mp4'.format(n + 1))
            combine_video_audio(path_video,
                                os.path.join(args.vis, filename_gtwav),
                                os.path.join(args.vis, filename_av))

            row_elements += [{
                'video': filename_av
            }, {
                'image': filename_predmag,
                'audio': filename_predwav
            }, {
                'image': filename_gtmag,
                'audio': filename_gtwav
            }, {
                'image': filename_predmask
            }, {
                'image': filename_gtmask
            }]

        row_elements += [{'image': filename_weight}]
        vis_rows.append(row_elements)
コード例 #3
0
ファイル: main_silent.py プロジェクト: YapengTian/CCOL-CVPR21
def calc_metrics(batch_data, outputs, args, num_sources=4):
    """Return mixture SDR and mean SDR/SIR/SAR of a batch's separations.

    Each estimate is the mixture magnitude times a predicted mask (unwarped
    from log frequency when ``args.log_freq``), reconstructed with
    ``istft_reconstruction`` using the mixture phase. Estimates are scored
    with ``bss_eval_sources`` against ``batch_data['audios']``. The mixture
    itself, used as every source's estimate, gives the ``sdr_mix`` baseline.

    Args:
        batch_data: dict with ``'mag_mix'``, ``'phase_mix'`` and ``'audios'``.
        outputs: dict with ``'pred_masks'`` (list of mask tensors). Not
            modified.
        args: options; reads ``log_freq``, ``stft_frame``, ``device``,
            ``binary_mask``, ``mask_thres`` and ``stft_hop``.
        num_sources: how many leading sources (entries of
            ``outputs['pred_masks']`` and ``batch_data['audios']``) are
            evaluated. Defaults to 4, the value this function previously
            hard-coded (its old note read ``args.num_mix-1``).

    Returns:
        ``[sdr_mix, sdr, sir, sar]``, each averaged over the valid samples.
    """
    # meters
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    audios = batch_data['audios']
    pred_masks_ = outputs['pred_masks']

    N = num_sources
    B = mag_mix.size(0)

    # unwarp log-frequency masks back to linear frequency; the grid only
    # depends on the batch/spectrogram size, so it is built once
    if args.log_freq:
        grid_unwarp = torch.from_numpy(
            warpgrid(B,
                     args.stft_frame // 2 + 1,
                     pred_masks_[0].size(3),
                     warp=False)).to(args.device)
        pred_masks_linear = [F.grid_sample(pred_masks_[n], grid_unwarp)
                             for n in range(N)]
    else:
        pred_masks_linear = [pred_masks_[n] for n in range(N)]

    # convert into numpy, thresholding if binary masks are used
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    for n in range(N):
        mask = pred_masks_linear[n].detach().cpu().numpy()
        if args.binary_mask:
            mask = (mask > args.mask_thres).astype(np.float32)
        pred_masks_linear[n] = mask

    # loop over each sample
    for j in range(B):
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)

        # NOTE(review): because of the 1e-6 offset on estimates and
        # references, an all-zero clip longer than 10 samples still passes
        # the silence check below. Presumably deliberate so that silent
        # sources are scored rather than skipped — confirm.
        preds_wav = [
            istft_reconstruction(mag_mix[j, 0] * pred_masks_linear[n][j, 0],
                                 phase_mix[j, 0],
                                 hop_length=args.stft_hop) + 1e-6
            for n in range(N)
        ]

        # Score on the common prefix of estimates and references. Trimming
        # only the references made bss_eval_sources fail whenever a
        # reference was shorter than the reconstruction.
        L = min([preds_wav[0].shape[0]] +
                [audios[n][j].shape[0] for n in range(N)])
        preds_wav = [wav[0:L] for wav in preds_wav]
        gts_wav = [audios[n][j, 0:L].numpy() + 1e-6 for n in range(N)]

        valid = all(np.sum(np.abs(wav)) > 1e-5 for wav in gts_wav + preds_wav)
        if valid:
            sdr, sir, sar, _ = bss_eval_sources(np.asarray(gts_wav),
                                                np.asarray(preds_wav), False)
            sdr_mix, _, _, _ = bss_eval_sources(
                np.asarray(gts_wav),
                np.asarray([mix_wav[0:L] for n in range(N)]), False)
            sdr_mix_meter.update(sdr_mix.mean())
            sdr_meter.update(sdr.mean())
            sir_meter.update(sir.mean())
            sar_meter.update(sar.mean())

    return [
        sdr_mix_meter.average(),
        sdr_meter.average(),
        sir_meter.average(),
        sar_meter.average()
    ]
def _save_mask_overlay(args, prefix, frames, frame_idx, j, mask, threshold, tag):
    """Save the centre frame of video ``frame_idx`` for sample ``j``, a
    thresholded heatmap of ``mask`` and its overlay on that frame; the heatmap
    and overlay file names end in ``tag``."""
    frames_tensor = recover_rgb(frames[frame_idx][j, :, int(args.num_frames // 2)])
    frames_tensor = np.asarray(frames_tensor)
    filename_frame = os.path.join(prefix, 'frame{}.png'.format(frame_idx + 1))
    matplotlib.image.imsave(os.path.join(args.vis, filename_frame), frames_tensor)
    frame = frames_tensor.copy()

    # Zero mask cells below ``threshold`` and expand each cell to a 16x16
    # pixel block. A float64 copy matches the former np.zeros-based buffer,
    # so the uint8 conversion yields the same values; ``mask`` is untouched.
    cells = np.array(mask, dtype=np.float64)
    cells[cells < threshold] = 0
    heatmap = np.repeat(np.repeat(cells, 16, axis=0), 16, axis=1)
    heatmap = (heatmap * 255).astype(np.uint8)
    filename_heatmap = os.path.join(prefix, 'heatmap_{}.jpg'.format(tag))
    plt.imsave(os.path.join(args.vis, filename_heatmap), heatmap, cmap='hot')
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    fin = cv2.addWeighted(heatmap, 0.5, frame, 0.5, 0, dtype=cv2.CV_32F)
    path_overlay = os.path.join(args.vis, prefix, 'overlay_{}.jpg'.format(tag))
    cv2.imwrite(path_overlay, fin)


def output_visuals_PosNeg(vis_rows, batch_data, masks_pos,  masks_neg, idx_pos, idx_neg, pred_masks_, gt_masks_, mag_mix_, weight_, args):
    """Save audio/spectrogram/mask visuals plus positive/negative mask overlays.

    Every sample ``j`` gets a directory ``args.vis/<prefix>`` (``<prefix>``
    joins the source clip names from ``batch_data['infos']``) containing:

    - the mixture waveform and magnitude heatmap;
    - a heatmap of ``weight_``;
    - per source: ground-truth/predicted masks, magnitude heatmaps and
      waveforms.

    In addition, the centre frames of videos ``idx_pos`` and ``idx_neg`` are
    saved together with thresholded heatmaps of ``masks_pos[j]`` and
    ``masks_neg[j]`` and their overlays on those frames. Only the mixture
    entries are appended to ``vis_rows``.

    Args:
        vis_rows: list receiving one row per sample (mutated).
        batch_data: dict with ``'mag_mix'``, ``'phase_mix'``, ``'frames'``
            and ``'infos'``.
        masks_pos, masks_neg: mask tensors whose dim 1 is squeezed away;
            presumably (B, 1, H, W) — TODO confirm.
        idx_pos, idx_neg: single-element tensors with the indices (into
            ``frames``) of the videos used for the positive / negative pair.
        pred_masks_, gt_masks_: lists of ``args.num_mix`` mask tensors. The
            caller's lists are not modified.
        mag_mix_, weight_: tensors rendered as heatmaps.
        args: options; reads ``num_mix``, ``log_freq``, ``stft_frame``,
            ``stft_hop``, ``device``, ``binary_mask``, ``mask_thres``,
            ``vis``, ``audRate`` and ``num_frames``.
    """
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']

    # masks to cpu, numpy
    masks_pos = torch.squeeze(masks_pos, dim=1).cpu().float().numpy()
    masks_neg = torch.squeeze(masks_neg, dim=1).cpu().float().numpy()

    # Work on copies: list entries are replaced by numpy arrays below and the
    # lists belong to the caller.
    pred_masks_ = list(pred_masks_)
    gt_masks_ = list(gt_masks_)

    N = args.num_mix
    B = mag_mix.size(0)

    # unwarp log-frequency masks back to linear frequency; the grid only
    # depends on the batch/spectrogram size, so it is built once
    if args.log_freq:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame//2+1, gt_masks_[0].size(3), warp=False)).to(args.device)
        pred_masks_linear = [F.grid_sample(pred_masks_[n], grid_unwarp) for n in range(N)]
        gt_masks_linear = [F.grid_sample(gt_masks_[n], grid_unwarp) for n in range(N)]
    else:
        pred_masks_linear = [pred_masks_[n] for n in range(N)]
        gt_masks_linear = [gt_masks_[n] for n in range(N)]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    idx_pos = int(idx_pos.detach().cpu().numpy())
    idx_neg = int(idx_neg.detach().cpu().numpy())
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()
        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] > args.mask_thres).astype(np.float32)

    # mask cells below this value are zeroed in the saved heatmaps
    threshold = 0.5
    # loop over each sample
    for j in range(B):
        row_elements = []
        # video names
        prefix = []
        for n in range(N):
            prefix.append('-'.join(infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0], phase_mix[j, 0], hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        matplotlib.image.imsave(os.path.join(args.vis, filename_mixmag), mix_amp[::-1, :, :])
        matplotlib.image.imsave(os.path.join(args.vis, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate, mix_wav)
        row_elements += [{'text': prefix}, {'image': filename_mixmag, 'audio': filename_mixwav}]

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # GT and predicted audio recovery (linear-frequency masks);
            # heatmaps below use mag_mix_ with the unwarped masks
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_mag_ = mag_mix_[j, 0] * gt_masks_[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag, phase_mix[j, 0], hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            pred_mag_ = mag_mix_[j, 0] * pred_masks_[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag, phase_mix[j, 0], hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(prefix, 'gtmask{}.jpg'.format(n+1))
            filename_predmask = os.path.join(prefix, 'predmask{}.jpg'.format(n+1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)

            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmask), gt_mask[::-1, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmask), pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n+1))
            filename_predmag = os.path.join(prefix, 'predamp{}.jpg'.format(n+1))
            gt_mag = magnitude2heatmap(gt_mag_)
            pred_mag = magnitude2heatmap(pred_mag_)

            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmag), gt_mag[::-1, :, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmag), pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n+1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n+1))
            wavfile.write(os.path.join(args.vis, filename_gtwav), args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav), args.audRate, preds_wav[n])

        # positive pair: frame of video idx_pos overlaid with masks_pos[j]
        _save_mask_overlay(args, prefix, frames, idx_pos, j, masks_pos[j],
                           threshold, '{}_{}'.format(idx_pos + 1, idx_pos + 1))
        # negative pair: frame of video idx_neg overlaid with masks_neg[j]
        _save_mask_overlay(args, prefix, frames, idx_neg, j, masks_neg[j],
                           threshold, '{}_{}'.format(idx_pos + 1, idx_neg + 1))

        vis_rows.append(row_elements)
コード例 #5
0
def output_visuals(batch_data, outputs, args):
    """Save mixture/per-source audio, spectrograms, masks and point clouds.

    For every sample ``j`` of the batch a directory ``args.vis/<prefix>`` is
    created (``<prefix>`` joins the source clip names from
    ``batch_data['infos']``). It is filled with:

    - the mixture waveform and magnitude heatmap;
    - a heatmap of ``outputs['weight']``;
    - for each of the ``args.num_mix`` sources: ground-truth/predicted masks,
      magnitude heatmaps and waveforms, and one ``.ply`` point cloud per
      frame.

    Side effect: the lists ``outputs['pred_masks']`` and
    ``outputs['gt_masks']`` have their entries replaced in place by numpy
    arrays.

    NOTE(review): another ``output_visuals`` (taking ``vis_rows``) appears
    earlier in this file; if both are in one module, this later definition
    replaces it — confirm which one callers expect.

    Args:
        batch_data: dict with ``'mag_mix'``, ``'phase_mix'``, ``'feats'``,
            ``'coords'`` and ``'infos'``.
        outputs: dict with ``'pred_masks'``, ``'gt_masks'``, ``'mag_mix'``
            and ``'weight'``.
        args: options; reads ``num_mix``, ``num_frames``, ``log_freq``,
            ``stft_frame``, ``stft_hop``, ``device``, ``binary_mask``,
            ``mask_thres``, ``vis``, ``audRate``, ``rgbs_feature`` and
            ``voxel_size``.
    """

    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    features = batch_data['feats']
    coords = batch_data['coords']
    infos = batch_data['infos']

    # These are the caller's lists; entries are overwritten with numpy below.
    pred_masks_ = outputs['pred_masks']
    gt_masks_ = outputs['gt_masks']
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    # unwarp log scale
    N = args.num_mix
    FN = args.num_frames
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    gt_masks_linear = [None for n in range(N)]
    for n in range(N):
        if args.log_freq:
            # Same grid for every source; it is rebuilt on each iteration.
            grid_unwarp = torch.from_numpy(
                warpgrid(B,
                         args.stft_frame // 2 + 1,
                         gt_masks_[0].size(3),
                         warp=False)).to(args.device)
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
            gt_masks_linear[n] = F.grid_sample(gt_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]
            gt_masks_linear[n] = gt_masks_[n]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(
                np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):

        # video names
        prefix = []
        for n in range(N):
            prefix.append('-'.join(
                infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        # [::-1] flips the first axis before writing the image
        imageio.imwrite(os.path.join(args.vis, filename_mixmag),
                        mix_amp[::-1, :, :])
        imageio.imwrite(os.path.join(args.vis, filename_weight),
                        weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate,
                      mix_wav)

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # GT and predicted audio recovery
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag,
                                          phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag,
                                                phase_mix[j, 0],
                                                hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(prefix,
                                           'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(prefix,
                                             'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imageio.imwrite(os.path.join(args.vis, filename_gtmask),
                            gt_mask[::-1, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmask),
                            pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(prefix,
                                            'predamp{}.jpg'.format(n + 1))
            gt_mag = magnitude2heatmap(gt_mag)
            pred_mag = magnitude2heatmap(pred_mag)
            imageio.imwrite(os.path.join(args.vis, filename_gtmag),
                            gt_mag[::-1, :, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmag),
                            pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav), args.audRate,
                          gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, preds_wav[n])

            #output pointclouds
            # One .ply per (source n, frame f). Column 0 of coords is compared
            # with the sample index j, so idx selects this sample's points.
            for f in range(FN):
                idx = torch.where(coords[n][f][:, 0] == j)
                path_point = os.path.join(
                    args.vis, prefix,
                    'point{}_frame{}.ply'.format(n + 1, f + 1))
                if args.rgbs_feature:
                    # Features are saved as per-point colours; xyz comes from
                    # coords columns 1-3 (presumably voxel indices) scaled by
                    # args.voxel_size.
                    colors = np.asarray(features[n][f][idx])
                    xyz = np.asarray(coords[n][f][idx][:, 1:4])
                    xyz = xyz * args.voxel_size
                    save_points(path_point, xyz, colors)
                else:
                    # Otherwise the features themselves are written as xyz.
                    xyz = np.asarray(features[n][f][idx])
                    save_points(path_point, xyz)