def review(self, batch, model_out):
    """Compute PIT training losses and debug images for one batch.

    Args:
        batch: dict with per-example lists under 'Y_abs' (observation
            magnitudes), 'X_abs' (target magnitudes) and
            'cos_phase_difference'.
        model_out: per-example mask tensors; model_out[b] is indexed as
            (time, speaker, frequency) — presumably, judging from the
            `[:, None, :]` broadcast and `[:, i, :]` slicing; TODO confirm.

    Returns:
        dict with 'losses' (scalar tensors) and 'images' for the summary
        writer.
    """
    # Permutation-invariant MSE between masked observation and target.
    pit_mse_loss = list()
    for mask, observation, target in zip(
            model_out, batch['Y_abs'], batch['X_abs']):
        pit_mse_loss.append(
            pt.ops.losses.pit_loss(
                mask * observation[:, None, :], target, axis=-2))

    # Ideal Phase Sensitive loss: target magnitude is scaled by the
    # cosine of the phase difference between observation and source.
    pit_ips_loss = list()
    for mask, observation, target, cos_phase_diff in zip(
            model_out, batch['Y_abs'], batch['X_abs'],
            batch['cos_phase_difference']):
        pit_ips_loss.append(
            pt.ops.losses.pit_loss(
                mask * observation[:, None, :],
                target * cos_phase_diff, axis=-2))

    losses = {
        'pit_mse_loss': torch.mean(torch.stack(pit_mse_loss)),
        'pit_ips_loss': torch.mean(torch.stack(pit_ips_loss)),
    }

    b = 0  # only print image of first example in a batch
    images = dict()
    images['observation'] = stft_to_image(batch['Y_abs'][b])
    for i in range(model_out[b].shape[1]):
        images[f'mask_{i}'] = mask_to_image(model_out[b][:, i, :])
        # Bug fix: previously this showed batch['X_abs'][b][:, 0, :] —
        # the *target* of speaker 0 for every i. Show the actual
        # estimation (mask applied to the observation) per speaker.
        images[f'estimation_{i}'] = stft_to_image(
            batch['Y_abs'][b] * model_out[b][:, i, :])
    return dict(losses=losses, images=images)
def review(self, batch, model_out):
    """Compute PIT losses (MSE, IPS, IPS-clean, binary) plus debug images.

    Args:
        batch: dict of per-example lists ('Y_abs', 'X_abs', 'X_clean',
            'cos_phase_difference', 'target_mask').
        model_out: per-example mask tensors, sliced as (time, speaker,
            frequency) — presumably; TODO confirm against the model.

    Returns:
        dict with 'losses' (scalar tensors) and 'images'.
    """
    # TODO: Maybe calculate only one loss? May be much faster.

    def _apply_mask(mask, observation):
        # Broadcast the observation over the speaker axis.
        return mask * observation[:, None, :]

    pit_mse_loss = [
        pt.ops.losses.pit_loss(_apply_mask(m, obs), tgt, axis=-2)
        for m, obs, tgt in zip(model_out, batch['Y_abs'], batch['X_abs'])
    ]

    # Phase-sensitive variants: target magnitudes scaled by the cosine
    # of the phase difference.
    pit_ips_loss = [
        pt.ops.losses.pit_loss(_apply_mask(m, obs), tgt * cpd, axis=-2)
        for m, obs, tgt, cpd in zip(
            model_out, batch['Y_abs'], batch['X_abs'],
            batch['cos_phase_difference'])
    ]

    pit_ips_clean_loss = [
        pt.ops.losses.pit_loss(_apply_mask(m, obs), tgt * cpd, axis=-2)
        for m, obs, tgt, cpd in zip(
            model_out, batch['Y_abs'], batch['X_clean'],
            batch['cos_phase_difference'])
    ]

    # Direct mask-vs-ideal-binary-mask loss (no observation applied).
    binary_loss = [
        pt.ops.losses.pit_loss(m, tgt, axis=-2)
        for m, tgt in zip(model_out, batch['target_mask'])
    ]

    losses = {
        'pit_mse_loss': torch.mean(torch.stack(pit_mse_loss)),
        'pit_ips_loss': torch.mean(torch.stack(pit_ips_loss)),
        'pit_ips_clean_loss': torch.mean(torch.stack(pit_ips_clean_loss)),
        'binary_loss': torch.mean(torch.stack(binary_loss)),
    }

    # Visualize only the first example of the batch.
    b = 0
    images = dict()
    images['observation'] = stft_to_image(batch['Y_abs'][b])
    num_speakers = model_out[b].shape[1]
    for i in range(num_speakers):
        speaker_mask = model_out[b][:, i, :]
        images[f'mask_{i}'] = mask_to_image(speaker_mask)
        images[f'target_{i}'] = stft_to_image(batch['X_abs'][b][:, i, :])
        images[f'estimation_{i}'] = stft_to_image(
            batch['Y_abs'][b] * speaker_mask)
    return dict(losses=losses, images=images)
def add_images(self, batch, output):
    """Build summary images from predicted and target masks.

    Args:
        batch: dict that may carry 'observation_abs',
            'speech_mask_target' and 'noise_mask_target'; may be None
            for target entries.
        output: dict with 'speech_mask_prediction' and optionally
            'noise_mask_prediction'.

    Returns:
        dict mapping image names to rendered images.
    """
    speech_mask = output['speech_mask_prediction']
    observation = batch['observation_abs']
    images = dict()
    images['speech_mask'] = mask_to_image(speech_mask, True)
    images['observed_stft'] = stft_to_image(observation, True)

    if 'noise_mask_prediction' in output:
        noise_mask = output['noise_mask_prediction']
        images['noise_mask'] = mask_to_image(noise_mask, True)

    # Bug fix: the guards previously tested the wrong keys
    # ('speech_mask_prediction' before reading 'speech_mask_target',
    # and 'speech_mask_target' before reading 'noise_mask_target'),
    # so a missing target key could raise KeyError. Each guard now
    # tests exactly the key it reads.
    if batch is not None:
        if 'speech_mask_target' in batch:
            images['speech_mask_target'] = mask_to_image(
                batch['speech_mask_target'], True)
        if 'noise_mask_target' in batch:
            images['noise_mask_target'] = mask_to_image(
                batch['noise_mask_target'], True)
    return images
def add_images(self, batch, output):
    """Build summary images from predictions and (optional) targets.

    Only the first example of the batch ([0]) is visualized. Keys in
    `output`/`batch` are addressed through the `K` constants.

    Args:
        batch: dict that may carry K.OBSERVATION_ABS,
            K.SPEECH_MASK_TARGET and K.NOISE_MASK_TARGET.
        output: dict that may carry K.SPEECH_PRED, K.SPEECH_MASK_PRED
            and K.NOISE_MASK_PRED.

    Returns:
        dict mapping image names to rendered images.
    """
    images = dict()

    if K.SPEECH_PRED in output:
        images['speech_pred'] = mask_to_image(output[K.SPEECH_PRED][0], True)
    if K.SPEECH_MASK_PRED in output:
        images['speech_mask'] = mask_to_image(
            output[K.SPEECH_MASK_PRED][0], True)

    images['observed_stft'] = stft_to_image(batch[K.OBSERVATION_ABS][0], True)

    if K.NOISE_MASK_PRED in output:
        images['noise_mask'] = mask_to_image(output[K.NOISE_MASK_PRED][0], True)

    # Target masks are only available during training.
    if batch is not None and K.SPEECH_MASK_TARGET in batch:
        images['speech_mask_target'] = mask_to_image(
            batch[K.SPEECH_MASK_TARGET][0], True)
        if K.NOISE_MASK_TARGET in batch:
            images['noise_mask_target'] = mask_to_image(
                batch[K.NOISE_MASK_TARGET][0], True)
    return images