def forward(self, x, seq_len=None):
    if self.training or not self.track_running_stats:
        y, mean, power, n_values = normalize(
            x, gamma=self.gamma, beta=self.beta,
            statistics_axis=self.statistics_axis,
            batch_axis=self.batch_axis,
            sequence_axis=self.sequence_axis,
            seq_len=seq_len, shift=self.shift, scale=self.scale,
            eps=self.eps
        )
        if self.track_running_stats:
            self.num_tracked_values += n_values.data
            if self.momentum is None:
                momentum = 1 - n_values / self.num_tracked_values.data
            else:
                momentum = self.momentum
            if self.shift:
                self.running_mean *= momentum
                self.running_mean += (1 - momentum) * mean.data
                # restore the raw second moment for the running_power update
                power = power.data + mean.data ** 2
            if self.scale:
                self.running_power *= momentum
                self.running_power += (1 - momentum) * power.data
        if self.interpolation_factor > 0.:
            # perform straight-through backpropagation
            # https://arxiv.org/pdf/1611.01144.pdf
            y_ = x
            if self.shift:
                y_ = y_ - self.running_mean
            if self.scale:
                y_ = y_ / torch.sqrt(self.running_var)
            y = y + self.interpolation_factor * (y_ - y).detach()
            y = y * compute_mask(
                x, seq_len, self.batch_axis, self.sequence_axis)
    else:
        y = x
        if self.shift:
            y = y - self.running_mean.data
        if self.scale:
            y = y / torch.sqrt(self.running_var)
        if self.gamma is not None:
            y = y * self.gamma
        if self.beta is not None:
            y = y + self.beta
        y = y * compute_mask(
            x, seq_len, self.batch_axis, self.sequence_axis)
    return y
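# A minimal, self-contained sketch of the straight-through interpolation used
# above (cf. https://arxiv.org/abs/1611.01144): the forward value blends the
# batch-statistics path with the running-statistics path, but the gradient
# flows only through the batch-statistics path because the difference is
# detached. `alpha`, `running_mean` and `running_std` are illustrative
# stand-ins, not the module's real buffers.
import torch

alpha = 0.5
x = torch.randn(4, 8, requires_grad=True)

y_batch = (x - x.mean(0)) / x.std(0, unbiased=False)  # differentiable path
running_mean, running_std = torch.zeros(8), torch.ones(8)
y_running = (x - running_mean) / running_std          # running-stats path

# forward value interpolates both paths; backward sees only y_batch
y = y_batch + alpha * (y_running - y_batch).detach()
y.sum().backward()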
def normalize_ref(x, gamma, beta, statistics_axis, batch_axis, sequence_axis,
                  seq_len, shift, scale, eps):
    # compute mask
    if seq_len is not None:
        mask = compute_mask(x, seq_len, batch_axis, sequence_axis)
    else:
        mask = torch.ones_like(x)

    # compute statistics
    n_values = mask.sum(dim=statistics_axis, keepdim=True)
    x = x * mask
    mean = x.sum(dim=statistics_axis, keepdim=True) / torch.max(
        n_values, torch.ones_like(n_values))
    power = (x ** 2).sum(dim=statistics_axis, keepdim=True) / torch.max(
        n_values, torch.ones_like(n_values))
    y = x
    if shift:
        y = y - mean
        power = power - mean ** 2
    if scale:
        y = y / torch.sqrt(power + eps)

    if gamma is not None:
        assert gamma.dim() == x.dim(), gamma.shape
        y = y * gamma
    if beta is not None:
        assert beta.dim() == x.dim(), beta.shape
        y = y + beta
    return y * mask, mean, power, n_values
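# A minimal, self-contained check of the masked-statistics idea above
# (toy B x T x F layout with statistics over the time axis; shapes and the
# 3-frame sequence length are illustrative):
import torch

x = torch.randn(2, 5, 3)
mask = torch.ones(2, 5, 1)
mask[1, 3:] = 0.  # second sequence is only 3 frames long

n = mask.sum(dim=1, keepdim=True)
mean = (x * mask).sum(dim=1, keepdim=True) / n
power = ((x * mask) ** 2).sum(dim=1, keepdim=True) / n
var = power - mean ** 2

# agrees with plain statistics computed over the valid frames only
assert torch.allclose(mean[1, 0], x[1, :3].mean(0))
assert torch.allclose(var[1, 0], x[1, :3].var(0, unbiased=False), atol=1e-6)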
def forward(ctx, x, gamma, beta, statistics_axis, batch_axis, sequence_axis,
            seq_len, shift, scale, eps):
    ctx.statistics_axis = statistics_axis
    ctx.batch_axis = batch_axis
    ctx.sequence_axis = sequence_axis
    ctx.seq_len = seq_len
    ctx.shift = shift
    ctx.scale = scale
    ctx.eps = eps

    # compute mask
    mask = compute_mask(x, seq_len, batch_axis, sequence_axis)

    # compute statistics
    n_values = mask.sum(dim=statistics_axis, keepdim=True)
    x = x * mask
    mean = x.sum(dim=statistics_axis, keepdim=True) / torch.max(
        n_values, torch.ones_like(n_values))
    power = (x ** 2).sum(dim=statistics_axis, keepdim=True) / torch.max(
        n_values, torch.ones_like(n_values))
    y = x
    if shift:
        y = y - mean
        power = power - mean ** 2
    if scale:
        y = y / torch.sqrt(power + eps)
    ctx.save_for_backward(x, gamma, beta, mean, power)
    if gamma is not None:
        assert gamma.dim() == x.dim(), gamma.shape
        y = y * gamma
    if beta is not None:
        assert beta.dim() == x.dim(), beta.shape
        y = y + beta
    return y * mask, mean, power, n_values
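# The two staticmethods above are meant to live on a torch.autograd.Function.
# A minimal, self-contained skeleton of that pattern (toy y = x**2 mapping,
# illustrative class name):
import torch

class Square(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # stash tensors needed in backward
        return x ** 2

    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        return grad_y * 2 * x  # one gradient per input of forward

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
assert torch.allclose(x.grad, 2 * x.detach())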
def backward(ctx, grad_y, grad_mean, grad_power, _):
    # gradients w.r.t. the returned statistics are not supported
    if (grad_mean != 0).any() or (grad_power != 0).any():
        raise NotImplementedError
    x, gamma, beta, mean, power = ctx.saved_tensors

    # compute mask
    if ctx.seq_len is not None:
        mask = compute_mask(x, ctx.seq_len, ctx.batch_axis, ctx.sequence_axis)
    else:
        mask = torch.ones_like(x)
    n_values = mask.sum(dim=ctx.statistics_axis, keepdim=True)
    grad_y = grad_y * mask

    # reconstruct the normalized input x_hat = (x - mean) / scale
    x_hat = x
    scale = torch.sqrt(power + ctx.eps)
    if ctx.shift:
        x_hat = x_hat - mean
    if ctx.scale:
        x_hat = x_hat / scale

    if beta is None:
        grad_beta = None
    else:
        reduce_axis = [i for i in range(beta.dim()) if beta.shape[i] == 1]
        grad_beta = grad_y.sum(reduce_axis, keepdim=True)
    if gamma is None:
        grad_gamma = None
        grad_x_hat = grad_y
    else:
        reduce_axis = [i for i in range(gamma.dim()) if gamma.shape[i] == 1]
        grad_gamma = (grad_y * x_hat).sum(reduce_axis, keepdim=True)
        grad_x_hat = grad_y * gamma

    if ctx.shift:
        x = (x - mean) * mask
        grad_mean_ = -grad_x_hat.sum(ctx.statistics_axis, keepdim=True)
    if ctx.scale:
        grad_power_ = (grad_x_hat * x).sum(
            ctx.statistics_axis, keepdim=True
        ) * (-1 / 2) * (power + ctx.eps) ** (-3 / 2)
        if ctx.shift:
            grad_mean_ = (
                grad_mean_ / scale
                - 2 * grad_power_
                * x.sum(ctx.statistics_axis, keepdim=True) / n_values
            )
    grad_x = grad_x_hat
    if ctx.scale:
        grad_x = grad_x / scale + grad_power_ * 2 * x / n_values
    if ctx.shift:
        grad_x = grad_x + grad_mean_ / n_values
    return (grad_x * mask, grad_gamma, grad_beta,
            None, None, None, None, None, None, None)
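# A hand-written backward like the one above can be sanity-checked against
# autograd on a differentiable reference. A minimal, self-contained sketch of
# that pattern (plain normalization over the last axis, no masking or
# gamma/beta; function names are illustrative):
import torch

def normalize_autograd(x, eps=1e-5):
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

def grad_manual(x, grad_y, eps=1e-5):
    # same batch-norm algebra as the backward above, specialized to one axis
    n = x.shape[-1]
    xc = x - x.mean(-1, keepdim=True)
    scale = torch.sqrt((xc ** 2).mean(-1, keepdim=True) + eps)
    grad_var = (grad_y * xc).sum(-1, keepdim=True) * (-1 / 2) * scale ** -3
    grad_mean = -grad_y.sum(-1, keepdim=True) / scale
    return grad_y / scale + grad_var * 2 * xc / n + grad_mean / n

x = torch.randn(4, 16, dtype=torch.double, requires_grad=True)
grad_y = torch.randn_like(x)
normalize_autograd(x).backward(grad_y)
assert torch.allclose(x.grad, grad_manual(x.detach(), grad_y), atol=1e-10)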
def forward(self, x, seq_len=None):
    mask = compute_mask(x, seq_len, self.batch_axis, self.sequence_axis)
    data_format = " ".join(list(self.data_format))
    x_tmp = rearrange(x * mask, f'{data_format} -> {self.tmp_format}')
    mask_tmp = rearrange(mask, f'{data_format} -> {self.tmp_format}')
    tmp_shape = x_tmp.shape
    x_tmp = x_tmp.reshape(
        (int(np.prod(tmp_shape[:-self.ndim])), *tmp_shape[-self.ndim:])
    ).unsqueeze(1)
    mask_tmp = mask_tmp.reshape(
        (int(np.prod(tmp_shape[:-self.ndim])), *tmp_shape[-self.ndim:])
    ).unsqueeze(1)
    x_tmp = Pad(side='both', mode='constant')(
        x_tmp, size=np.array(self.window_size) - 1)
    mask_tmp = Pad(side='both', mode='constant')(
        mask_tmp, size=np.array(self.window_size) - 1)
    # fraction of valid (unmasked) values within each pooling window
    signal_fraction = self.pool_fn(mask_tmp)
    if self.shift:
        mean = self.pool_fn(x_tmp) / (signal_fraction + 1e-6)
        mean = mean.reshape(tmp_shape)
        mean = rearrange(mean, f'{self.tmp_format} -> {data_format}')
        if self.statistics_axis:
            mean = mean.mean(self.statistics_axis, keepdim=True)
        x = x - mean
    if self.scale:
        power = self.pool_fn(x_tmp ** 2) / (signal_fraction + 1e-6)
        power = power.reshape(tmp_shape)
        power = rearrange(power, f'{self.tmp_format} -> {data_format}')
        if self.statistics_axis:
            power = power.mean(self.statistics_axis, keepdim=True)
        if self.shift:
            power = power - mean ** 2
        x = x / torch.sqrt(power + self.eps)
    if self.gamma is not None:
        x = x * self.gamma
    if self.beta is not None:
        x = x + self.beta
    return x * mask
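# A minimal, self-contained sketch of the sliding-window statistics idea above
# (1-D case; average pooling stands in for self.pool_fn, and the window size
# and 'same' padding are illustrative, not the module's actual configuration):
# pooling the padded mask yields, per position, the fraction of valid frames
# in the window, which renormalizes the pooled sums into proper local moments.
import torch
import torch.nn.functional as F

window = 5
x = torch.randn(2, 3, 20)   # B x C x T
mask = torch.ones(2, 1, 20)
mask[1, :, 15:] = 0.        # second signal is only 15 frames long

pad = (window - 1) // 2
x_pad = F.pad(x * mask, (pad, pad))
mask_pad = F.pad(mask, (pad, pad))

frac = F.avg_pool1d(mask_pad, window, stride=1)  # valid fraction per window
local_mean = F.avg_pool1d(x_pad, window, stride=1) / (frac + 1e-6)
local_power = F.avg_pool1d(x_pad ** 2, window, stride=1) / (frac + 1e-6)
local_var = local_power - local_mean ** 2

y = (x - local_mean) / torch.sqrt(local_var.clamp_min(0.) + 1e-5) * mask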
def reverse_sequence(x, seq_len=None):
    """
    Reverses each sequence along the time axis (dim 1), respecting
    per-sequence lengths, and zeros the padded positions.

    >>> x, seq_len = torch.cumsum(torch.ones((3, 5, 8)), dim=1), [4, 5, 2]
    >>> reverse_sequence(x, seq_len)[:, :, 0]
    tensor([[4., 3., 2., 1., 0.],
            [5., 4., 3., 2., 1.],
            [2., 1., 0., 0., 0.]])
    >>> y = reverse_sequence(reverse_sequence(x, seq_len), seq_len)
    >>> bool((y == x * compute_mask(x, seq_len)).all())
    True

    Args:
        x: tensor of shape (batch, time, ...)
        seq_len: list of sequence lengths; if None, all of dim 1 is flipped

    Returns:
        tensor of the same shape with each sequence reversed and masked
    """
    if seq_len is None:
        return x.flip(1)
    T = x.shape[1]
    # duplicate along time so a slice of length T starting at seq_len[i]
    # wraps around; flipping that slice reverses the first seq_len[i] frames
    x = torch.cat((x, x), dim=1)
    x = torch.stack(
        tuple(
            x[i, seq_len[i]:seq_len[i] + T].flip(0) for i in range(len(x))
        ),
        dim=0,
    )
    mask = compute_mask(x, seq_len)
    return x * mask
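# The concat-and-slice trick above can be checked in isolation. A minimal,
# self-contained sketch (single sequence; T and L are illustrative):
import torch

T, L = 5, 3
x = torch.tensor([1., 2., 3., 0., 0.])  # length-3 sequence, padded to 5

# duplicate, take T frames starting at L, flip: the first L outputs are the
# reversed valid frames, the remainder is padding to be masked
xx = torch.cat((x, x))
y = xx[L:L + T].flip(0)
y[L:] = 0.  # apply the length mask
assert torch.equal(y, torch.tensor([3., 2., 1., 0., 0.]))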
def scaled_dot_product_attention(q, k, v, seq_len=None, bidirectional=False):
    """
    Scaled dot-product attention (https://arxiv.org/abs/1706.03762). By
    default attention is causal; invalid positions are masked by adding
    log(0) = -inf to the scores before the softmax.

    >>> q = torch.zeros((2, 3, 4))
    >>> k = torch.zeros((2, 6, 4))
    >>> v = torch.randn((2, 6, 8))
    >>> x = scaled_dot_product_attention(q, k, v, bidirectional=True)
    >>> x.shape
    torch.Size([2, 3, 8])
    >>> q = torch.zeros((2, 6, 4))
    >>> x = scaled_dot_product_attention(q, k, v)  # causal (default)
    >>> (x[0, 0] == v[0, 0]).all()
    tensor(1, dtype=torch.uint8)
    >>> (torch.abs(x[0, -1] - v[0].mean(0)) < 1e-6).all()
    tensor(1, dtype=torch.uint8)
    >>> x = scaled_dot_product_attention(q, k, v, seq_len=[6, 4],
    ...                                  bidirectional=True)

    Args:
        q: queries of shape (..., T_q, D)
        k: keys of shape (..., T_k, D)
        v: values of shape (..., T_k, D_v)
        seq_len: key sequence lengths (only used if bidirectional is True)
        bidirectional: if False (default), a causal mask is applied

    Returns:
        attention output of shape (..., T_q, D_v)
    """
    y = q @ k.transpose(-2, -1) / np.sqrt(k.shape[-1])
    if not bidirectional:
        mask = get_causal_mask(y)
        y = y + torch.log((mask > 0).float())
    elif seq_len is not None:
        mask = compute_mask(y, seq_len, seq_axis=-1)
        y = y + torch.log((mask > 0).float())
    return torch.softmax(y, dim=-1) @ v
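# The additive log-mask used above can be reproduced with plain torch. A
# minimal, self-contained sketch (illustrative shapes; torch.tril stands in
# for get_causal_mask):
import torch

B, T, D, Dv = 2, 4, 8, 5
q, k, v = torch.randn(B, T, D), torch.randn(B, T, D), torch.randn(B, T, Dv)

scores = q @ k.transpose(-2, -1) / D ** 0.5  # (B, T, T)

# causal mask: position t may only attend to positions <= t;
# log(0) = -inf removes those weights after the softmax
causal = torch.tril(torch.ones(T, T))
scores = scores + torch.log(causal)

w = torch.softmax(scores, dim=-1)
out = w @ v
assert torch.allclose(out[:, 0], v[:, 0])  # first step attends only to itself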