Example #1
    def forward(self, batch_data, args):
        audio_mix = batch_data['audio_mix']  # B, audio_len
        audios = batch_data['audios']  # num_mix, B, audio_len
        frames = batch_data['frames']  # num_mix, B, xxx

        N = args.num_mix
        B = audio_mix.size(0)

        # 2. forward net_frame -> Bx1xC
        feat_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        # 3. sound synthesizer
        pred_audios = [None for n in range(N)]
        for n in range(N):
            #     pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
            #     pred_masks[n] = activate(pred_masks[n], args.output_activation)
            pred_audios[n] = self.net_sound(audio_mix, feat_frames[n])
            pred_audios[n] = activate(pred_audios[n], args.sound_activation)

        # 4. loss
        err = self.crit(pred_audios, audios).reshape(1)
        # print("\"", self.crit([audio_mix, audio_mix], audios).item(), self.crit(audios, audios).item(), err.item(),"\"")

        return err, pred_audios  # or masks
    def forward(self, feat_frame, feat_sound, args):
        N = args.num_mix

        pred_mask = [None for n in range(N)]
        # appearance attention
        for n in range(N):
            pred_mask[n] = self.net_avol(feat_frame[n], feat_sound)
            pred_mask[n] = activate(pred_mask[n], args.output_activation)

        return pred_mask
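
Every snippet below calls an activate() helper that is not included in these examples. A minimal sketch consistent with how it is used here (dispatching on an activation name carried in args), offered as an assumption rather than the original implementation:

import torch
import torch.nn.functional as F

# Minimal sketch of the assumed activate() helper: maps an activation name
# (e.g. args.sound_activation) to the corresponding nonlinearity.
def activate(x, activation):
    if activation == 'sigmoid':
        return torch.sigmoid(x)
    elif activation == 'softmax':
        return F.softmax(x, dim=1)
    elif activation == 'relu':
        return F.relu(x)
    elif activation == 'tanh':
        return torch.tanh(x)
    elif activation == 'no':
        return x
    else:
        raise ValueError('unknown activation: {}'.format(activation))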
Example #3
    def forward(self, batch_data, args):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['appearance_imags']
        clip_imgs = batch_data['clips_frames']
        gt_masks = batch_data['masks']
        mag_mix = mag_mix + 1e-10
        N = args.num_mix
        weight = torch.ones_like(mag_mix)

        B = mag_mix.size(0)
        T = mag_mix.size(2)
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix.unsqueeze(1), grid_warp, align_corners=False).squeeze()

        mag_mix = torch.log(mag_mix).detach() if args.use_mel else mag_mix

        feat_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        feat_motion = [None for n in range(N)]
        for n in range(N):
            feat_motion[n] = self.net_motion(clip_imgs[n])

        feat_sound = [None for n in range(N)]
        pred_masks = [None for n in range(N)]

        for n in range(N):
            feat_sound[n], _, _ = self.net_sound(mag_mix.to(args.device), feat_motion[n], feat_frames[n])
            pred_masks[n] = activate(feat_sound[n], args.sound_activation)
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, T, warp=False)).to(args.device)
        for n in range(N):
            pred_masks[n] = F.grid_sample(pred_masks[n].unsqueeze(1), grid_unwarp, align_corners=False).squeeze()

        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
               {'pred_masks': pred_masks, 'gt_masks': gt_masks,
                'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
    def forward(self, frame, args):
        
        N = args.num_mix

        # return appearance features and appearance embedding
        feat_frames = [None for n in range(N)]
        emb_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n], emb_frames[n] = self.net_frame.forward_multiframe_feat_emb(frame[n], pool=True)
            emb_frames[n] = activate(emb_frames[n], args.img_activation)
        
        return feat_frames, emb_frames
    def forward(self, mags, mag_mix, args):
        mag_mix = mag_mix + 1e-10

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(
                warpgrid(B, 256, T, warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # ground truth masks are computed after warping!
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 5.)

        # LOG magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # forward net_sound
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, args.sound_activation)

        return feat_sound, \
            {'gt_masks': gt_masks, 'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
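
Most examples warp the spectrogram with warpgrid() before feeding it to the network. The function is not shown here; a sketch matching its call signature (batch size, output height, output width, warp flag), with constants assumed from the widely used open-source implementation:

import numpy as np

# Assumed warpgrid() sketch: builds a (bs, HO, WO, 2) sampling grid for
# F.grid_sample that warps (warp=True) or unwarps (warp=False) the frequency
# axis onto a log-like scale.
def warpgrid(bs, HO, WO, warp=True):
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((bs, HO, WO, 2), dtype=np.float32)
    if warp:
        grid_y = (np.power(21, (yv + 1) / 2) - 11) / 10
    else:
        grid_y = np.log(yv * 10 + 11) / np.log(21) * 2 - 1
    grid[:, :, :, 0] = xv
    grid[:, :, :, 1] = grid_y
    return grid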
Example #6
    def forward(self, batch_data, args):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # 0.1 calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground truth masks are computed after warping!
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 5.)

        # zero out the ground-truth targets for sources 1 and 3
        gt_masks[1] = torch.zeros_like(gt_masks[1])
        gt_masks[3] = torch.zeros_like(gt_masks[3])
        # LOG magnitude
        log_mag_mix = torch.log1p(mag_mix).detach()
        log_mag0 = torch.log1p(mags[0]).detach()
        log_mag1 = torch.log1p(mags[1]).detach()
        log_mag2 = torch.log1p(mags[2]).detach()
        #log_mag3 = torch.log1p(mags[3]).detach()
        #with torch.no_grad():
        # grounding
        feat_sound_ground = self.net_sound_ground(log_mag_mix)

        feat_frames_ground = [None for n in range(N)]
        for n in range(N):
            feat_frames_ground[n] = self.net_frame_ground.forward_multiframe(
                frames[n])

        # Grounding for sep
        g_sep = [None for n in range(N)]
        x = [None for n in range(N)]
        for n in range(N):
            g_sep[n] = self.net_grounding(feat_sound_ground,
                                          feat_frames_ground[n])
            x[n] = torch.softmax(g_sep[n].clone(), dim=-1)

        # Grounding module
        #feat_frame = (feat_frames_ground[0] + feat_frames_ground[1]) * 0.5
        g_pos = self.net_grounding(self.net_sound_ground(log_mag0),
                                   feat_frames_ground[0])
        g_pos1 = self.net_grounding(self.net_sound_ground(log_mag0),
                                    feat_frames_ground[1])
        g_neg = self.net_grounding(self.net_sound_ground(log_mag2),
                                   feat_frames_ground[0])

        # Grounding for solo sound
        g_solo = [None for n in range(N)]
        g_solo[0] = self.net_grounding(self.net_sound_ground(log_mag0),
                                       feat_frames_ground[0])
        g_solo[1] = self.net_grounding(self.net_sound_ground(log_mag0),
                                       feat_frames_ground[1])
        g_solo[2] = self.net_grounding(self.net_sound_ground(log_mag2),
                                       feat_frames_ground[2])
        g_solo[3] = self.net_grounding(self.net_sound_ground(log_mag2),
                                       feat_frames_ground[3])
        for n in range(N):
            g_solo[n] = torch.softmax(g_solo[n], dim=-1)
        g = [
            torch.softmax(g_pos, dim=-1),
            torch.softmax(g_neg, dim=-1), x, g_solo
        ]

        # 1. forward net_sound -> BxCxHxW
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, args.sound_activation)

        # 2. forward net_frame -> Bx1xC
        feat_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        # 3. sound synthesizer
        masks = [None for n in range(N)]
        for n in range(N):
            masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
            masks[n] = activate(masks[n], args.output_activation)

        # 4. adjusted masks with grounding confidence scores
        pred_masks = [None for n in range(N)]

        s1 = masks[1].size(2)
        s2 = masks[1].size(3)

        if args.testing:
            # pred_masks[0] = masks[0]
            # pred_masks[1] = masks[1]
            # pred_masks[2] = masks[2]
            # pred_masks[3] = masks[3]
            pred_masks[0] = torch.mul(
                tf_data(g_sep[0], B, s1, s2).round(), masks[0])
            pred_masks[1] = torch.mul(
                tf_data(g_sep[1], B, s1, s2).round(), masks[1])
            pred_masks[2] = torch.mul(
                tf_data(g_sep[2], B, s1, s2).round(), masks[2])
            pred_masks[3] = torch.mul(
                tf_data(g_sep[3], B, s1, s2).round(), masks[3])
        else:
            pred_masks[0] = torch.mul(
                tf_data(g_sep[0], B, s1, s2).round(), masks[0]) + torch.mul(
                    tf_data(g_sep[1], B, s1, s2).round(), masks[1])
            pred_masks[1] = masks[1]
            pred_masks[2] = torch.mul(
                tf_data(g_sep[2], B, s1, s2).round(), masks[2]) + torch.mul(
                    tf_data(g_sep[3], B, s1, s2).round(), masks[3])
            pred_masks[3] = masks[3]
        # pred_masks[0] = masks[0]#+masks[1]
        # pred_masks[1] = masks[1]
        # pred_masks[2] = masks[2]#+masks[3]
        # pred_masks[3] = masks[3]

        # 5. loss
        loss_sep = 0.5 * (
            self.crit(pred_masks[0], gt_masks[0], weight).reshape(1) +
            self.crit(pred_masks[2], gt_masks[2], weight).reshape(1))

        df_mask = [torch.zeros(B, 1, device=args.device) for _ in range(N)]

        for j in range(B):
            df_mask[0][j] = self.crit(masks[0][j:j + 1], gt_masks[0][j:j + 1],
                                      weight[j:j + 1])
            df_mask[1][j] = self.crit(masks[1][j:j + 1], gt_masks[0][j:j + 1],
                                      weight[j:j + 1])
            df_mask[2][j] = self.crit(masks[2][j:j + 1], gt_masks[2][j:j + 1],
                                      weight[j:j + 1])
            df_mask[3][j] = self.crit(masks[3][j:j + 1], gt_masks[2][j:j + 1],
                                      weight[j:j + 1])

        a1 = (1 - df_mask[0].div(df_mask[1] + df_mask[0]))
        a2 = (1 - df_mask[2].div(df_mask[2] + df_mask[3]))

        p = torch.zeros(B, device=args.device)  # positive-class labels
        sep_pos = [torch.zeros(B, 1, device=args.device) for _ in range(N)]

        for i in range(N):
            for j in range(B):
                sep_pos[i][j] = self.cts(g_sep[i][j:j + 1], p[j:j + 1].long())

        sep_neg = [torch.zeros(B, 1, device=args.device) for _ in range(N)]

        neg_labels = torch.ones(B, device=args.device)  # negative-class labels
        for i in range(N):
            for j in range(B):
                sep_neg[i][j] = self.cts(g_sep[i][j:j + 1], neg_labels[j:j + 1].long())

        cts_pos = torch.zeros(B, device=args.device)
        cts_pos1 = torch.zeros(B, device=args.device)
        for i in range(B):
            cts_pos[i] = self.cts(g_pos[i:i + 1], p[i:i + 1].long())
            cts_pos1[i] = self.cts(g_pos1[i:i + 1], p[i:i + 1].long())

        # pick the better-matching positive pairing per sample, plus the negative term
        # alternative: torch.min(cts_pos, cts_pos1).mean() + self.cts(g_neg, neg_labels.long())
        loss_grd = (cts_pos * a1.round() + cts_pos1 * (1 - a1).round()).mean() \
            + self.cts(g_neg, neg_labels.long())

        loss_grd_pos = (sep_pos[0] * a1.round() + sep_pos[1] * (1 - a1).round()
                        + sep_pos[2] * a2.round() + sep_pos[3] * (1 - a2).round()).mean()
        th = 0.1
        # (1 + sign(df_mask - th)) / 2 gates each negative term to samples
        # whose per-source separation error exceeds the threshold th
        loss_grd_neg = (sep_neg[0] * (1 - a1).round() * ((1 + torch.sign(df_mask[0] - th)) / 2)
                        + sep_neg[1] * a1.round() * ((1 + torch.sign(df_mask[1] - th)) / 2)
                        + sep_neg[2] * (1 - a2).round() * ((1 + torch.sign(df_mask[2] - th)) / 2)
                        + sep_neg[3] * a2.round() * ((1 + torch.sign(df_mask[3] - th)) / 2)).mean()

        # alternative: (torch.min(sep_pos[0], sep_pos[1]).mean() + torch.min(sep_pos[2], sep_pos[3]).mean()) * 0.5
        loss_grd_sep = (loss_grd_neg + loss_grd_pos) * 0.25
        loss = loss_sep + loss_grd_sep

        err = loss + loss_grd * 0.5

        return err, loss_sep, g,\
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
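
Example #6 gates each predicted mask with tf_data(g_sep[n], B, s1, s2).round(), but tf_data() itself is not shown. From the call site it apparently turns the two-way grounding logits into a per-sample confidence broadcast to the mask resolution; a hypothetical sketch under that reading:

import torch

# Hypothetical tf_data(): softmax the (B, 2) grounding logits, take the
# positive-class probability, and broadcast it to the (B, 1, s1, s2) mask
# shape so .round() yields a hard per-sample gate. Inferred from usage only.
def tf_data(g, B, s1, s2):
    conf = torch.softmax(g, dim=-1)[:, 0]
    return conf.view(B, 1, 1, 1).expand(B, 1, s1, s2)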
Example #7
    def forward(self, batch_data, args):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # 0.1 calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground truth masks are computed after warping!
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 5.)

        # LOG magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # 1. forward net_sound -> BxCxHxW
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, args.sound_activation)

        # 2. forward net_frame -> Bx1xC
        feat_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        # 3. sound synthesizer
        pred_masks = [None for n in range(N)]
        for n in range(N):
            pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
            pred_masks[n] = activate(pred_masks[n], args.output_activation)

        # 4. loss
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
    def forward(self, batch_data, args):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        frame_emb = batch_data['frame_emb']
        mag_mix = mag_mix + 1e-10

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # ground truth masks are computed after warping!
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 5.)

        # LOG magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # forward net_sound
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, args.sound_activation)

        # separating sound
        sound_size = feat_sound.size()
        B, C = sound_size[0], sound_size[1]
        pred_masks = [None for n in range(N)]
        for n in range(N):
            feat_img = frame_emb[n].float()
            feat_img = feat_img.view(B, 1, C)
            pred_masks[n] = torch.bmm(feat_img, feat_sound.view(B, C, -1)) \
                .view(B, 1, *sound_size[2:])
            pred_masks[n] = activate(pred_masks[n], args.output_activation)

        # loss
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
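
The second forward above separates sounds with a single batched matrix multiply between the (B, 1, C) visual embedding and the (B, C, H*W) sound feature map. A standalone shape check with made-up sizes:

import torch

# Shape walk-through of the bmm-based mask prediction (sizes are made up).
B, C, H, W = 4, 32, 256, 256
feat_sound = torch.randn(B, C, H, W)   # output of net_sound
feat_img = torch.randn(B, C)           # one visual embedding per sample
pred_mask = torch.bmm(feat_img.view(B, 1, C),
                      feat_sound.view(B, C, -1)).view(B, 1, H, W)
assert pred_mask.shape == (B, 1, H, W)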
def train(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3, loader, optimizer, history, epoch, args):
    print('Training at {} epochs...'.format(epoch))
    torch.set_grad_enabled(True)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # switch to train mode
    netWrapper1.train()
    netWrapper2.train()
    netWrapper3.train()

    # main loop
    torch.cuda.synchronize()
    tic = time.perf_counter()
    for i, batch_data in enumerate(loader):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']

        N = args.num_mix
        B = mag_mix.shape[0]
        for n in range(N):
            frames[n] = frames[n].to(args.device)
            mags[n] = mags[n].to(args.device)
        mag_mix = mag_mix.to(args.device)

        # forward pass
        optimizer.zero_grad()
        # return feat_sound
        feat_sound, outputs = netWrapper1.forward(mags, mag_mix, args)
        gt_masks = outputs['gt_masks']
        mag_mix_ = outputs['mag_mix']
        weight_ = outputs['weight']

        # return feat_frame, and emb_frame
        feat_frame, emb_frame = netWrapper2.forward(frames, args)
        
        # randomly select a positive/negative pair
        idx_pos = torch.randint(0, N, (1,)).item()
        idx_neg = N - 1 - idx_pos
        # appearance attention
        masks = netWrapper3.forward(feat_frame, emb_frame[idx_pos], args)
        mask_pos = masks[idx_pos]
        mask_neg = masks[idx_neg]

        # max pooling
        pred_pos = F.adaptive_max_pool2d(mask_pos, 1)
        pred_pos = pred_pos.view(mask_pos.shape[0])
        pred_neg = F.adaptive_max_pool2d(mask_neg, 1)
        pred_neg = pred_neg.view(mask_neg.shape[0])

        # ground truth for the positive/negative pairs
        y1 = torch.ones(B,device=args.device).detach()
        y0 = torch.zeros(B, device=args.device).detach()

        # localization loss and acc
        loss_loc_pos = crit_loc(pred_pos, y1).reshape(1)
        loss_loc_neg = crit_loc(pred_neg, y0).reshape(1)
        loss_loc = args.lamda * (loss_loc_pos + loss_loc_neg)/N
        pred_pos = (pred_pos > args.mask_thres)
        pred_neg = (pred_neg > args.mask_thres)
        valacc = 0
        for j in range(B):
            if pred_pos[j].item() == y1[j].item():
                valacc += 1.0
            if pred_neg[j].item() == y0[j].item():
                valacc += 1.0
        valacc = valacc/N/B

        # separate sounds (for simplicity, we don't use the alpha and beta)
        sound_size = feat_sound.size()
        B, C = sound_size[0], sound_size[1]
        pred_masks = [None for n in range(N)]
        for n in range(N):
            feat_img = emb_frame[n]
            feat_img = feat_img.view(B, 1, C)
            pred_masks[n] = torch.bmm(feat_img, feat_sound.view(B, C, -1)) \
                .view(B, 1, *sound_size[2:])
            pred_masks[n] = activate(pred_masks[n], args.output_activation)

        # separation loss
        loss_sep = crit_sep(pred_masks, gt_masks, weight_).reshape(1)
        
        # total loss
        loss = loss_loc + loss_sep
        loss.backward()
        optimizer.step()

        # measure total time
        torch.cuda.synchronize()
        batch_time.update(time.perf_counter() - tic)
        tic = time.perf_counter()

        # display
        if i % args.disp_iter == 0:
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_sound: {}, lr_frame: {}, lr_avol: {}, '
                  'loss: {:.5f}, loss_loc: {:.5f}, loss_sep: {:.5f}, acc: {:.5f} '
                  .format(epoch, i, args.epoch_iters,
                          batch_time.average(), data_time.average(),
                          args.lr_sound, args.lr_frame, args.lr_avol,
                          loss.item(), loss_loc.item(), loss_sep.item(), 
                          valacc))
            fractional_epoch = epoch - 1 + 1. * i / args.epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['err'].append(loss.item())
            history['train']['err_loc'].append(loss_loc.item())
            history['train']['err_sep'].append(loss_sep.item())
            history['train']['acc'].append(valacc)
def evaluate(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3, loader, history, epoch, args):
    print('Evaluating at {} epochs...'.format(epoch))
    torch.set_grad_enabled(False)

    # prepare the visualization directory (keep previous results)
    makedirs(args.vis, remove=False)

    # switch to eval mode
    netWrapper1.eval()
    netWrapper2.eval()
    netWrapper3.eval()

    # initialize meters
    loss_meter = AverageMeter()
    loss_acc_meter = AverageMeter()
    loss_sep_meter = AverageMeter()
    loss_loc_meter = AverageMeter()
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()
    
    vis_rows = []
    for i, batch_data in enumerate(loader):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']

        N = args.num_mix
        B = mag_mix.shape[0]
        
        for n in range(N):
            frames[n] = frames[n].to(args.device)
            mags[n] = mags[n].to(args.device)
        mag_mix = mag_mix.to(args.device)
            
        # forward pass
        # return feat_sound
        feat_sound, outputs = netWrapper1.forward(mags, mag_mix, args)
        gt_masks = outputs['gt_masks']
        mag_mix_ = outputs['mag_mix']
        weight_ = outputs['weight']
        
        # return feat_frame, and emb_frame
        feat_frame, emb_frame = netWrapper2.forward(frames, args)

        # randomly select a positive/negative pair
        idx_pos = torch.randint(0, N, (1,)).item()
        idx_neg = N - 1 - idx_pos

        # appearance attention
        masks = netWrapper3.forward(feat_frame, emb_frame[idx_pos], args)
        mask_pos = masks[idx_pos]
        mask_neg = masks[idx_neg]

        # max pooling
        pred_pos = F.adaptive_max_pool2d(mask_pos, 1)
        pred_pos = pred_pos.view(mask_pos.shape[0])
        pred_neg = F.adaptive_max_pool2d(mask_neg, 1)
        pred_neg = pred_neg.view(mask_neg.shape[0])

        # ground truth for the positive/negative pairs
        y1 = torch.ones(B,device=args.device).detach()
        y0 = torch.zeros(B, device=args.device).detach()

        # localization loss
        loss_loc_pos = crit_loc(pred_pos, y1).reshape(1)
        loss_loc_neg = crit_loc(pred_neg, y0).reshape(1)
        loss_loc = args.lamda * (loss_loc_pos + loss_loc_neg)/N

        # Calculate val accuracy
        pred_pos = (pred_pos > args.mask_thres)
        pred_neg = (pred_neg > args.mask_thres)
        valacc = 0
        for j in range(B):
            if pred_pos[j].item() == y1[j].item():
                valacc += 1.0
            if pred_neg[j].item() == y0[j].item():
                valacc += 1.0
        valacc = valacc/N/B

        # separate sounds
        sound_size = feat_sound.size()
        B, C = sound_size[0], sound_size[1]

        pred_masks = [None for n in range(N)]
        for n in range(N):
            feat_img = emb_frame[n]
            feat_img = feat_img.view(B, 1, C)
            pred_masks[n] = torch.bmm(feat_img, feat_sound.view(B, C, -1)) \
                .view(B, 1, *sound_size[2:])
            pred_masks[n] = activate(pred_masks[n], args.output_activation)

        # separation loss
        loss_sep = crit_sep(pred_masks, gt_masks, weight_).reshape(1)

        # total loss
        loss = loss_loc + loss_sep

        loss_meter.update(loss.item())
        loss_acc_meter.update(valacc)
        loss_sep_meter.update(loss_sep.item())
        loss_loc_meter.update(loss_loc.item())

        print('[Eval] iter {}, loss: {:.4f}, loss_loc: {:.4f}, loss_sep: {:.4f}, acc: {:.4f} '.format(i, loss.item(), loss_loc.item(), loss_sep.item(),  valacc))

        # calculate metrics
        sdr_mix, sdr, sir, sar = calc_metrics(batch_data, pred_masks, args)
        sdr_mix_meter.update(sdr_mix)
        sdr_meter.update(sdr)
        sir_meter.update(sir)
        sar_meter.update(sar)

        # output visualization
        if len(vis_rows) < args.num_vis:
            output_visuals_PosNeg(vis_rows, batch_data, mask_pos, mask_neg, idx_pos, idx_neg, pred_masks, gt_masks, mag_mix_, weight_, args)

    print('[Eval Summary] Epoch: {}, Loss: {:.4f}, Loss_loc: {:.4f},  Loss_sep: {:.4f},  acc: {:.4f}, sdr_mix: {:.4f}, sdr: {:.4f}, sir: {:.4f}, sar: {:.4f}, '
           .format(epoch, loss_meter.average(), loss_loc_meter.average(), loss_sep_meter.average(), loss_acc_meter.average(), sdr_mix_meter.average(), sdr_meter.average(), sir_meter.average(), sar_meter.average()))
    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['err_loc'].append(loss_loc_meter.average())
    history['val']['err_sep'].append(loss_sep_meter.average())
    history['val']['acc'].append(loss_acc_meter.average())
    history['val']['sdr'].append(sdr_meter.average())
    history['val']['sir'].append(sir_meter.average())
    history['val']['sar'].append(sar_meter.average())

    # Plot figure
    if epoch > 0:
        print('Plotting figures...')
        plot_loss_loc_sep_acc_metrics(args.ckpt, history)
    print('this evaluation round is done!')
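
train() and evaluate() accumulate statistics with AverageMeter, which is not defined in these snippets. A minimal sketch matching the update()/average() calls above:

# Minimal AverageMeter sketch (assumed): running sum and count behind the
# update()/average() interface used in train() and evaluate().
class AverageMeter(object):
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def average(self):
        return self.sum / self.count if self.count > 0 else 0.0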
Example #11
    def forward(self, batch_data, args):
        ### prepare data
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10
        mag_mix_tmp = mag_mix.clone()

        N = args.num_mix
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # 0.1 calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground truth masks are computed after warping!
        # Note that gt_masks are unordered
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 2.)

        ### Minus part
        if 'Minus' not in self.mode:
            self.requires_grad(self.net_sound_M, False)
            self.requires_grad(self.net_frame_M, False)

        feat_frames = [None for n in range(N)]
        ordered_pred_masks = [None for n in range(N)]
        ordered_pred_mags = [None for n in range(N)]

        # Step1: obtain all the frame features
        # forward net_frame_M -> Bx1xC
        for n in range(N):
            feat_frames[n] = self.net_frame_M.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n], args.img_activation)

        # Step2: separate the sounds one by one
        # forward net_sound_M -> BxCxHxW
        if args.log_freq:
            grid_unwarp = torch.from_numpy(
                warpgrid(B, args.stft_frame // 2 + 1, 256,
                         warp=False)).to(args.device)
        index_record = []
        for n in range(N):
            log_mag_mix = torch.log(mag_mix).detach()
            feat_sound = self.net_sound_M(log_mag_mix)
            _, C, H, W = feat_sound.shape
            feat_sound = feat_sound.view(B, C, -1)

            # obtain current separated sound
            energy_list = []
            tmp_masks = []
            tmp_pred_mags = []

            for feat_frame in feat_frames:
                cur_pred_mask = torch.bmm(feat_frame.unsqueeze(1),
                                          feat_sound).view(B, 1, H, W)
                cur_pred_mask = activate(cur_pred_mask, args.output_activation)
                tmp_masks.append(cur_pred_mask)
                # detach here to cut the gradient flow from the Minus net to
                # the Plus net, which makes training more stable
                if args.log_freq:
                    cur_pred_mask_unwrap = F.grid_sample(
                        cur_pred_mask.detach(), grid_unwarp)
                    if args.binary_mask:
                        cur_pred_mask_unwrap = (cur_pred_mask_unwrap >
                                                args.mask_thres).float()
                else:
                    cur_pred_mask_unwrap = cur_pred_mask.detach()
                cur_pred_mag = cur_pred_mask_unwrap * mag_mix_tmp
                tmp_pred_mags.append(cur_pred_mag)
                energy_list.append(
                    np.array(cur_pred_mag.view(B, -1).mean(dim=1).cpu().data))

            total_energy = np.stack(energy_list, axis=1)
            # _, cur_index = torch.max(total_energy)
            cur_index = self.choose_max(index_record, total_energy)
            index_record.append(cur_index)

            masks = torch.stack(tmp_masks, dim=0)
            ordered_pred_masks[n] = masks[cur_index, list(range(B))]
            pred_mags = torch.stack(tmp_pred_mags, dim=0)
            ordered_pred_mags[n] = pred_mags[cur_index, list(range(B))]

            #log_mag_mix = log_mag_mix - log_mag_mix * pred_masks[n]
            mag_mix = mag_mix - mag_mix * ordered_pred_masks[n] + 1e-10

        # reorder predictions so the loss lines up with the unordered gt_masks:
        # map ordered_pred_masks back to source order
        index_record = np.stack(index_record, axis=1)
        total_masks = torch.stack(ordered_pred_masks, dim=1)
        total_pred_mags = torch.stack(ordered_pred_mags, dim=1)
        unordered_pred_masks = []
        unordered_pred_mags = []
        for n in range(N):
            mask_index = np.where(index_record == n)
            if args.binary_mask:
                unordered_pred_masks.append(total_masks[mask_index])
                unordered_pred_mags.append(total_pred_mags[mask_index])
            else:
                unordered_pred_masks.append(total_masks[mask_index] * 2)
                unordered_pred_mags.append(total_pred_mags[mask_index])

        ### Plus part
        if 'Plus' in self.mode:
            pre_sum = torch.zeros_like(unordered_pred_masks[0]).to(args.device)
            Plus_pred_masks = []
            for n in range(N):
                unordered_pred_mag = unordered_pred_mags[n].log()
                unordered_pred_mag = F.grid_sample(unordered_pred_mag,
                                                   grid_warp)
                input_concat = torch.cat((pre_sum, unordered_pred_mag), dim=1)

                residual_mask = activate(self.net_sound_P(input_concat),
                                         args.sound_activation)
                Plus_pred_masks.append(unordered_pred_masks[n] + residual_mask)

                pre_sum = pre_sum.sum(dim=1, keepdim=True).detach()

            unordered_pred_masks = Plus_pred_masks

        # loss
        if args.need_loss_ratio:
            err = 0
            for n in range(N):
                err += self.crit(unordered_pred_masks[n], gt_masks[n],
                                 weight) / N * 2**(n - 1)
        else:
            err = self.crit(unordered_pred_masks, gt_masks, weight).reshape(1)

        if 'Minus' in self.mode:
            res_mag_mix = torch.exp(log_mag_mix)
            err_remain = torch.mean(weight *
                                    torch.clamp(res_mag_mix - 1e-2, min=0))
            err += err_remain

        outputs = {
            'pred_masks': unordered_pred_masks,
            'gt_masks': gt_masks,
            'mag_mix': mag_mix,
            'mags': mags,
            'weight': weight
        }

        return err, outputs
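
self.choose_max() is only referenced in the forward above. From the surrounding code it should pick, for every batch element, the highest-energy source that has not been selected in a previous iteration; a plausible sketch under that assumption:

import numpy as np

# Plausible choose_max() sketch: per-sample argmax over the (B, N) energy
# matrix, excluding indices recorded in earlier iterations. Inferred from
# the call site; the original implementation may differ.
def choose_max(index_record, total_energy):
    energy = total_energy.copy()                      # (B, N)
    for prev in index_record:                         # prev: (B,) indices
        energy[np.arange(energy.shape[0]), prev] = -np.inf
    return np.argmax(energy, axis=1)                  # (B,)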
    def forward(self, batch_data, args):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        feats = batch_data['feats']
        coords = batch_data['coords']
        mag_mix = mag_mix + 1e-10

        N = args.num_mix
        FN = args.num_frames
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if args.log_freq:
            grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                                  warp=True)).to(args.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp)

        # 0.1 calculate loss weighting coefficient: magnitude of input mixture
        if args.weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground truth masks are computed after warping!
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if args.binary_mask:
                if N > 2:
                    binary_masks = torch.ones_like(mags[n])
                    for m in range(N):
                        binary_masks *= (mags[n] >= mags[m])
                    gt_masks[n] = binary_masks
                else:
                    gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()

            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 5.)

        # LOG magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # 1. forward net_sound
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, args.sound_activation)

        # 2. forward net_vision
        feat_frames = [None for n in range(N)]
        for n in range(N):
            frames_features = []
            for f in range(FN):
                sin = ME.SparseTensor(
                    feats[n][f],
                    coords[n][f].int(),
                    allow_duplicate_coords=True)  #Create SparseTensor
                frames_features.append(self.net_vision.forward(sin))
            frames_features = torch.stack(frames_features)
            frames_features = frames_features.permute(1, 2, 0)
            if args.frame_pool == 'maxpool':
                feat_frame = F.adaptive_max_pool1d(frames_features, 1)
            elif args.frame_pool == 'avgpool':
                feat_frame = F.adaptive_avg_pool1d(frames_features, 1)
            else:
                raise ValueError('unknown frame_pool: {}'.format(args.frame_pool))
            feat_frames[n] = feat_frame.squeeze()
            feat_frames[n] = activate(feat_frames[n], args.vision_activation)

        # 3. sound synthesizer
        pred_masks = [None for n in range(N)]
        for n in range(N):
            pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
            pred_masks[n] = activate(pred_masks[n], args.output_activation)

        # 4. loss
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
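
For N > 2 the last example builds binary ground-truth masks by assigning each time-frequency bin to the source whose magnitude dominates all others. A tiny numeric check of that rule:

import torch

# Numeric check of the N > 2 binary-mask rule: a bin belongs to source n only
# if its magnitude is >= every other source's magnitude in that bin.
mags = [torch.tensor([[3.0, 1.0]]),
        torch.tensor([[2.0, 5.0]]),
        torch.tensor([[1.0, 4.0]])]
gt = []
for n in range(3):
    m = torch.ones_like(mags[n])
    for k in range(3):
        m = m * (mags[n] >= mags[k]).float()
    gt.append(m)
print(gt)  # [[1., 0.]], [[0., 1.]], [[0., 0.]]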