def forward(self, batch_data, args):
    audio_mix = batch_data['audio_mix']  # B, audio_len
    audios = batch_data['audios']        # num_mix, B, audio_len
    frames = batch_data['frames']        # num_mix, B, xxx
    N = args.num_mix
    B = audio_mix.size(0)

    # 2. forward net_frame -> Bx1xC
    feat_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # 3. sound synthesizer (time domain: predict waveforms directly)
    pred_audios = [None for n in range(N)]
    for n in range(N):
        pred_audios[n] = self.net_sound(audio_mix, feat_frames[n])
        pred_audios[n] = activate(pred_audios[n], args.sound_activation)

    # 4. loss
    err = self.crit(pred_audios, audios).reshape(1)
    return err, pred_audios
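# `activate` is used by every wrapper in this listing but is not defined here.
# A minimal sketch consistent with its call sites, assuming the activation is
# selected by a string flag (args.img_activation / args.sound_activation /
# args.output_activation):
import torch
import torch.nn.functional as F

def activate(x, activation):
    # dispatch on the activation name; 'no' leaves the tensor untouched
    if activation == 'sigmoid':
        return torch.sigmoid(x)
    elif activation == 'softmax':
        return F.softmax(x, dim=1)
    elif activation == 'relu':
        return F.relu(x)
    elif activation == 'tanh':
        return torch.tanh(x)
    elif activation == 'no':
        return x
    else:
        raise ValueError('Unknown activation: {}'.format(activation))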
def forward(self, feat_frame, feat_sound, args):
    N = args.num_mix
    pred_mask = [None for n in range(N)]
    # appearance attention
    for n in range(N):
        pred_mask[n] = self.net_avol(feat_frame[n], feat_sound)
        pred_mask[n] = activate(pred_mask[n], args.output_activation)
    return pred_mask
def forward(self, batch_data, args):
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['appearance_imags']
    clip_imgs = batch_data['clips_frames']
    gt_masks = batch_data['masks']
    mag_mix = mag_mix + 1e-10

    N = args.num_mix
    weight = torch.ones_like(mag_mix)
    B = mag_mix.size(0)
    T = mag_mix.size(2)

    # warp the spectrogram onto the log-frequency axis
    grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
    mag_mix = F.grid_sample(mag_mix.unsqueeze(1), grid_warp,
                            align_corners=False).squeeze()
    mag_mix = torch.log(mag_mix).detach() if args.use_mel else mag_mix

    # forward net_frame -> appearance features
    feat_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # forward net_motion -> motion features
    feat_motion = [None for n in range(N)]
    for n in range(N):
        feat_motion[n] = self.net_motion(clip_imgs[n])

    # forward net_sound -> predicted masks
    feat_sound = [None for n in range(N)]
    pred_masks = [None for n in range(N)]
    for n in range(N):
        feat_sound[n], _, _ = self.net_sound(mag_mix.to(args.device),
                                             feat_motion[n], feat_frames[n])
        pred_masks[n] = activate(feat_sound[n], args.sound_activation)

    # unwarp the predicted masks back to the linear STFT frequency axis
    grid_unwarp = torch.from_numpy(
        warpgrid(B, args.stft_frame // 2 + 1, T, warp=False)).to(args.device)
    for n in range(N):
        pred_masks[n] = F.grid_sample(pred_masks[n].unsqueeze(1), grid_unwarp,
                                      align_corners=False).squeeze()

    err = self.crit(pred_masks, gt_masks, weight).reshape(1)
    return err, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
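# `warpgrid` builds the sampling grid that F.grid_sample uses to map between a
# linear and a log-scaled frequency axis. A sketch consistent with its uses in
# this listing (warp=True squeezes the spectrogram onto a 256-bin log axis;
# warp=False inverts the mapping back to stft_frame // 2 + 1 linear bins):
import numpy as np

def warpgrid(bs, HO, WO, warp=True):
    # normalized meshgrid in [-1, 1], as grid_sample expects
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((bs, HO, WO, 2), dtype=np.float32)
    grid[:, :, :, 0] = xv
    if warp:
        # exponential warp of the frequency axis: oversample low frequencies
        grid[:, :, :, 1] = (np.power(21, (yv + 1) / 2) - 11) / 10
    else:
        # inverse (log) warp back to the linear frequency axis
        grid[:, :, :, 1] = np.log(yv * 10 + 11) / np.log(21) * 2 - 1
    return grid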
def forward(self, frame, args):
    N = args.num_mix
    # return appearance features and appearance embedding
    feat_frames = [None for n in range(N)]
    emb_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n], emb_frames[n] = \
            self.net_frame.forward_multiframe_feat_emb(frame[n], pool=True)
        emb_frames[n] = activate(emb_frames[n], args.img_activation)
    return feat_frames, emb_frames
def forward(self, mags, mag_mix, args):
    mag_mix = mag_mix + 1e-10
    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # warp the spectrogram
    if args.log_freq:
        grid_warp = torch.from_numpy(
            warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # ground-truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)

    # LOG magnitude
    log_mag_mix = torch.log(mag_mix).detach()

    # forward net_sound
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, args.sound_activation)

    return feat_sound, \
        {'gt_masks': gt_masks, 'mag_mix': mag_mix,
         'mags': mags, 'weight': weight}
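# `self.crit` is not defined in this listing. Its call sites elsewhere in this
# file (crit(pred_masks, gt_masks, weight)) accept either single tensors or
# lists of per-source masks, so a plausible sketch is a list-aware weighted
# loss, shown here as weighted L1 (a binary-mask variant would use BCE):
import torch
import torch.nn as nn

class WeightedL1Loss(nn.Module):
    def forward(self, preds, targets, weight=None):
        # accept a single tensor or a list of per-source tensors
        if not isinstance(preds, list):
            preds, targets = [preds], [targets]
        errs = []
        for pred, target in zip(preds, targets):
            err = torch.abs(pred - target)
            if weight is not None:
                err = err * weight
            errs.append(err.mean())
        return sum(errs) / len(errs)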
def forward(self, batch_data, args):
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10

    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.1 calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground-truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)
    # zero out the ground-truth masks of the distractor sources
    # (this branch assumes N == 4)
    gt_masks[1] = torch.mul(gt_masks[3], 0.)
    gt_masks[3] = torch.mul(gt_masks[3], 0.)

    # LOG magnitude
    log_mag_mix = torch.log1p(mag_mix).detach()
    log_mag0 = torch.log1p(mags[0]).detach()
    log_mag1 = torch.log1p(mags[1]).detach()
    log_mag2 = torch.log1p(mags[2]).detach()

    # grounding
    feat_sound_ground = self.net_sound_ground(log_mag_mix)
    feat_frames_ground = [None for n in range(N)]
    for n in range(N):
        feat_frames_ground[n] = self.net_frame_ground.forward_multiframe(frames[n])

    # grounding for separation
    g_sep = [None for n in range(N)]
    x = [None for n in range(N)]
    for n in range(N):
        g_sep[n] = self.net_grounding(feat_sound_ground, feat_frames_ground[n])
        x[n] = torch.softmax(g_sep[n].clone(), dim=-1)

    # grounding module: positive/negative audio-visual pairs
    g_pos = self.net_grounding(self.net_sound_ground(log_mag0), feat_frames_ground[0])
    g_pos1 = self.net_grounding(self.net_sound_ground(log_mag0), feat_frames_ground[1])
    g_neg = self.net_grounding(self.net_sound_ground(log_mag2), feat_frames_ground[0])

    # grounding for solo sounds
    g_solo = [None for n in range(N)]
    g_solo[0] = self.net_grounding(self.net_sound_ground(log_mag0), feat_frames_ground[0])
    g_solo[1] = self.net_grounding(self.net_sound_ground(log_mag0), feat_frames_ground[1])
    g_solo[2] = self.net_grounding(self.net_sound_ground(log_mag2), feat_frames_ground[2])
    g_solo[3] = self.net_grounding(self.net_sound_ground(log_mag2), feat_frames_ground[3])
    for n in range(N):
        g_solo[n] = torch.softmax(g_solo[n], dim=-1)
    g = [torch.softmax(g_pos, dim=-1), torch.softmax(g_neg, dim=-1), x, g_solo]

    # 1. forward net_sound -> BxCxHxW
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, args.sound_activation)

    # 2. forward net_frame -> Bx1xC
    feat_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # 3. sound synthesizer
    masks = [None for n in range(N)]
    for n in range(N):
        masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
        masks[n] = activate(masks[n], args.output_activation)

    # 4. adjust masks with grounding confidence scores
    pred_masks = [None for n in range(N)]
    s1 = masks[1].size(2)
    s2 = masks[1].size(3)
    if args.testing:
        pred_masks[0] = torch.mul(tf_data(g_sep[0], B, s1, s2).round(), masks[0])
        pred_masks[1] = torch.mul(tf_data(g_sep[1], B, s1, s2).round(), masks[1])
        pred_masks[2] = torch.mul(tf_data(g_sep[2], B, s1, s2).round(), masks[2])
        pred_masks[3] = torch.mul(tf_data(g_sep[3], B, s1, s2).round(), masks[3])
    else:
        pred_masks[0] = torch.mul(tf_data(g_sep[0], B, s1, s2).round(), masks[0]) \
            + torch.mul(tf_data(g_sep[1], B, s1, s2).round(), masks[1])
        pred_masks[1] = masks[1]
        pred_masks[2] = torch.mul(tf_data(g_sep[2], B, s1, s2).round(), masks[2]) \
            + torch.mul(tf_data(g_sep[3], B, s1, s2).round(), masks[3])
        pred_masks[3] = masks[3]

    # 5. loss
    loss_sep = 0.5 * (self.crit(pred_masks[0], gt_masks[0], weight).reshape(1)
                      + self.crit(pred_masks[2], gt_masks[2], weight).reshape(1))

    # per-sample mask errors, used to decide which view each sound belongs to
    df_mask = [torch.zeros(B, 1).cuda() for n in range(N)]
    for j in range(B):
        df_mask[0][j] = self.crit(masks[0][j:j + 1], gt_masks[0][j:j + 1], weight[j:j + 1])
        df_mask[1][j] = self.crit(masks[1][j:j + 1], gt_masks[0][j:j + 1], weight[j:j + 1])
        df_mask[2][j] = self.crit(masks[2][j:j + 1], gt_masks[2][j:j + 1], weight[j:j + 1])
        df_mask[3][j] = self.crit(masks[3][j:j + 1], gt_masks[2][j:j + 1], weight[j:j + 1])
    a1 = 1 - df_mask[0].div(df_mask[1] + df_mask[0])
    a2 = 1 - df_mask[2].div(df_mask[2] + df_mask[3])

    # class 0 = positive pair, class 1 = negative pair
    pos_label = torch.zeros(B).cuda()
    neg_label = torch.ones(B).cuda()
    sep_pos = [torch.zeros(B, 1).cuda() for n in range(N)]
    for i in range(N):
        for j in range(B):
            sep_pos[i][j] = self.cts(g_sep[i][j:j + 1], pos_label[j:j + 1].long())
    sep_neg = [torch.zeros(B, 1).cuda() for n in range(N)]
    for i in range(N):
        for j in range(B):
            sep_neg[i][j] = self.cts(g_sep[i][j:j + 1], neg_label[j:j + 1].long())

    cts_pos = torch.zeros(B).cuda()
    cts_pos1 = torch.zeros(B).cuda()
    for i in range(B):
        cts_pos[i] = self.cts(g_pos[i:i + 1], pos_label[i:i + 1].long())
        cts_pos1[i] = self.cts(g_pos1[i:i + 1], pos_label[i:i + 1].long())
    # a1 is (B, 1); flatten the gates so they broadcast per sample
    loss_grd = (cts_pos * a1.round().view(-1)
                + cts_pos1 * (1 - a1).round().view(-1)).mean() \
        + self.cts(g_neg, neg_label.long())

    loss_grd_pos = (sep_pos[0] * a1.round() + sep_pos[1] * (1 - a1).round()
                    + sep_pos[2] * a2.round() + sep_pos[3] * (1 - a2).round()).mean()
    th = 0.1
    loss_grd_neg = (sep_neg[0] * (1 - a1).round() * ((1 + torch.sign(df_mask[0] - th)) / 2)
                    + sep_neg[1] * a1.round() * ((1 + torch.sign(df_mask[1] - th)) / 2)
                    + sep_neg[2] * (1 - a2).round() * ((1 + torch.sign(df_mask[2] - th)) / 2)
                    + sep_neg[3] * a2.round() * ((1 + torch.sign(df_mask[3] - th)) / 2)).mean()
    loss_grd_sep = (loss_grd_neg + loss_grd_pos) * 0.25

    loss = loss_sep + loss_grd_sep
    err = loss + loss_grd * 0.5
    return err, loss_sep, g, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
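# `tf_data` is not defined in this listing. From its call sites it takes the
# two-way grounding logits g_sep[n] (shape (B, 2), judged against class 0 =
# positive by self.cts above) and returns a confidence map broadcastable over
# the (B, 1, s1, s2) masks. A purely hypothetical sketch:
import torch

def tf_data(logits, B, s1, s2):
    # positive-class probability per batch item, expanded to the mask shape;
    # .round() at the call site then turns it into a hard 0/1 gate
    conf = torch.softmax(logits, dim=-1)[:, 0]
    return conf.view(B, 1, 1, 1).expand(B, 1, s1, s2)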
def forward(self, batch_data, args):
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10

    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.1 calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground-truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)

    # LOG magnitude
    log_mag_mix = torch.log(mag_mix).detach()

    # 1. forward net_sound -> BxCxHxW
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, args.sound_activation)

    # 2. forward net_frame -> Bx1xC
    feat_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # 3. sound synthesizer
    pred_masks = [None for n in range(N)]
    for n in range(N):
        pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
        pred_masks[n] = activate(pred_masks[n], args.output_activation)

    # 4. loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)
    return err, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
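# A minimal smoke test for the wrapper above (illustrative shapes only; the
# args fields are inferred from their uses in this file, and the nets and
# criterion must already be constructed; `wrapper` is a placeholder name):
import argparse
import torch

args = argparse.Namespace(
    num_mix=2, log_freq=1, weighted_loss=1, binary_mask=0,
    device='cpu', sound_activation='no', img_activation='sigmoid',
    output_activation='sigmoid')
B, F_bins, T = 4, 512, 256
batch_data = {
    'mag_mix': torch.rand(B, 1, F_bins, T),
    'mags': [torch.rand(B, 1, F_bins, T) for _ in range(args.num_mix)],
    'frames': [torch.rand(B, 3, 3, 224, 224) for _ in range(args.num_mix)],
}
# err, outputs = wrapper(batch_data, args)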
def forward(self, batch_data, args):
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    frame_emb = batch_data['frame_emb']
    mag_mix = mag_mix + 1e-10

    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # warp the spectrogram
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # ground-truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)

    # LOG magnitude
    log_mag_mix = torch.log(mag_mix).detach()

    # forward net_sound
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, args.sound_activation)

    # separate sounds: inner product between frame embedding and sound features
    sound_size = feat_sound.size()
    B, C = sound_size[0], sound_size[1]
    pred_masks = [None for n in range(N)]
    for n in range(N):
        feat_img = frame_emb[n].float()
        feat_img = feat_img.view(B, 1, C)
        pred_masks[n] = torch.bmm(feat_img, feat_sound.view(B, C, -1)) \
            .view(B, 1, *sound_size[2:])
        pred_masks[n] = activate(pred_masks[n], args.output_activation)

    # loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)
    return err, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
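# Shape walk-through of the inner-product mask prediction above: the frame
# embedding acts as a 1xC filter applied at every time-frequency location of
# the CxHxW sound feature map. Illustrative standalone check:
import torch

B, C, H, W = 2, 32, 256, 256
feat_sound = torch.randn(B, C, H, W)
feat_img = torch.randn(B, 1, C)
mask = torch.bmm(feat_img, feat_sound.view(B, C, -1)).view(B, 1, H, W)
assert mask.shape == (B, 1, H, W)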
def train(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3,
          loader, optimizer, history, epoch, args):
    print('Training at {} epochs...'.format(epoch))
    torch.set_grad_enabled(True)
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # switch to train mode
    netWrapper1.train()
    netWrapper2.train()
    netWrapper3.train()

    # main loop
    torch.cuda.synchronize()
    tic = time.perf_counter()
    for i, batch_data in enumerate(loader):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        N = args.num_mix
        B = mag_mix.shape[0]
        for n in range(N):
            frames[n] = frames[n].to(args.device)
            mags[n] = mags[n].to(args.device)
        mag_mix = mag_mix.to(args.device)

        # forward pass
        optimizer.zero_grad()

        # return feat_sound
        feat_sound, outputs = netWrapper1.forward(mags, mag_mix, args)
        gt_masks = outputs['gt_masks']
        mag_mix_ = outputs['mag_mix']
        weight_ = outputs['weight']

        # return feat_frame and emb_frame
        feat_frame, emb_frame = netWrapper2.forward(frames, args)

        # randomly select positive/negative pairs
        idx_pos = torch.randint(0, N, (1,))
        idx_neg = N - 1 - idx_pos

        # appearance attention
        masks = netWrapper3.forward(feat_frame, emb_frame[idx_pos], args)
        mask_pos = masks[idx_pos]
        mask_neg = masks[idx_neg]

        # max pooling
        pred_pos = F.adaptive_max_pool2d(mask_pos, 1)
        pred_pos = pred_pos.view(mask_pos.shape[0])
        pred_neg = F.adaptive_max_pool2d(mask_neg, 1)
        pred_neg = pred_neg.view(mask_neg.shape[0])

        # ground truth for the positive/negative pairs
        y1 = torch.ones(B, device=args.device).detach()
        y0 = torch.zeros(B, device=args.device).detach()

        # localization loss and accuracy
        loss_loc_pos = crit_loc(pred_pos, y1).reshape(1)
        loss_loc_neg = crit_loc(pred_neg, y0).reshape(1)
        loss_loc = args.lamda * (loss_loc_pos + loss_loc_neg) / N

        pred_pos = (pred_pos > args.mask_thres)
        pred_neg = (pred_neg > args.mask_thres)
        valacc = 0
        for j in range(B):
            if pred_pos[j].item() == y1[j].item():
                valacc += 1.0
            if pred_neg[j].item() == y0[j].item():
                valacc += 1.0
        valacc = valacc / N / B

        # separate sounds (for simplicity, we don't use the alpha and beta)
        sound_size = feat_sound.size()
        B, C = sound_size[0], sound_size[1]
        pred_masks = [None for n in range(N)]
        for n in range(N):
            feat_img = emb_frame[n]
            feat_img = feat_img.view(B, 1, C)
            pred_masks[n] = torch.bmm(feat_img, feat_sound.view(B, C, -1)) \
                .view(B, 1, *sound_size[2:])
            pred_masks[n] = activate(pred_masks[n], args.output_activation)

        # separation loss
        loss_sep = crit_sep(pred_masks, gt_masks, weight_).reshape(1)

        # total loss
        loss = loss_loc + loss_sep
        loss.backward()
        optimizer.step()

        # measure total time
        torch.cuda.synchronize()
        batch_time.update(time.perf_counter() - tic)
        tic = time.perf_counter()

        # display
        if i % args.disp_iter == 0:
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_sound: {}, lr_frame: {}, lr_avol: {}, '
                  'loss: {:.5f}, loss_loc: {:.5f}, loss_sep: {:.5f}, acc: {:.5f}'
                  .format(epoch, i, args.epoch_iters,
                          batch_time.average(), data_time.average(),
                          args.lr_sound, args.lr_frame, args.lr_avol,
                          loss.item(), loss_loc.item(), loss_sep.item(), valacc))

            fractional_epoch = epoch - 1 + 1. * i / args.epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['err'].append(loss.item())
            history['train']['err_loc'].append(loss_loc.item())
            history['train']['err_sep'].append(loss_sep.item())
            history['train']['acc'].append(valacc)
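# `AverageMeter` is not defined in this listing; a minimal sketch matching the
# .update(value) / .average() calls in train() above and evaluate() below:
class AverageMeter:
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # accumulate a running (optionally weighted) sum
        self.sum += value * n
        self.count += n

    def average(self):
        return self.sum / self.count if self.count > 0 else 0.0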
def evaluate(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3,
             loader, history, epoch, args):
    print('Evaluating at {} epochs...'.format(epoch))
    torch.set_grad_enabled(False)

    # prepare the visualization directory (previous results are kept)
    makedirs(args.vis, remove=False)

    # switch to eval mode
    netWrapper1.eval()
    netWrapper2.eval()
    netWrapper3.eval()

    # initialize meters
    loss_meter = AverageMeter()
    loss_acc_meter = AverageMeter()
    loss_sep_meter = AverageMeter()
    loss_loc_meter = AverageMeter()
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    vis_rows = []
    for i, batch_data in enumerate(loader):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        N = args.num_mix
        B = mag_mix.shape[0]
        for n in range(N):
            frames[n] = frames[n].to(args.device)
            mags[n] = mags[n].to(args.device)
        mag_mix = mag_mix.to(args.device)

        # forward pass
        # return feat_sound
        feat_sound, outputs = netWrapper1.forward(mags, mag_mix, args)
        gt_masks = outputs['gt_masks']
        mag_mix_ = outputs['mag_mix']
        weight_ = outputs['weight']

        # return feat_frame and emb_frame
        feat_frame, emb_frame = netWrapper2.forward(frames, args)

        # randomly select positive/negative pairs
        idx_pos = torch.randint(0, N, (1,))
        idx_neg = N - 1 - idx_pos

        # appearance attention
        masks = netWrapper3.forward(feat_frame, emb_frame[idx_pos], args)
        mask_pos = masks[idx_pos]
        mask_neg = masks[idx_neg]

        # max pooling
        pred_pos = F.adaptive_max_pool2d(mask_pos, 1)
        pred_pos = pred_pos.view(mask_pos.shape[0])
        pred_neg = F.adaptive_max_pool2d(mask_neg, 1)
        pred_neg = pred_neg.view(mask_neg.shape[0])

        # ground truth for the positive/negative pairs
        y1 = torch.ones(B, device=args.device).detach()
        y0 = torch.zeros(B, device=args.device).detach()

        # localization loss
        loss_loc_pos = crit_loc(pred_pos, y1).reshape(1)
        loss_loc_neg = crit_loc(pred_neg, y0).reshape(1)
        loss_loc = args.lamda * (loss_loc_pos + loss_loc_neg) / N

        # calculate val accuracy
        pred_pos = (pred_pos > args.mask_thres)
        pred_neg = (pred_neg > args.mask_thres)
        valacc = 0
        for j in range(B):
            if pred_pos[j].item() == y1[j].item():
                valacc += 1.0
            if pred_neg[j].item() == y0[j].item():
                valacc += 1.0
        valacc = valacc / N / B

        # separate sounds
        sound_size = feat_sound.size()
        B, C = sound_size[0], sound_size[1]
        pred_masks = [None for n in range(N)]
        for n in range(N):
            feat_img = emb_frame[n]
            feat_img = feat_img.view(B, 1, C)
            pred_masks[n] = torch.bmm(feat_img, feat_sound.view(B, C, -1)) \
                .view(B, 1, *sound_size[2:])
            pred_masks[n] = activate(pred_masks[n], args.output_activation)

        # separation loss
        loss_sep = crit_sep(pred_masks, gt_masks, weight_).reshape(1)

        # total loss
        loss = loss_loc + loss_sep

        loss_meter.update(loss.item())
        loss_acc_meter.update(valacc)
        loss_sep_meter.update(loss_sep.item())
        loss_loc_meter.update(loss_loc.item())
        print('[Eval] iter {}, loss: {:.4f}, loss_loc: {:.4f}, '
              'loss_sep: {:.4f}, acc: {:.4f}'
              .format(i, loss.item(), loss_loc.item(), loss_sep.item(), valacc))

        # calculate metrics
        sdr_mix, sdr, sir, sar = calc_metrics(batch_data, pred_masks, args)
        sdr_mix_meter.update(sdr_mix)
        sdr_meter.update(sdr)
        sir_meter.update(sir)
        sar_meter.update(sar)

        # output visualization
        if len(vis_rows) < args.num_vis:
            output_visuals_PosNeg(vis_rows, batch_data, mask_pos, mask_neg,
                                  idx_pos, idx_neg, pred_masks, gt_masks,
                                  mag_mix_, weight_, args)

    print('[Eval Summary] Epoch: {}, Loss: {:.4f}, Loss_loc: {:.4f}, '
          'Loss_sep: {:.4f}, acc: {:.4f}, '
          'sdr_mix: {:.4f}, sdr: {:.4f}, sir: {:.4f}, sar: {:.4f}'
          .format(epoch, loss_meter.average(), loss_loc_meter.average(),
                  loss_sep_meter.average(), loss_acc_meter.average(),
                  sdr_mix_meter.average(), sdr_meter.average(),
                  sir_meter.average(), sar_meter.average()))

    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['err_loc'].append(loss_loc_meter.average())
    history['val']['err_sep'].append(loss_sep_meter.average())
    history['val']['acc'].append(loss_acc_meter.average())
    history['val']['sdr'].append(sdr_meter.average())
    history['val']['sir'].append(sir_meter.average())
    history['val']['sar'].append(sar_meter.average())

    # plot figures
    if epoch > 0:
        print('Plotting figures...')
        plot_loss_loc_sep_acc_metrics(args.ckpt, history)
    print('this evaluation round is done!')
def forward(self, batch_data, args):
    ### prepare data
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10
    mag_mix_tmp = mag_mix.clone()

    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.1 calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground-truth masks are computed after warping!
    # Note that gt_masks are unordered.
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 2.)

    ### Minus part
    if 'Minus' not in self.mode:
        self.requires_grad(self.net_sound_M, False)
        self.requires_grad(self.net_frame_M, False)

    feat_frames = [None for n in range(N)]
    ordered_pred_masks = [None for n in range(N)]
    ordered_pred_mags = [None for n in range(N)]

    # Step 1: obtain all the frame features
    # forward net_frame_M -> Bx1xC
    for n in range(N):
        feat_frames[n] = self.net_frame_M.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # Step 2: separate the sounds one by one
    # forward net_sound_M -> BxCxHxW
    if args.log_freq:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, 256, warp=False)).to(args.device)

    index_record = []
    for n in range(N):
        log_mag_mix = torch.log(mag_mix).detach()
        feat_sound = self.net_sound_M(log_mag_mix)
        _, C, H, W = feat_sound.shape
        feat_sound = feat_sound.view(B, C, -1)

        # obtain the currently separated sound for every candidate frame
        energy_list = []
        tmp_masks = []
        tmp_pred_mags = []
        for feat_frame in feat_frames:
            cur_pred_mask = torch.bmm(feat_frame.unsqueeze(1), feat_sound).view(B, 1, H, W)
            cur_pred_mask = activate(cur_pred_mask, args.output_activation)
            tmp_masks.append(cur_pred_mask)
            # Cut off the gradient flow from the Minus net to the Plus net
            # so that training is more stable.
            if args.log_freq:
                cur_pred_mask_unwarp = F.grid_sample(cur_pred_mask.detach(), grid_unwarp)
                if args.binary_mask:
                    cur_pred_mask_unwarp = (cur_pred_mask_unwarp > args.mask_thres).float()
            else:
                cur_pred_mask_unwarp = cur_pred_mask.detach()
            cur_pred_mag = cur_pred_mask_unwarp * mag_mix_tmp
            tmp_pred_mags.append(cur_pred_mag)
            energy_list.append(np.array(cur_pred_mag.view(B, -1).mean(dim=1).cpu().data))

        # pick, per batch item, the not-yet-chosen frame with the highest energy
        total_energy = np.stack(energy_list, axis=1)
        cur_index = self.choose_max(index_record, total_energy)
        index_record.append(cur_index)

        masks = torch.stack(tmp_masks, dim=0)
        ordered_pred_masks[n] = masks[cur_index, list(range(B))]
        pred_mags = torch.stack(tmp_pred_mags, dim=0)
        ordered_pred_mags[n] = pred_mags[cur_index, list(range(B))]

        # remove the separated source from the mixture
        mag_mix = mag_mix - mag_mix * ordered_pred_masks[n] + 1e-10

    # Swap pred_masks so the loss can be computed conveniently:
    # since gt_masks are unordered, transfer ordered_pred_masks back to unordered.
    index_record = np.stack(index_record, axis=1)
    total_masks = torch.stack(ordered_pred_masks, dim=1)
    total_pred_mags = torch.stack(ordered_pred_mags, dim=1)

    unordered_pred_masks = []
    unordered_pred_mags = []
    for n in range(N):
        mask_index = np.where(index_record == n)
        if args.binary_mask:
            unordered_pred_masks.append(total_masks[mask_index])
            unordered_pred_mags.append(total_pred_mags[mask_index])
        else:
            unordered_pred_masks.append(total_masks[mask_index] * 2)
            unordered_pred_mags.append(total_pred_mags[mask_index])

    ### Plus part
    if 'Plus' in self.mode:
        pre_sum = torch.zeros_like(unordered_pred_masks[0]).to(args.device)
        Plus_pred_masks = []
        for n in range(N):
            unordered_pred_mag = unordered_pred_mags[n].log()
            unordered_pred_mag = F.grid_sample(unordered_pred_mag, grid_warp)
            input_concat = torch.cat((pre_sum, unordered_pred_mag), dim=1)
            residual_mask = activate(self.net_sound_P(input_concat), args.sound_activation)
            Plus_pred_masks.append(unordered_pred_masks[n] + residual_mask)
            pre_sum = pre_sum.sum(dim=1, keepdim=True).detach()
        unordered_pred_masks = Plus_pred_masks

    # loss
    if args.need_loss_ratio:
        err = 0
        for n in range(N):
            err += self.crit(unordered_pred_masks[n], gt_masks[n], weight) / N * 2**(n - 1)
    else:
        err = self.crit(unordered_pred_masks, gt_masks, weight).reshape(1)

    if 'Minus' in self.mode:
        # penalize any energy left in the residual mixture
        res_mag_mix = torch.exp(log_mag_mix)
        err_remain = torch.mean(weight * torch.clamp(res_mag_mix - 1e-2, min=0))
        err += err_remain

    outputs = {'pred_masks': unordered_pred_masks, 'gt_masks': gt_masks,
               'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
    return err, outputs
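# `choose_max` (used above as self.choose_max) is not shown in this listing.
# From its call site it must, for each batch item, pick the frame index with
# the largest separated-sound energy among the indices not chosen in earlier
# rounds. A hypothetical sketch as a free function (as a method, add `self`):
import numpy as np

def choose_max(index_record, total_energy):
    # total_energy: (B, N) numpy array of per-frame energies;
    # index_record: list of (B,) arrays of previously chosen indices
    B, N = total_energy.shape
    chosen = np.zeros(B, dtype=np.int64)
    for b in range(B):
        used = {rec[b] for rec in index_record}
        order = np.argsort(-total_energy[b])
        chosen[b] = next(i for i in order if i not in used)
    return chosen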
def forward(self, batch_data, args):
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    feats = batch_data['feats']
    coords = batch_data['coords']
    mag_mix = mag_mix + 1e-10

    N = args.num_mix
    FN = args.num_frames
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.1 calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground-truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            if N > 2:
                # a time-frequency bin belongs to source n only if its
                # magnitude dominates every other source
                binary_masks = torch.ones_like(mags[n])
                for m in range(N):
                    binary_masks *= (mags[n] >= mags[m])
                gt_masks[n] = binary_masks
            else:
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)

    # LOG magnitude
    log_mag_mix = torch.log(mag_mix).detach()

    # 1. forward net_sound
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, args.sound_activation)

    # 2. forward net_vision
    feat_frames = [None for n in range(N)]
    for n in range(N):
        frames_features = []
        for f in range(FN):
            # create a SparseTensor per frame
            sin = ME.SparseTensor(feats[n][f], coords[n][f].int(),
                                  allow_duplicate_coords=True)
            frames_features.append(self.net_vision.forward(sin))
        frames_features = torch.stack(frames_features)      # (FN, B, C)
        frames_features = frames_features.permute(1, 2, 0)  # (B, C, FN)
        if args.frame_pool == 'maxpool':
            feat_frame = F.adaptive_max_pool1d(frames_features, 1)
        elif args.frame_pool == 'avgpool':
            feat_frame = F.adaptive_avg_pool1d(frames_features, 1)
        feat_frames[n] = feat_frame.squeeze()
        feat_frames[n] = activate(feat_frames[n], args.vision_activation)

    # 3. sound synthesizer
    pred_masks = [None for n in range(N)]
    for n in range(N):
        pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
        pred_masks[n] = activate(pred_masks[n], args.output_activation)

    # 4. loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)
    return err, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
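# Shape walk-through of the temporal pooling above: stacking FN per-frame
# sparse-net outputs of shape (B, C) gives (FN, B, C); permute(1, 2, 0) yields
# (B, C, FN), so the adaptive 1d pool collapses the frame axis. Illustrative:
import torch
import torch.nn.functional as F

FN, B, C = 3, 4, 32
frames_features = torch.stack([torch.randn(B, C) for _ in range(FN)])  # (FN, B, C)
frames_features = frames_features.permute(1, 2, 0)                     # (B, C, FN)
feat_frame = F.adaptive_max_pool1d(frames_features, 1)                 # (B, C, 1)
assert feat_frame.squeeze().shape == (B, C)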