def visualize_weights(self, rows, cols, col):
    w_conv_d, w_prop_e, w_pow_e, w_conv_e, w_channel_e = self.prepare_weights()
    w_conv_d = rearrange(w_conv_d, 'o i (h w) -> o i h w', h=self.kernel_size)
    w_spatial_d = reduce(w_conv_d, 'o i h w -> i h w', 'sum')
    w_channel_d = reduce(w_conv_d, 'o i h w -> o i', 'sum')
    w_conv_e = rearrange(w_conv_e, '(o d2) (i d1) h w -> o d2 i d1 h w', d1=4, d2=4)
    w_spatial_e = reduce(w_conv_e, 'o d2 i d1 h w -> i d1 h w', 'sum')
    if self.n_in >= self.n_out:
        w_channel_e = reduce(w_conv_e, 'o d2 i d1 h w -> o i', 'sum')
    w_dir_e = reduce(w_conv_e, 'o d2 i d1 h w -> d2 i d1', 'sum')

    idx = col
    for c in range(self.n_in):
        ax = plt.subplot(rows, cols, idx)
        plt.imshow(w_pow_e[c, :, :])
        for (j, i), label in np.ndenumerate(w_pow_e[c, :, :]):
            ax.text(i, j, '{:.02f}'.format(label), ha='center', va='center',
                    fontsize=min(10, 60 / rows), color='tomato')
        plt.xticks([0, 1, 2, 3],
                   [r'$^a / _b$', r'$\frac{a}{b}$', r'$_b \backslash ^a$', 'b|a'],
                   fontsize=min(10, 100 / rows))
        plt.yticks([0, 1], [r'$\frac{a}{b}$', r'$\frac{b}{a}$'],
                   fontsize=min(10, 100 / rows))
        ax.tick_params(length=0)
        idx += cols
    for c in range(self.n_in, self.n_out):
        idx += cols

    ax = plt.subplot(rows, cols, idx)
    plt.imshow(w_prop_e[0, :, :, 0, 0])
    for (j, i), label in np.ndenumerate(w_prop_e[0, :, :, 0, 0]):
        ax.text(i, j, '{:.02f}'.format(label), ha='center', va='center',
                fontsize=min(10, 80 / rows), color='tomato')
    plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'], fontsize=min(10, 100 / rows))
    plt.yticks([])
    if self.n_in > 1:
        plt.ylabel('c', fontsize=min(10, 100 / rows))
    ax.tick_params(length=0)
    idx += cols

    for c in range(self.n_in):
        for d in range(4):
            if self.kernel_size > 1:
                ax = plt.subplot(rows, cols, idx)
                plt.imshow(w_spatial_e[c, d, :, :])
                plt.xlabel('w', fontsize=min(10, 100 / rows))
                plt.ylabel('h', fontsize=min(10, 100 / rows))
                plt.xticks([])
                plt.yticks([])
            idx += cols
    idx += max(0, self.n_out - self.n_in) * 4 * cols

    if self.n_in > 1:
        ax = plt.subplot(rows, cols, idx)
        plt.xlabel('in', fontsize=min(10, 100 / rows))
        plt.ylabel('out', fontsize=min(10, 100 / rows))
        plt.imshow(w_channel_e)
        plt.xticks([])
        plt.yticks([])
        idx += cols
    elif self.n_out > 1:
        idx += cols

    for c in range(self.n_in):
        ax = plt.subplot(rows, cols, idx)
        plt.imshow(w_dir_e[:, c, :])
        plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'], fontsize=min(10, 100 / rows))
        plt.yticks([0, 1, 2, 3], ['/', '-', '\\', '|'], fontsize=min(10, 100 / rows))
        ax.tick_params(length=0)
        plt.ylabel('out', fontsize=min(10, 100 / rows))
        idx += cols
    idx += max(0, self.n_out - self.n_in) * cols
    idx += cols

    if self.kernel_size > 1:
        for c in range(self.n_in):
            ax = plt.subplot(rows, cols, idx)
            plt.imshow(w_spatial_d[c, :, :])
            plt.xlabel('w', fontsize=min(10, 100 / rows))
            plt.ylabel('h', fontsize=min(10, 100 / rows))
            plt.xticks([])
            plt.yticks([])
            idx += cols
        idx += max(0, self.n_out - self.n_in) * cols
    else:
        idx += max(self.n_in, self.n_out) * cols

    if self.n_in > 1:
        ax = plt.subplot(rows, cols, idx)
        plt.xlabel('in', fontsize=min(10, 100 / rows))
        plt.ylabel('out', fontsize=min(10, 100 / rows))
        plt.imshow(w_channel_d[:, :])
        plt.xticks([])
        plt.yticks([])
        idx += cols
    elif self.n_out > 1:
        idx += cols
def forward(self, x, mask=None, return_attn=False):
    b, n, _, h, m, iters, eps = *x.shape, self.heads, self.m, self.pinv_iterations, self.eps

    # pad so that the sequence can be evenly divided into m landmarks
    remainder = n % m
    if remainder > 0:
        padding = m - (n % m)
        x = F.pad(x, (0, 0, 0, padding), value=0)
        if exists(mask):
            mask = F.pad(mask, (0, padding), value=False)

    # derive queries, keys, values
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

    # set masked positions to 0 in queries, keys, values
    if exists(mask):
        mask = rearrange(mask, 'b n -> b () n')
        q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

    q *= self.scale

    # generate landmarks by sum reduction, then calculate the mean using the mask
    l = ceil(n / m)
    landmark_einops_eq = '... (n l) d -> ... n d'
    q_landmarks = reduce(q, landmark_einops_eq, 'sum', l=l)
    k_landmarks = reduce(k, landmark_einops_eq, 'sum', l=l)

    # calculate the landmark mask, and also get the sum of non-masked elements
    # in preparation for the masked mean
    divisor = l
    if exists(mask):
        mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l=l)
        divisor = mask_landmarks_sum[..., None] + eps
        mask_landmarks = mask_landmarks_sum > 0

    # masked mean (if mask exists)
    q_landmarks /= divisor
    k_landmarks /= divisor

    # similarities
    einops_eq = '... i d, ... j d -> ... i j'
    sim1 = einsum(einops_eq, q, k_landmarks)
    sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
    sim3 = einsum(einops_eq, q_landmarks, k)

    # masking
    if exists(mask):
        mask_value = -torch.finfo(q.dtype).max
        sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
        sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
        sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

    # eq (15) in the paper
    attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3))
    attn2_inv = moore_penrose_iter_pinv(attn2, iters)
    attn = attn1 @ attn2_inv @ attn3

    # aggregate
    out = einsum('... i j, ... j d -> ... i d', attn, v)

    # add depth-wise conv residual of values
    if self.residual:
        out += self.res_conv(v)

    # merge and combine heads
    out = rearrange(out, 'b h n d -> b n (h d)', h=h)
    out = self.to_out(out)
    out = out[:, :n]

    if return_attn:
        return out, attn
    return out
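# `moore_penrose_iter_pinv`, called above, is not shown in this file. A minimal
# sketch, assuming the Newton-Schulz-style iterative pseudoinverse from the
# Nystromformer paper; the name and the `iters` argument are taken from the call
# site, and the body is an illustration rather than the repository's exact code.
import torch
from einops import rearrange

def moore_penrose_iter_pinv(x, iters=6):
    device = x.device
    abs_x = torch.abs(x)
    col = abs_x.sum(dim=-1)
    row = abs_x.sum(dim=-2)
    # a scaled transpose is a convergent starting point for the iteration
    z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))
    eye = torch.eye(x.shape[-1], device=device)
    eye = rearrange(eye, 'i j -> () i j')
    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz)))))
    return z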
def prepare_weights(self):
    # enforce limits
    w_channel_d = F.softplus(self.w_channel_d)
    w_spatial_d = F.softplus(self.w_spatial_d)
    w_prop_e = torch.sigmoid(self.w_prop_e)
    w_pow_e = F.softplus(self.w_pow_e)
    w_dir_e = F.softplus(self.w_dir_e)
    w_channel_e = F.softplus(self.w_channel_e)
    if self.symmentric:
        w_spatial_e_0 = F.softplus(self.w_spatial_e_0)
        w_spatial_e_1 = F.softplus(self.w_spatial_e_1)
        w_spatial_e_3 = F.softplus(self.w_spatial_e_3)
    else:
        w_spatial_e = F.softplus(self.w_spatial_e)

    if self.no_prop:
        w_prop_e = torch.zeros_like(w_prop_e)

    # enforce symmetry by weight sharing
    if self.symmentric:
        w_spatial_d = torch.cat(
            (w_spatial_d,
             w_spatial_d[:, :, :self.kernel_size // 2].flip(dims=(2,))),
            dim=2)

        # directions: 0 => /, 1 => -, 2 => \, 3 => |
        # 1 and 3 are symmetric; 2 is a mirror of 0
        w_spatial_e = torch.stack(
            (w_spatial_e_0,
             torch.cat((w_spatial_e_1, w_spatial_e_1[:, :, :-1].flip(dims=(2,))), dim=2),
             w_spatial_e_0.flip(dims=(2,)),
             torch.cat((w_spatial_e_3, w_spatial_e_3[:, :, :-1].flip(dims=(2,))), dim=2)),
            dim=1)

        # connect directions to each other; connections from or to 0 and 2
        # share the same weights
        w_dir_e = w_dir_e.unbind(1)
        w_dir_e = torch.stack(
            (torch.stack((w_dir_e[0], w_dir_e[1], w_dir_e[2], w_dir_e[3]), dim=1),
             torch.stack((w_dir_e[4], w_dir_e[5], w_dir_e[4], w_dir_e[6]), dim=1),
             torch.stack((w_dir_e[2], w_dir_e[1], w_dir_e[0], w_dir_e[3]), dim=1),
             torch.stack((w_dir_e[7], w_dir_e[8], w_dir_e[7], w_dir_e[9]), dim=1)),
            dim=0)

        # one w_pow_e per side; 0 and 2 are mirrors; 3 has symmetric sides
        w_pow_e = w_pow_e.unbind(1)
        w_pow_e = torch.stack(
            (torch.stack((w_pow_e[0], w_pow_e[1]), dim=1),
             torch.stack((w_pow_e[2], w_pow_e[3]), dim=1),
             torch.stack((w_pow_e[1], w_pow_e[0]), dim=1),
             torch.stack((w_pow_e[4], w_pow_e[4]), dim=1)),
            dim=2)

        # direction 0 (/) mirrors direction 2 (\) and reuses its propagation weight
        w_prop_e = torch.cat((w_prop_e[:, :, 1, None, :, :], w_prop_e), dim=2)

    # normalize by output channel
    # technically not needed for d, but kept here for consistency
    w_channel_d = w_channel_d / reduce(w_channel_d, 'o i -> o 1', 'sum')
    w_spatial_d = w_spatial_d / reduce(w_spatial_d, 'i h w -> i 1 1', 'sum')
    w_channel_e = w_channel_e / reduce(w_channel_e, 'o i -> o 1', 'sum')
    w_dir_e = w_dir_e / reduce(w_dir_e, 'd2 i d1 -> d2 i 1', 'sum')
    w_spatial_e = w_spatial_e / reduce(w_spatial_e, 'i d h w -> i d 1 1', 'sum')

    # combine the separable convolution for speed
    w_conv_d = rearrange(
        torch.einsum('o i, i h w -> o i h w', w_channel_d, w_spatial_d),
        'o i h w -> o i (h w)')
    if self.n_in >= self.n_out:
        w_conv_e = rearrange(
            torch.einsum('o i, p i d, i d h w -> o p i d h w',
                         w_channel_e, w_dir_e, w_spatial_e),
            'o d2 i d1 h w -> (o d2) (i d1) h w')
        return w_conv_d, w_prop_e, w_pow_e, w_conv_e, None
    else:
        w_conv_e = rearrange(
            torch.einsum('p i d, i d h w -> p i d h w', w_dir_e, w_spatial_e),
            'd2 i d1 h w -> d2 (i d1) h w')
        return w_conv_d, w_prop_e, w_pow_e, w_conv_e, w_channel_e
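# For intuition, a minimal standalone sketch of the mirror-based weight sharing
# applied above: only the left half of a kernel (including the center) is learned,
# and the right half is its flipped copy, so the full kernel is symmetric by
# construction. Shapes here are hypothetical; the pattern mirrors the
# `w_spatial_e_1` handling in `prepare_weights`.
import torch
import torch.nn.functional as F

k = 5  # odd kernel width
w_half = F.softplus(torch.randn(1, 1, k // 2 + 1))  # learned left half incl. center
w_full = torch.cat((w_half, w_half[:, :, :-1].flip(dims=(2,))), dim=2)
assert w_full.shape[-1] == k
assert torch.equal(w_full, w_full.flip(dims=(2,)))  # symmetric by construction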
def visualize_weights(self, rows, cols, col):
    w_conv_d, w_conv_e, w_skip_d, w_skip_e = self.prepare_weights()
    w_spatial_d = reduce(w_conv_d, 'o i h w -> i h w', 'sum')
    w_channel_d = reduce(w_conv_d, 'o i h w -> o i', 'sum')
    w_conv_e = rearrange(w_conv_e, '(o d2) (i d1) h w -> o d2 i d1 h w', d1=4, d2=4)
    w_spatial_e = reduce(w_conv_e, 'o d2 i d1 h w -> i d1 h w', 'sum')
    w_channel_e = reduce(w_conv_e, 'o d2 i d1 h w -> o i', 'sum')
    w_dir_e = reduce(w_conv_e, 'o d2 i d1 h w -> d2 i d1', 'sum')

    idx = col
    plt.axis('off')
    idx += cols * self.n_c
    idx += cols

    for c in range(self.n_c):
        for d in range(4):
            if self.kernel_size > 1:
                ax = plt.subplot(rows, cols, idx)
                plt.imshow(w_spatial_e[c, d, :, :])
                plt.xlabel('w', fontsize=min(10, 100 / rows))
                plt.ylabel('h', fontsize=min(10, 100 / rows))
                plt.xticks([])
                plt.yticks([])
            idx += cols

    if self.n_c > 1:
        ax = plt.subplot(rows, cols, idx)
        plt.xlabel('in', fontsize=min(10, 100 / rows))
        plt.ylabel('out', fontsize=min(10, 100 / rows))
        plt.imshow(w_channel_e)
        plt.xticks([])
        plt.yticks([])
        idx += cols

    for c in range(self.n_c):
        ax = plt.subplot(rows, cols, idx)
        plt.imshow(w_dir_e[:, c, :])
        plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'], fontsize=min(10, 100 / rows))
        plt.yticks([0, 1, 2, 3], ['/', '-', '\\', '|'], fontsize=min(10, 100 / rows))
        ax.tick_params(length=0)
        plt.ylabel('out', fontsize=min(10, 100 / rows))
        idx += cols

    ax = plt.subplot(rows, cols, idx)
    plt.imshow(w_skip_e[0, :, :, 0, 0])
    for (j, i), label in np.ndenumerate(w_skip_e[0, :, :, 0, 0]):
        ax.text(i, j, '{:.02f}'.format(label), ha='center', va='center',
                fontsize=min(10, 80 / rows), color='tomato')
    plt.xticks([0, 1, 2, 3], ['/', '-', '\\', '|'], fontsize=min(10, 100 / rows))
    plt.yticks([])
    if self.n_c > 1:
        plt.ylabel('c', fontsize=min(10, 100 / rows))
    ax.tick_params(length=0)
    idx += cols

    if self.kernel_size > 1:
        for c in range(self.n_c):
            ax = plt.subplot(rows, cols, idx)
            plt.imshow(w_spatial_d[c, :, :])
            plt.xlabel('w', fontsize=min(10, 100 / rows))
            plt.ylabel('h', fontsize=min(10, 100 / rows))
            plt.xticks([])
            plt.yticks([])
            idx += cols

    if self.n_c > 1:
        ax = plt.subplot(rows, cols, idx)
        plt.xlabel('in', fontsize=min(10, 100 / rows))
        plt.ylabel('out', fontsize=min(10, 100 / rows))
        plt.imshow(w_channel_d[:, :])
        plt.xticks([])
        plt.yticks([])
        idx += cols

    ax = plt.subplot(rows, cols, idx)
    plt.imshow(w_skip_d[:, :, 0, 0])
    for (j, i), label in np.ndenumerate(w_skip_d[:, :, 0, 0]):
        ax.text(i, j, '{:.02f}'.format(label), ha='center', va='center',
                fontsize=min(10, 80 / rows), color='tomato')
    plt.xticks([])
    plt.yticks([])
    if self.n_c > 1:
        plt.xlabel('c', fontsize=min(10, 100 / rows))
    idx += cols
def forward(self, x, context=None, mask=None, context_mask=None, tie_attn_dim=None):
    device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)

    context = default(context, x)

    q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))

    i, j = q.shape[-2], k.shape[-2]

    # memory compressed attention, to make cross-attention more efficient
    if exists(self.compress_fn):
        assert has_context, 'memory compressed attention only works in the context of cross attention for now'

        ratio = self.compress_ratio
        padding = ratio - (j % ratio)

        if padding < ratio:
            k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value=0), (k, v))
            if exists(context_mask):
                context_mask = F.pad(context_mask, (0, padding), value=False)

        k, v = map(lambda t: rearrange(t, 'b n c -> b c n'), (k, v))
        k, v = map(self.compress_fn, (k, v))
        k, v = map(lambda t: rearrange(t, 'b c n -> b n c'), (k, v))

        if exists(context_mask):
            context_mask = reduce(context_mask.float(), 'b (n r) -> b n', 'sum', r=ratio)
            context_mask = context_mask > 0

        j = (j + padding) // ratio

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

    # for tying row-attention, for MSA axial self-attention
    if exists(tie_attn_dim):
        q, k, v = map(lambda t: rearrange(t, '(b r) h n d -> b r h n d', r=tie_attn_dim), (q, k, v))

        # when tying row-attention, one cannot have any masked out tokens
        if exists(mask):
            assert torch.all(mask), 'you cannot have any padding if you are to tie the row attention across MSAs'
            mask = None

        dots = einsum('b r h i d, b r h j d -> b h i j', q, k) * self.scale * (tie_attn_dim ** -0.5)
    else:
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

    # masking
    if exists(mask) or exists(context_mask):
        mask = default(mask, lambda: torch.ones(1, i, device=device).bool())
        context_mask = default(context_mask, mask) if not has_context else default(
            context_mask, lambda: torch.ones(1, j, device=device).bool())
        mask_value = -torch.finfo(dots.dtype).max
        mask = mask[:, None, :, None] * context_mask[:, None, None, :]
        dots.masked_fill_(~mask, mask_value)

    # attention
    attn = dots.softmax(dim=-1)
    attn = self.dropout(attn)

    # aggregate
    if exists(tie_attn_dim):
        out = einsum('b h i j, b r h j d -> b r h i d', attn, v)
        out = rearrange(out, 'b r h n d -> (b r) h n d')
    else:
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

    # combine heads and project out
    out = rearrange(out, 'b h n d -> b n (h d)')
    out = self.to_out(out)
    return out
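# The `exists` / `default` helpers assumed by the attention modules above are not
# shown here. A minimal sketch; the lazy (callable) branch of `default` is an
# assumption inferred from the `lambda:` call sites above.
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d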
def compute_pairwise_losses(
        estimate: torch.Tensor,
        target: torch.Tensor,
        axis: int,
        loss_fn=torch.nn.functional.mse_loss,
):
    """
    The function pit_loss can be implemented more efficiently when the loss
    function allows calculating pairwise losses. The pairwise losses are
    written to a matrix (each estimated signal vs. each target signal), on
    which `scipy.optimize.linear_sum_assignment` (the Hungarian algorithm)
    finds the best permutation.

    The runtime of `scipy.optimize.linear_sum_assignment` does not matter
    here, so the runtime complexity with respect to the number of speakers
    decreases from factorial to quadratic. For 2 speakers this is slightly
    slower, but for larger numbers of speakers (e.g. 7) this function is
    significantly faster.

    Limitation: Not every loss function can be factorized into pairwise
    losses, and sometimes the pairwise loss is difficult to implement (see
    the special handling of cross_entropy in this function). Fortunately,
    most commonly used loss functions can be factorized.

    Does not support a batch dimension.
    Does not support PackedSequence.

    Args:
        estimate: Padded sequence. The speaker axis is specified with `axis`,
            so the default shape is (T, K, F).
        target: Padded sequence with the same shape as `estimate`
            (defaults to (T, K, F)).
        loss_fn: Loss function to apply on each permutation. It must accept
            two arguments (estimate and target) of the same shape that this
            function receives.
        axis: Speaker axis K. The permutation is applied along this axis.
            axis=-2 and an input shape of (T, K, F) corresponds to the old
            default behaviour.

    Examples:
        >>> T, K, F = 4, 2, 5
        >>> estimate, target = torch.ones(T, K, F), torch.zeros(T, K, F)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, 1))
        tensor(1.)

        >>> T, K, F = 4, 2, 5
        >>> estimate, target = torch.ones(T, K, F), torch.zeros(T, F, dtype=torch.int64)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, 1, loss_fn=torch.nn.functional.cross_entropy), reduction='sum')
        tensor(0.6931)
        >>> pit_loss(estimate, target, 1, loss_fn=torch.nn.functional.cross_entropy)
        tensor(0.6931)

        >>> T, K, F = 4, 2, 5
        >>> estimate, target = torch.ones(K, F, T), torch.zeros(K, F, T)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, 0))
        tensor(1.)

        >>> T, K, F = 4, 2, 5
        >>> estimate = torch.stack([torch.ones(F, T), torch.zeros(F, T)])
        >>> target = estimate[(1, 0), :, :]
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, axis=0), return_permutation=True)
        (tensor(0.), array([1, 0]))

        >>> K = 5
        >>> estimate, target = torch.ones(K), torch.zeros(K)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, axis=0))
        tensor(1.)

        >>> A, B, K, C, F = 4, 5, 3, 100, 128
        >>> estimate, target = torch.ones(A, B, K, C, F), torch.zeros(A, B, K, C, F)
        >>> pit_loss_from_loss_matrix(compute_pairwise_losses(estimate, target, axis=-3))
        tensor(1.)
    """
    sources = estimate.size()[axis]
    assert sources < 30, f'Are you sure? sources={sources}'

    if loss_fn in [torch.nn.functional.cross_entropy]:
        import einops
        assert axis % estimate.ndimension() == 1, axis
        estimate_shape = list(estimate.shape)
        del estimate_shape[1]
        assert estimate_shape == list(target.shape), (
            f'{estimate.shape} (N, K, ...) does not match {target.shape} (N, ...)')
        assert loss_fn == torch.nn.functional.cross_entropy, loss_fn
        assert axis == 1, axis
        # torch.einsum cannot reduce over the `...` dimensions, hence the
        # einops.reduce afterwards
        return einops.reduce(
            torch.einsum(
                'nc...,n...k->n...ck',
                -torch.nn.LogSoftmax(dim=1)(estimate),
                torch.nn.functional.one_hot(target, num_classes=sources).to(estimate.dtype)),
            'n ... c k -> c k',
            reduction='mean')
    else:
        assert estimate.shape == target.shape, (estimate.shape, target.shape)

        indexer_e = [slice(None)] * estimate.ndim
        indexer_t = [slice(None)] * target.ndim
        pair_wise_loss_matrix = []
        for i in range(sources):
            indexer_e[axis] = i
            for j in range(sources):
                indexer_t[axis] = j
                pair_wise_loss_matrix.append(
                    loss_fn(estimate[tuple(indexer_e)], target[tuple(indexer_t)]))

        return torch.stack(pair_wise_loss_matrix, 0).reshape(sources, sources)
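# The docstring above pairs this function with `pit_loss_from_loss_matrix`, which
# runs the Hungarian algorithm on the K x K pairwise loss matrix. A minimal sketch
# of such a function; the signature and reduction handling are assumptions
# reconstructed from the doctests, not necessarily the library's implementation.
import torch
from scipy.optimize import linear_sum_assignment

def pit_loss_from_loss_matrix(loss_matrix, reduction='mean', return_permutation=False):
    # find the permutation that minimizes the summed pairwise losses
    row_ind, col_ind = linear_sum_assignment(loss_matrix.detach().cpu().numpy())
    losses = loss_matrix[torch.as_tensor(row_ind), torch.as_tensor(col_ind)]
    loss = losses.sum() if reduction == 'sum' else losses.mean()
    if return_permutation:
        return loss, col_ind
    return loss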
def forward(self, x_skip, ece_skip, ce_skip, x_pool, ece_pool, ce_pool):
    # x = d*cd, cd
    # d = depth
    # cd = confidence over depth
    # ece = directed smoothness * ce; dim 2 corresponds to edge directions: /, -, \, |
    # ce = confidence over directed smoothness
    if self.training:
        w_conv_d, w_conv_e, w_skip_d, w_skip_e = self.prepare_weights()
    else:
        w_conv_d, w_conv_e, w_skip_d, w_skip_e = self.weights

    # unpooling d
    # even if there is an edge, it would be difficult to assign the d_pool to one
    # side, so it is unpooled without s_skip
    if self.kernel_size == 2 or not self.scs_unpool_d:
        # no need for smoothness since it would factor out if nothing overlaps
        x_pool = F.conv_transpose2d(x_pool, w_conv_d, padding=self.padding, stride=2)
    else:
        # only use a single s_pool factor, as opposed to uSNC
        # if a location is on an edge, each side of the unpooled version will
        # depend on values from its side of the edge
        scs_pool = reduce(ece_pool, 'b c d h w -> b c h w', 'prod')
        x_pool = F.conv_transpose2d(x_pool * scs_pool.repeat(2, 1, 1, 1), w_conv_d,
                                    padding=self.padding, stride=2) \
            / (F.conv_transpose2d(scs_pool, w_conv_d, padding=self.padding, stride=2)
               + self.eps).repeat(2, 1, 1, 1)

    # unpooling e
    # if the pooled data predicts an edge, the skip connection knows where (low e)
    # or where not (high ece at one point, less confidence at another)
    # => during unpooling of e, favour locations with low ece_skip
    # if the pooled data does not predict an edge, it should not be focused onto edges
    # => deconv without skip
    # combine both versions with an nconv weighted by the unfocused version
    ece_pool = rearrange(ece_pool, 'b c d h w -> b (c d) h w')
    ce_pool = rearrange(ce_pool, 'b c d h w -> b (c d) h w')
    if self.unfocused_unpool_e and self.focused_unpool_e:
        w_s = 1 / (rearrange(ece_skip, 'b c d h w -> b (c d) h w') + self.eps)
        w_s_sum = F.conv2d(w_s, w_conv_e, padding=self.padding, stride=2) + self.eps
        # divide by w_s_sum first in this deconvolution, then multiply by w_s
        ce_pool_focus = F.conv_transpose2d(ce_pool / w_s_sum, w_conv_e,
                                           padding=self.padding, stride=2) * w_s
        ece_pool_focus = F.conv_transpose2d(ece_pool / w_s_sum, w_conv_e,
                                            padding=self.padding, stride=2) * w_s
        ce_pool_unfocused = F.conv_transpose2d(ce_pool, w_conv_e,
                                               padding=self.padding, stride=2)
        ece_pool_unfocused = F.conv_transpose2d(ece_pool, w_conv_e,
                                                padding=self.padding, stride=2)
        e_pool_unfocused = (ece_pool_unfocused / (ce_pool_unfocused + self.eps)).detach()
        # note: ece_pool_unfocused == e_pool_unfocused * ce_pool_unfocused,
        # so this is the usual convex blend of unfocused and focused confidences
        ce_pool = rearrange(ece_pool_unfocused + (1 - e_pool_unfocused) * ce_pool_focus,
                            'b (c d) h w -> b c d h w', d=4)
        ece_pool = rearrange(e_pool_unfocused * ece_pool_unfocused
                             + (1 - e_pool_unfocused) * ece_pool_focus,
                             'b (c d) h w -> b c d h w', d=4)
    elif self.focused_unpool_e:
        w_s = 1 / (rearrange(ece_skip, 'b c d h w -> b (c d) h w') + self.eps)
        w_s_sum = F.conv2d(w_s, w_conv_e, padding=self.padding, stride=2) + self.eps
        # divide by w_s_sum first in this deconvolution, then multiply by w_s
        ce_pool = rearrange(
            F.conv_transpose2d(ce_pool / w_s_sum, w_conv_e,
                               padding=self.padding, stride=2) * w_s,
            'b (c d) h w -> b c d h w', d=4)
        ece_pool = rearrange(
            F.conv_transpose2d(ece_pool / w_s_sum, w_conv_e,
                               padding=self.padding, stride=2) * w_s,
            'b (c d) h w -> b c d h w', d=4)
    else:
        ce_pool = rearrange(
            F.conv_transpose2d(ce_pool, w_conv_e, padding=self.padding, stride=2),
            'b (c d) h w -> b c d h w', d=4)
        ece_pool = rearrange(
            F.conv_transpose2d(ece_pool, w_conv_e, padding=self.padding, stride=2),
            'b (c d) h w -> b c d h w', d=4)

    s_pool = reduce(ece_pool / (ce_pool + self.eps), 'b c d h w -> b c h w', 'prod')

    # combining pool and skip
    # in general, each should have a proportionally higher c in areas it is more
    # suited to in terms of distance
    # additionally, skip is preferred around edges because of its higher resolution
    # to determine whether there is an edge, s_pool is used:
    # it has values anywhere there is data to interpolate, likely includes fewer
    # input errors, and is less likely to have gaps in edges
    # => use w_skip and w_pool * s_pool
    w_pool_d = (1 - w_skip_d) * s_pool
    w_sum_d = w_skip_d + w_pool_d + self.eps
    x = (w_skip_d * x_skip + w_pool_d.repeat(2, 1, 1, 1) * x_pool) / w_sum_d.repeat(2, 1, 1, 1)
    w_pool_e = (1 - w_skip_e) * s_pool[:, :, None, :, :]
    w_sum_e = w_skip_e + w_pool_e + self.eps
    ce = (w_skip_e * ce_skip + w_pool_e * ce_pool) / w_sum_e
    ece = (w_skip_e * ece_skip + w_pool_e * ece_pool) / w_sum_e

    # clamp gradients to keep the confidence-weighted divisions numerically stable
    if ece.requires_grad:
        ece.register_hook(lambda grad: torch.clamp(grad, -1000, 1000))
        ce.register_hook(lambda grad: torch.clamp(grad, -1000, 1000))

    return x, ece, ce
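# A minimal standalone illustration of the backward-hook clamping used at the end
# of the forward pass above (hypothetical tensor; the hook rewrites the incoming
# gradient before it is accumulated):
import torch

t = torch.randn(3, requires_grad=True)
t.register_hook(lambda grad: torch.clamp(grad, -1000, 1000))
(t * 1e6).sum().backward()
assert t.grad.abs().max() <= 1000  # the raw gradient of 1e6 was clamped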
def max_pool2d_layer(self, x):
    # 2x2 max pooling expressed as an einops reduction
    result = reduce(x, 'b c (h h1) (w w1) -> b c h w', 'max', h1=2, w1=2)
    return result
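# Sanity check (a hypothetical standalone snippet): the einops reduction above is
# equivalent to stride-2 max pooling for even spatial sizes.
import torch
import torch.nn.functional as F
from einops import reduce

x = torch.randn(2, 3, 8, 8)
pooled = reduce(x, 'b c (h h1) (w w1) -> b c h w', 'max', h1=2, w1=2)
assert torch.equal(pooled, F.max_pool2d(x, kernel_size=2))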
def test8(x):
    # max-pooling
    y = reduce(x, 'b c (h h1) (w w1) -> b c h w', reduction='max', h1=2, w1=2)
    assert y.shape == (10, 20, 30 // 2, 40 // 2)
    return y
def forward(self, x, adj, size=None, return_attention_weights=None):
    """
    Args:
        x: Union[Tensor, PairTensor]
        adj: Tensor[2, num_edges] or list of Tensor
        size: Size
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple :obj:`(adj, attention_weights)`,
            holding the computed attention weights for each edge.
            (default: :obj:`None`)
    """
    h, c = self.heads, self.out_channels
    # assert (not isinstance(adj, Tensor)) and h == len(adj), 'Number of heads is number of adjacency matrices'
    x_l, x_r, alpha_l, alpha_r, alpha_l_, alpha_r_ = None, None, None, None, None, None

    if isinstance(x, Tensor):
        x_l, x_r = x, None
    else:
        x_l, x_r = x[0], x[1]
        # assert x_l.dim() == 2, 'Static graphs not supported in `HGAConv`.'

    x_l = self.lin_l(x_l)
    if x_l.dim() == 2:
        alpha_l = torch.mm(x_l, self.att_l)
    else:  # x_l is 3D; matmul is in batched mode
        alpha_l = torch.matmul(x_l, self.att_l)

    if x_r is not None:
        x_r = self.lin_r(x_r)
        alpha_r = torch.mm(x_r, self.att_r)
        alpha_r_ = torch.mm(x_l, self.att_r)
        alpha_l_ = torch.mm(x_r, self.att_l)
        self.add_self_loops = False
    else:
        if x_l.dim() == 2:
            alpha_r = torch.mm(x_l, self.att_r)
        else:
            alpha_r = torch.matmul(x_l, self.att_r)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        num_nodes = x_l.shape[-2]
        num_nodes = size[1] if size is not None else num_nodes
        num_nodes = x_r.shape[-2] if x_r is not None else num_nodes
        if isinstance(adj, Tensor):
            adj = self_loop_augment(num_nodes, adj)  # TODO Bug found
        else:
            for i in range(len(adj)):
                adj[i] = self_loop_augment(num_nodes, adj[i])

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    _x_ = (x_l, x_r) if x_r is not None else x_l
    _alpha_ = (alpha_l, alpha_r)
    alpha_ = (alpha_l_, alpha_r_)
    out = self.propagate(adj, x=_x_, alpha=_alpha_, alpha_=alpha_, size=size)

    alpha = self._alpha
    self._alpha = None

    if isinstance(out, Tensor):
        # reshape here is equivalent to concatenation
        if len(x_l.shape) == 2:
            out = rearrange(out, '(h n) c -> n (h c)', h=h)
        else:
            out = rearrange(out, 't (h n) c -> t n (h c)', h=h)
    else:
        out = (out[0].reshape(-1, h * c), out[1].reshape(-1, h * c))

    if not self.concat:  # calculate mean
        if isinstance(out, Tensor):
            if len(x_l.shape) == 2:
                out = reduce(out, 'n (h c) -> n c', 'mean', h=h)
            else:
                out = reduce(out, 't n (h c) -> t n c', 'mean', h=h)
        else:
            out = (out[0].mean(dim=1), out[1].mean(dim=1))

    if self.bias is not None:
        if isinstance(out, Tensor):
            out += self.bias
        else:
            out = (out[0] + self.bias, out[1] + self.bias)

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        return out, (adj, alpha)
    else:
        return out
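# The "reshape here is equivalent to concatenation" comment above can be verified
# in isolation (hypothetical shapes): merging heads via rearrange matches chunking
# along the head-major axis and concatenating along channels.
import torch
from einops import rearrange

h, n, c = 4, 3, 8
out = torch.randn(h * n, c)
merged = rearrange(out, '(h n) c -> n (h c)', h=h)
assert torch.equal(merged, torch.cat(out.chunk(h, dim=0), dim=-1))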
def forward(
        self,
        seq,
        msa=None,
        mask=None,
        msa_mask=None,
        templates_seq=None,
        templates_dist=None,
        templates_mask=None,
        templates_coors=None,
        templates_sidechains=None,
        embedds=None,
):
    # unpack (AA_code, atom_pos) first, so the shape access below sees a tensor
    if isinstance(seq, (list, tuple)):
        seq, seq_pos = seq

    n, device = seq.shape[1], seq.device
    n_range = torch.arange(n, device=device)

    # embed main sequence
    x = self.token_emb(seq)

    # outer sum to create pair-wise residue embeds
    x = rearrange(x, 'b i d -> b () i () d') + rearrange(x, 'b j d -> b () () j d')
    x_mask = rearrange(mask, 'b i -> b () i ()') + rearrange(mask, 'b j -> b () () j') if exists(mask) else None

    # axial positional embedding
    pos_emb = rearrange(self.pos_emb(n_range), 'i d -> () i () d') + rearrange(self.pos_emb_ax(n_range), 'j d -> () () j d')
    x += pos_emb

    # embed multiple sequence alignment (msa)
    m = None
    msa_shape = None
    if exists(msa):
        m = self.token_emb(msa)
        m += self.msa_pos_emb(torch.arange(msa.shape[-1], device=device))[None, None, ...]
        m += self.msa_num_pos_emb(torch.arange(msa.shape[1], device=device))[None, :, None, :]
        msa_shape = m.shape
        m = rearrange(m, 'b m n d -> b (m n) d')
    elif exists(embedds):
        m = self.embedd_project(embedds)
        m = rearrange(m, 'b i d -> b i () d') + rearrange(m, 'b j d -> b () j d')
        m = rearrange(m, 'b m n d -> b (m n) d')

    if exists(msa_mask):
        msa_mask = rearrange(msa_mask, 'b m n -> b (m n)')

    # embed templates, if present
    if exists(templates_seq):
        assert exists(templates_coors), 'template residue coordinates must be supplied `templates_coors`'
        _, num_templates, *_ = templates_seq.shape

        if not exists(templates_dist):
            templates_dist = get_bucketed_distance_matrix(templates_coors, templates_mask, constants.DISTOGRAM_BUCKETS)

        # embed template
        t_seq = self.token_emb(templates_seq)

        # if sidechain information is present,
        # color the residue embeddings with the sidechain type 1 features
        # todo (make efficient)
        if exists(templates_sidechains):
            if self.use_se3_transformer:
                t_seq = self.template_sidechain_emb(t_seq, templates_sidechains, templates_coors, mask=templates_mask)
            else:
                shape = t_seq.shape
                t_seq = rearrange(t_seq, 'b t n d -> (b t) n d')
                templates_coors = rearrange(templates_coors, 'b t n c -> (b t) n c')
                en_mask = rearrange(templates_mask, 'b t n -> (b t) n')
                t_seq, _ = self.template_sidechain_emb(t_seq, templates_coors, mask=en_mask)
                t_seq = t_seq.reshape(*shape)

        # embed template distances
        t_dist = self.template_dist_emb(templates_dist)
        t_seq = rearrange(t_seq, 'b t i d -> b t i () d') + rearrange(t_seq, 'b t j d -> b t () j d')
        t = t_seq + t_dist

        # template pos emb
        template_num_pos_emb = self.template_num_pos_emb(torch.arange(num_templates, device=device))
        t += rearrange(template_num_pos_emb, 't d -> () t () () d')

        pos_emb = rearrange(self.template_pos_emb(n_range), 'i d -> () () i () d') + rearrange(self.template_pos_emb_ax(n_range), 'j d -> () () () j d')
        t += pos_emb

        assert t.shape[-2:] == x.shape[-2:]
        x = torch.cat((x, t), dim=1)

        if exists(templates_mask):
            t_mask = rearrange(templates_mask, 'b t i -> b t i ()') * rearrange(templates_mask, 'b t j -> b t () j')
            x_mask = torch.cat((x_mask, t_mask), dim=1)

    # flatten
    seq_shape = x.shape
    x = rearrange(x, 'b t i j d -> b (t i j) d')
    x_mask = rearrange(x_mask, 'b t i j -> b (t i j)') if exists(mask) else None

    # trunk
    x, m = self.net(x, m, seq_shape, msa_shape, mask=x_mask, msa_mask=msa_mask)

    # remove templates, if present
    x = x.view(seq_shape)
    x = x[:, 0]

    # embeds to distogram
    trunk_embeds = (x + rearrange(x, 'b i j d -> b j i d')) * 0.5  # symmetrize
    distogram_logits = self.to_distogram_logits(trunk_embeds)

    if not self.predict_coords:
        return distogram_logits

    # prepare mask for backbone coordinates
    assert self.num_backbone_atoms > 1, 'must constitute at least 3 atomic coordinates for backbone'

    if self.num_backbone_atoms >= 3:
        N_mask, CA_mask, C_mask = scn_backbone_mask(seq, boolean=True, n_aa=self.num_backbone_atoms)

    cloud_mask = scn_cloud_mask(seq, boolean=True)
    flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')
    chain_mask = (mask.unsqueeze(-1) * cloud_mask)
    flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')
    mask = rearrange(chain_mask[:, :, :self.num_backbone_atoms], 'b l c -> b (l c)')

    # structural refinement
    distances, weights = center_distogram_torch(distogram_logits)
    coords_3d, _ = MDScaling(distances,
                             weights=weights,
                             iters=self.mds_iters,
                             fix_mirror=True,
                             N_mask=N_mask,
                             CA_mask=CA_mask,
                             C_mask=C_mask)
    coords = rearrange(coords_3d, 'b c n -> b n c')

    # will init all sidechain coords to c-beta if present, else c-alpha
    coords = sidechain_container(coords, n_aa=self.num_backbone_atoms, cloud_mask=cloud_mask)
    coords = rearrange(coords, 'b n l d -> b (n l) d')

    atom_tokens = scn_atom_embedd(seq)  # not used for now, but could be

    trunk_embeds = self.trunk_to_structure_dim(trunk_embeds)
    x = reduce(trunk_embeds, 'b i j d -> b i d', 'mean')
    x += self.structure_module_embeds(seq)
    x = repeat(x, 'b n d -> b n l d', l=cloud_mask.shape[-1])
    x += self.atom_tokens_embed(atom_tokens)
    x = rearrange(x, 'b n l d -> b (n l) d')

    original_dtype = coords.dtype
    x, coords = map(lambda t: t.double(), (x, coords))

    with torch_default_dtype(torch.float64):
        for _ in range(self.structure_module_refinement_iters):
            x, coords = self.structure_module(x, coords, mask=flat_chain_mask)

    # .type() is not in-place, so the result must be assigned
    coords = coords.type(original_dtype)
    return coords
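# The `torch_default_dtype` context manager used above is assumed to temporarily
# switch the global default dtype around the refinement loop; a minimal sketch:
from contextlib import contextmanager
import torch

@contextmanager
def torch_default_dtype(dtype):
    prev = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(prev)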
def data_init(tensor_p: 'path', labels_p: 'path'):
    tensor = np.load(tensor_p).astype('float32') / 255
    labels = np.load(labels_p)
    # halve the spatial size with 2x2 max pooling
    tensor = reduce(tensor, 'b (h h2) (w w2) c -> b h w c', 'max', h2=2, w2=2)
    return (tensor, labels)