def forward(self, x, state_old):
    states = States(state_old)
    x = self.bn(x)
    x = self.act(x)
    if self.pad is not None:
        pad = self.pad
        if pad[0] > 0:
            # save the last time steps for left padding on the next iteration
            x = states.pad_left(x, pad[0], 3)
            pad = (0,) + pad[1:]
        x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
    x = self.conv(x)
    return x, states.state
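# Every forward in this section threads a `States` object through the graph to
# make the model streamable. The class itself is not shown here; the sketch
# below is a hypothetical reconstruction from its call sites (pad_left, update,
# state / state_old) and may differ from the real helper. It assumes old tails
# are consumed FIFO and new tails are queued for the next streaming call.
import torch

class States:
    def __init__(self, state_old=None):
        # reuse the caller's list so nested modules consume from the same queue
        self.state_old = state_old if state_old else []
        self.state = []  # tails to hand back for the next streaming call

    def pad_left(self, x, n, dim, shift_right=False):
        if n == 0:
            return x
        if self.state_old:
            tail = self.state_old.pop(0)  # history saved by the previous call
        else:
            shape = list(x.shape)
            shape[dim] = n
            tail = x.new_zeros(shape)     # first call: zero history
        x = torch.cat([tail, x], dim)
        # queue the last n frames as history for the next call
        self.state.append(x.narrow(dim, x.shape[dim] - n, n))
        if shift_right:
            # delay variant: emit only the first T frames, dropping the new tail
            return x.narrow(dim, 0, x.shape[dim] - n)
        return x

    def update(self, state):
        # absorb the state list returned by a sub-module
        self.state.extend(state)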
def forward(self, x, state_old=None):
    B, C, F, _ = get_shape(x)
    # [B,C,F,T] -> [B,3C,F,T]
    qkv = self.linear_qkv(x)
    # [B,3C,F,T] -> [B,F,3C,T]
    qkv = qkv.transpose(1, 2)
    # [B,F,3C,T] -> 3*[B,F,C,T]
    query, key, value = qkv.chunk(3, dim=2)
    # [B,F,C,T] -> [B,F,C,attn_range[0]+T]
    states = States(state_old)
    key = states.pad_left(key, self.attn_range[0], dim=3)
    value = states.pad_left(value, self.attn_range[0], dim=3)
    # [B,F,C,T] -> [BF, n_heads, k_channels, T]
    T_q = get_shape(query)[3]
    T_kv = get_shape(key)[3]
    assert C == self.n_heads * self.k_channels
    query = query.reshape(B * F, self.n_heads, self.k_channels, T_q)
    key = key.reshape(B * F, self.n_heads, self.k_channels, T_kv)
    value = value.reshape(B * F, self.n_heads, self.k_channels, T_kv)
    # [BF, n_heads, k_channels, T] -> [BF, n_heads, T, k_channels]
    query = query.transpose(2, 3)
    key = key.transpose(2, 3)
    value = value.transpose(2, 3)

    assert self.attn_unroll > 0
    q_s = 0
    outputs = []
    while q_s < T_q:
        q_e = min(q_s + self.attn_unroll, T_q)
        kv_s = q_s
        kv_e = min(q_e + sum(self.attn_range), T_kv)
        if q_s == 0 and q_e == T_q:
            query_slice = query
        else:
            query_slice = query[:, :, q_s:q_e, :]
        if kv_s == 0 and kv_e == T_kv:
            key_slice = key
            value_slice = value
        else:
            key_slice = key[:, :, kv_s:kv_e, :]
            value_slice = value[:, :, kv_s:kv_e, :]
        scores = torch.matmul(query_slice, key_slice.transpose(-2, -1)) / math.sqrt(self.k_channels)
        # mask out key scores that fall outside the attention range
        q_n = q_e - q_s
        k_mn = 1 + sum(self.attn_range)
        mask = torch.ones(q_n, k_mn, dtype=scores.dtype, device=scores.device)
        mask = torch.nn.functional.pad(mask, (0, q_n), mode='constant', value=0)
        mask = mask.reshape(-1)[:q_n * (k_mn + q_n - 1)]
        mask = mask.reshape(q_n, (k_mn + q_n - 1))
        mask = mask[:, :kv_e - kv_s]
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores * mask - 1e4 * (1 - mask)
        p_attn = torch.nn.functional.softmax(scores, dim=3)  # [b, n_h, l_q, l_kv]
        output = torch.matmul(p_attn, value_slice)
        outputs.append(output)
        q_s = q_e
    output = torch.cat(outputs, 2) if len(outputs) > 1 else outputs[0]
    # [BF, n_h, T_q, d_k] -> [BF, n_h, d_k, T_q]
    output = output.transpose(2, 3)
    # [BF, n_h, d_k, T_q] -> [B, F, C, T_q]
    output = output.reshape(B, F, C, T_q)
    # [B, F, C, T] -> [B, C, F, T]
    output = output.transpose(1, 2)
    # [B, C, F, T] -> [B, Co, F, T]
    output = self.conv_o(output)
    return output, states.state
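# The mask construction above builds a banded (local) attention pattern with a
# pad-and-reshape trick: padding each row of ones with q_n zeros and re-wrapping
# the flat buffer at width k_mn + q_n - 1 shifts the band one column per row.
# A standalone toy demo (the numbers are illustrative, not from the model):
import torch

q_n = 4    # queries in one unrolled chunk
k_mn = 4   # band width: 1 + sum(attn_range), e.g. attn_range = (2, 1)
mask = torch.ones(q_n, k_mn)
mask = torch.nn.functional.pad(mask, (0, q_n), mode='constant', value=0)
mask = mask.reshape(-1)[:q_n * (k_mn + q_n - 1)]
mask = mask.reshape(q_n, k_mn + q_n - 1)
print(mask)
# tensor([[1., 1., 1., 1., 0., 0., 0.],
#         [0., 1., 1., 1., 1., 0., 0.],
#         [0., 0., 1., 1., 1., 1., 0.],
#         [0., 0., 0., 1., 1., 1., 1.]])
# Row i attends to keys i .. i + k_mn - 1: attn_range[0] past frames, the
# current frame, and attn_range[1] future frames relative to query i.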
def forward(self, x, state_old=None):
    states = States(state_old)
    tail_size = self.wnd_length - self.hop_length
    x_padded = states.pad_left(x, tail_size, 1)
    X = self.encode(x_padded)  # [B,2,F,T]
    z = X

    # DOWN
    skips = []
    for b in self.blocks_down:
        z, state = b(z, states.state_old)
        states.update(state)
        skips.append(z)
        z = self.pool(z)

    # BOTTOM
    z, state = self.block_bottom(z, states.state_old)
    states.update(state)

    # UP
    for skip, conv_up, block_up in zip(reversed(skips), self.convs_up, self.blocks_up):
        z = torch.nn.functional.interpolate(z, scale_factor=2, mode='nearest')
        Fs = get_shape(skip)[-2]
        Fz = get_shape(z)[-2]
        if Fz != Fs:
            # match the skip connection's frequency size after upsampling
            z = torch.nn.functional.pad(z, (0, 0, 0, 1), mode='replicate')
        z = torch.cat([z, skip], 1)
        pad = self.convs_up_pad
        if pad[0] > 0:
            z = states.pad_left(z, pad[0], 3)
            pad = (0,) + pad[1:]
        z = torch.nn.functional.pad(z, pad, mode='constant', value=0)
        z = conv_up(z)
        z, state = block_up(z, states.state_old)
        states.update(state)

    # shift X right by `ahead` frames so the input aligns with the mask's lookahead
    X = states.pad_left(X, self.ahead, 3, shift_right=True)

    # [B,2,F,T] -> [B,F,T],[B,F,T]
    Mr, Mi = z[:, 0], z[:, 1]
    Xr, Xi = X[:, 0], X[:, 1]
    # apply the mask in complex space
    Yr = Xr * Mr - Xi * Mi
    Yi = Xr * Mi + Xi * Mr
    # [B,F,T] + [B,F,T] -> [B,2,F,T]
    Y = torch.stack([Yr, Yi], 1)

    # decode and return only valid samples
    Y_padded = states.pad_left(Y, self.ahead_ifft, 3)
    y = self.decode(Y_padded)
    y = y[:, tail_size:-self.ahead_ifft * self.hop_length]
    assert not states.state_old
    return y, Y, states.state
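# A chunked streaming loop over the full model might look like the sketch
# below. The name `model`, the toy input, and the chunk size are assumptions
# for illustration; the only contract used is forward(x, state_old) ->
# (y, Y, state), with the returned state fed into the next call.
import torch

model.eval()
x = torch.randn(1, 16000)      # [B, samples] toy input waveform (assumed shape)
chunk = model.hop_length * 10  # assumed: chunks aligned to hop_length
state = None
out = []
with torch.no_grad():
    for s in range(0, x.shape[1], chunk):
        y, Y, state = model(x[:, s:s + chunk], state)
        out.append(y)
y_stream = torch.cat(out, 1)   # should match offline output up to the lookahead delay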