Example #1
    def forward(self, x, seq_len=None):
        if self.training or not self.track_running_stats:
            y, mean, power, n_values = normalize(
                x,
                gamma=self.gamma,
                beta=self.beta,
                statistics_axis=self.statistics_axis,
                batch_axis=self.batch_axis,
                sequence_axis=self.sequence_axis,
                seq_len=seq_len,
                shift=self.shift,
                scale=self.scale,
                eps=self.eps)
            if self.track_running_stats:
                self.num_tracked_values += n_values.data
                if self.momentum is None:
                    # cumulative moving average over all values seen so far
                    momentum = 1 - n_values / self.num_tracked_values.data
                else:
                    momentum = self.momentum
                if self.shift:
                    self.running_mean *= momentum
                    self.running_mean += (1 - momentum) * mean.data
                    # normalize returned the variance; recover the raw second
                    # moment before updating the running power estimate
                    power = power.data + mean.data**2
                if self.scale:
                    self.running_power *= momentum
                    self.running_power += (1 - momentum) * power.data
                if self.interpolation_factor > 0.:
                    # perform straight through backpropagation
                    # https://arxiv.org/pdf/1611.01144.pdf
                    y_ = x
                    if self.shift:
                        y_ = y_ - self.running_mean
                    if self.scale:
                        y_ = y_ / torch.sqrt(self.running_var + self.eps)
                    y = y + self.interpolation_factor * (y_ - y).detach()
                    y = y * compute_mask(x, seq_len, self.batch_axis,
                                         self.sequence_axis)
        else:
            y = x
            if self.shift:
                y = y - self.running_mean.data
            if self.scale:
                y = y / torch.sqrt(self.running_var + self.eps)
            if self.gamma is not None:
                y = y * self.gamma
            if self.beta is not None:
                y = y + self.beta
            y = y * compute_mask(x, seq_len, self.batch_axis,
                                 self.sequence_axis)
        return y
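A minimal sketch of the running-statistics update used above, for the case momentum is None, where the buffers hold a cumulative average over all frames seen so far (the names below are illustrative, not taken from the original module):

import torch

running_mean = torch.zeros(1)
num_tracked_values = torch.zeros(1)
batches = [torch.randn(100) + 3. for _ in range(5)]

for batch in batches:
    n_values = torch.tensor([float(batch.numel())])
    mean = batch.mean()
    num_tracked_values += n_values
    momentum = 1 - n_values / num_tracked_values
    running_mean *= momentum
    running_mean += (1 - momentum) * mean

# the update is equivalent to averaging over all values seen so far
assert torch.allclose(running_mean, torch.cat(batches).mean(), atol=1e-4)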
Example #2
def normalize_ref(x, gamma, beta, statistics_axis, batch_axis, sequence_axis,
                  seq_len, shift, scale, eps):
    # compute mask
    if seq_len is not None:
        mask = compute_mask(x, seq_len, batch_axis, sequence_axis)
    else:
        mask = torch.ones_like(x)

    # compute statistics
    n_values = mask.sum(dim=statistics_axis, keepdim=True)
    x = x * mask
    mean = x.sum(dim=statistics_axis, keepdim=True) / torch.max(
        n_values, torch.ones_like(n_values))
    power = (x**2).sum(dim=statistics_axis, keepdim=True) / torch.max(
        n_values, torch.ones_like(n_values))
    y = x
    if shift:
        y = y - mean
        power = power - mean**2
    if scale:
        y = y / torch.sqrt(power + eps)

    if gamma is not None:
        assert gamma.dim() == x.dim(), gamma.shape
        y = y * gamma
    if beta is not None:
        assert beta.dim() == x.dim(), beta.shape
        y = y + beta
    return y * mask, mean, power, n_values
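As a quick sanity check of the reference implementation (assuming normalize_ref above is in scope; with seq_len=None the compute_mask helper is never called), the output should have zero mean and roughly unit power along the statistics axis:

import torch

x = torch.randn(4, 100, 8)  # (batch, time, features)
y, mean, power, n_values = normalize_ref(
    x, gamma=None, beta=None, statistics_axis=1, batch_axis=0,
    sequence_axis=1, seq_len=None, shift=True, scale=True, eps=1e-3)

assert y.shape == x.shape
assert torch.allclose(y.mean(dim=1), torch.zeros(4, 8), atol=1e-5)
# power is slightly below one because eps is added before the square root
assert torch.allclose(y.pow(2).mean(dim=1), torch.ones(4, 8), atol=1e-2)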
Example #3
    def forward(ctx, x, gamma, beta, statistics_axis, batch_axis, sequence_axis, seq_len, shift, scale, eps):
        ctx.statistics_axis = statistics_axis
        ctx.batch_axis = batch_axis
        ctx.sequence_axis = sequence_axis
        ctx.seq_len = seq_len
        ctx.shift = shift
        ctx.scale = scale
        ctx.eps = eps

        # compute mask
        mask = compute_mask(x, seq_len, batch_axis, sequence_axis)

        # compute statistics
        n_values = mask.sum(dim=statistics_axis, keepdim=True)
        x = x * mask
        mean = x.sum(dim=statistics_axis, keepdim=True) / torch.max(n_values, torch.ones_like(n_values))
        power = (x ** 2).sum(dim=statistics_axis, keepdim=True) / torch.max(n_values, torch.ones_like(n_values))
        y = x
        if shift:
            y = y - mean
            power = power - mean**2
        if scale:
            y = y / torch.sqrt(power + eps)
        ctx.save_for_backward(x, gamma, beta, mean, power)

        if gamma is not None:
            assert gamma.dim() == x.dim(), gamma.shape
            y = y * gamma
        if beta is not None:
            assert beta.dim() == x.dim(), beta.shape
            y = y + beta
        return y*mask, mean, power, n_values
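The core of the forward above is that all statistics are driven by the sequence mask: padded frames contribute neither to the sum nor to the count. A small standalone sketch of that masked-mean computation (mask built by hand instead of compute_mask):

import torch

x = torch.randn(2, 5, 3)  # (batch, time, feature)
mask = torch.tensor([1., 1., 1., 0., 0.]).reshape(1, 5, 1).expand(2, 5, 3)

n_values = mask.sum(dim=1, keepdim=True)
x_masked = x * mask
mean = x_masked.sum(dim=1, keepdim=True) / torch.max(
    n_values, torch.ones_like(n_values))

# the masked mean averages only the valid frames
assert torch.allclose(mean.squeeze(1), x[:, :3].mean(dim=1), atol=1e-6)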
Example #4
    def backward(ctx, grad_y, grad_mean, grad_power, _):
        if (grad_mean != 0).any() or (grad_power != 0).any():
            raise NotImplementedError
        x, gamma, beta, mean, power = ctx.saved_tensors
        # compute mask
        if ctx.seq_len is not None:
            mask = compute_mask(x, ctx.seq_len, ctx.batch_axis,
                                ctx.sequence_axis)
        else:
            mask = torch.ones_like(x)
        n_values = mask.sum(dim=ctx.statistics_axis, keepdim=True)

        grad_y = grad_y * mask
        # reconstruct the normalized input x_hat as it was computed in forward
        x_hat = x
        scale = torch.sqrt(power + ctx.eps)
        if ctx.shift:
            x_hat = x_hat - mean
        if ctx.scale:
            x_hat = x_hat / scale
        if beta is None:
            grad_beta = None
        else:
            reduce_axis = [i for i in range(beta.dim()) if beta.shape[i] == 1]
            grad_beta = grad_y.sum(reduce_axis, keepdim=True)
        if gamma is None:
            grad_gamma = None
            grad_x_hat = grad_y
        else:
            reduce_axis = [
                i for i in range(gamma.dim()) if gamma.shape[i] == 1
            ]
            grad_gamma = (grad_y * x_hat).sum(reduce_axis, keepdim=True)
            grad_x_hat = grad_y * gamma
        if ctx.shift:
            x = (x - mean) * mask
            grad_mean_ = -grad_x_hat.sum(ctx.statistics_axis, keepdim=True)
        if ctx.scale:
            grad_power_ = (grad_x_hat * x).sum(
                ctx.statistics_axis,
                keepdim=True) * (-1 / 2) * (power + ctx.eps)**(-3 / 2)
            if ctx.shift:
                grad_mean_ = (
                    grad_mean_ / scale - 2 * grad_power_ *
                    x.sum(ctx.statistics_axis, keepdim=True) / n_values)

        grad_x = grad_x_hat
        if ctx.scale:
            grad_x = grad_x / scale + grad_power_ * 2 * x / n_values
        if ctx.shift:
            grad_x = grad_x + grad_mean_ / n_values
        return grad_x * mask, grad_gamma, grad_beta, None, None, None, None, None, None, None
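Hand-written backwards like the one above are usually validated against numerical gradients with torch.autograd.gradcheck in double precision. A self-contained sketch with a much simpler Function (mean subtraction only), just to illustrate the pattern:

import torch


class Center(torch.autograd.Function):
    """Subtracts the mean along dim 1; the backward is written by hand."""

    @staticmethod
    def forward(ctx, x):
        return x - x.mean(dim=1, keepdim=True)

    @staticmethod
    def backward(ctx, grad_y):
        # d(x_i - mean)/dx_j = delta_ij - 1/n  =>  grad_x = grad_y - mean(grad_y)
        return grad_y - grad_y.mean(dim=1, keepdim=True)


x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(Center.apply, (x,))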
Example #5
    def forward(self, x, seq_len=None):
        mask = compute_mask(x, seq_len, self.batch_axis, self.sequence_axis)
        data_format = " ".join(list(self.data_format))
        x_tmp = rearrange(x * mask, f'{data_format} -> {self.tmp_format}')
        mask_tmp = rearrange(mask, f'{data_format} -> {self.tmp_format}')
        tmp_shape = x_tmp.shape
        x_tmp = x_tmp.reshape((int(np.prod(tmp_shape[:-self.ndim])),
                               *tmp_shape[-self.ndim:])).unsqueeze(1)
        mask_tmp = mask_tmp.reshape((int(np.prod(tmp_shape[:-self.ndim])),
                                     *tmp_shape[-self.ndim:])).unsqueeze(1)
        x_tmp = Pad(side='both',
                    mode='constant')(x_tmp,
                                     size=np.array(self.window_size) - 1)
        mask_tmp = Pad(side='both',
                       mode='constant')(mask_tmp,
                                        size=np.array(self.window_size) - 1)
        signal_fraction = self.pool_fn(mask_tmp)
        if self.shift:
            mean = self.pool_fn(x_tmp) / (signal_fraction + 1e-6)
            mean = mean.reshape(tmp_shape)
            mean = rearrange(mean, f'{self.tmp_format} -> {data_format}')
            if self.statistics_axis:
                mean = mean.mean(self.statistics_axis, keepdim=True)
            x = x - mean
        if self.scale:
            power = self.pool_fn(x_tmp**2) / (signal_fraction + 1e-6)
            power = power.reshape(tmp_shape)
            power = rearrange(power, f'{self.tmp_format} -> {data_format}')
            if self.statistics_axis:
                power = power.mean(self.statistics_axis, keepdim=True)
            if self.shift:
                power = (power - mean**2)
            # print(power.min(), power.max())
            x = x / torch.sqrt(power + self.eps)

        if self.gamma is not None:
            x = x * self.gamma
        if self.beta is not None:
            x = x + self.beta
        return x * mask
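The pooling trick above, dividing the pooled masked signal by the pooled mask, yields per-window statistics that ignore padded positions. A minimal 1-D sketch with torch.nn.functional.avg_pool1d (window size and values are illustrative):

import torch
import torch.nn.functional as F

x = torch.arange(1., 7.).reshape(1, 1, 6)          # (batch, channel, time)
mask = torch.tensor([[[1., 1., 1., 1., 0., 0.]]])  # last two frames are padding

window = 3
pooled_x = F.avg_pool1d(x * mask, window, stride=1)
pooled_mask = F.avg_pool1d(mask, window, stride=1)

# per-window mean over valid frames only
local_mean = pooled_x / (pooled_mask + 1e-6)
assert torch.allclose(local_mean.squeeze(),
                      torch.tensor([2.0, 3.0, 3.5, 4.0]), atol=1e-4)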
Example #6
def reverse_sequence(x, seq_len=None):
    """
    >>> x, seq_len = (torch.cumsum(torch.ones((3,5,8)), dim=1), [4,5,2])
    >>> reverse_sequence(x, seq_len)
    >>> reverse_sequence(reverse_sequence(x, seq_len), seq_len)
    Args:
        x:
        seq_len:

    Returns:

    """
    if seq_len is None:
        return x.flip(1)
    else:
        T = x.shape[1]
        x = torch.cat((x, x), dim=1)
        x = torch.stack(tuple(
            [x[i, seq_len[i]:seq_len[i] + T].flip(0) for i in range(len(x))]),
                        dim=0)
        mask = compute_mask(x, seq_len)
    return x * mask
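The concatenate-slice-flip trick above is easiest to see on a single sequence: duplicating the time axis and slicing [seq_len : seq_len + T] reverses the valid frames while keeping the padded tail at the end, where the mask later zeroes it. A small illustration:

import torch

T, length = 6, 4
t = torch.tensor([1., 2., 3., 4., 0., 0.])  # padded sequence with 4 valid frames
doubled = torch.cat((t, t))
reversed_valid = doubled[length:length + T].flip(0)
assert torch.equal(reversed_valid, torch.tensor([4., 3., 2., 1., 0., 0.]))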
Example #7
def scaled_dot_product_attention(q, k, v, seq_len=None, bidirectional=False):
    """
    >>> q = torch.zeros((2, 3, 4))
    >>> k = torch.zeros((2, 6, 4))
    >>> v = torch.randn((2, 6, 8))
    >>> x = scaled_dot_product_attention(q, k, v)
    >>> x.shape
    torch.Size([2, 3, 8])
    >>> q = torch.zeros((2, 6, 4))
    >>> x = scaled_dot_product_attention(q, k, v, bidirectional=False)
    >>> (x[0,0] == v[0,0]).all()
    tensor(1, dtype=torch.uint8)
    >>> (torch.abs(x[0,-1] - v[0].mean(0)) < 1e-6).all()
    tensor(1, dtype=torch.uint8)
    >>> x = scaled_dot_product_attention(q, k, v, seq_len=[6,4], bidirectional=True)
    """
    y = q @ k.transpose(-2, -1) / np.sqrt(k.shape[-1])
    if not bidirectional:
        mask = get_causal_mask(y)
        y = y + torch.log((mask > 0).float())
    elif seq_len is not None:
        mask = compute_mask(y, seq_len, seq_axis=-1)
        y = y + torch.log((mask > 0).float())
    return torch.softmax(y, dim=-1) @ v
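A quick check of the behaviour the docstring relies on: with all-zero queries and keys every score is equal, the softmax is uniform, and each output frame is the mean of v over the key axis (plain torch, not relying on the masking helpers above):

import torch

q = torch.zeros((2, 3, 4))
k = torch.zeros((2, 6, 4))
v = torch.randn((2, 6, 8))

scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
out = torch.softmax(scores, dim=-1) @ v  # shape (2, 3, 8)

assert out.shape == (2, 3, 8)
assert torch.allclose(out, v.mean(dim=1, keepdim=True).expand(-1, 3, -1),
                      atol=1e-6)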