Esempio n. 1
0
    def forward(self, batch_data, args):
        """Separate N mixed sources using appearance, motion and sound nets.

        Warps the mixture spectrogram to a log-frequency grid, extracts
        per-source appearance and motion features, predicts one mask per
        source, unwarps the masks back to linear frequency and computes the
        mask loss against the ground-truth masks.

        Args:
            batch_data: dict with 'mag_mix', 'mags', 'appearance_imags',
                'clips_frames' and 'masks' entries.
            args: namespace providing num_mix, device, use_mel, stft_frame
                and the activation names.

        Returns:
            (err, outputs): a 1-element loss tensor and a dict of
            predicted/ground-truth masks plus inputs for visualization.
        """
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['appearance_imags']
        clip_imgs = batch_data['clips_frames']
        gt_masks = batch_data['masks']
        mag_mix = mag_mix + 1e-10  # keep log() finite / avoid divide-by-zero
        N = args.num_mix
        weight = torch.ones_like(mag_mix)

        B = mag_mix.size(0)
        T = mag_mix.size(2)
        # warp the linear-frequency spectrogram onto a 256-bin log-freq grid
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        # squeeze(1) rather than squeeze(): a bare squeeze() also dropped the
        # batch dimension when B == 1
        mag_mix = F.grid_sample(mag_mix.unsqueeze(1), grid_warp, align_corners=False).squeeze(1)

        mag_mix = torch.log(mag_mix).detach() if args.use_mel else mag_mix

        # per-source appearance features
        feat_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        # per-source motion features
        feat_motion = [None for n in range(N)]
        for n in range(N):
            feat_motion[n] = self.net_motion(clip_imgs[n])

        feat_sound = [None for n in range(N)]
        pred_masks = [None for n in range(N)]

        for n in range(N):
            feat_sound[n], _, _ = self.net_sound(mag_mix.to(args.device), feat_motion[n], feat_frames[n])
            pred_masks[n] = activate(feat_sound[n], args.sound_activation)

        # unwarp predicted masks back to the linear STFT frequency axis; the
        # grid does not depend on n, so build it once instead of per source
        grid_unwarp = torch.from_numpy(warpgrid(B, args.stft_frame // 2 + 1, T, warp=False)).to(args.device)
        for n in range(N):
            pred_masks[n] = F.grid_sample(pred_masks[n].unsqueeze(1), grid_unwarp, align_corners=False).squeeze(1)

        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
               {'pred_masks': pred_masks, 'gt_masks': gt_masks,
                'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
    def forward(self, mags, mag_mix, args):
        """Prepare spectrograms and targets, then run the sound network.

        Optionally warps mixture/component magnitudes to a log-frequency
        grid, builds the loss weighting and the ground-truth masks, and
        returns the activated sound features plus the supervision dict.
        """
        mag_mix = mag_mix + 1e-10  # keep log() finite

        N = args.num_mix
        B, T = mag_mix.size(0), mag_mix.size(3)

        # optionally resample onto a 256-bin log-frequency grid
        if args.log_freq:
            grid_warp = torch.from_numpy(
                warpgrid(B, 256, T, warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # loss weighting coefficient: magnitude of the input mixture
        if args.weighted_loss:
            weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # ground-truth masks, computed on the (possibly warped) spectrograms
        gt_masks = []
        for n in range(N):
            if args.binary_mask:
                # binary target: component louder than half the mixture
                gt_masks.append((mags[n] > 0.5 * mag_mix).float())
            else:
                # ratio mask, clamped to avoid huge values
                gt_masks.append((mags[n] / mag_mix).clamp_(0., 5.))

        # log-magnitude input for the sound network
        log_mag_mix = torch.log(mag_mix).detach()

        feat_sound = activate(self.net_sound(log_mag_mix),
                              args.sound_activation)

        return feat_sound, \
            {'gt_masks': gt_masks, 'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
Esempio n. 3
0
    def forward(self, batch_data, args):
        """Joint separation + audio-visual grounding forward pass.

        NOTE(review): the body hard-codes four sources (explicit indices
        0..3) with (0, 1) and (2, 3) treated as candidate pairs; it assumes
        args.num_mix == 4 -- confirm against the caller.

        Returns:
            err: total loss (separation + weighted grounding terms).
            loss_sep: separation-only loss, exposed for logging.
            g: grounding outputs [softmax(g_pos), softmax(g_neg), x, g_solo].
            outputs: dict of masks/spectrograms/weights for visualization.
        """
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10  # keep log()/division finite

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # 0.1 calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground truth masks are computed after warpping!
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 5.)

        # zero out the targets for sources 1 and 3.
        # NOTE(review): gt_masks[1] is built from gt_masks[3]; the product is
        # all-zeros either way, but confirm index 3 was not meant to be 1.
        gt_masks[1] = torch.mul(gt_masks[3], 0.)
        gt_masks[3] = torch.mul(gt_masks[3], 0.)
        # LOG magnitude (log1p here, unlike the plain-log variants elsewhere)
        log_mag_mix = torch.log1p(mag_mix).detach()
        log_mag0 = torch.log1p(mags[0]).detach()
        log_mag2 = torch.log1p(mags[2]).detach()

        # grounding features from the mixture audio and each visual stream
        feat_sound_ground = self.net_sound_ground(log_mag_mix)

        feat_frames_ground = [None for n in range(N)]
        for n in range(N):
            feat_frames_ground[n] = self.net_frame_ground.forward_multiframe(
                frames[n])

        # grounding scores between the mixture and each video (for separation)
        g_sep = [None for n in range(N)]
        x = [None for n in range(N)]
        for n in range(N):
            g_sep[n] = self.net_grounding(feat_sound_ground,
                                          feat_frames_ground[n])
            x[n] = torch.softmax(g_sep[n].clone(), dim=-1)

        # grounding on solo audio: two positive candidates for audio 0
        # (videos 0 and 1) and one negative pair (audio 2 vs. video 0)
        g_pos = self.net_grounding(self.net_sound_ground(log_mag0),
                                   feat_frames_ground[0])
        g_pos1 = self.net_grounding(self.net_sound_ground(log_mag0),
                                    feat_frames_ground[1])
        g_neg = self.net_grounding(self.net_sound_ground(log_mag2),
                                   feat_frames_ground[0])

        # grounding of each solo sound against its paired frame streams
        g_solo = [None for n in range(N)]
        g_solo[0] = self.net_grounding(self.net_sound_ground(log_mag0),
                                       feat_frames_ground[0])
        g_solo[1] = self.net_grounding(self.net_sound_ground(log_mag0),
                                       feat_frames_ground[1])
        g_solo[2] = self.net_grounding(self.net_sound_ground(log_mag2),
                                       feat_frames_ground[2])
        g_solo[3] = self.net_grounding(self.net_sound_ground(log_mag2),
                                       feat_frames_ground[3])
        for n in range(N):
            g_solo[n] = torch.softmax(g_solo[n], dim=-1)
        g = [
            torch.softmax(g_pos, dim=-1),
            torch.softmax(g_neg, dim=-1), x, g_solo
        ]

        # 1. forward net_sound -> BxCxHxW
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, args.sound_activation)

        # 2. forward net_frame -> Bx1xC
        feat_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        # 3. sound synthesizer
        masks = [None for n in range(N)]
        for n in range(N):
            masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
            masks[n] = activate(masks[n], args.output_activation)

        # 4. adjusted masks with grounding confidence scores
        pred_masks = [None for n in range(N)]

        s1 = masks[1].size(2)
        s2 = masks[1].size(3)

        if args.testing:
            # gate each mask by its (rounded) grounding confidence
            pred_masks[0] = torch.mul(
                tf_data(g_sep[0], B, s1, s2).round(), masks[0])
            pred_masks[1] = torch.mul(
                tf_data(g_sep[1], B, s1, s2).round(), masks[1])
            pred_masks[2] = torch.mul(
                tf_data(g_sep[2], B, s1, s2).round(), masks[2])
            pred_masks[3] = torch.mul(
                tf_data(g_sep[3], B, s1, s2).round(), masks[3])
        else:
            # training: each grounded prediction is the confidence-gated sum
            # of the two candidate masks in its pair
            pred_masks[0] = torch.mul(
                tf_data(g_sep[0], B, s1, s2).round(), masks[0]) + torch.mul(
                    tf_data(g_sep[1], B, s1, s2).round(), masks[1])
            pred_masks[1] = masks[1]
            pred_masks[2] = torch.mul(
                tf_data(g_sep[2], B, s1, s2).round(), masks[2]) + torch.mul(
                    tf_data(g_sep[3], B, s1, s2).round(), masks[3])
            pred_masks[3] = masks[3]

        # 5. loss: separation on the two grounded predictions
        loss_sep = 0.5 * (
            self.crit(pred_masks[0], gt_masks[0], weight).reshape(1) +
            self.crit(pred_masks[2], gt_masks[2], weight).reshape(1))

        # per-sample mask errors of every candidate against its pair's target
        df_mask = [None for n in range(N)]
        for i in range(N):
            df_mask[i] = torch.zeros(B, 1).cuda()

        for j in range(B):
            df_mask[0][j] = self.crit(masks[0][j:j + 1], gt_masks[0][j:j + 1],
                                      weight[j:j + 1])
            df_mask[1][j] = self.crit(masks[1][j:j + 1], gt_masks[0][j:j + 1],
                                      weight[j:j + 1])
            df_mask[2][j] = self.crit(masks[2][j:j + 1], gt_masks[2][j:j + 1],
                                      weight[j:j + 1])
            df_mask[3][j] = self.crit(masks[3][j:j + 1], gt_masks[2][j:j + 1],
                                      weight[j:j + 1])

        # soft assignment: a1/a2 ~ how much better candidate 0 (resp. 2)
        # fits than its pair partner; rounded below to pick a winner
        a1 = (1 - df_mask[0].div(df_mask[1] + df_mask[0]))
        a2 = (1 - df_mask[2].div(df_mask[2] + df_mask[3]))

        # contrastive targets: class 0 = positive, class 1 = negative
        p = torch.zeros(B).cuda()
        sep_pos = [None for n in range(N)]
        for i in range(N):
            sep_pos[i] = torch.zeros(B, 1).cuda()

        for i in range(N):
            for j in range(B):
                sep_pos[i][j] = self.cts(g_sep[i][j:j + 1], p[j:j + 1].long())

        sep_neg = [None for n in range(N)]
        for i in range(N):
            sep_neg[i] = torch.zeros(B, 1).cuda()

        n = torch.ones(B).cuda()
        for i in range(N):
            for j in range(B):
                sep_neg[i][j] = self.cts(g_sep[i][j:j + 1], n[j:j + 1].long())

        cts_pos = torch.zeros(B).cuda()
        cts_pos1 = torch.zeros(B).cuda()
        for i in range(B):
            cts_pos[i] = self.cts(g_pos[i:i + 1], p[i:i + 1].long())
            cts_pos1[i] = self.cts(g_pos1[i:i + 1], p[i:i + 1].long())

        # grounding loss: positive term chosen per sample by a1, plus the
        # fixed negative pair
        loss_grd = (cts_pos * a1.round() + cts_pos1 * (1 - a1).round()).mean(
        ) + self.cts(g_neg, n.long())

        loss_grd_pos = (sep_pos[0] * a1.round() + sep_pos[1] *
                        (1 - a1).round() + sep_pos[2] * a2.round() +
                        sep_pos[3] * (1 - a2).round()).mean()
        th = 0.1  # only penalize negatives whose mask error exceeds this
        loss_grd_neg = (sep_neg[0] * (1 - a1).round() *
                        ((1 + torch.sign(df_mask[0] - th)) / 2) +
                        sep_neg[1] * a1.round() *
                        ((1 + torch.sign(df_mask[1] - th)) / 2) + sep_neg[2] *
                        (1 - a2).round() *
                        ((1 + torch.sign(df_mask[2] - th)) / 2) +
                        sep_neg[3] * a2.round() *
                        ((1 + torch.sign(df_mask[3] - th)) / 2)).mean()

        loss_grd_sep = (loss_grd_neg + loss_grd_pos) * 0.25
        loss = loss_sep + loss_grd_sep

        err = loss + loss_grd * 0.5

        return err, loss_sep, g,\
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
Esempio n. 4
0
def output_visuals(vis_rows, batch_data, outputs, args):
    """Write per-sample visualization assets and append HTML rows.

    For every sample in the batch, saves the mixture wav/spectrogram/weight
    image and, per source, the GT/predicted masks, spectrograms, wavs, the
    frame video and the combined audio+video clip under args.vis, then
    appends one row of links to vis_rows.

    NOTE(review): mutates outputs['pred_masks'] and outputs['gt_masks'] in
    place (tensors are replaced by numpy arrays) -- callers must not reuse
    them as tensors afterwards.
    """
    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']

    pred_masks_ = outputs['pred_masks']
    gt_masks_ = outputs['gt_masks']
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    # unwarp log scale
    N = args.num_mix
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    gt_masks_linear = [None for n in range(N)]
    if args.log_freq:
        # the unwarp grid is identical for every source: build it once
        # instead of once per loop iteration
        grid_unwarp = torch.from_numpy(
            warpgrid(B,
                     args.stft_frame // 2 + 1,
                     gt_masks_[0].size(3),
                     warp=False)).to(args.device)
    for n in range(N):
        if args.log_freq:
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
            gt_masks_linear[n] = F.grid_sample(gt_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]
            gt_masks_linear[n] = gt_masks_[n]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(
                np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        row_elements = []

        # output directory name built from the mixed video names
        prefix = []
        for n in range(N):
            prefix.append('-'.join(
                infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture wav, spectrogram and loss-weight heatmaps
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        # [::-1] flips the frequency axis so low frequencies are at the bottom
        imsave(os.path.join(args.vis, filename_mixmag), mix_amp[::-1, :, :])
        imsave(os.path.join(args.vis, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate,
                      mix_wav)
        row_elements += [{
            'text': prefix
        }, {
            'image': filename_mixmag,
            'audio': filename_mixwav
        }]

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):

            # GT and predicted audio recovery (mask applied to the mixture,
            # reconstructed with the mixture phase)
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag,
                                          phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag,
                                                phase_mix[j, 0],
                                                hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(prefix,
                                           'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(prefix,
                                             'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imsave(os.path.join(args.vis, filename_gtmask), gt_mask[::-1, :])
            imsave(os.path.join(args.vis, filename_predmask),
                   pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(prefix,
                                            'predamp{}.jpg'.format(n + 1))
            gt_mag = magnitude2heatmap(gt_mag)
            pred_mag = magnitude2heatmap(pred_mag)
            imsave(os.path.join(args.vis, filename_gtmag), gt_mag[::-1, :, :])
            imsave(os.path.join(args.vis, filename_predmag),
                   pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav), args.audRate,
                          gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, preds_wav[n])

            # output video from the de-normalized frames
            frames_tensor = [
                recover_rgb(frames[n][j, :, t]) for t in range(args.num_frames)
            ]
            frames_tensor = np.asarray(frames_tensor)
            path_video = os.path.join(args.vis, prefix,
                                      'video{}.mp4'.format(n + 1))
            save_video(path_video,
                       frames_tensor,
                       fps=args.frameRate / args.stride_frames)

            # combine gt video and audio
            filename_av = os.path.join(prefix, 'av{}.mp4'.format(n + 1))
            combine_video_audio(path_video,
                                os.path.join(args.vis, filename_gtwav),
                                os.path.join(args.vis, filename_av))

            row_elements += [{
                'video': filename_av
            }, {
                'image': filename_predmag,
                'audio': filename_predwav
            }, {
                'image': filename_gtmag,
                'audio': filename_gtwav
            }, {
                'image': filename_predmask
            }, {
                'image': filename_gtmask
            }]

        row_elements += [{'image': filename_weight}]
        vis_rows.append(row_elements)
Esempio n. 5
0
def calc_metrics(batch_data, outputs, args):
    """Compute separation metrics for a batch.

    Reconstructs each predicted source from its mask and the mixture phase,
    then evaluates BSS metrics against the ground-truth audio.

    Returns:
        [sdr_mix, sdr, sir, sar]: batch averages, where sdr_mix is the SDR
        of the unprocessed mixture (the baseline).
    """
    # meters
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    audios = batch_data['audios']

    pred_masks_ = outputs['pred_masks']

    # unwarp log scale
    # NOTE(review): hard-coded to 4 sources (was args.num_mix - 1 at some
    # point, per the original comment) -- confirm this matches the model
    N = 4
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    if args.log_freq:
        # the unwarp grid is identical for every source: build it once
        # instead of once per loop iteration
        grid_unwarp = torch.from_numpy(
            warpgrid(B,
                     args.stft_frame // 2 + 1,
                     pred_masks_[0].size(3),
                     warp=False)).to(args.device)
    for n in range(N):
        if args.log_freq:
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    for n in range(N):
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        # reconstruct the mixture waveform
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)

        # reconstruct each predicted component (masked mixture magnitude,
        # mixture phase); +1e-6 keeps the evaluator away from all-zero input
        preds_wav = [None for n in range(N)]
        for n in range(N):
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(
                pred_mag, phase_mix[j, 0], hop_length=args.stft_hop) + 1e-6

        # separation performance; skip samples where any reference or
        # estimate is (near-)silent, which would break bss_eval
        L = preds_wav[0].shape[0]
        gts_wav = [None for n in range(N)]
        valid = True
        for n in range(N):
            gts_wav[n] = audios[n][j, 0:L].numpy() + 1e-6
            valid *= np.sum(np.abs(gts_wav[n])) > 1e-5
            valid *= np.sum(np.abs(preds_wav[n])) > 1e-5
        if valid:
            sdr, sir, sar, _ = bss_eval_sources(np.asarray(gts_wav),
                                                np.asarray(preds_wav), False)
            # mixture baseline: every "estimate" is the raw mixture
            sdr_mix, _, _, _ = bss_eval_sources(
                np.asarray(gts_wav),
                np.asarray([mix_wav[0:L] for n in range(N)]), False)
            sdr_mix_meter.update(sdr_mix.mean())
            sdr_meter.update(sdr.mean())
            sir_meter.update(sir.mean())
            sar_meter.update(sar.mean())

    return [
        sdr_mix_meter.average(),
        sdr_meter.average(),
        sir_meter.average(),
        sar_meter.average()
    ]
Esempio n. 6
0
    def forward(self, batch_data, args):
        """Standard mix-and-separate forward pass.

        Warps the spectrograms, builds loss weights and ground-truth masks,
        runs the sound/frame nets, synthesizes one mask per source and
        returns the mask loss with the visualization dict.
        """
        mag_mix = batch_data['mag_mix'] + 1e-10  # keep log() finite
        mags = batch_data['mags']
        frames = batch_data['frames']

        N = args.num_mix
        B, T = mag_mix.size(0), mag_mix.size(3)

        # 0.0 optionally resample onto a 256-bin log-frequency grid
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # 0.1 loss weighting coefficient: magnitude of the input mixture
        if args.weighted_loss:
            weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground-truth masks, computed after warping
        gt_masks = []
        for n in range(N):
            if args.binary_mask:
                # binary target: component louder than half the mixture
                gt_masks.append((mags[n] > 0.5 * mag_mix).float())
            else:
                # ratio mask, clamped to avoid huge values
                gt_masks.append((mags[n] / mag_mix).clamp_(0., 5.))

        # log-magnitude input for the sound network
        log_mag_mix = torch.log(mag_mix).detach()

        # 1. sound features -> BxCxHxW
        feat_sound = activate(self.net_sound(log_mag_mix),
                              args.sound_activation)

        # 2. frame features -> Bx1xC
        feat_frames = []
        for n in range(N):
            vis_feat = self.net_frame.forward_multiframe(frames[n])
            feat_frames.append(activate(vis_feat, args.img_activation))

        # 3. synthesize one mask per source
        pred_masks = []
        for n in range(N):
            mask = self.net_synthesizer(feat_frames[n], feat_sound)
            pred_masks.append(activate(mask, args.output_activation))

        # 4. mask loss
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
    def forward(self, batch_data, args):
        """Mix-and-separate using precomputed frame embeddings.

        Like the standard pass, but each source's mask is the inner product
        (over channels) between its frame embedding and the sound feature
        map, instead of a synthesizer network.
        """
        mag_mix = batch_data['mag_mix'] + 1e-10  # keep log() finite
        mags = batch_data['mags']
        frames = batch_data['frames']
        frame_emb = batch_data['frame_emb']

        N = args.num_mix
        B, T = mag_mix.size(0), mag_mix.size(3)

        # optionally resample onto a 256-bin log-frequency grid
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # loss weighting coefficient: magnitude of the input mixture
        if args.weighted_loss:
            weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # ground-truth masks, computed after warping
        gt_masks = []
        for n in range(N):
            if args.binary_mask:
                # binary target: component louder than half the mixture
                gt_masks.append((mags[n] > 0.5 * mag_mix).float())
            else:
                # ratio mask, clamped to avoid huge values
                gt_masks.append((mags[n] / mag_mix).clamp_(0., 5.))

        # log-magnitude input for the sound network
        log_mag_mix = torch.log(mag_mix).detach()

        feat_sound = activate(self.net_sound(log_mag_mix),
                              args.sound_activation)

        # project each frame embedding onto the sound features: a batched
        # dot product over channels yields one single-channel mask per source
        sound_size = feat_sound.size()
        B, C = sound_size[0], sound_size[1]
        pred_masks = []
        for n in range(N):
            emb = frame_emb[n].float().view(B, 1, C)
            mask = torch.bmm(emb, feat_sound.view(B, C, -1)) \
                .view(B, 1, *sound_size[2:])
            pred_masks.append(activate(mask, args.output_activation))

        # mask loss
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
def output_visuals_PosNeg(vis_rows, batch_data, masks_pos,  masks_neg, idx_pos, idx_neg, pred_masks_, gt_masks_, mag_mix_, weight_, args):
    """Save qualitative results plus positive/negative grounding heatmaps.

    For each sample in the batch: writes the mixture audio/spectrogram,
    per-source ground-truth and predicted masks/spectrograms/audio, and the
    grounding attention maps for the positive pair (idx_pos) and negative
    pair (idx_neg) overlaid on the corresponding video frames. Appends one
    row of artifact references to vis_rows. All files are written under
    args.vis/<prefix>/.
    """
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']

    # masks to cpu, numpy
    masks_pos = torch.squeeze(masks_pos, dim=1)
    masks_pos = masks_pos.cpu().float().numpy()
    masks_neg = torch.squeeze(masks_neg, dim=1)
    masks_neg = masks_neg.cpu().float().numpy()

    N = args.num_mix
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    gt_masks_linear = [None for n in range(N)]

    # unwarp masks back to linear frequency for audio reconstruction
    for n in range(N):
        if args.log_freq:
            grid_unwarp = torch.from_numpy(
                warpgrid(B, args.stft_frame//2+1, gt_masks_[0].size(3), warp=False)).to(args.device)
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
            gt_masks_linear[n] = F.grid_sample(gt_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]
            gt_masks_linear[n] = gt_masks_[n]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    idx_pos = int(idx_pos.detach().cpu().numpy())
    idx_neg = int(idx_neg.detach().cpu().numpy())
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()
        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] > args.mask_thres).astype(np.float32)

    # grounding-heatmap cells below this value are zeroed out
    threshold = 0.5
    # loop over each sample
    for j in range(B):
        row_elements = []
        # video names
        prefix = []
        for n in range(N):
            prefix.append('-'.join(infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0], phase_mix[j, 0], hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        # [::-1] flips the frequency axis so low frequencies sit at the bottom
        matplotlib.image.imsave(os.path.join(args.vis, filename_mixmag), mix_amp[::-1, :, :])
        matplotlib.image.imsave(os.path.join(args.vis, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate, mix_wav)
        row_elements += [{'text': prefix}, {'image': filename_mixmag, 'audio': filename_mixwav}]

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # GT and predicted audio recovery
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_mag_ = mag_mix_[j, 0] * gt_masks_[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag, phase_mix[j, 0], hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            pred_mag_ = mag_mix_[j, 0] * pred_masks_[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag, phase_mix[j, 0], hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(prefix, 'gtmask{}.jpg'.format(n+1))
            filename_predmask = os.path.join(prefix, 'predmask{}.jpg'.format(n+1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)

            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmask), gt_mask[::-1, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmask), pred_mask[::-1, :])

            # ouput spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n+1))
            filename_predmag = os.path.join(prefix, 'predamp{}.jpg'.format(n+1))
            gt_mag = magnitude2heatmap(gt_mag_)
            pred_mag = magnitude2heatmap(pred_mag_)

            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmag), gt_mag[::-1, :, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmag), pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n+1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n+1))
            wavfile.write(os.path.join(args.vis, filename_gtwav), args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav), args.audRate, preds_wav[n])

        # save frame (middle frame of the positive-pair clip)
        frames_tensor = recover_rgb(frames[idx_pos][j,:,int(args.num_frames//2)])
        frames_tensor = np.asarray(frames_tensor)
        filename_frame = os.path.join(prefix, 'frame{}.png'.format(idx_pos+1))
        matplotlib.image.imsave(os.path.join(args.vis, filename_frame), frames_tensor)
        frame = frames_tensor.copy()
        # get heatmap and overlay for positive pair
        # each grounding-map cell is upsampled to a 16x16 pixel patch
        height, width = masks_pos.shape[-2:]
        heatmap = np.zeros((height*16, width*16))
        for i in range(height):
            for k in range(width):
                mask_pos = masks_pos[j]
                value = mask_pos[i,k]
                value = 0 if value < threshold else value
                ii = i * 16
                jj = k * 16
                heatmap[ii:ii + 16, jj:jj + 16] = value
        heatmap = (heatmap * 255).astype(np.uint8)
        filename_heatmap = os.path.join(prefix, 'heatmap_{}_{}.jpg'.format(idx_pos+1, idx_pos+1))
        plt.imsave(os.path.join(args.vis, filename_heatmap), heatmap, cmap='hot')
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        fin = cv2.addWeighted(heatmap, 0.5, frame, 0.5, 0, dtype = cv2.CV_32F)
        path_overlay = os.path.join(args.vis, prefix, 'overlay_{}_{}.jpg'.format(idx_pos+1, idx_pos+1))
        cv2.imwrite(path_overlay, fin)

        # save frame (middle frame of the negative-pair clip)
        frames_tensor = recover_rgb(frames[idx_neg][j,:,int(args.num_frames//2)])
        frames_tensor = np.asarray(frames_tensor)
        filename_frame = os.path.join(prefix, 'frame{}.png'.format(idx_neg+1))
        matplotlib.image.imsave(os.path.join(args.vis, filename_frame), frames_tensor)
        frame = frames_tensor.copy()
        # get heatmap and overlay for negative pair (same layout as above)
        height, width = masks_neg.shape[-2:]
        heatmap = np.zeros((height*16, width*16))
        for i in range(height):
            for k in range(width):
                mask_neg = masks_neg[j]
                value = mask_neg[i,k]
                value = 0 if value < threshold else value
                ii = i * 16
                jj = k * 16
                heatmap[ii:ii + 16, jj:jj + 16] = value
        heatmap = (heatmap * 255).astype(np.uint8)
        filename_heatmap = os.path.join(prefix, 'heatmap_{}_{}.jpg'.format(idx_pos+1, idx_neg+1))
        plt.imsave(os.path.join(args.vis, filename_heatmap), heatmap, cmap='hot')
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        fin = cv2.addWeighted(heatmap, 0.5, frame, 0.5, 0, dtype = cv2.CV_32F)
        path_overlay = os.path.join(args.vis, prefix, 'overlay_{}_{}.jpg'.format(idx_pos+1, idx_neg+1))
        cv2.imwrite(path_overlay, fin)

        vis_rows.append(row_elements)
# Esempio n. 9
# 0
    def forward(self, batch_data, args):
        """Contrastive audio-visual grounding forward pass.

        Computes grounding scores between sound embeddings (from the
        mixture's and per-source log1p-magnitude spectrograms) and
        per-source frame embeddings, and a contrastive loss over
        positive / negative audio-visual pairs.

        Args:
            batch_data: dict with 'mag_mix' (B, 1, F, T mixture magnitude),
                'mags' (list of per-source magnitudes) and 'frames'
                (list of per-source video frames).
            args: namespace with num_mix, log_freq, device, ...

        Returns:
            (err, g): scalar loss and the list
            [softmax(g_pos), softmax(g_neg), x, g_solo] of grounding scores.

        NOTE(review): the mags[0]/mags[2] indexing and the g_solo pairings
        hard-code num_mix == 4 arranged as two duets -- confirm with the
        data loader before changing num_mix.
        """
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10  # avoid log(0) downstream

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrograms to a 256-bin log-frequency scale
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # LOG magnitudes, detached: the grounding branch sees fixed inputs
        log_mag_mix = torch.log1p(mag_mix).detach()
        log_mag0 = torch.log1p(mags[0]).detach()
        log_mag2 = torch.log1p(mags[2]).detach()

        # grounding: sound embedding of the mixture
        feat_sound_ground = self.net_sound_ground(log_mag_mix)

        # per-source visual embeddings
        feat_frames_ground = [None for n in range(N)]
        for n in range(N):
            feat_frames_ground[n] = self.net_frame_ground.forward_multiframe(
                frames[n])

        # Grounding for sep: mixture sound vs. each source's frames
        g_sep = [None for n in range(N)]
        x = [None for n in range(N)]
        for n in range(N):
            g_sep[n] = self.net_grounding(feat_sound_ground,
                                          feat_frames_ground[n])
            x[n] = torch.softmax(g_sep[n].clone(), dim=-1)

        # Grounding module: positive (matching) and negative (mismatched) pairs
        g_pos = self.net_grounding(self.net_sound_ground(log_mag0),
                                   feat_frames_ground[0])
        g_pos1 = self.net_grounding(self.net_sound_ground(log_mag0),
                                    feat_frames_ground[1])
        g_neg = self.net_grounding(self.net_sound_ground(log_mag2),
                                   feat_frames_ground[0])

        # Grounding for solo sound
        g_solo = [None for n in range(N)]
        g_solo[0] = self.net_grounding(self.net_sound_ground(log_mag0),
                                       feat_frames_ground[0])
        g_solo[1] = self.net_grounding(self.net_sound_ground(log_mag0),
                                       feat_frames_ground[1])
        g_solo[2] = self.net_grounding(self.net_sound_ground(log_mag2),
                                       feat_frames_ground[2])
        g_solo[3] = self.net_grounding(self.net_sound_ground(log_mag2),
                                       feat_frames_ground[3])
        for n in range(N):
            g_solo[n] = torch.softmax(g_solo[n], dim=-1)
        g = [
            torch.softmax(g_pos, dim=-1),
            torch.softmax(g_neg, dim=-1), x, g_solo
        ]

        # contrastive targets: class 0 = positive pair, class 1 = negative.
        # Use args.device (consistent with grid_warp above) instead of
        # .cuda() so CPU / non-default-GPU runs also work.
        pos_label = torch.zeros(B, device=args.device)
        neg_label = torch.ones(B, device=args.device)

        # per-sample positive losses, so the elementwise min can pick the
        # easier of the two candidate positive pairings per sample
        cts_pos = torch.zeros(B, device=args.device)
        cts_pos1 = torch.zeros(B, device=args.device)
        for i in range(B):
            cts_pos[i] = self.cts(g_pos[i:i + 1], pos_label[i:i + 1].long())
            cts_pos1[i] = self.cts(g_pos1[i:i + 1], pos_label[i:i + 1].long())

        # 5. loss: best positive pairing + negative-pair term
        err = (torch.min(cts_pos, cts_pos1).mean()
               + self.cts(g_neg, neg_label.long()))

        return err, g
# Esempio n. 10
# 0
    def forward(self, batch_data, args):
        """Recursive "Minus-Plus" separation forward pass.

        The Minus stage separates sources one at a time: at each step it
        predicts a mask for every frame embedding, keeps (per sample) the
        not-yet-chosen candidate with the largest separated energy, and
        subtracts that source from the mixture. The optional Plus stage
        refines each mask with a residual network conditioned on previously
        separated sounds.

        Args:
            batch_data: dict with 'mag_mix' (B, 1, F, T), 'mags' (list of
                per-source magnitudes) and 'frames' (per-source video frames).
            args: namespace with num_mix, log_freq, binary_mask, ...

        Returns:
            (err, outputs): scalar loss and a dict of masks/magnitudes
            for visualization.
        """
        ### prepare data
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10
        # keep the unwarped (linear-frequency) mixture for energy ranking
        mag_mix_tmp = mag_mix.clone()

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # 0.1 calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground truth masks are computed after warpping!
        # Please notice that, gt_masks are unordered
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 2.)

        ### Minus part
        # freeze the Minus nets when only the Plus stage is being trained
        if 'Minus' not in self.mode:
            self.requires_grad(self.net_sound_M, False)
            self.requires_grad(self.net_frame_M, False)

        feat_frames = [None for n in range(N)]
        ordered_pred_masks = [None for n in range(N)]
        ordered_pred_mags = [None for n in range(N)]

        # Step1: obtain all the frame features
        # forward net_frame_M -> Bx1xC
        for n in range(N):
            # NOTE(review): this log_mag_mix is not used inside this loop
            log_mag_mix = torch.log(mag_mix)
            feat_frames[n] = self.net_frame_M.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        # Step2: separate the sounds one by one
        # forward net_sound_M -> BxCxHxW
        if args.log_freq:
            grid_unwarp = torch.from_numpy(
                warpgrid(B, args.stft_frame // 2 + 1, 256,
                         warp=False)).to(args.device)
        index_record = []
        for n in range(N):
            # recompute the (shrinking) mixture's log magnitude each step
            log_mag_mix = torch.log(mag_mix).detach()
            feat_sound = self.net_sound_M(log_mag_mix)
            _, C, H, W = feat_sound.shape
            feat_sound = feat_sound.view(B, C, -1)

            # obtain current separated sound
            energy_list = []
            tmp_masks = []
            tmp_pred_mags = []

            for feat_frame in feat_frames:
                cur_pred_mask = torch.bmm(feat_frame.unsqueeze(1),
                                          feat_sound).view(B, 1, H, W)
                cur_pred_mask = activate(cur_pred_mask, args.output_activation)
                tmp_masks.append(cur_pred_mask)
                # Here we cut off the loss flow from Minus net to Plus net
                # in order to train more steadily
                if args.log_freq:
                    cur_pred_mask_unwrap = F.grid_sample(
                        cur_pred_mask.detach(), grid_unwarp)
                    if args.binary_mask:
                        cur_pred_mask_unwrap = (cur_pred_mask_unwrap >
                                                args.mask_thres).float()
                else:
                    cur_pred_mask_unwrap = cur_pred_mask.detach()
                cur_pred_mag = cur_pred_mask_unwrap * mag_mix_tmp
                tmp_pred_mags.append(cur_pred_mag)
                energy_list.append(
                    np.array(cur_pred_mag.view(B, -1).mean(dim=1).cpu().data))

            # pick, per sample, the not-yet-used candidate with max energy
            total_energy = np.stack(energy_list, axis=1)
            # _, cur_index = torch.max(total_energy)
            cur_index = self.choose_max(index_record, total_energy)
            index_record.append(cur_index)

            masks = torch.stack(tmp_masks, dim=0)
            ordered_pred_masks[n] = masks[cur_index, list(range(B))]
            pred_mags = torch.stack(tmp_pred_mags, dim=0)
            ordered_pred_mags[n] = pred_mags[cur_index, list(range(B))]

            #log_mag_mix = log_mag_mix - log_mag_mix * pred_masks[n]
            # subtract the separated source from the mixture ("Minus")
            mag_mix = mag_mix - mag_mix * ordered_pred_masks[n] + 1e-10

        # just for swap pred_masks, in order to compute loss conveniently
        # since gt_masks are unordered, we must transfer ordered_pred_masks to unordered
        index_record = np.stack(index_record, axis=1)
        total_masks = torch.stack(ordered_pred_masks, dim=1)
        total_pred_mags = torch.stack(ordered_pred_mags, dim=1)
        unordered_pred_masks = []
        unordered_pred_mags = []
        for n in range(N):
            mask_index = np.where(index_record == n)
            if args.binary_mask:
                unordered_pred_masks.append(total_masks[mask_index])
                unordered_pred_mags.append(total_pred_mags[mask_index])
            else:
                # NOTE(review): the x2 rescaling presumably matches the
                # ratio-mask clamp range [0, 2] above -- verify
                unordered_pred_masks.append(total_masks[mask_index] * 2)
                unordered_pred_mags.append(total_pred_mags[mask_index])

        ### Plus part
        if 'Plus' in self.mode:
            pre_sum = torch.zeros_like(unordered_pred_masks[0]).to(args.device)
            Plus_pred_masks = []
            for n in range(N):
                unordered_pred_mag = unordered_pred_mags[n].log()
                unordered_pred_mag = F.grid_sample(unordered_pred_mag,
                                                   grid_warp)
                input_concat = torch.cat((pre_sum, unordered_pred_mag), dim=1)

                # predict a residual correction on top of the Minus mask
                residual_mask = activate(self.net_sound_P(input_concat),
                                         args.sound_activation)
                Plus_pred_masks.append(unordered_pred_masks[n] + residual_mask)

                pre_sum = pre_sum.sum(dim=1, keepdim=True).detach()

            unordered_pred_masks = Plus_pred_masks

        # loss
        if args.need_loss_ratio:
            # later-separated (harder) sources get larger loss weights
            err = 0
            for n in range(N):
                err += self.crit(unordered_pred_masks[n], gt_masks[n],
                                 weight) / N * 2**(n - 1)
        else:
            err = self.crit(unordered_pred_masks, gt_masks, weight).reshape(1)

        if 'Minus' in self.mode:
            # NOTE(review): log_mag_mix here is the log-mixture from the
            # *last* separation step (taken before the final subtraction),
            # not the final residual -- confirm this remainder penalty is
            # intended to use that value.
            res_mag_mix = torch.exp(log_mag_mix)
            err_remain = torch.mean(weight *
                                    torch.clamp(res_mag_mix - 1e-2, min=0))
            err += err_remain

        outputs = {
            'pred_masks': unordered_pred_masks,
            'gt_masks': gt_masks,
            'mag_mix': mag_mix,
            'mags': mags,
            'weight': weight
        }

        return err, outputs
    def forward(self, batch_data, args):
        """Separate a sound mixture using point-cloud visual features.

        Warps the mixture/source spectrograms to a log-frequency scale,
        derives ground-truth masks, extracts a sound feature map and
        per-source visual embeddings (a sparse 3D network over each frame's
        point cloud, pooled over frames), synthesizes one predicted mask
        per source, and returns the weighted mask loss together with
        everything needed for visualization.
        """
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        feats = batch_data['feats']
        coords = batch_data['coords']
        mag_mix = mag_mix + 1e-10  # keep log() finite

        N = args.num_mix
        FN = args.num_frames
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 resample spectrograms onto a 256-bin log-frequency grid
        if args.log_freq:
            warp_grid = torch.from_numpy(
                warpgrid(B, 256, T, warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, warp_grid)
            for idx, src_mag in enumerate(mags):
                mags[idx] = F.grid_sample(src_mag, warp_grid)

        # 0.1 loss weights follow the mixture's (log) magnitude
        weight = (torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
                  if args.weighted_loss else torch.ones_like(mag_mix))

        # 0.2 ground-truth masks, computed on the warped spectrograms
        gt_masks = []
        for n in range(N):
            src_mag = mags[n]
            if not args.binary_mask:
                # ratio mask, clamped in place to avoid huge values
                ratio = src_mag / mag_mix
                ratio.clamp_(0., 5.)
                gt_masks.append(ratio)
            elif N > 2:
                # winner-take-all: 1 where this source dominates every other
                dominant = torch.ones_like(src_mag)
                for m in range(N):
                    dominant *= (src_mag >= mags[m])
                gt_masks.append(dominant)
            else:
                gt_masks.append((src_mag > 0.5 * mag_mix).float())

        # LOG magnitude, detached from the graph
        log_mag_mix = torch.log(mag_mix).detach()

        # 1. sound feature map
        feat_sound = activate(self.net_sound(log_mag_mix),
                              args.sound_activation)

        # 2. visual embedding per source: run the sparse network on each
        #    frame's point cloud, then pool across frames
        feat_frames = []
        for n in range(N):
            per_frame = []
            for f in range(FN):
                sparse_in = ME.SparseTensor(
                    feats[n][f],
                    coords[n][f].int(),
                    allow_duplicate_coords=True)  #Create SparseTensor
                per_frame.append(self.net_vision.forward(sparse_in))
            clip_feat = torch.stack(per_frame).permute(1, 2, 0)
            if args.frame_pool == 'maxpool':
                pooled = F.adaptive_max_pool1d(clip_feat, 1)
            elif args.frame_pool == 'avgpool':
                pooled = F.adaptive_avg_pool1d(clip_feat, 1)
            feat_frames.append(
                activate(pooled.squeeze(), args.vision_activation))

        # 3. synthesize one mask per source
        pred_masks = []
        for vis_emb in feat_frames:
            mask = self.net_synthesizer(vis_emb, feat_sound)
            pred_masks.append(activate(mask, args.output_activation))

        # 4. weighted mask loss
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        outputs = {'pred_masks': pred_masks, 'gt_masks': gt_masks,
                   'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
        return err, outputs
def output_visuals(batch_data, outputs, args):
    """Save qualitative separation results for every sample in the batch:
    mixture/source audio, masks, spectrogram heatmaps and the input point
    clouds, all written under args.vis/<prefix>/.
    """

    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    features = batch_data['feats']
    coords = batch_data['coords']
    infos = batch_data['infos']

    pred_masks_ = outputs['pred_masks']
    gt_masks_ = outputs['gt_masks']
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    # unwarp log scale (back to linear frequency for audio reconstruction)
    N = args.num_mix
    FN = args.num_frames
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    gt_masks_linear = [None for n in range(N)]
    for n in range(N):
        if args.log_freq:
            grid_unwarp = torch.from_numpy(
                warpgrid(B,
                         args.stft_frame // 2 + 1,
                         gt_masks_[0].size(3),
                         warp=False)).to(args.device)
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
            gt_masks_linear[n] = F.grid_sample(gt_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]
            gt_masks_linear[n] = gt_masks_[n]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(
                np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):

        # video names
        prefix = []
        for n in range(N):
            prefix.append('-'.join(
                infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        # [::-1] flips the frequency axis so low frequencies sit at the bottom
        imageio.imwrite(os.path.join(args.vis, filename_mixmag),
                        mix_amp[::-1, :, :])
        imageio.imwrite(os.path.join(args.vis, filename_weight),
                        weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate,
                      mix_wav)

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # GT and predicted audio recovery
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag,
                                          phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag,
                                                phase_mix[j, 0],
                                                hop_length=args.stft_hop)

            # output masks
            filename_gtmask = os.path.join(prefix,
                                           'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(prefix,
                                             'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imageio.imwrite(os.path.join(args.vis, filename_gtmask),
                            gt_mask[::-1, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmask),
                            pred_mask[::-1, :])

            # ouput spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(prefix,
                                            'predamp{}.jpg'.format(n + 1))
            gt_mag = magnitude2heatmap(gt_mag)
            pred_mag = magnitude2heatmap(pred_mag)
            imageio.imwrite(os.path.join(args.vis, filename_gtmag),
                            gt_mag[::-1, :, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmag),
                            pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav), args.audRate,
                          gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, preds_wav[n])

            #output pointclouds
            for f in range(FN):
                # select the coordinate rows belonging to sample j
                idx = torch.where(coords[n][f][:, 0] == j)
                path_point = os.path.join(
                    args.vis, prefix,
                    'point{}_frame{}.ply'.format(n + 1, f + 1))
                if args.rgbs_feature:
                    colors = np.asarray(features[n][f][idx])
                    xyz = np.asarray(coords[n][f][idx][:, 1:4])
                    xyz = xyz * args.voxel_size
                    save_points(path_point, xyz, colors)
                else:
                    xyz = np.asarray(features[n][f][idx])
                    save_points(path_point, xyz)