Example #1
    def visualize_weights(self, rows, cols, col):
        w_conv_d, w_prop_e, w_pow_e, w_conv_e, w_channel_e = self.prepare_weights(
        )

        w_conv_d = rearrange(w_conv_d,
                             'o i (h w) -> o i h w',
                             h=self.kernel_size)
        w_spatial_d = reduce(w_conv_d, 'o i h w -> i h w', 'sum')
        w_channel_d = reduce(w_conv_d, 'o i h w -> o i', 'sum')

        w_conv_e = rearrange(w_conv_e,
                             '(o d2) (i d1) h w -> o d2 i d1 h w',
                             d1=4,
                             d2=4)
        w_spatial_e = reduce(w_conv_e, 'o d2 i d1 h w -> i d1 h w', 'sum')
        if self.n_in >= self.n_out:
            w_channel_e = reduce(w_conv_e, 'o d2 i d1 h w -> o i', 'sum')
        w_dir_e = reduce(w_conv_e, 'o d2 i d1 h w -> d2 i d1', 'sum')

        idx = col

        for c in range(self.n_in):
            ax = plt.subplot(rows, cols, idx)
            plt.imshow(w_pow_e[c, :, :])
            for (j, i), label in np.ndenumerate(w_pow_e[c, :, :]):
                ax.text(i,
                        j,
                        '{:.02f}'.format(label),
                        ha='center',
                        va='center',
                        fontsize=min(10, 60 / rows),
                        color='tomato')
            plt.xticks(
                [0, 1, 2, 3],
                [r'$^a / _b$', r'$\frac{a}{b}$', r'$_b \backslash ^a$', 'b|a'],
                fontsize=min(10, 100 / rows))
            plt.yticks([0, 1], [r'$\frac{a}{b}$', r'$\frac{b}{a}$'],
                       fontsize=min(10, 100 / rows))
            ax.tick_params(length=0)
            idx += cols
        for c in range(self.n_in, self.n_out):
            idx += cols

        ax = plt.subplot(rows, cols, idx)
        plt.imshow(w_prop_e[0, :, :, 0, 0])
        for (j, i), label in np.ndenumerate(w_prop_e[0, :, :, 0, 0]):
            ax.text(i,
                    j,
                    '{:.02f}'.format(label),
                    ha='center',
                    va='center',
                    fontsize=min(10, 80 / rows),
                    color='tomato')
        plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'],
                   fontsize=min(10, 100 / rows))
        plt.yticks([])
        if self.n_in > 1:
            plt.ylabel('c', fontsize=min(10, 100 / rows))
        ax.tick_params(length=0)
        idx += cols

        for c in range(self.n_in):
            for d in range(4):
                if self.kernel_size > 1:
                    ax = plt.subplot(rows, cols, idx)
                    plt.imshow(w_spatial_e[c, d, :, :])
                    plt.xlabel('w', fontsize=min(10, 100 / rows))
                    plt.ylabel('h', fontsize=min(10, 100 / rows))
                    plt.xticks([])
                    plt.yticks([])
                idx += cols
        idx += max(0, self.n_out - self.n_in) * 4 * cols

        if self.n_in > 1:
            ax = plt.subplot(rows, cols, idx)
            plt.xlabel('in', fontsize=min(10, 100 / rows))
            plt.ylabel('out', fontsize=min(10, 100 / rows))
            plt.imshow(w_channel_e)
            plt.xticks([])
            plt.yticks([])
            idx += cols
        elif self.n_out > 1:
            idx += cols

        for c in range(self.n_in):
            ax = plt.subplot(rows, cols, idx)
            plt.imshow(w_dir_e[:, c, :])
            plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'],
                       fontsize=min(10, 100 / rows))
            plt.yticks([0, 1, 2, 3], ['/', '-', '\\', '|'],
                       fontsize=min(10, 100 / rows))
            ax.tick_params(length=0)
            plt.ylabel('out', fontsize=min(10, 100 / rows))
            idx += cols
        idx += max(0, self.n_out - self.n_in) * cols

        idx += cols

        if self.kernel_size > 1:
            for c in range(self.n_in):
                ax = plt.subplot(rows, cols, idx)
                plt.imshow(w_spatial_d[c, :, :])
                plt.xlabel('w', fontsize=min(10, 100 / rows))
                plt.ylabel('h', fontsize=min(10, 100 / rows))
                plt.xticks([])
                plt.yticks([])
                idx += cols
            idx += max(0, self.n_out - self.n_in) * cols
        else:
            idx += max(self.n_in, self.n_out) * cols

        if self.n_in > 1:
            ax = plt.subplot(rows, cols, idx)
            plt.xlabel('in', fontsize=min(10, 100 / rows))
            plt.ylabel('out', fontsize=min(10, 100 / rows))
            plt.imshow(w_channel_d[:, :])
            plt.xticks([])
            plt.yticks([])
            idx += cols
        elif self.n_out > 1:
            idx += cols
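
In this example (and Example #4) the encoder kernel arrives with the direction and channel axes flattened together; a minimal, self-contained sketch of the unflattening pattern and the marginal summaries used above (all sizes here are made up for illustration):
import torch
from einops import rearrange, reduce

o, i, d, k = 3, 2, 4, 5                    # hypothetical channel / direction / kernel sizes
w_conv_e = torch.rand(o * d, i * d, k, k)  # flattened grouped layout, as used above

w = rearrange(w_conv_e, '(o d2) (i d1) h w -> o d2 i d1 h w', d1=d, d2=d)
assert w.shape == (o, d, i, d, k, k)

# marginal summaries, as in visualize_weights
w_spatial = reduce(w, 'o d2 i d1 h w -> i d1 h w', 'sum')
w_channel = reduce(w, 'o d2 i d1 h w -> o i', 'sum')
w_dir = reduce(w, 'o d2 i d1 h w -> d2 i d1', 'sum')
assert w_spatial.shape == (i, d, k, k)
assert w_channel.shape == (o, i)
assert w_dir.shape == (d, i, d)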
Example #2
    def forward(self, x, mask=None, return_attn=False):
        b, n, _, h, m, iters, eps = *x.shape, self.heads, self.m, self.pinv_iterations, self.eps

        # pad so that sequence can be evenly divided into m landmarks

        remainder = n % m
        if remainder > 0:
            padding = m - (n % m)
            x = F.pad(x, (0, 0, 0, padding), value=0)

            if exists(mask):
                mask = F.pad(mask, (0, padding), value=False)

        # derive query, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
                      (q, k, v))

        # set masked positions to 0 in queries, keys, values

        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n')
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q *= self.scale

        # generate landmarks by sum reduction, and then calculate mean using the mask

        l = ceil(n / m)
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(q, landmark_einops_eq, 'sum', l=l)
        k_landmarks = reduce(k, landmark_einops_eq, 'sum', l=l)

        # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean

        divisor = l
        if exists(mask):
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l=l)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        # masked mean (if mask exists)

        q_landmarks /= divisor
        k_landmarks /= divisor

        # similarities

        einops_eq = '... i d, ... j d -> ... i j'
        sim1 = einsum(einops_eq, q, k_landmarks)
        sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = einsum(einops_eq, q_landmarks, k)

        # masking

        if exists(mask):
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(
                ~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(
                ~(mask_landmarks[..., None] * mask_landmarks[..., None, :]),
                mask_value)
            sim3.masked_fill_(
                ~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) in the paper

        attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1),
                                  (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)
        attn = attn1 @ attn2_inv @ attn3

        # aggregate

        out = einsum('... i j, ... j d -> ... i d', attn, v)

        # add depth-wise conv residual of values

        if self.residual:
            out += self.res_conv(v)

        # merge and combine heads

        out = rearrange(out, 'b h n d -> b n (h d)', h=h)
        out = self.to_out(out)
        out = out[:, :n]

        if return_attn:
            return out, attn

        return out
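
The `moore_penrose_iter_pinv` helper used above is not shown in this example. As a reference, a minimal sketch of an iterative Moore-Penrose pseudoinverse in the style of the Nyströmformer paper (this is an assumption about what the helper does, not its actual definition):
import torch
from einops import rearrange

def moore_penrose_iter_pinv(x, iters=6):
    # iterative approximation of the pseudoinverse of a (batched) square matrix
    device = x.device
    abs_x = torch.abs(x)
    col = abs_x.sum(dim=-1)
    row = abs_x.sum(dim=-2)
    z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))

    I = torch.eye(x.shape[-1], device=device)
    I = rearrange(I, 'i j -> () i j')

    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))

    return z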
Example #3
    def prepare_weights(self):
        # enforce limits
        w_channel_d = F.softplus(self.w_channel_d)
        w_spatial_d = F.softplus(self.w_spatial_d)
        w_prop_e = torch.sigmoid(self.w_prop_e)
        w_pow_e = F.softplus(self.w_pow_e)
        w_dir_e = F.softplus(self.w_dir_e)
        w_channel_e = F.softplus(self.w_channel_e)
        if self.symmentric:
            w_spatial_e_0 = F.softplus(self.w_spatial_e_0)
            w_spatial_e_1 = F.softplus(self.w_spatial_e_1)
            w_spatial_e_3 = F.softplus(self.w_spatial_e_3)
        else:
            w_spatial_e = F.softplus(self.w_spatial_e)

        if self.no_prop:
            w_prop_e = torch.zeros_like(w_prop_e)

        # enforce symmetry by weight sharing
        if self.symmentric:
            w_spatial_d = torch.cat(
                (w_spatial_d,
                 w_spatial_d[:, :, :self.kernel_size // 2].flip(dims=(2, ))),
                dim=2)
            #0 => /
            #1 => -
            #2 => \
            #3 => |
            # 1, 3 are symmetric; 2 is a mirror of 0
            w_spatial_e = torch.stack((
                w_spatial_e_0,
                torch.cat(
                    (w_spatial_e_1, w_spatial_e_1[:, :, :-1].flip(dims=(2, ))),
                    dim=2), w_spatial_e_0.flip(dims=(2, )),
                torch.cat(
                    (w_spatial_e_3, w_spatial_e_3[:, :, :-1].flip(dims=(2, ))),
                    dim=2)),
                                      dim=1)
            # connect directions to each other; with connections from or to 0 and 2 sharing the same weights
            w_dir_e = w_dir_e.unbind(1)
            w_dir_e = torch.stack(
                (torch.stack(
                    (w_dir_e[0], w_dir_e[1], w_dir_e[2], w_dir_e[3]), dim=1),
                 torch.stack(
                     (w_dir_e[4], w_dir_e[5], w_dir_e[4], w_dir_e[6]), dim=1),
                 torch.stack(
                     (w_dir_e[2], w_dir_e[1], w_dir_e[0], w_dir_e[3]), dim=1),
                 torch.stack(
                     (w_dir_e[7], w_dir_e[8], w_dir_e[7], w_dir_e[9]), dim=1)),
                dim=0)
            # 1 w_pow_e per side; 0 and 2 are mirrors; 3 has symmetric sides
            w_pow_e = w_pow_e.unbind(1)
            w_pow_e = torch.stack((torch.stack(
                (w_pow_e[0], w_pow_e[1]),
                dim=1), torch.stack(
                    (w_pow_e[2], w_pow_e[3]),
                    dim=1), torch.stack(
                        (w_pow_e[1], w_pow_e[0]),
                        dim=1), torch.stack((w_pow_e[4], w_pow_e[4]), dim=1)),
                                  dim=2)
            w_prop_e = torch.cat((w_prop_e[:, :, 1, None, :, :], w_prop_e),
                                 dim=2)

        # normalize by output channel
        # technically not needed for d, but here for consistency
        w_channel_d = w_channel_d / reduce(w_channel_d, 'o i -> o 1', 'sum')
        w_spatial_d = w_spatial_d / reduce(w_spatial_d, 'i h w -> i 1 1',
                                           'sum')
        w_channel_e = w_channel_e / reduce(w_channel_e, 'o i -> o 1', 'sum')
        w_dir_e = w_dir_e / reduce(w_dir_e, 'd2 i d1 -> d2 i 1', 'sum')
        w_spatial_e = w_spatial_e / reduce(w_spatial_e, 'i d h w -> i d 1 1',
                                           'sum')

        # combine separable convolutions for speed
        w_conv_d = rearrange(
            torch.einsum('o i, i h w -> o i h w', w_channel_d, w_spatial_d),
            'o i h w -> o i (h w)')
        if self.n_in >= self.n_out:
            w_conv_e = rearrange(
                torch.einsum('o i, p i d, i d h w -> o p i d h w', w_channel_e,
                             w_dir_e, w_spatial_e),
                'o d2 i d1 h w -> (o d2) (i d1) h w')
            return w_conv_d, w_prop_e, w_pow_e, w_conv_e, None
        else:
            w_conv_e = rearrange(
                torch.einsum('p i d, i d h w -> p i d h w', w_dir_e,
                             w_spatial_e), 'd2 i d1 h w -> d2 (i d1) h w')
            return w_conv_d, w_prop_e, w_pow_e, w_conv_e, w_channel_e
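
The "combine separable convolutions for speed" step fuses a channel-mixing weight with a per-channel spatial weight into one dense kernel. A small self-contained check of that equivalence, with made-up sizes:
import torch
import torch.nn.functional as F
from einops import rearrange

o, i, k = 3, 2, 5
w_channel = torch.rand(o, i)          # channel mixing weights
w_spatial = torch.rand(i, k, k)       # per-input-channel spatial weights
x = torch.rand(1, i, 16, 16)

# fused kernel, as in prepare_weights
w_fused = torch.einsum('o i, i h w -> o i h w', w_channel, w_spatial)
y_fused = F.conv2d(x, w_fused, padding=k // 2)

# equivalent two-step form: depthwise spatial conv, then 1x1 channel-mixing conv
y_spatial = F.conv2d(x, rearrange(w_spatial, 'i h w -> i 1 h w'), padding=k // 2, groups=i)
y_two_step = F.conv2d(y_spatial, rearrange(w_channel, 'o i -> o i 1 1'))

assert torch.allclose(y_fused, y_two_step, atol=1e-4)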
Example #4
    def visualize_weights(self, rows, cols, col):
        w_conv_d, w_conv_e, w_skip_d, w_skip_e = self.prepare_weights()

        w_spatial_d = reduce(w_conv_d, 'o i h w -> i h w', 'sum')
        w_channel_d = reduce(w_conv_d, 'o i h w -> o i', 'sum')

        w_conv_e = rearrange(w_conv_e,
                             '(o d2) (i d1) h w -> o d2 i d1 h w',
                             d1=4,
                             d2=4)
        w_spatial_e = reduce(w_conv_e, 'o d2 i d1 h w -> i d1 h w', 'sum')
        w_channel_e = reduce(w_conv_e, 'o d2 i d1 h w -> o i', 'sum')
        w_dir_e = reduce(w_conv_e, 'o d2 i d1 h w -> d2 i d1', 'sum')

        idx = col
        plt.axis('off')

        idx += cols * self.n_c
        idx += cols

        for c in range(self.n_c):
            for d in range(4):
                if self.kernel_size > 1:
                    ax = plt.subplot(rows, cols, idx)
                    plt.imshow(w_spatial_e[c, d, :, :])
                    plt.xlabel('w', fontsize=min(10, 100 / rows))
                    plt.ylabel('h', fontsize=min(10, 100 / rows))
                    plt.xticks([])
                    plt.yticks([])
                idx += cols

        if self.n_c > 1:
            ax = plt.subplot(rows, cols, idx)
            plt.xlabel('in', fontsize=min(10, 100 / rows))
            plt.ylabel('out', fontsize=min(10, 100 / rows))
            plt.imshow(w_channel_e)
            plt.xticks([])
            plt.yticks([])
            idx += cols

        for c in range(self.n_c):
            ax = plt.subplot(rows, cols, idx)
            plt.imshow(w_dir_e[:, c, :])
            plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'],
                       fontsize=min(10, 100 / rows))
            plt.yticks([0, 1, 2, 3], ['/', '-', '\\', '|'],
                       fontsize=min(10, 100 / rows))
            ax.tick_params(length=0)
            plt.ylabel('out', fontsize=min(10, 100 / rows))
            idx += cols

        ax = plt.subplot(rows, cols, idx)
        plt.imshow(w_skip_e[0, :, :, 0, 0])
        for (j, i), label in np.ndenumerate(w_skip_e[0, :, :, 0, 0]):
            ax.text(i,
                    j,
                    '{:.02f}'.format(label),
                    ha='center',
                    va='center',
                    fontsize=min(10, 80 / rows),
                    color='tomato')
        plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'],
                   fontsize=min(10, 100 / rows))
        plt.yticks([])
        if self.n_c > 1:
            plt.ylabel('c', fontsize=min(10, 100 / rows))
        ax.tick_params(length=0)
        idx += cols

        if self.kernel_size > 1:
            for c in range(self.n_c):
                ax = plt.subplot(rows, cols, idx)
                plt.imshow(w_spatial_d[c, :, :])
                plt.xlabel('w', fontsize=min(10, 100 / rows))
                plt.ylabel('h', fontsize=min(10, 100 / rows))
                plt.xticks([])
                plt.yticks([])
                idx += cols

        if self.n_c > 1:
            ax = plt.subplot(rows, cols, idx)
            plt.xlabel('in', fontsize=min(10, 100 / rows))
            plt.ylabel('out', fontsize=min(10, 100 / rows))
            plt.imshow(w_channel_d[:, :])
            plt.xticks([])
            plt.yticks([])
            idx += cols

        ax = plt.subplot(rows, cols, idx)
        plt.imshow(w_skip_d[:, :, 0, 0])
        for (j, i), label in np.ndenumerate(w_skip_d[:, :, 0, 0]):
            ax.text(i,
                    j,
                    '{:.02f}'.format(label),
                    ha='center',
                    va='center',
                    fontsize=min(10, 80 / rows),
                    color='tomato')
        plt.xticks([])
        plt.yticks([])
        if self.n_c > 1:
            plt.xlabel('c', fontsize=min(10, 100 / rows))
        idx += cols
Example #5
    def forward(self,
                x,
                context=None,
                mask=None,
                context_mask=None,
                tie_attn_dim=None):
        device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(
            context)

        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))

        i, j = q.shape[-2], k.shape[-2]

        # memory compressed attention, to make cross-attention more efficient

        if exists(self.compress_fn):
            assert has_context, 'memory compressed attention only works in the context of cross attention for now'

            ratio = self.compress_ratio
            padding = ratio - (j % ratio)

            if padding < ratio:
                k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value=0),
                           (k, v))

                if exists(context_mask):
                    context_mask = F.pad(context_mask, (0, padding),
                                         value=False)

                k, v = map(lambda t: rearrange(t, 'b n c -> b c n'), (k, v))
                k, v = map(self.compress_fn, (k, v))
                k, v = map(lambda t: rearrange(t, 'b c n -> b n c'), (k, v))

                if exists(context_mask):
                    context_mask = reduce(context_mask.float(),
                                          'b (n r) -> b n',
                                          'sum',
                                          r=ratio)
                    context_mask = context_mask > 0

                j = (j + padding) // ratio

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
                      (q, k, v))

        # for tying row-attention, for MSA axial self-attention

        if exists(tie_attn_dim):
            q, k, v = map(
                lambda t: rearrange(
                    t, '(b r) h n d -> b r h n d', r=tie_attn_dim), (q, k, v))

            # when tying row-attention, one cannot have any masked out tokens
            if exists(mask):
                assert torch.all(
                    mask
                ), 'you cannot have any padding if you are to tie the row attention across MSAs'
                mask = None

            dots = einsum('b r h i d, b r h j d -> b h i j', q,
                          k) * self.scale * (tie_attn_dim**-0.5)
        else:
            dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking

        if exists(mask) or exists(context_mask):
            mask = default(mask,
                           lambda: torch.ones(1, i, device=device).bool())
            context_mask = default(
                context_mask, mask) if not has_context else default(
                    context_mask,
                    lambda: torch.ones(1, j, device=device).bool())
            mask_value = -torch.finfo(dots.dtype).max
            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
            dots.masked_fill_(~mask, mask_value)

        # attention

        attn = dots.softmax(dim=-1)
        attn = self.dropout(attn)

        # aggregate

        if exists(tie_attn_dim):
            out = einsum('b h i j, b r h j d -> b r h i d', attn, v)
            out = rearrange(out, 'b r h n d -> (b r) h n d')
        else:
            out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # combine heads and project out

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        return out
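
The memory-compressed branch above shortens the keys and values with `self.compress_fn` before attention. A minimal sketch of that compression using a strided 1D convolution (an assumption; the actual `compress_fn` is not shown here):
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

dim, ratio = 64, 4
compress_fn = nn.Conv1d(dim, dim, kernel_size=ratio, stride=ratio)  # assumed definition

k = torch.rand(2, 102, dim)                        # (batch, seq, dim)
padding = (ratio - k.shape[1] % ratio) % ratio
k = F.pad(k, (0, 0, 0, padding))                   # pad the sequence axis

k = rearrange(k, 'b n c -> b c n')
k = compress_fn(k)
k = rearrange(k, 'b c n -> b n c')                 # (batch, ceil(seq / ratio), dim)
assert k.shape == (2, (102 + padding) // ratio, dim)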
Example #6
def compute_pairwise_losses(
    estimate: torch.Tensor,
    target: torch.Tensor,
    axis: int,
    loss_fn=torch.nn.functional.mse_loss,
):
    """
    The function pit_loss can be implemented more efficiently when the
    loss allows calculating a pair-wise loss. The pair-wise losses are
    then written to a matrix (each estimated signal vs. each target signal).
    On this matrix of pair-wise losses, the function
    `scipy.optimize.linear_sum_assignment` (Hungarian algorithm) can find the
    best permutation.

    The runtime of `scipy.optimize.linear_sum_assignment` is negligible,
    so the runtime complexity decreases from factorial to quadratic
    with respect to the number of speakers.
    For 2 speakers this is slightly slower, but for larger numbers of speakers
    (e.g. 7) this function is significantly faster.

    Limitation:
        Not every loss function can be factorized into pair-wise losses,
        and sometimes the pair-wise loss is difficult to implement
        (see the special implementation for cross_entropy in this function).
        On the plus side, most commonly used loss functions can be factorized.

    Does not support batch dimension. Does not support PackedSequence.

    Args:
        estimate: Padded sequence. The speaker axis is specified with `axis`,
            so the default shape is (T, K, F)
        target: Padded sequence with the same shape as `estimate` (defaults
            to (T, K, F))
        loss_fn: Loss function to apply on each permutation. It must accept two
            arguments (estimate and target) with the same shapes as the
            arguments this function receives.
        axis: Speaker axis K. The permutation is applied along this axis. axis=-2
            and an input shape of (T, K, F) corresponds to the old default
            behaviour.

    Examples:
        >>> T, K, F = 4, 2, 5
        >>> estimate, target = torch.ones(T, K, F), torch.zeros(T, K, F)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, 1))
        tensor(1.)

        >>> T, K, F = 4, 2, 5
        >>> estimate, target = torch.ones(T, K, F), torch.zeros(T, F, dtype=torch.int64)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, 1, loss_fn=torch.nn.functional.cross_entropy), reduction='sum')
        tensor(0.6931)
        >>> pit_loss(estimate, target, 1, loss_fn=torch.nn.functional.cross_entropy)
        tensor(0.6931)

        >>> T, K, F = 4, 2, 5
        >>> estimate, target = torch.ones(K, F, T), torch.zeros(K, F, T)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, 0))
        tensor(1.)

        >>> T, K, F = 4, 2, 5
        >>> estimate = torch.stack([torch.ones(F, T), torch.zeros(F, T)])
        >>> target = estimate[(1, 0), :, :]
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, axis=0), return_permutation=True)
        (tensor(0.), array([1, 0]))

        >>> K = 5
        >>> estimate, target = torch.ones(K), torch.zeros(K)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, axis=0))
        tensor(1.)

        >>> A, B, K, C, F = 4, 5, 3, 100, 128
        >>> estimate, target = torch.ones(A, B, K, C, F), torch.zeros(A, B, K, C, F)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, axis=-3))
        tensor(1.)
    """
    sources = estimate.size()[axis]
    assert sources < 30, f'Are you sure? sources={sources}'
    if loss_fn in [torch.nn.functional.cross_entropy]:
        import einops

        assert axis % estimate.ndimension() == 1, axis
        estimate_shape = list(estimate.shape)
        del estimate_shape[1]
        assert estimate_shape == list(target.shape), (
            f'{estimate.shape} (N, K, ...) does not match {target.shape} (N, ...)'
        )

        assert loss_fn == torch.nn.functional.cross_entropy, loss_fn
        assert axis == 1, axis

        # torch.einsum does not support reduction of ...
        return einops.reduce(torch.einsum(
            'nc...,n...k->n...ck', -torch.nn.LogSoftmax(dim=1)(estimate),
            torch.nn.functional.one_hot(target, num_classes=sources).to(
                estimate.dtype)),
                             'n ... c k -> c k',
                             reduction='mean')

    else:
        assert estimate.size() == target.size(), (
            f'{estimate.size()} != {target.size()}')

        assert estimate.shape == target.shape, (estimate.shape, target.shape)

        indexer_e = [
            slice(None),
        ] * estimate.ndim
        indexer_t = [
            slice(None),
        ] * target.ndim
        pair_wise_loss_matrix = []
        for i in range(sources):
            indexer_e[axis] = i
            for j in range(0, sources):
                indexer_t[axis] = j
                pair_wise_loss_matrix.append(
                    loss_fn(
                        estimate[tuple(indexer_e)],
                        target[tuple(indexer_t)],
                    ))
        return torch.stack(pair_wise_loss_matrix, 0).reshape(sources, sources)
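
The docstring above refers to `pit_loss_from_loss_matrix`, which applies the Hungarian algorithm to the pairwise loss matrix. A minimal sketch of such a function, consistent with the doctests but only an assumption about the real implementation:
import torch
from scipy.optimize import linear_sum_assignment

def pit_loss_from_loss_matrix(pair_wise_loss_matrix,
                              reduction='mean',
                              return_permutation=False):
    # Hungarian algorithm on the (sources x sources) pairwise loss matrix
    row_idx, col_idx = linear_sum_assignment(
        pair_wise_loss_matrix.detach().cpu().numpy())
    losses = pair_wise_loss_matrix[torch.as_tensor(row_idx),
                                   torch.as_tensor(col_idx)]
    loss = losses.sum() if reduction == 'sum' else losses.mean()
    if return_permutation:
        return loss, col_idx
    return loss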
Example #7
    def forward(self, x_skip, ece_skip, ce_skip, x_pool, ece_pool, ce_pool):
        # x = d*cd, cd
        # d = depth
        # cd = confidence over depth
        # ece = directed smoothness * ce;  dim 2 corresponds to edge directions: /, -, \, |
        # ce = confidence over directed smoothness

        if self.training:
            w_conv_d, w_conv_e, w_skip_d, w_skip_e = self.prepare_weights()
        else:
            w_conv_d, w_conv_e, w_skip_d, w_skip_e = self.weights

        # unpooling d
        # even if there is an edge, it would be difficult to assign the d_pool to one side, so it is unpooled without s_skip
        if self.kernel_size == 2 or not self.scs_unpool_d:
            # no need for smoothness since it would factor out if nothing overlaps
            x_pool = F.conv_transpose2d(x_pool,
                                        w_conv_d,
                                        padding=self.padding,
                                        stride=2)
        else:
            # only use a single s_pool factor as opposed to uSNC
            # if a location is on an edge, each side of the unpooled version will depend on values from their side of the edge
            scs_pool = reduce(ece_pool, 'b c d h w -> b c h w', 'prod')
            x_pool = F.conv_transpose2d(x_pool * scs_pool.repeat(2,1,1,1), w_conv_d, padding=self.padding, stride=2) \
                / (F.conv_transpose2d(scs_pool, w_conv_d, padding=self.padding, stride=2) + self.eps).repeat(2,1,1,1)

        # unpooling e
        # if the pooled data predicts an edge, the skip connection knows where (low e) or where not (high ece at one point, less confidence at another)
        #    => during unpooling of e, favour locations with low ece_skip
        # if the pooled data does not predict an edge, it should not be focused onto edges
        #    => deconv without skip
        # combine both versions with a nconv weighted by the unfocused version

        ece_pool = rearrange(ece_pool, 'b c d h w -> b (c d) h w')
        ce_pool = rearrange(ce_pool, 'b c d h w -> b (c d) h w')
        if self.unfocused_unpool_e and self.focused_unpool_e:
            w_s = 1 / (rearrange(ece_skip, 'b c d h w -> b (c d) h w') +
                       self.eps)
            w_s_sum = F.conv2d(w_s, w_conv_e, padding=self.padding,
                               stride=2) + self.eps
            # divide by w_s_sum first in this deconvolution, then multiply by w_s
            ce_pool_focus = F.conv_transpose2d(
                ce_pool / w_s_sum, w_conv_e, padding=self.padding,
                stride=2) * w_s
            ece_pool_focus = F.conv_transpose2d(
                ece_pool / w_s_sum, w_conv_e, padding=self.padding,
                stride=2) * w_s

            ce_pool_unfocused = F.conv_transpose2d(ce_pool,
                                                   w_conv_e,
                                                   padding=self.padding,
                                                   stride=2)
            ece_pool_unfocused = F.conv_transpose2d(ece_pool,
                                                    w_conv_e,
                                                    padding=self.padding,
                                                    stride=2)

            e_pool_unfocused = (ece_pool_unfocused /
                                (ce_pool_unfocused + self.eps)).detach()
            ce_pool = rearrange(ece_pool_unfocused +
                                (1 - e_pool_unfocused) * ce_pool_focus,
                                'b (c d) h w -> b c d h w',
                                d=4)
            ece_pool = rearrange(e_pool_unfocused * ece_pool_unfocused +
                                 (1 - e_pool_unfocused) * ece_pool_focus,
                                 'b (c d) h w -> b c d h w',
                                 d=4)
        elif self.focused_unpool_e:
            w_s = 1 / (rearrange(ece_skip, 'b c d h w -> b (c d) h w') +
                       self.eps)
            w_s_sum = F.conv2d(w_s, w_conv_e, padding=self.padding,
                               stride=2) + self.eps
            # divide by w_s_sum first in this deconvolution, then multiply by w_s
            ce_pool = rearrange(F.conv_transpose2d(
                ce_pool / w_s_sum, w_conv_e, padding=self.padding, stride=2) *
                                w_s,
                                'b (c d) h w -> b c d h w',
                                d=4)
            ece_pool = rearrange(F.conv_transpose2d(
                ece_pool / w_s_sum, w_conv_e, padding=self.padding, stride=2) *
                                 w_s,
                                 'b (c d) h w -> b c d h w',
                                 d=4)
        else:
            ce_pool = rearrange(F.conv_transpose2d(ce_pool,
                                                   w_conv_e,
                                                   padding=self.padding,
                                                   stride=2),
                                'b (c d) h w -> b c d h w',
                                d=4)
            ece_pool = rearrange(F.conv_transpose2d(ece_pool,
                                                    w_conv_e,
                                                    padding=self.padding,
                                                    stride=2),
                                 'b (c d) h w -> b c d h w',
                                 d=4)
        s_pool = reduce(ece_pool / (ce_pool + self.eps),
                        'b c d h w -> b c h w', 'prod')

        # combining pool and skip
        # in general, each should have proportionally higher c in areas they are more suited to in terms of distance
        # additionally, skip is preferred around edges because of its higher resolution
        # to determine whether there is an edge, s_pool is used:
        # it has values wherever there is data to interpolate, likely includes fewer input errors and is less likely to have gaps in edges
        #    => use w_skip, w_pool*s_pool

        w_pool_d = (1 - w_skip_d) * s_pool
        w_sum_d = w_skip_d + w_pool_d + self.eps
        x = (w_skip_d * x_skip + w_pool_d.repeat(2, 1, 1, 1) *
             x_pool) / w_sum_d.repeat(2, 1, 1, 1)

        w_pool_e = (1 - w_skip_e) * s_pool[:, :, None, :, :]
        w_sum_e = w_skip_e + w_pool_e + self.eps
        ce = (w_skip_e * ce_skip + w_pool_e * ce_pool) / w_sum_e
        ece = (w_skip_e * ece_skip + w_pool_e * ece_pool) / w_sum_e

        if ece.requires_grad:
            ece.register_hook(lambda grad: torch.clamp(grad, -1000, 1000))
            ce.register_hook(lambda grad: torch.clamp(grad, -1000, 1000))

        return x, ece, ce
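
The unpooling above repeatedly uses a confidence-weighted (normalized) transposed convolution: data is upsampled weighted by its confidence, then renormalized by the upsampled confidence. A stripped-down sketch of that pattern with made-up shapes:
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 4, 4)       # data
c = torch.rand(1, 1, 4, 4)       # confidence in [0, 1]
w = torch.rand(1, 1, 2, 2)       # unpooling kernel
eps = 1e-6

numerator = F.conv_transpose2d(x * c, w, stride=2)
denominator = F.conv_transpose2d(c, w, stride=2) + eps
x_up = numerator / denominator   # confidence-normalized unpooled data, shape (1, 1, 8, 8)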
Example #8
 def max_pool2d_layer(self, x):
     result = reduce(x, 'b c (h h1) (w w1) -> b c h w', 'max', h1=2, w1=2)
     return result
Example #9
 def test8(x):
     # max-pooling
     y = reduce(x, 'b c (h h1) (w w1) -> b c h w', reduction='max', h1=2, w1=2)
     assert y.shape == (10, 20, 30 // 2, 40 // 2)
     return y
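
For reference, the reduce-based max pooling in Examples #8 and #9 matches `F.max_pool2d` with a 2x2 window:
import torch
import torch.nn.functional as F
from einops import reduce

x = torch.rand(10, 20, 30, 40)
y_einops = reduce(x, 'b c (h h1) (w w1) -> b c h w', 'max', h1=2, w1=2)
y_torch = F.max_pool2d(x, kernel_size=2)
assert torch.equal(y_einops, y_torch)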
Example #10
    def forward(self, x, adj, size=None, return_attention_weights=None):
        """
        Args:
            x: Union[Tensor, PairTensor]
            adj: Tensor[2, num_edges] or list of Tensor
            size: Size
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(adj, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        h, c = self.heads, self.out_channels
        # assert (not isinstance(adj, Tensor)) and h == len(adj), 'Number of heads is number of adjacency matrices'

        x_l, x_r, alpha_l, alpha_r, alpha_l_, alpha_r_ = None, None, None, None, None, None

        if isinstance(x, Tensor):
            x_l, x_r = x, None
        else:
            x_l, x_r = x[0], x[1]
        # assert x_l.dim() == 2, 'Static graphs not supported in `HGAConv`.'
        x_l = self.lin_l(x_l)
        if x_l.dim() == 2:
            alpha_l = torch.mm(x_l, self.att_l)
        else:  # x_l is 3D shape, matmul is in batched mode
            alpha_l = torch.matmul(x_l, self.att_l)

        if x_r is not None:
            x_r = self.lin_r(x_r)
            alpha_r = torch.mm(x_r, self.att_r)
            alpha_r_ = torch.mm(x_l, self.att_r)
            alpha_l_ = torch.mm(x_r, self.att_l)
            self.add_self_loops = False
        else:
            if x_l.dim() == 2:
                alpha_r = torch.mm(x_l, self.att_r)
            else:
                alpha_r = torch.matmul(x_l, self.att_r)

        assert x_l is not None
        assert alpha_l is not None

        if self.add_self_loops:
            num_nodes = x_l.shape[-2]
            num_nodes = size[1] if size is not None else num_nodes
            num_nodes = x_r.shape[-2] if x_r is not None else num_nodes
            if isinstance(adj, Tensor):
                adj = self_loop_augment(num_nodes, adj)  # TODO Bug found
            else:
                for i in range(len(adj)):
                    adj[i] = self_loop_augment(num_nodes, adj[i])

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
        _x_ = (x_l, x_r) if x_r is not None else x_l
        _alpha_ = (alpha_l, alpha_r)
        alpha_ = (alpha_l_, alpha_r_)
        out = self.propagate(adj,
                             x=_x_,
                             alpha=_alpha_,
                             alpha_=alpha_,
                             size=size)

        alpha = self._alpha
        self._alpha = None

        if isinstance(out,
                      Tensor):  # reshape here is equivalent to concatenation
            if len(x_l.shape) == 2:
                out = rearrange(out, '(h n) c -> n (h c)', h=h)
            else:
                out = rearrange(out, 't (h n) c -> t n (h c)', h=h)
        else:
            out = (out[0].reshape(-1, h * c), out[1].reshape(-1, h * c))

        if not self.concat:  # calculate mean
            if isinstance(out, Tensor):
                if len(x_l.shape) == 2:
                    out = reduce(out, 'n (h c) -> n c', 'mean', h=h)
                else:
                    out = reduce(out, 't n (h c) -> t n c', 'mean', h=h)
            else:
                out = (out[0].mean(dim=1), out[1].mean(dim=1))

        if self.bias is not None:
            if isinstance(out, Tensor):
                out += self.bias
            else:
                out = (out[0] + self.bias, out[1] + self.bias)
        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            return out, (adj, alpha)
        else:
            return out
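
The head handling at the end of the forward pass concatenates heads via `rearrange('(h n) c -> n (h c)')` and, when `concat` is off, averages them with a 'mean' reduction. A small shape sketch of both paths (sizes are hypothetical):
import torch
from einops import rearrange, reduce

h, n, c = 4, 7, 8
out = torch.rand(h * n, c)                                  # heads stacked along the node axis

concat = rearrange(out, '(h n) c -> n (h c)', h=h)          # concatenate heads
assert concat.shape == (n, h * c)

averaged = reduce(concat, 'n (h c) -> n c', 'mean', h=h)    # mean over heads
assert averaged.shape == (n, c)
assert torch.allclose(averaged, rearrange(out, '(h n) c -> n h c', h=h).mean(dim=1))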
Example #11
    def forward(
        self,
        seq,
        msa=None,
        mask=None,
        msa_mask=None,
        templates_seq=None,
        templates_dist=None,
        templates_mask=None,
        templates_coors=None,
        templates_sidechains=None,
        embedds=None,
    ):
        n, device = seq.shape[1], seq.device
        n_range = torch.arange(n, device=device)

        # unpack (AA_code, atom_pos)

        if isinstance(seq, (list, tuple)):
            seq, seq_pos = seq

        # embed main sequence

        x = self.token_emb(seq)

        # outer sum

        x = rearrange(x, 'b i d -> b () i () d') + rearrange(
            x, 'b j d-> b () () j d')  # create pair-wise residue embeds
        x_mask = rearrange(mask, 'b i -> b () i ()') + rearrange(
            mask, 'b j -> b () () j') if exists(mask) else None

        # axial positional embedding

        pos_emb = rearrange(self.pos_emb(n_range),
                            'i d -> () i () d') + rearrange(
                                self.pos_emb_ax(n_range), 'j d -> () () j d')
        x += pos_emb

        # embed multiple sequence alignment (msa)

        m = None
        msa_shape = None
        if exists(msa):
            m = self.token_emb(msa)
            m += self.msa_pos_emb(torch.arange(msa.shape[-1],
                                               device=device))[None, None, ...]
            m += self.msa_num_pos_emb(torch.arange(msa.shape[1],
                                                   device=device))[None, :,
                                                                   None, :]

            msa_shape = m.shape
            m = rearrange(m, 'b m n d -> b (m n) d')

        elif exists(embedds):
            m = self.embedd_project(embedds)
            m = rearrange(m, 'b i d -> b i () d') + rearrange(
                m, 'b j d -> b () j d')
            m = rearrange(m, 'b m n d -> b (m n) d')

        if exists(msa_mask):
            msa_mask = rearrange(msa_mask, 'b m n -> b (m n)')

        # embed templates, if present

        if exists(templates_seq):
            assert exists(
                templates_coors
            ), 'template residue coordinates must be supplied `templates_coors`'
            _, num_templates, *_ = templates_seq.shape

            if not exists(templates_dist):
                templates_dist = get_bucketed_distance_matrix(
                    templates_coors, templates_mask,
                    constants.DISTOGRAM_BUCKETS)

            # embed template

            t_seq = self.token_emb(templates_seq)

            # if sidechain information is present
            # color the residue embeddings with the sidechain type 1 features
            # todo (make efficient)

            if exists(templates_sidechains):
                if self.use_se3_transformer:
                    t_seq = self.template_sidechain_emb(t_seq,
                                                        templates_sidechains,
                                                        templates_coors,
                                                        mask=templates_mask)
                else:
                    shape = t_seq.shape
                    t_seq = rearrange(t_seq, 'b t n d -> (b t) n d')
                    templates_coors = rearrange(templates_coors,
                                                'b t n c -> (b t) n c')
                    en_mask = rearrange(templates_mask, 'b t n -> (b t) n')

                    t_seq, _ = self.template_sidechain_emb(t_seq,
                                                           templates_coors,
                                                           mask=en_mask)

                    t_seq = t_seq.reshape(*shape)

            # embed template distances

            t_dist = self.template_dist_emb(templates_dist)

            t_seq = rearrange(t_seq, 'b t i d -> b t i () d') + rearrange(
                t_seq, 'b t j d -> b t () j d')
            t = t_seq + t_dist

            # template pos emb

            template_num_pos_emb = self.template_num_pos_emb(
                torch.arange(num_templates, device=device))
            t += rearrange(template_num_pos_emb, 't d-> () t () () d')

            pos_emb = rearrange(self.template_pos_emb(n_range),
                                'i d -> () () i () d') + rearrange(
                                    self.template_pos_emb_ax(n_range),
                                    'j d -> () () () j d')
            t += pos_emb

            assert t.shape[-2:] == x.shape[-2:]

            x = torch.cat((x, t), dim=1)

            if exists(templates_mask):
                t_mask = rearrange(templates_mask,
                                   'b t i -> b t i ()') * rearrange(
                                       templates_mask, 'b t j -> b t () j')
                x_mask = torch.cat((x_mask, t_mask), dim=1)

        # flatten

        seq_shape = x.shape
        x = rearrange(x, 'b t i j d -> b (t i j) d')
        x_mask = rearrange(x_mask,
                           'b t i j -> b (t i j)') if exists(mask) else None

        # trunk

        x, m = self.net(x,
                        m,
                        seq_shape,
                        msa_shape,
                        mask=x_mask,
                        msa_mask=msa_mask)

        # remove templates, if present

        x = x.view(seq_shape)
        x = x[:, 0]

        # embeds to distogram

        trunk_embeds = (x +
                        rearrange(x, 'b i j d -> b j i d')) * 0.5  # symmetrize
        distogram_logits = self.to_distogram_logits(trunk_embeds)

        if not self.predict_coords:
            return distogram_logits

        # prepare mask for backbone coordinates

        assert self.num_backbone_atoms > 1, 'must constitute to at least 3 atomic coordinates for backbone'

        if self.num_backbone_atoms >= 3:
            N_mask, CA_mask, C_mask = scn_backbone_mask(
                seq, boolean=True, n_aa=self.num_backbone_atoms)

            cloud_mask = scn_cloud_mask(seq, boolean=True)
            flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')
            chain_mask = (mask.unsqueeze(-1) * cloud_mask)
            flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')

            mask = rearrange(chain_mask[:, :, :self.num_backbone_atoms],
                             'b l c -> b (l c)')

        # structural refinement

        distances, weights = center_distogram_torch(distogram_logits)
        coords_3d, _ = MDScaling(distances,
                                 weights=weights,
                                 iters=self.mds_iters,
                                 fix_mirror=True,
                                 N_mask=N_mask,
                                 CA_mask=CA_mask,
                                 C_mask=C_mask)
        coords = rearrange(coords_3d, 'b c n -> b n c')
        # will init all sidechain coords to cbeta if present else c_alpha
        coords = sidechain_container(coords,
                                     n_aa=self.num_backbone_atoms,
                                     cloud_mask=cloud_mask)
        coords = rearrange(coords, 'b n l d -> b (n l) d')
        atom_tokens = scn_atom_embedd(seq)  # not used for now, but could be

        trunk_embeds = self.trunk_to_structure_dim(trunk_embeds)
        x = reduce(trunk_embeds, 'b i j d -> b i d', 'mean')
        x += self.structure_module_embeds(seq)
        x = repeat(x, 'b n d -> b n l d', l=cloud_mask.shape[-1])
        x += self.atom_tokens_embed(atom_tokens)
        x = rearrange(x, 'b n l d -> b (n l) d')

        original_dtype = coords.dtype
        x, coords = map(lambda t: t.double(), (x, coords))

        with torch_default_dtype(torch.float64):
            for _ in range(self.structure_module_refinement_iters):
                x, coords = self.structure_module(x,
                                                  coords,
                                                  mask=flat_chain_mask)

        coords = coords.type(original_dtype)
        return coords
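
The pairwise "outer sum" used above for the residue embeddings broadcasts a per-residue embedding into an i x j grid. A minimal shape check with made-up sizes:
import torch
from einops import rearrange

b, n, d = 2, 5, 16
x = torch.rand(b, n, d)

pairwise = rearrange(x, 'b i d -> b () i () d') + rearrange(x, 'b j d -> b () () j d')
assert pairwise.shape == (b, 1, n, n, d)
assert torch.allclose(pairwise[0, 0, 1, 3], x[0, 1] + x[0, 3])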
Example #12
def data_init(tensor_p: 'path', labels_p: 'path'):
    tensor = np.load(tensor_p).astype('float32') / 255
    labels = np.load(labels_p)
    # half the size
    tensor = reduce(tensor, 'b (h h2) (w w2) c -> b h w c', 'max', h2=2, w2=2)
    return (tensor, labels)
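
A usage sketch with random arrays in place of the .npy files (shapes and names are made up for illustration):
import numpy as np
from einops import reduce

tensor = np.random.randint(0, 256, size=(8, 64, 64, 3)).astype('float32') / 255
labels = np.random.randint(0, 10, size=(8,))

half = reduce(tensor, 'b (h h2) (w w2) c -> b h w c', 'max', h2=2, w2=2)
assert half.shape == (8, 32, 32, 3)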