def forward(self, batch_data, args):
    """Run the motion-augmented separation model on one batch.

    Args:
        batch_data: dict with mixture magnitude ('mag_mix'), per-source
            magnitudes ('mags'), appearance frames ('appearance_imags'),
            motion clips ('clips_frames') and ground-truth masks ('masks').
        args: namespace providing num_mix, device, use_mel, stft_frame,
            and the activation names.

    Returns:
        (err, outputs): 1-element loss tensor plus a dict of predictions
        and targets for visualization.
    """
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['appearance_imags']
    clip_imgs = batch_data['clips_frames']
    gt_masks = batch_data['masks']
    mag_mix = mag_mix + 1e-10  # avoid log(0) below

    N = args.num_mix
    weight = torch.ones_like(mag_mix)  # unweighted criterion
    B = mag_mix.size(0)
    # NOTE(review): time axis read from dim 2 — assumes mag_mix arrives as
    # (B, freq, T) here, unlike the (B, 1, freq, T) layout used elsewhere
    # in this file; confirm against the dataloader.
    T = mag_mix.size(2)

    # Warp the mixture onto a 256-bin log-frequency grid.
    grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
    # squeeze(1): remove only the channel dim added by unsqueeze(1); a bare
    # squeeze() would also collapse the batch dimension when B == 1.
    mag_mix = F.grid_sample(mag_mix.unsqueeze(1), grid_warp,
                            align_corners=False).squeeze(1)
    mag_mix = torch.log(mag_mix).detach() if args.use_mel else mag_mix

    # Appearance features, one per mixture component.
    feat_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # Motion features, one per mixture component.
    feat_motion = [None for n in range(N)]
    for n in range(N):
        feat_motion[n] = self.net_motion(clip_imgs[n])

    # Sound network predicts one mask per component from audio + motion
    # + appearance features.
    feat_sound = [None for n in range(N)]
    pred_masks = [None for n in range(N)]
    for n in range(N):
        feat_sound[n], _, _ = self.net_sound(mag_mix.to(args.device),
                                             feat_motion[n], feat_frames[n])
        pred_masks[n] = activate(feat_sound[n], args.sound_activation)

    # Unwarp predicted masks back to the linear STFT frequency axis.
    # The grid is identical for every component, so build it once.
    grid_unwarp = torch.from_numpy(
        warpgrid(B, args.stft_frame // 2 + 1, T, warp=False)).to(args.device)
    for n in range(N):
        pred_masks[n] = F.grid_sample(pred_masks[n].unsqueeze(1), grid_unwarp,
                                      align_corners=False).squeeze(1)

    err = self.crit(pred_masks, gt_masks, weight).reshape(1)
    return err, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
def forward(self, mags, mag_mix, args):
    """Preprocess spectrograms and embed the mixture with net_sound.

    Optionally warps the mixture and per-source magnitudes onto a
    log-frequency grid, derives loss weights and ground-truth masks,
    then returns the activated sound features together with a dict of
    the intermediate tensors.
    """
    mag_mix = mag_mix + 1e-10
    num_srcs = args.num_mix
    batch = mag_mix.size(0)
    n_frames = mag_mix.size(3)

    # warp the spectrogram onto a 256-bin log-frequency grid
    if args.log_freq:
        grid = torch.from_numpy(
            warpgrid(batch, 256, n_frames, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid)
        for idx in range(num_srcs):
            mags[idx] = F.grid_sample(mags[idx], grid)

    # loss weighting coefficient: magnitude of input mixture (or uniform)
    if args.weighted_loss:
        weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # ground truth masks are computed after warping!
    gt_masks = []
    for idx in range(num_srcs):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks.append((mags[idx] > 0.5 * mag_mix).float())
        else:
            ratio = mags[idx] / mag_mix
            # clamp to avoid large numbers in ratio masks
            ratio.clamp_(0., 5.)
            gt_masks.append(ratio)

    # sound network consumes the detached LOG magnitude
    feat_sound = activate(self.net_sound(torch.log(mag_mix).detach()),
                          args.sound_activation)

    return feat_sound, {
        'gt_masks': gt_masks,
        'mag_mix': mag_mix,
        'mags': mags,
        'weight': weight,
    }
def forward(self, batch_data, args):
    """Joint separation + grounding forward pass over a 4-stream mixture.

    NOTE(review): indices 0-3 are hard-coded throughout, so this assumes
    args.num_mix == 4 with streams 1 and 3 acting as negative/silent
    candidates (their ground-truth masks are zeroed) — confirm against
    the dataloader.

    Returns:
        err: total loss (separation + grounding terms).
        loss_sep: separation-only loss, for logging.
        g: [softmax(g_pos), softmax(g_neg), per-stream softmax scores,
            solo softmax scores] — grounding confidences.
        dict of tensors for visualization.
    """
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10  # avoid division by zero / log(0)
    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)
    # (removed leftover debug `print(B)`)

    # 0.0 warp the spectrogram onto a log-frequency grid
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.1 calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)
    # streams 1 and 3 are given all-zero (silent) targets
    gt_masks[1] = torch.mul(gt_masks[3], 0.)
    gt_masks[3] = torch.mul(gt_masks[3], 0.)

    # LOG magnitude, detached so the grounding nets see fixed audio input
    # (removed unused local `log_mag1`)
    log_mag_mix = torch.log1p(mag_mix).detach()
    log_mag0 = torch.log1p(mags[0]).detach()
    log_mag2 = torch.log1p(mags[2]).detach()

    # grounding features: audio embedding of the mixture + per-stream frames
    feat_sound_ground = self.net_sound_ground(log_mag_mix)
    feat_frames_ground = [None for n in range(N)]
    for n in range(N):
        feat_frames_ground[n] = self.net_frame_ground.forward_multiframe(
            frames[n])

    # Grounding for sep: mixture audio vs. each stream's frames
    g_sep = [None for n in range(N)]
    x = [None for n in range(N)]
    for n in range(N):
        g_sep[n] = self.net_grounding(feat_sound_ground, feat_frames_ground[n])
        x[n] = torch.softmax(g_sep[n].clone(), dim=-1)

    # Grounding module: clean source 0 vs. frames 0/1 (positives),
    # clean source 2 vs. frames 0 (negative)
    g_pos = self.net_grounding(self.net_sound_ground(log_mag0),
                               feat_frames_ground[0])
    g_pos1 = self.net_grounding(self.net_sound_ground(log_mag0),
                                feat_frames_ground[1])
    g_neg = self.net_grounding(self.net_sound_ground(log_mag2),
                               feat_frames_ground[0])

    # Grounding for solo sound
    g_solo = [None for n in range(N)]
    g_solo[0] = self.net_grounding(self.net_sound_ground(log_mag0),
                                   feat_frames_ground[0])
    g_solo[1] = self.net_grounding(self.net_sound_ground(log_mag0),
                                   feat_frames_ground[1])
    g_solo[2] = self.net_grounding(self.net_sound_ground(log_mag2),
                                   feat_frames_ground[2])
    g_solo[3] = self.net_grounding(self.net_sound_ground(log_mag2),
                                   feat_frames_ground[3])
    for n in range(N):
        g_solo[n] = torch.softmax(g_solo[n], dim=-1)
    g = [
        torch.softmax(g_pos, dim=-1),
        torch.softmax(g_neg, dim=-1), x, g_solo
    ]

    # 1. forward net_sound -> BxCxHxW
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, args.sound_activation)

    # 2. forward net_frame -> Bx1xC
    feat_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # 3. sound synthesizer: raw masks per stream
    masks = [None for n in range(N)]
    for n in range(N):
        masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
        masks[n] = activate(masks[n], args.output_activation)

    # 4. adjusted masks: gate each raw mask by its (rounded) grounding
    #    confidence broadcast to the mask's spatial size via tf_data
    pred_masks = [None for n in range(N)]
    s1 = masks[1].size(2)
    s2 = masks[1].size(3)
    if args.testing:
        pred_masks[0] = torch.mul(tf_data(g_sep[0], B, s1, s2).round(), masks[0])
        pred_masks[1] = torch.mul(tf_data(g_sep[1], B, s1, s2).round(), masks[1])
        pred_masks[2] = torch.mul(tf_data(g_sep[2], B, s1, s2).round(), masks[2])
        pred_masks[3] = torch.mul(tf_data(g_sep[3], B, s1, s2).round(), masks[3])
    else:
        # training: pair streams (0,1) and (2,3) so the confident one of
        # each pair explains the mixture
        pred_masks[0] = torch.mul(tf_data(g_sep[0], B, s1, s2).round(),
                                  masks[0]) + \
            torch.mul(tf_data(g_sep[1], B, s1, s2).round(), masks[1])
        pred_masks[1] = masks[1]
        pred_masks[2] = torch.mul(tf_data(g_sep[2], B, s1, s2).round(),
                                  masks[2]) + \
            torch.mul(tf_data(g_sep[3], B, s1, s2).round(), masks[3])
        pred_masks[3] = masks[3]

    # 5. loss
    # separation loss only on the two "real" streams (0 and 2)
    loss_sep = 0.5 * (
        self.crit(pred_masks[0], gt_masks[0], weight).reshape(1) +
        self.crit(pred_masks[2], gt_masks[2], weight).reshape(1))

    # per-sample reconstruction error of each raw mask against its
    # paired ground truth (streams 0/1 share gt 0; streams 2/3 share gt 2)
    df_mask = [None for n in range(N)]
    for i in range(N):
        df_mask[i] = torch.zeros(B, 1).cuda()
    for j in range(B):
        df_mask[0][j] = self.crit(masks[0][j:j + 1], gt_masks[0][j:j + 1],
                                  weight[j:j + 1])
        df_mask[1][j] = self.crit(masks[1][j:j + 1], gt_masks[0][j:j + 1],
                                  weight[j:j + 1])
        df_mask[2][j] = self.crit(masks[2][j:j + 1], gt_masks[2][j:j + 1],
                                  weight[j:j + 1])
        df_mask[3][j] = self.crit(masks[3][j:j + 1], gt_masks[2][j:j + 1],
                                  weight[j:j + 1])
    # a1/a2 in [0,1]: relative confidence that stream 0 (resp. 2) of each
    # pair matches its ground truth better than its sibling
    a1 = (1 - df_mask[0].div(df_mask[1] + df_mask[0]))
    a2 = (1 - df_mask[2].div(df_mask[2] + df_mask[3]))

    # per-sample contrastive losses with positive label 0 / negative label 1
    p = torch.zeros(B).cuda()
    sep_pos = [None for n in range(N)]
    for i in range(N):
        sep_pos[i] = torch.zeros(B, 1).cuda()
    for i in range(N):
        for j in range(B):
            sep_pos[i][j] = self.cts(g_sep[i][j:j + 1], p[j:j + 1].long())
    sep_neg = [None for n in range(N)]
    for i in range(N):
        sep_neg[i] = torch.zeros(B, 1).cuda()
    n = torch.ones(B).cuda()
    for i in range(N):
        for j in range(B):
            sep_neg[i][j] = self.cts(g_sep[i][j:j + 1], n[j:j + 1].long())
    cts_pos = torch.zeros(B).cuda()
    cts_pos1 = torch.zeros(B).cuda()
    for i in range(B):
        cts_pos[i] = self.cts(g_pos[i:i + 1], p[i:i + 1].long())
        cts_pos1[i] = self.cts(g_pos1[i:i + 1], p[i:i + 1].long())

    # NOTE(review): cts_pos is shape (B,) while a1 is (B, 1); the product
    # broadcasts to (B, B) before .mean() — verify this is intended.
    loss_grd = (cts_pos * a1.round() + cts_pos1 * (1 - a1).round()).mean(
    ) + self.cts(g_neg, n.long(
    ))  # torch.min(cts_pos, cts_pos1).mean() + self.cts(g_neg, n.long())

    # push the better stream of each pair toward the positive label ...
    loss_grd_pos = (sep_pos[0] * a1.round() + sep_pos[1] * (1 - a1).round() +
                    sep_pos[2] * a2.round() +
                    sep_pos[3] * (1 - a2).round()).mean()
    # ... and the worse one toward the negative label, but only when its
    # reconstruction error exceeds the threshold th
    th = 0.1
    loss_grd_neg = (sep_neg[0] * (1 - a1).round() *
                    ((1 + torch.sign(df_mask[0] - th)) / 2) +
                    sep_neg[1] * a1.round() *
                    ((1 + torch.sign(df_mask[1] - th)) / 2) +
                    sep_neg[2] * (1 - a2).round() *
                    ((1 + torch.sign(df_mask[2] - th)) / 2) +
                    sep_neg[3] * a2.round() *
                    ((1 + torch.sign(df_mask[3] - th)) / 2)).mean()
    loss_grd_sep = (loss_grd_neg + loss_grd_pos) * 0.25
    # (torch.min(sep_pos[0], sep_pos[1]).mean() + torch.min(sep_pos[2], sep_pos[3]).mean())*0.5

    loss = loss_sep + loss_grd_sep
    err = loss + loss_grd * 0.5
    return err, loss_sep, g,\
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
def output_visuals(vis_rows, batch_data, outputs, args):
    """Write per-sample visualizations (wavs, spectrogram/mask images,
    videos) under args.vis and append an HTML-table row per sample to
    vis_rows.

    Relies on module-level helpers: warpgrid, istft_reconstruction,
    magnitude2heatmap, recover_rgb, save_video, combine_video_audio,
    makedirs, imsave, wavfile.
    """
    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']
    pred_masks_ = outputs['pred_masks']
    gt_masks_ = outputs['gt_masks']
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    # unwarp log scale back to the linear STFT frequency axis
    N = args.num_mix  #-1
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    gt_masks_linear = [None for n in range(N)]
    for n in range(N):
        if args.log_freq:
            # grid is rebuilt each iteration although it does not depend on n
            grid_unwarp = torch.from_numpy(
                warpgrid(B,
                         args.stft_frame // 2 + 1,
                         gt_masks_[0].size(3),
                         warp=False)).to(args.device)
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
            gt_masks_linear[n] = F.grid_sample(gt_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]
            gt_masks_linear[n] = gt_masks_[n]

    # convert into numpy (mutates the outputs lists in place)
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(
                np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        row_elements = []

        # video names: "<dir>-<file>" per stream, joined with '+'
        prefix = []
        for n in range(N):
            prefix.append('-'.join(
                infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture (waveform + magnitude heatmap + loss weight image)
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        # [::-1] flips the frequency axis so low frequencies are at the bottom
        imsave(os.path.join(args.vis, filename_mixmag), mix_amp[::-1, :, :])
        imsave(os.path.join(args.vis, filename_weight), weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate,
                      mix_wav)
        row_elements += [{
            'text': prefix
        }, {
            'image': filename_mixmag,
            'audio': filename_mixwav
        }]

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # GT and predicted audio recovery via masked mixture + mixture phase
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag,
                                          phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag,
                                                phase_mix[j, 0],
                                                hop_length=args.stft_hop)

            # output masks as grayscale images
            filename_gtmask = os.path.join(prefix,
                                           'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(prefix,
                                             'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imsave(os.path.join(args.vis, filename_gtmask), gt_mask[::-1, :])
            imsave(os.path.join(args.vis, filename_predmask),
                   pred_mask[::-1, :])

            # ouput spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(prefix,
                                            'predamp{}.jpg'.format(n + 1))
            gt_mag = magnitude2heatmap(gt_mag)
            pred_mag = magnitude2heatmap(pred_mag)
            imsave(os.path.join(args.vis, filename_gtmag), gt_mag[::-1, :, :])
            imsave(os.path.join(args.vis, filename_predmag),
                   pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(prefix,
                                            'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav),
                          args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, preds_wav[n])

            # output video: middle frames of this stream, re-encoded
            frames_tensor = [
                recover_rgb(frames[n][j, :, t])
                for t in range(args.num_frames)
            ]
            frames_tensor = np.asarray(frames_tensor)
            path_video = os.path.join(args.vis, prefix,
                                      'video{}.mp4'.format(n + 1))
            save_video(path_video,
                       frames_tensor,
                       fps=args.frameRate / args.stride_frames)

            # combine gt video and audio
            filename_av = os.path.join(prefix, 'av{}.mp4'.format(n + 1))
            combine_video_audio(path_video,
                                os.path.join(args.vis, filename_gtwav),
                                os.path.join(args.vis, filename_av))

            row_elements += [{
                'video': filename_av
            }, {
                'image': filename_predmag,
                'audio': filename_predwav
            }, {
                'image': filename_gtmag,
                'audio': filename_gtwav
            }, {
                'image': filename_predmask
            }, {
                'image': filename_gtmask
            }]

        row_elements += [{'image': filename_weight}]
        vis_rows.append(row_elements)
def calc_metrics(batch_data, outputs, args):
    """Compute BSS separation metrics (SDR/SIR/SAR and mixture SDR)
    averaged over the valid samples of one batch.

    Returns:
        [sdr_mix, sdr, sir, sar] batch averages from the meters.
    """
    # meters
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    audios = batch_data['audios']
    pred_masks_ = outputs['pred_masks']

    # unwarp log scale
    # NOTE(review): N is hard-coded to 4 here instead of args.num_mix —
    # confirm this matches the model that produced pred_masks.
    N = 4  #args.num_mix-1
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    for n in range(N):
        if args.log_freq:
            # grid is rebuilt each iteration although it does not depend on n
            grid_unwarp = torch.from_numpy(
                warpgrid(B,
                         args.stft_frame // 2 + 1,
                         pred_masks_[0].size(3),
                         warp=False)).to(args.device)
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    for n in range(N):
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        # reconstruct the mixture waveform from magnitude + mixture phase
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)

        # reconstruct each predicted component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # Predicted audio recovery; +1e-6 keeps bss_eval from seeing
            # an exactly-silent estimate
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(
                pred_mag, phase_mix[j, 0], hop_length=args.stft_hop) + 1e-6

        # separation performance computes
        L = preds_wav[0].shape[0]
        gts_wav = [None for n in range(N)]
        # `valid` starts as bool and becomes an int/bool product — any
        # near-silent reference or estimate disables the sample
        valid = True
        for n in range(N):
            gts_wav[n] = audios[n][j, 0:L].numpy() + 1e-6
            valid *= np.sum(np.abs(gts_wav[n])) > 1e-5
            valid *= np.sum(np.abs(preds_wav[n])) > 1e-5
        if valid:
            sdr, sir, sar, _ = bss_eval_sources(np.asarray(gts_wav),
                                                np.asarray(preds_wav), False)
            # baseline: score the raw mixture against every reference
            sdr_mix, _, _, _ = bss_eval_sources(
                np.asarray(gts_wav),
                np.asarray([mix_wav[0:L] for n in range(N)]), False)
            sdr_mix_meter.update(sdr_mix.mean())
            sdr_meter.update(sdr.mean())
            sir_meter.update(sir.mean())
            sar_meter.update(sar.mean())

    return [
        sdr_mix_meter.average(),
        sdr_meter.average(),
        sir_meter.average(),
        sar_meter.average()
    ]
def forward(self, batch_data, args):
    """Standard mask-prediction forward pass.

    Warps spectrograms onto a log-frequency grid (optional), builds loss
    weights and ground-truth masks, runs the sound and frame networks,
    synthesizes one predicted mask per component, and returns the loss
    plus a dict of intermediates for visualization.
    """
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = batch_data['mag_mix'] + 1e-10

    num_mix = args.num_mix
    batch = mag_mix.size(0)
    time_len = mag_mix.size(3)

    # 0.0 warp the spectrogram onto a 256-bin log-frequency grid
    if args.log_freq:
        grid = torch.from_numpy(
            warpgrid(batch, 256, time_len, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid)
        for k in range(num_mix):
            mags[k] = F.grid_sample(mags[k], grid)

    # 0.1 loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground truth masks are computed after warping!
    gt_masks = []
    for k in range(num_mix):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks.append((mags[k] > 0.5 * mag_mix).float())
        else:
            mask = mags[k] / mag_mix
            # clamp to avoid large numbers in ratio masks
            mask.clamp_(0., 5.)
            gt_masks.append(mask)

    # 1. forward net_sound on the detached LOG magnitude -> BxCxHxW
    feat_sound = activate(self.net_sound(torch.log(mag_mix).detach()),
                          args.sound_activation)

    # 2. forward net_frame -> Bx1xC per component
    feat_frames = [
        activate(self.net_frame.forward_multiframe(frames[k]),
                 args.img_activation) for k in range(num_mix)
    ]

    # 3. sound synthesizer: one predicted mask per component
    pred_masks = [
        activate(self.net_synthesizer(feat_frames[k], feat_sound),
                 args.output_activation) for k in range(num_mix)
    ]

    # 4. loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)

    return err, {'pred_masks': pred_masks, 'gt_masks': gt_masks,
                 'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
def forward(self, batch_data, args):
    """Separation forward pass driven by precomputed frame embeddings.

    Instead of running a frame network, dot-products each provided frame
    embedding against the flattened sound feature map (batched matmul) to
    produce one predicted mask per mixture component.
    """
    mags = batch_data['mags']
    frames = batch_data['frames']  # fetched for parity; embeddings drive the masks
    frame_emb = batch_data['frame_emb']
    mag_mix = batch_data['mag_mix'] + 1e-10

    num_mix = args.num_mix
    batch = mag_mix.size(0)
    time_len = mag_mix.size(3)

    # warp the spectrogram onto a 256-bin log-frequency grid
    if args.log_freq:
        grid = torch.from_numpy(
            warpgrid(batch, 256, time_len, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid)
        for k in range(num_mix):
            mags[k] = F.grid_sample(mags[k], grid)

    # loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.clamp(torch.log1p(mag_mix), 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # ground truth masks are computed after warping!
    gt_masks = []
    for k in range(num_mix):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks.append((mags[k] > 0.5 * mag_mix).float())
        else:
            mask = mags[k] / mag_mix
            # clamp to avoid large numbers in ratio masks
            mask.clamp_(0., 5.)
            gt_masks.append(mask)

    # forward net_sound on the detached LOG magnitude
    feat_sound = activate(self.net_sound(torch.log(mag_mix).detach()),
                          args.sound_activation)

    # separating sound: bmm each (B,1,C) embedding with the (B,C,H*W)
    # sound features, then reshape back to a (B,1,H,W) mask
    sound_size = feat_sound.size()
    B, C = sound_size[0], sound_size[1]
    pred_masks = []
    for k in range(num_mix):
        emb = frame_emb[k].float().view(B, 1, C)
        mask = torch.bmm(emb, feat_sound.view(B, C, -1)) \
            .view(B, 1, *sound_size[2:])
        pred_masks.append(activate(mask, args.output_activation))

    # loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)

    return err, {'pred_masks': pred_masks, 'gt_masks': gt_masks,
                 'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
def output_visuals_PosNeg(vis_rows, batch_data, masks_pos, masks_neg,
                          idx_pos, idx_neg, pred_masks_, gt_masks_, mag_mix_,
                          weight_, args):
    """Write per-sample visualizations plus positive/negative grounding
    heatmap overlays under args.vis, appending one row per sample to
    vis_rows.

    masks_pos/masks_neg are low-resolution attention maps for the
    positive (idx_pos) and negative (idx_neg) video streams; each cell is
    upsampled 16x into a block heatmap and blended over the middle frame.
    """
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    frames = batch_data['frames']
    infos = batch_data['infos']

    # masks to cpu, numpy
    masks_pos = torch.squeeze(masks_pos, dim=1)
    masks_pos = masks_pos.cpu().float().numpy()
    masks_neg = torch.squeeze(masks_neg, dim=1)
    masks_neg = masks_neg.cpu().float().numpy()

    N = args.num_mix
    B = mag_mix.size(0)

    # unwarp the masks back to the linear STFT frequency axis
    pred_masks_linear = [None for n in range(N)]
    gt_masks_linear = [None for n in range(N)]
    for n in range(N):
        if args.log_freq:
            # grid is rebuilt each iteration although it does not depend on n
            grid_unwarp = torch.from_numpy(
                warpgrid(B, args.stft_frame//2+1, gt_masks_[0].size(3),
                         warp=False)).to(args.device)
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
            gt_masks_linear[n] = F.grid_sample(gt_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]
            gt_masks_linear[n] = gt_masks_[n]

    # convert into numpy (mutates the passed-in lists in place)
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    idx_pos = int(idx_pos.detach().cpu().numpy())
    idx_neg = int(idx_neg.detach().cpu().numpy())
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] > args.mask_thres).astype(np.float32)

    # attention values below this are zeroed in the heatmaps
    threshold = 0.5

    # loop over each sample
    for j in range(B):
        row_elements = []

        # video names: "<dir>-<file>" per stream, joined with '+'
        prefix = []
        for n in range(N):
            prefix.append('-'.join(infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture (waveform + magnitude heatmap + loss weight image)
        mix_wav = istft_reconstruction(mag_mix[j, 0], phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        # [::-1] flips the frequency axis so low frequencies are at the bottom
        matplotlib.image.imsave(os.path.join(args.vis, filename_mixmag),
                                mix_amp[::-1, :, :])
        matplotlib.image.imsave(os.path.join(args.vis, filename_weight),
                                weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav), args.audRate,
                      mix_wav)
        row_elements += [{'text': prefix},
                         {'image': filename_mixmag,
                          'audio': filename_mixwav}]

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # GT and predicted audio recovery (linear-scale magnitudes with
            # the mixture phase); *_ variants are the warped versions used
            # only for the heatmap images
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_mag_ = mag_mix_[j, 0] * gt_masks_[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag, phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            pred_mag_ = mag_mix_[j, 0] * pred_masks_[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag, phase_mix[j, 0],
                                                hop_length=args.stft_hop)

            # output masks as grayscale images
            filename_gtmask = os.path.join(prefix, 'gtmask{}.jpg'.format(n+1))
            filename_predmask = os.path.join(prefix, 'predmask{}.jpg'.format(n+1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(np.uint8)
            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmask),
                                    gt_mask[::-1, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmask),
                                    pred_mask[::-1, :])

            # ouput spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix, 'gtamp{}.jpg'.format(n+1))
            filename_predmag = os.path.join(prefix, 'predamp{}.jpg'.format(n+1))
            gt_mag = magnitude2heatmap(gt_mag_)
            pred_mag = magnitude2heatmap(pred_mag_)
            matplotlib.image.imsave(os.path.join(args.vis, filename_gtmag),
                                    gt_mag[::-1, :, :])
            matplotlib.image.imsave(os.path.join(args.vis, filename_predmag),
                                    pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n+1))
            filename_predwav = os.path.join(prefix, 'pred{}.wav'.format(n+1))
            wavfile.write(os.path.join(args.vis, filename_gtwav),
                          args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, preds_wav[n])

        # save frame (middle frame of the positive stream)
        frames_tensor = recover_rgb(frames[idx_pos][j, :, int(args.num_frames//2)])
        frames_tensor = np.asarray(frames_tensor)
        filename_frame = os.path.join(prefix, 'frame{}.png'.format(idx_pos+1))
        matplotlib.image.imsave(os.path.join(args.vis, filename_frame),
                                frames_tensor)
        frame = frames_tensor.copy()

        # get heatmap and overlay for postive pair: each attention cell is
        # expanded into a 16x16 block; sub-threshold cells are zeroed
        height, width = masks_pos.shape[-2:]
        heatmap = np.zeros((height*16, width*16))
        for i in range(height):
            for k in range(width):
                mask_pos = masks_pos[j]
                value = mask_pos[i, k]
                value = 0 if value < threshold else value
                ii = i * 16
                jj = k * 16
                heatmap[ii:ii + 16, jj:jj + 16] = value
        heatmap = (heatmap * 255).astype(np.uint8)
        filename_heatmap = os.path.join(prefix,
                                        'heatmap_{}_{}.jpg'.format(idx_pos+1, idx_pos+1))
        plt.imsave(os.path.join(args.vis, filename_heatmap), heatmap,
                   cmap='hot')
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        fin = cv2.addWeighted(heatmap, 0.5, frame, 0.5, 0, dtype=cv2.CV_32F)
        path_overlay = os.path.join(args.vis, prefix,
                                    'overlay_{}_{}.jpg'.format(idx_pos+1, idx_pos+1))
        cv2.imwrite(path_overlay, fin)

        # save frame (middle frame of the negative stream)
        frames_tensor = recover_rgb(frames[idx_neg][j, :, int(args.num_frames//2)])
        frames_tensor = np.asarray(frames_tensor)
        filename_frame = os.path.join(prefix, 'frame{}.png'.format(idx_neg+1))
        matplotlib.image.imsave(os.path.join(args.vis, filename_frame),
                                frames_tensor)
        frame = frames_tensor.copy()

        # get heatmap and overlay for postive pair (negative stream)
        height, width = masks_neg.shape[-2:]
        heatmap = np.zeros((height*16, width*16))
        for i in range(height):
            for k in range(width):
                mask_neg = masks_neg[j]
                value = mask_neg[i, k]
                value = 0 if value < threshold else value
                ii = i * 16
                jj = k * 16
                heatmap[ii:ii + 16, jj:jj + 16] = value
        heatmap = (heatmap * 255).astype(np.uint8)
        filename_heatmap = os.path.join(prefix,
                                        'heatmap_{}_{}.jpg'.format(idx_pos+1, idx_neg+1))
        plt.imsave(os.path.join(args.vis, filename_heatmap), heatmap,
                   cmap='hot')
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        fin = cv2.addWeighted(heatmap, 0.5, frame, 0.5, 0, dtype=cv2.CV_32F)
        path_overlay = os.path.join(args.vis, prefix,
                                    'overlay_{}_{}.jpg'.format(idx_pos+1, idx_neg+1))
        cv2.imwrite(path_overlay, fin)

        # NOTE(review): indentation reconstructed — assumed inside the
        # per-sample loop, matching output_visuals; confirm.
        vis_rows.append(row_elements)
def forward(self, batch_data, args):
    """Grounding-only forward pass: contrastive losses between clean
    source audio and per-stream frame features.

    NOTE(review): indices 0-3 are hard-coded, so this assumes
    args.num_mix == 4 — confirm against the dataloader. The gt_masks
    computed below are never used or returned from this function;
    presumably kept for parity with the separation forward.

    Returns:
        err: positive/negative contrastive grounding loss.
        g: [softmax(g_pos), softmax(g_neg), per-stream softmax scores,
            solo softmax scores].
    """
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10  # avoid division by zero / log(0)
    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram onto a log-frequency grid
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                              warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.2 ground truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)
    # streams 1 and 3 are given all-zero (silent) targets
    gt_masks[1] = torch.mul(gt_masks[3], 0.)
    gt_masks[3] = torch.mul(gt_masks[3], 0.)

    # LOG magnitude, detached so the grounding nets see fixed audio input
    log_mag_mix = torch.log1p(mag_mix).detach()
    log_mag0 = torch.log1p(mags[0]).detach()
    log_mag2 = torch.log1p(mags[2]).detach()

    # grounding features: audio embedding of the mixture + per-stream frames
    feat_sound_ground = self.net_sound_ground(log_mag_mix)
    feat_frames_ground = [None for n in range(N)]
    for n in range(N):
        feat_frames_ground[n] = self.net_frame_ground.forward_multiframe(
            frames[n])

    # Grounding for sep: mixture audio vs. each stream's frames
    g_sep = [None for n in range(N)]
    x = [None for n in range(N)]
    for n in range(N):
        g_sep[n] = self.net_grounding(feat_sound_ground,
                                      feat_frames_ground[n])
        x[n] = torch.softmax(g_sep[n].clone(), dim=-1)

    # Grounding module: clean source 0 vs. frames 0/1 (positives),
    # clean source 2 vs. frames 0 (negative)
    g_pos = self.net_grounding(self.net_sound_ground(log_mag0),
                               feat_frames_ground[0])
    g_pos1 = self.net_grounding(self.net_sound_ground(log_mag0),
                                feat_frames_ground[1])
    g_neg = self.net_grounding(self.net_sound_ground(log_mag2),
                               feat_frames_ground[0])

    # Grounding for solo sound
    g_solo = [None for n in range(N)]
    g_solo[0] = self.net_grounding(self.net_sound_ground(log_mag0),
                                   feat_frames_ground[0])
    g_solo[1] = self.net_grounding(self.net_sound_ground(log_mag0),
                                   feat_frames_ground[1])
    g_solo[2] = self.net_grounding(self.net_sound_ground(log_mag2),
                                   feat_frames_ground[2])
    g_solo[3] = self.net_grounding(self.net_sound_ground(log_mag2),
                                   feat_frames_ground[3])
    for n in range(N):
        g_solo[n] = torch.softmax(g_solo[n], dim=-1)
    g = [
        torch.softmax(g_pos, dim=-1),
        torch.softmax(g_neg, dim=-1), x, g_solo
    ]

    # per-sample contrastive targets: label 0 = positive, label 1 = negative
    p = torch.zeros(B).cuda()
    n = torch.ones(B).cuda()
    cts_pos = torch.zeros(B).cuda()
    cts_pos1 = torch.zeros(B).cuda()
    for i in range(B):
        cts_pos[i] = self.cts(g_pos[i:i + 1], p[i:i + 1].long())
        cts_pos1[i] = self.cts(g_pos1[i:i + 1], p[i:i + 1].long())

    # 5. loss: best of the two positive pairings + the negative pairing
    err = torch.min(cts_pos, cts_pos1).mean() + self.cts(g_neg, n.long())
    # + torch.min(self.cts(g_sep[0], p.long()), self.cts(g_sep[1], p.long())) + torch.min(self.cts(g_sep[2], p.long()), self.cts(g_sep[3], p.long()))

    return err, g
def forward(self, batch_data, args):
    """Minus/Plus iterative separation forward pass.

    The "Minus" stage separates sources one at a time: at each step the
    sound net embeds the current (residual) mixture, every frame feature
    proposes a mask, the highest-energy unclaimed proposal is selected,
    and its contribution is subtracted from the mixture. The "Plus"
    stage then refines each selected mask with a residual network.
    """
    ### prepare data
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10  # avoid division by zero / log(0)
    # unwarped copy, kept to reconstruct linear-scale magnitudes later
    mag_mix_tmp = mag_mix.clone()
    N = args.num_mix
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram onto a log-frequency grid
    # NOTE(review): grid_warp is only defined when args.log_freq is true,
    # but the Plus stage below uses it unconditionally — NameError risk
    # when log_freq is false; confirm log_freq is always set.
    if args.log_freq:
        grid_warp = torch.from_numpy(warpgrid(B, 256, T,
                                              warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.1 calculate loss weighting coefficient: magnitude of input mixture
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground truth masks are computed after warping!
    # Please notice that, gt_masks are unordered
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            # (note: tighter 0..2 clamp than the other forward variants)
            gt_masks[n].clamp_(0., 2.)

    ### Minus part
    # freeze the Minus nets when this stage is not being trained
    if 'Minus' not in self.mode:
        self.requires_grad(self.net_sound_M, False)
        self.requires_grad(self.net_frame_M, False)
    feat_frames = [None for n in range(N)]
    ordered_pred_masks = [None for n in range(N)]
    ordered_pred_mags = [None for n in range(N)]

    # Step1: obtain all the frame features
    # forward net_frame_M -> Bx1xC
    for n in range(N):
        log_mag_mix = torch.log(mag_mix)
        feat_frames[n] = self.net_frame_M.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n], args.img_activation)

    # Step2: separate the sounds one by one
    # forward net_sound_M -> BxCxHxW
    if args.log_freq:
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1, 256,
                     warp=False)).to(args.device)
    index_record = []  # per-step chosen frame index, per batch element
    for n in range(N):
        # re-embed the current residual mixture
        log_mag_mix = torch.log(mag_mix).detach()
        feat_sound = self.net_sound_M(log_mag_mix)
        _, C, H, W = feat_sound.shape
        feat_sound = feat_sound.view(B, C, -1)

        # obtain current separated sound: one candidate mask per stream
        energy_list = []
        tmp_masks = []
        tmp_pred_mags = []
        for feat_frame in feat_frames:
            cur_pred_mask = torch.bmm(feat_frame.unsqueeze(1),
                                      feat_sound).view(B, 1, H, W)
            cur_pred_mask = activate(cur_pred_mask, args.output_activation)
            tmp_masks.append(cur_pred_mask)
            # Here we cut off the loss flow from Minus net to Plus net
            # in order to train more steadily
            if args.log_freq:
                cur_pred_mask_unwrap = F.grid_sample(cur_pred_mask.detach(),
                                                     grid_unwarp)
                if args.binary_mask:
                    cur_pred_mask_unwrap = (cur_pred_mask_unwrap >
                                            args.mask_thres).float()
            else:
                cur_pred_mask_unwrap = cur_pred_mask.detach()
            # linear-scale magnitude of this candidate, from the
            # original (unwarped) mixture
            cur_pred_mag = cur_pred_mask_unwrap * mag_mix_tmp
            tmp_pred_mags.append(cur_pred_mag)
            energy_list.append(
                np.array(cur_pred_mag.view(B, -1).mean(dim=1).cpu().data))
        total_energy = np.stack(energy_list, axis=1)
        # _, cur_index = torch.max(total_energy)
        # pick the highest-energy stream not already claimed
        cur_index = self.choose_max(index_record, total_energy)
        index_record.append(cur_index)
        masks = torch.stack(tmp_masks, dim=0)
        # gather the chosen candidate per batch element
        ordered_pred_masks[n] = masks[cur_index, list(range(B))]
        pred_mags = torch.stack(tmp_pred_mags, dim=0)
        ordered_pred_mags[n] = pred_mags[cur_index, list(range(B))]
        #log_mag_mix = log_mag_mix - log_mag_mix * pred_masks[n]
        # subtract the separated source from the residual mixture
        mag_mix = mag_mix - mag_mix * ordered_pred_masks[n] + 1e-10

    # just for swap pred_masks, in order to compute loss conveniently
    # since gt_masks are unordered, we must transfer ordered_pred_masks
    # to unordered
    index_record = np.stack(index_record, axis=1)
    total_masks = torch.stack(ordered_pred_masks, dim=1)
    total_pred_mags = torch.stack(ordered_pred_mags, dim=1)
    unordered_pred_masks = []
    unordered_pred_mags = []
    for n in range(N):
        mask_index = np.where(index_record == n)
        if args.binary_mask:
            unordered_pred_masks.append(total_masks[mask_index])
            unordered_pred_mags.append(total_pred_mags[mask_index])
        else:
            # NOTE(review): ratio masks are doubled here — presumably to
            # compensate the 0..2 gt clamp above; confirm against crit.
            unordered_pred_masks.append(total_masks[mask_index] * 2)
            unordered_pred_mags.append(total_pred_mags[mask_index])

    ### Plus part: refine each mask with a residual predicted from the
    ### running context (pre_sum) and the current source's log magnitude
    if 'Plus' in self.mode:
        pre_sum = torch.zeros_like(unordered_pred_masks[0]).to(args.device)
        Plus_pred_masks = []
        for n in range(N):
            unordered_pred_mag = unordered_pred_mags[n].log()
            unordered_pred_mag = F.grid_sample(unordered_pred_mag, grid_warp)
            input_concat = torch.cat((pre_sum, unordered_pred_mag), dim=1)
            residual_mask = activate(self.net_sound_P(input_concat),
                                     args.sound_activation)
            Plus_pred_masks.append(unordered_pred_masks[n] + residual_mask)
            # NOTE(review): pre_sum sums only its own (zero-initialized)
            # channels, so it stays zero — verify whether input_concat or
            # the residual was meant to be accumulated here.
            pre_sum = pre_sum.sum(dim=1, keepdim=True).detach()
        unordered_pred_masks = Plus_pred_masks

    # loss
    if args.need_loss_ratio:
        # later-separated (harder) sources get exponentially larger weight
        err = 0
        for n in range(N):
            err += self.crit(unordered_pred_masks[n], gt_masks[n],
                             weight) / N * 2**(n - 1)
    else:
        err = self.crit(unordered_pred_masks, gt_masks, weight).reshape(1)
    if 'Minus' in self.mode:
        # penalize leftover energy in the final residual mixture
        # (log_mag_mix holds the log of the last residual)
        res_mag_mix = torch.exp(log_mag_mix)
        err_remain = torch.mean(weight *
                                torch.clamp(res_mag_mix - 1e-2, min=0))
        err += err_remain

    outputs = {
        'pred_masks': unordered_pred_masks,
        'gt_masks': gt_masks,
        'mag_mix': mag_mix,
        'mags': mags,
        'weight': weight
    }
    return err, outputs
def forward(self, batch_data, args):
    """Separate N mixed sounds using per-frame sparse point-cloud features.

    Pipeline: warp spectrograms to log-frequency scale, build ground-truth
    masks, encode the mixture with ``net_sound``, encode each video's point
    clouds with ``net_vision`` (one sparse tensor per frame, pooled over
    frames), synthesize one mask per source, and compute the masked loss.

    Args:
        batch_data: dict with keys 'mag_mix' (mixture magnitude, assumed
            (B, 1, F, T) — size(0)=B and size(3)=T are read below; TODO
            confirm), 'mags' (list of N per-source magnitudes), 'feats' and
            'coords' (per-source, per-frame sparse point features/coords).
        args: experiment config (num_mix, num_frames, log_freq, device,
            weighted_loss, binary_mask, frame_pool, activations, ...).

    Returns:
        (err, outputs): scalar loss tensor of shape (1,) and a dict with
        predicted masks, ground-truth masks, (warped) magnitudes and the
        loss weight map.
    """
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    feats = batch_data['feats']
    coords = batch_data['coords']
    mag_mix = mag_mix + 1e-10  # avoid log(0) / division by zero below
    N = args.num_mix
    FN = args.num_frames
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram to a log-frequency scale (256 bins)
    if args.log_freq:
        grid_warp = torch.from_numpy(
            warpgrid(B, 256, T, warp=True)).to(args.device)
        mag_mix = F.grid_sample(mag_mix, grid_warp)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp)

    # 0.1 loss weighting coefficient: log-magnitude of the input mixture,
    # clamped so silent bins still contribute a small weight
    if args.weighted_loss:
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground-truth masks are computed AFTER warping so they align with
    # the warped predictions
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if args.binary_mask:
            if N > 2:
                # source n wins a bin only if it dominates every source
                binary_masks = torch.ones_like(mags[n])
                for m in range(N):
                    binary_masks *= (mags[n] >= mags[m])
                gt_masks[n] = binary_masks
            else:
                # two-source shortcut: mag_n > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)

    # LOG magnitude; detached so the mask targets carry no gradient
    log_mag_mix = torch.log(mag_mix).detach()

    # 1. forward net_sound: mixture -> sound feature map
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, args.sound_activation)

    # 2. forward net_vision: one sparse tensor per frame, then pool the
    # FN per-frame features into a single vector per source
    feat_frames = [None for n in range(N)]
    for n in range(N):
        frames_features = []
        for f in range(FN):
            sin = ME.SparseTensor(
                feats[n][f], coords[n][f].int(),
                allow_duplicate_coords=True)  # create SparseTensor
            frames_features.append(self.net_vision.forward(sin))
        frames_features = torch.stack(frames_features)
        # (FN, B, C) -> (B, C, FN) so the 1d pool reduces the frame axis
        frames_features = frames_features.permute(1, 2, 0)
        if args.frame_pool == 'maxpool':
            feat_frame = F.adaptive_max_pool1d(frames_features, 1)
        elif args.frame_pool == 'avgpool':
            feat_frame = F.adaptive_avg_pool1d(frames_features, 1)
        else:
            # BUGFIX: previously fell through with feat_frame unbound,
            # raising an opaque NameError for unknown pool modes
            raise ValueError(
                'unknown frame_pool: {}'.format(args.frame_pool))
        # BUGFIX: squeeze only the pooled frame axis; a bare .squeeze()
        # also dropped the batch dimension when B == 1
        feat_frames[n] = feat_frame.squeeze(-1)
        feat_frames[n] = activate(feat_frames[n], args.vision_activation)

    # 3. sound synthesizer: fuse vision vector with sound features
    pred_masks = [None for n in range(N)]
    for n in range(N):
        pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
        pred_masks[n] = activate(pred_masks[n], args.output_activation)

    # 4. loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)
    return err, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
def output_visuals(batch_data, outputs, args):
    """Write per-sample visualization artifacts for a batch to ``args.vis``.

    For each sample: the mixture waveform/spectrogram/weight map, then per
    source the GT and predicted masks, spectrograms, reconstructed audio,
    and per-frame point clouds. Output paths are built from the video names
    in ``infos``; files land under ``args.vis/<prefix>/``.

    Args:
        batch_data: dict with 'mag_mix', 'phase_mix', 'feats', 'coords',
            'infos' (infos[n][0][j] presumably a path-like string used to
            build the output directory name — confirm against the loader).
        outputs: network outputs dict ('pred_masks', 'gt_masks', 'mag_mix',
            'weight') as returned by ``forward``.
        args: config (num_mix, num_frames, log_freq, stft_frame, stft_hop,
            mask_thres, binary_mask, vis, audRate, rgbs_feature, voxel_size).

    Side effects only — returns nothing; mutates the mask lists in
    ``outputs`` in place (tensor -> numpy conversion below).
    """
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    features = batch_data['feats']
    coords = batch_data['coords']
    infos = batch_data['infos']
    pred_masks_ = outputs['pred_masks']
    gt_masks_ = outputs['gt_masks']
    mag_mix_ = outputs['mag_mix']
    weight_ = outputs['weight']

    # unwarp log scale: map masks back to the linear-frequency grid
    # (stft_frame // 2 + 1 bins) so they can multiply the raw mixture
    N = args.num_mix
    FN = args.num_frames
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    gt_masks_linear = [None for n in range(N)]
    for n in range(N):
        if args.log_freq:
            grid_unwarp = torch.from_numpy(
                warpgrid(B, args.stft_frame // 2 + 1,
                         gt_masks_[0].size(3), warp=False)).to(args.device)
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
            gt_masks_linear[n] = F.grid_sample(gt_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]
            gt_masks_linear[n] = gt_masks_[n]

    # convert into numpy (detach from graph, move to CPU)
    mag_mix = mag_mix.numpy()
    mag_mix_ = mag_mix_.detach().cpu().numpy()
    phase_mix = phase_mix.numpy()
    weight_ = weight_.detach().cpu().numpy()
    for n in range(N):
        pred_masks_[n] = pred_masks_[n].detach().cpu().numpy()
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()
        gt_masks_[n] = gt_masks_[n].detach().cpu().numpy()
        gt_masks_linear[n] = gt_masks_linear[n].detach().cpu().numpy()
        # threshold if binary mask
        if args.binary_mask:
            pred_masks_[n] = (pred_masks_[n] > args.mask_thres).astype(
                np.float32)
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        # video names: "<parent>-<stem>" per source, joined with '+'
        prefix = []
        for n in range(N):
            prefix.append('-'.join(
                infos[n][0][j].split('/')[-2:]).split('.')[0])
        prefix = '+'.join(prefix)
        makedirs(os.path.join(args.vis, prefix))

        # save mixture: waveform via inverse STFT, spectrogram and weight
        # maps as heatmap images (rows flipped so low freq is at bottom)
        mix_wav = istft_reconstruction(mag_mix[j, 0], phase_mix[j, 0],
                                       hop_length=args.stft_hop)
        mix_amp = magnitude2heatmap(mag_mix_[j, 0])
        weight = magnitude2heatmap(weight_[j, 0], log=False, scale=100.)
        filename_mixwav = os.path.join(prefix, 'mix.wav')
        filename_mixmag = os.path.join(prefix, 'mix.jpg')
        filename_weight = os.path.join(prefix, 'weight.jpg')
        imageio.imwrite(os.path.join(args.vis, filename_mixmag),
                        mix_amp[::-1, :, :])
        imageio.imwrite(os.path.join(args.vis, filename_weight),
                        weight[::-1, :])
        wavfile.write(os.path.join(args.vis, filename_mixwav),
                      args.audRate, mix_wav)

        # save each component
        preds_wav = [None for n in range(N)]
        for n in range(N):
            # GT and predicted audio recovery: mask the linear-frequency
            # mixture magnitude, then invert with the mixture phase
            gt_mag = mag_mix[j, 0] * gt_masks_linear[n][j, 0]
            gt_wav = istft_reconstruction(gt_mag, phase_mix[j, 0],
                                          hop_length=args.stft_hop)
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag, phase_mix[j, 0],
                                                hop_length=args.stft_hop)

            # output masks (clipped to [0, 1], scaled to 8-bit grayscale)
            filename_gtmask = os.path.join(prefix,
                                           'gtmask{}.jpg'.format(n + 1))
            filename_predmask = os.path.join(prefix,
                                             'predmask{}.jpg'.format(n + 1))
            gt_mask = (np.clip(gt_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            pred_mask = (np.clip(pred_masks_[n][j, 0], 0, 1) * 255).astype(
                np.uint8)
            imageio.imwrite(os.path.join(args.vis, filename_gtmask),
                            gt_mask[::-1, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmask),
                            pred_mask[::-1, :])

            # output spectrogram (log of magnitude, show colormap)
            filename_gtmag = os.path.join(prefix,
                                          'gtamp{}.jpg'.format(n + 1))
            filename_predmag = os.path.join(prefix,
                                            'predamp{}.jpg'.format(n + 1))
            gt_mag = magnitude2heatmap(gt_mag)
            pred_mag = magnitude2heatmap(pred_mag)
            imageio.imwrite(os.path.join(args.vis, filename_gtmag),
                            gt_mag[::-1, :, :])
            imageio.imwrite(os.path.join(args.vis, filename_predmag),
                            pred_mag[::-1, :, :])

            # output audio
            filename_gtwav = os.path.join(prefix, 'gt{}.wav'.format(n + 1))
            filename_predwav = os.path.join(prefix,
                                            'pred{}.wav'.format(n + 1))
            wavfile.write(os.path.join(args.vis, filename_gtwav),
                          args.audRate, gt_wav)
            wavfile.write(os.path.join(args.vis, filename_predwav),
                          args.audRate, preds_wav[n])

            # output pointclouds: select this sample's points via the batch
            # index stored in coords column 0 (assumed batched-coords layout
            # [batch, x, y, z] — TODO confirm against the sparse collation)
            for f in range(FN):
                idx = torch.where(coords[n][f][:, 0] == j)
                path_point = os.path.join(
                    args.vis, prefix,
                    'point{}_frame{}.ply'.format(n + 1, f + 1))
                if args.rgbs_feature:
                    # features hold RGB colors; coords cols 1:4 are voxel
                    # coordinates, rescaled to metric units by voxel_size
                    colors = np.asarray(features[n][f][idx])
                    xyz = np.asarray(coords[n][f][idx][:, 1:4])
                    xyz = xyz * args.voxel_size
                    save_points(path_point, xyz, colors)
                else:
                    # features hold raw xyz positions directly
                    xyz = np.asarray(features[n][f][idx])
                    save_points(path_point, xyz)