def boundaries_detection(self, inputs):
    """Detect boundary scores as the element-wise minimum of the masked
    forward and backward outputs of ``self.forward``.

    Args:
        inputs: batch forwarded unchanged to ``self.forward``.

    Returns:
        Tuple of (boundary scores, output sequence lengths).
    """
    forward_scores, backward_scores, out_seq_len, *_ = self.forward(inputs)
    # zero out frames beyond each example's length (batch first, time last)
    valid = compute_mask(
        forward_scores, out_seq_len, batch_axis=0, sequence_axis=-1)
    boundaries = torch.minimum(
        forward_scores * valid, backward_scores * valid)
    return boundaries, out_seq_len
def __call__(self, x, seq_len=None):
    """Sum-pool ``x`` over ``self.axis``.

    Args:
        x: input tensor.
        seq_len: optional sequence lengths; when given, padded positions
            are zeroed via ``compute_mask`` before summation.

    Returns:
        The (masked) sum over ``self.axis``.
    """
    if seq_len is None:
        return x.sum(self.axis, keepdim=self.keepdims)
    mask = compute_mask(x, seq_len, 0, self.axis)
    return (x * mask).sum(dim=self.axis, keepdim=self.keepdims)
def normalize_ref(x, gamma, beta, statistics_axis, batch_axis, sequence_axis,
                  seq_len, shift, scale, eps):
    """Reference implementation of masked (sequence-aware) normalization.

    Args:
        x: input tensor.
        gamma: optional scale parameter; must have the same rank as ``x``.
        beta: optional shift parameter; must have the same rank as ``x``.
        statistics_axis: axis (or axes) the statistics are computed over.
        batch_axis: batch axis, forwarded to ``compute_mask``.
        sequence_axis: sequence axis, forwarded to ``compute_mask``.
        seq_len: optional sequence lengths; ``None`` disables masking.
        shift: if True, subtract the mean (and center the power).
        scale: if True, divide by the standard deviation.
        eps: numerical stabilizer added under the square root.

    Returns:
        Tuple of (normalized and masked x, mean, power, n_values).
    """
    if seq_len is None:
        mask = torch.ones_like(x)
    else:
        mask = compute_mask(x, seq_len, batch_axis, sequence_axis)

    n_values = mask.sum(dim=statistics_axis, keepdim=True)
    x = x * mask
    # guard against division by zero for fully-masked positions
    denom = torch.max(n_values, torch.ones_like(n_values))
    mean = x.sum(dim=statistics_axis, keepdim=True) / denom
    power = (x ** 2).sum(dim=statistics_axis, keepdim=True) / denom

    y = x
    if shift:
        y = y - mean
        power_scale = power - mean ** 2  # centered second moment (variance)
    else:
        power_scale = power
    if scale:
        y = y / torch.sqrt(power_scale + eps)
    if gamma is not None:
        assert gamma.dim() == x.dim(), gamma.shape
        y = y * gamma
    if beta is not None:
        assert beta.dim() == x.dim(), beta.shape
        y = y + beta
    return y * mask, mean, power, n_values
def _running_norm(self, x, sequence_lengths):
    """Normalize with detached running statistics, apply the affine
    parameters, and zero out padded positions.

    Args:
        x: input tensor.
        sequence_lengths: lengths used by ``compute_mask`` (``None`` allowed
            if ``compute_mask`` accepts it).

    Returns:
        Normalized, masked tensor with the same shape as ``x``.
    """
    out = x
    if self.shift:
        out = out - self.running_mean.detach()
    if self.scale:
        out = out / torch.sqrt(self.running_var.detach() + self.eps)
    if self.gamma is not None:
        out = out * self.gamma
    if self.beta is not None:
        out = out + self.beta
    mask = compute_mask(
        out, sequence_lengths, self.batch_axis, self.sequence_axis)
    return out * mask
def mask_and_compute_stats(x, sequence_lengths, statistics_axis, batch_axis,
                           sequence_axis):
    """Mask padded positions and compute first/second moments.

    Args:
        x: input tensor.
        sequence_lengths: sequence lengths forwarded to ``compute_mask``.
        statistics_axis: axis (or axes) to reduce over.
        batch_axis: batch axis, forwarded to ``compute_mask``.
        sequence_axis: sequence axis, forwarded to ``compute_mask``.

    Returns:
        Tuple (masked x, mask, mean, power, n_values); mean and power are
        normalized by the valid-frame count, clipped to at least one to
        avoid division by zero.
    """
    mask = compute_mask(x, sequence_lengths, batch_axis, sequence_axis)
    n_values = mask.sum(dim=statistics_axis, keepdim=True)
    x = x * mask
    denom = torch.max(n_values, torch.ones_like(n_values))
    mean = x.sum(dim=statistics_axis, keepdim=True) / denom
    power = (x ** 2).sum(dim=statistics_axis, keepdim=True) / denom
    return x, mask, mean, power, n_values
def inverse(self, x, sequence_lengths=None):
    """Invert the normalization using the running statistics.

    Undoes (in reverse order) beta, gamma, scaling and shifting, then
    re-applies the sequence mask.

    Args:
        x: normalized tensor.
        sequence_lengths: optional lengths forwarded to ``compute_mask``.

    Returns:
        De-normalized, masked tensor.

    Raises:
        NotImplementedError: if running statistics are not tracked.
    """
    if not self.track_running_stats:
        raise NotImplementedError
    out = x
    if self.beta is not None:
        out = out - self.beta
    if self.gamma is not None:
        out = out / self.gamma
    if self.scale:
        out = torch.sqrt(self.running_var.detach() + self.eps) * out
    if self.shift:
        out = out + self.running_mean.detach()
    mask = compute_mask(
        out, sequence_lengths, self.batch_axis, self.sequence_axis)
    return out * mask
def backward(ctx, grad_y, grad_mean, grad_power, _):
    """Custom backward for the masked-normalization autograd Function.

    Gradients follow the batch-norm equations from
    https://arxiv.org/abs/1502.03167, extended with a sequence mask so
    padded frames contribute neither to the statistics nor to the gradient.

    Returns one gradient per forward input: grad wrt x, gamma and beta,
    and ``None`` for the seven non-tensor configuration arguments.
    """
    # equations from https://arxiv.org/abs/1502.03167
    # Backprop through the returned mean/power statistics is not supported.
    if (grad_mean != 0).any() or (grad_power != 0).any():
        raise NotImplementedError
    x, gamma, beta, mean, power_scale = ctx.saved_tensors
    # compute mask
    mask = compute_mask(x, ctx.seq_len, ctx.batch_axis, ctx.sequence_axis)
    # number of valid frames per reduced position
    n_values = mask.sum(dim=ctx.statistics_axis, keepdim=True)
    grad_y = grad_y * mask
    # x_hat: the normalized activation as computed in forward
    x_hat = x
    scale = torch.sqrt(power_scale + ctx.eps)
    if ctx.shift:
        x_hat = x_hat - mean
    if ctx.scale:
        x_hat = x_hat / scale
    if beta is None:
        grad_beta = None
    else:
        # sum over all axes beta broadcasts over (size-1 axes of beta)
        reduce_axis = [i for i in range(beta.dim()) if beta.shape[i] == 1]
        grad_beta = grad_y.sum(reduce_axis, keepdim=True)
    if gamma is None:
        grad_gamma = None
        grad_x_hat = grad_y
    else:
        reduce_axis = [i for i in range(gamma.dim()) if gamma.shape[i] == 1]
        grad_gamma = (grad_y * x_hat).sum(reduce_axis, keepdim=True)
        grad_x_hat = grad_y * gamma
    if ctx.shift:
        # re-center x (masked) for the statistic gradients below
        x = (x - mean) * mask
        grad_mean_ = -grad_x_hat.sum(ctx.statistics_axis, keepdim=True)
    if ctx.scale:
        # d/d(power_scale) of x_hat = x * (power_scale + eps)^(-1/2)
        grad_power_ = (
            (grad_x_hat * x).sum(ctx.statistics_axis, keepdim=True)
            * (-1 / 2) * (power_scale + ctx.eps) ** (-3 / 2)
        )
        if ctx.shift:
            grad_mean_ = (
                grad_mean_ / scale
                - 2 * grad_power_
                * x.sum(ctx.statistics_axis, keepdim=True) / n_values
            )
    grad_x = grad_x_hat
    if ctx.scale:
        grad_x = grad_x / scale + grad_power_ * 2 * x / n_values
    if ctx.shift:
        grad_x = grad_x + grad_mean_ / n_values
    # None for: seq_len, shift, scale, eps, statistics_axis, batch_axis,
    # sequence_axis (non-differentiable forward arguments)
    return grad_x * mask, grad_gamma, grad_beta, None, None, None, None, \
        None, None, None
def forward(self, x, seq_len=None):
    """Smooth the last axis with a depthwise 1d convolution.

    All leading axes are collapsed into conv channels, the valid conv
    output is zero-padded back to the input length, and frames whose
    filter window reaches past the sequence end are masked out.

    Args:
        x: tensor (..., T) filtered along the last axis.
        seq_len: optional sequence lengths (batch axis 0).

    Returns:
        Filtered tensor with the same shape as ``x``.
    """
    assert self.width >= 3, self.width
    half = (self.width - 1) // 2
    # pack batch: one conv group per leading position
    shape = x.size()
    flat = x.reshape(1, -1, shape[-1])
    weights = self.kernel.repeat(flat.shape[1], 1, 1)
    out = torch.nn.functional.conv1d(flat, weights, groups=flat.shape[1])
    # restore the original length with zeros at both borders
    out = torch.nn.functional.pad(out, [half, half], mode="constant")
    # unpack batch
    out = out.reshape(shape)
    if seq_len is not None:
        # NOTE(review): masking uses seq_len - half, i.e. the last `half`
        # frames before the sequence end are zeroed as well — presumably
        # because their filter window is incomplete; confirm with callers.
        out = out * compute_mask(
            out, np.array(seq_len) - half, batch_axis=0, sequence_axis=-1)
    return out
def reverse_sequence(x, seq_len=None):
    """Reverse each sequence along axis 1.

    Without ``seq_len`` this is a plain flip of axis 1. With ``seq_len``,
    only the first ``seq_len[i]`` frames of example ``i`` are reversed
    (remaining at the front) and the padding is zeroed.

    >>> x, seq_len = (torch.cumsum(torch.ones((3,5,8)), dim=1), [4,5,2])
    >>> reverse_sequence(x, seq_len)
    >>> reverse_sequence(reverse_sequence(x, seq_len), seq_len)

    Args:
        x: tensor with the sequence on axis 1.
        seq_len: optional per-example lengths.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if seq_len is None:
        return x.flip(1)
    T = x.shape[1]
    # duplicate along time so the window [seq_len, seq_len + T) always exists
    doubled = torch.cat((x, x), dim=1)
    per_example = [
        doubled[i, seq_len[i]:seq_len[i] + T].flip(0)
        for i in range(len(doubled))
    ]
    y = torch.stack(tuple(per_example), dim=0)
    return y * compute_mask(y, seq_len)
def scaled_dot_product_attention(q, k, v, seq_len=None, bidirectional=False):
    """Scaled dot-product attention.

    Attention is causal by default (a query attends only to keys at the
    same or earlier positions). With ``bidirectional=True`` all keys are
    attended; ``seq_len`` then masks out padded key positions. Note that
    ``seq_len`` is ignored in the causal case.

    >>> q = torch.zeros((2, 3, 4))
    >>> k = torch.zeros((2, 6, 4))
    >>> v = torch.randn((2, 6, 8))
    >>> x = scaled_dot_product_attention(q, k, v, bidirectional=True)
    >>> x.shape
    torch.Size([2, 3, 8])
    >>> q = torch.zeros((2, 6, 4))
    >>> x = scaled_dot_product_attention(q, k, v)
    >>> (x[0,0] == v[0,0]).all()
    tensor(1, dtype=torch.uint8)
    >>> (torch.abs(x[0,-1] - v[0].mean(0)) < 1e-6).all()
    tensor(1, dtype=torch.uint8)
    >>> x = scaled_dot_product_attention(q, k, v, seq_len=[6,4], bidirectional=True)

    Args:
        q: queries (..., Tq, D).
        k: keys (..., Tk, D).
        v: values (..., Tk, Dv).
        seq_len: optional key sequence lengths (bidirectional case only).
        bidirectional: if False (default), apply a causal mask.

    Returns:
        Attention output (..., Tq, Dv).
    """
    y = q @ k.transpose(-2, -1) / np.sqrt(k.shape[-1])
    if not bidirectional:
        # log of a 0/1 mask pushes disallowed (future) positions to -inf
        mask = get_causal_mask(y)
        y = y + torch.log((mask > 0).float())
    elif seq_len is not None:
        mask = compute_mask(y, seq_len, sequence_axis=-1)
        y = y + torch.log((mask > 0).float())
    return torch.softmax(y, dim=-1) @ v
def __call__(self, x, seq_len=None):
    """Max-pool ``x`` over ``self.axis``, ignoring padded positions.

    Padded positions are pushed to -inf (``log`` of a zero/one mask) so
    they can never win the maximum.

    Args:
        x: input tensor.
        seq_len: optional sequence lengths.

    Returns:
        Tensor of maxima over ``self.axis``.
    """
    if seq_len is not None:
        mask = compute_mask(x, seq_len, 0, self.axis)
        # log(mask) is 0 for valid frames, -inf for padding
        x = x + torch.log(mask)
    # Fix: ``Tensor.max(dim)`` returns a (values, indices) namedtuple;
    # return only the values, consistent with the sum-pooling __call__.
    x, _ = x.max(self.axis, keepdim=self.keepdims)
    return x
def inference(
        model,
        method,
        dataset,
        device,
        max_segment_length=None,
        segment_overlap=0,
        merge_score_segments=False,
        score_segment_overlap=None,
        model_kwargs=None,
        medfilt_length=1,
        stepfilt_length=None,
        apply_mask=False,
        masks=None,
        post_processing_fn=None,
        timestamps=None,
        event_classes=None,
        score_storage_dir=None,
):
    """Run (ensemble) inference over a dataset and collect per-audio scores.

    For each batch, calls ``method`` on each model, averages the scores,
    masks padded frames, applies median (and optional boundary) filtering,
    optionally masks/merges/post-processes the scores, and accumulates them
    per ``example_id``.

    Args:
        model: a model or a list/tuple of models (scores are averaged).
        method: name of the scoring method to call on each model.
        dataset: iterable of batches with 'example_id' entries.
        device: device the models and batches are moved to.
        max_segment_length: if given, batches are split into segments of at
            most this length.
        segment_overlap: overlap between consecutive segments.
        merge_score_segments: if True, merge segment scores back into
            full-length score arrays once all segments of a clip are seen.
        score_segment_overlap: overlap used when merging scores
            (defaults to ``segment_overlap``).
        model_kwargs: kwargs (or list of kwargs, one per model) forwarded
            to the scoring method.
        medfilt_length: median filter length(s).
        stepfilt_length: optional boundary filter length(s).
        apply_mask: bool (or array of bools) selecting which score channels
            are multiplied with the provided ``masks``.
        masks: dict audio_id -> mask; required if ``apply_mask`` is set.
        post_processing_fn: optional function applied to each clip's scores.
        timestamps: frame timestamps, required for dataframe conversion.
        event_classes: class names, required for dataframe conversion.
        score_storage_dir: if given, dataframes are written there and the
            per-clip cache is returned instead of accumulated.

    Returns:
        Dict (or structure produced by ``scores_to_dataframes``) of scores.
    """
    if not isinstance(model, (list, tuple)):
        model = [model]
    if model_kwargs is None:
        model_kwargs = {}
    if not isinstance(model_kwargs, (list, tuple)):
        model_kwargs = len(model) * [model_kwargs]
    else:
        assert len(model_kwargs) == len(model), (len(model), len(model_kwargs))
    # Fix: np.int/np.bool were removed in NumPy 1.24; use the builtins.
    medfilt_length = np.array(medfilt_length, dtype=int)
    apply_mask = np.array(apply_mask, dtype=bool)
    for i in range(len(model)):
        assert hasattr(model[i], method), (model[i], method)
        model[i].to(device)
        model[i].eval()
    scores = {}
    with torch.no_grad():
        score_cache = {}
        for batch in tqdm(dataset):
            # targets are not needed for inference
            if 'weak_targets' in batch:
                batch.pop('weak_targets')
            if 'boundary_targets' in batch:
                batch.pop('boundary_targets')
            if 'strong_targets' in batch:
                batch.pop('strong_targets')
            if max_segment_length is not None:
                input_segments = segment_batch(
                    batch,
                    max_length=max_segment_length,
                    overlap=segment_overlap
                )
            else:
                input_segments = [batch]
            for segment in input_segments:
                segment = model[0].example_to_device(segment, device)
                segment_scores = []
                seq_len = None
                for i in range(len(model)):
                    yi, seq_len_i = getattr(model[i], method)(
                        segment, **model_kwargs[i])
                    segment_scores.append(yi.detach().cpu().numpy())
                    # all ensemble members must agree on sequence lengths
                    if i == 0:
                        seq_len = seq_len_i
                    else:
                        assert (seq_len_i == seq_len).all(), (
                            seq_len, seq_len_i)
                segment_scores = np.mean(segment_scores, axis=0)
                sequence_mask = compute_mask(
                    torch.from_numpy(segment_scores), seq_len,
                    batch_axis=0, sequence_axis=-1,
                ).numpy()
                segment_scores = segment_scores * sequence_mask
                # median filtering:
                segment_scores = filtering(
                    segment_scores, medfilt, medfilt_length)
                if stepfilt_length is not None:
                    # boundary filtering:
                    stepfilt_length = np.array(stepfilt_length, dtype=int)
                    segment_scores = filtering(
                        segment_scores, boundariesfilt, stepfilt_length)
                # separate examples within batch
                if post_processing_fn is None:
                    def post_processing_fn(x):
                        return x
                score_cache.update({
                    audio_id: post_processing_fn(
                        segment_scores[i, ..., :sl].swapaxes(-2, -1)
                    )
                    for i, (audio_id, sl) in enumerate(zip(
                        segment['example_id'], seq_len))
                })
                # applying mask allows to, e.g, mask SED score by tags.
                if apply_mask.any():
                    assert masks is not None
                    for audio_id in score_cache:
                        assert audio_id in masks, audio_id
                        if apply_mask.ndim == 2:
                            apply_mask = apply_mask[..., None, :]
                        # elif apply_mask.ndim > 2:
                        #     raise ValueError(
                        #         f'apply_mask must be 0-,1- or 2-dimensional '
                        #         f'but shape {apply_mask.shape} was given.'
                        #     )
                        mask = np.maximum(masks[audio_id], 1 - apply_mask)
                        score_cache[audio_id] *= mask
                if merge_score_segments:
                    if '_!segment!_' in segment['example_id'][0]:
                        seg_idx, n_segments = segment['example_id'][0].split(
                            '_!segment!_')[-1].split('_')
                        seg_idx = int(seg_idx)
                        n_segments = int(n_segments)
                        # only merge once the last segment of a clip arrived
                        if seg_idx == n_segments - 1:
                            score_cache = merge_segments(
                                score_cache,
                                segment_overlap=segment_overlap
                                if score_segment_overlap is None
                                else score_segment_overlap
                            )
                        else:
                            continue
                if (
                    timestamps is not None
                    or event_classes is not None
                    or score_storage_dir is not None
                ):
                    assert timestamps is not None
                    assert event_classes is not None
                    score_cache = scores_to_dataframes(
                        score_cache, timestamps, event_classes,
                        score_storage_dir,
                    )
                if score_storage_dir is None:
                    if not scores:
                        scores = score_cache
                    elif isinstance(scores, (list, tuple)):
                        assert isinstance(score_cache, (list, tuple))
                        assert len(score_cache) == len(scores)
                        for i in range(len(scores)):
                            scores[i].update(score_cache[i])
                    else:
                        assert isinstance(scores, dict)
                        assert isinstance(score_cache, dict)
                        scores.update(score_cache)
                else:
                    scores = score_cache
                score_cache = {}
    return scores