def forward(self, x, state_old):
    states = States(state_old)

    # dense path: each block's output is concatenated back onto its input
    for m in self.bd:
        x_add, state = m(x, states.state_old)
        states.update(state)
        x = torch.cat([x, x_add], 1)
    x, state = self.bd_out(x, states.state_old)
    states.update(state)

    if self.sa is None:
        return x, states.state

    # apply multi-head self-attention module
    # [B,C,F,T]
    x_add, state = self.sa(x, state_old=states.state_old)
    states.update(state)
    x = torch.cat([x, x_add], 1)
    x, state = self.sa_out(x, states.state_old)
    states.update(state)
    return x, states.state
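# ---------------------------------------------------------------------------
# Note: the forward above only assumes that every entry of `self.bd` (and
# `self.bd_out`, `self.sa`, `self.sa_out`) follows the same stateful
# interface: forward(x, state_old) -> (x_out, state).  The class below is a
# hypothetical minimal example of such a block -- a time-causal Conv2d over
# [B,C,F,T] that carries its last (kernel_t - 1) frames as state -- it is a
# sketch for illustration only, not this project's actual block implementation.
import torch


class CausalConvBlock(torch.nn.Module):
    """Hypothetical stateful block, causal along the last (time) axis."""

    def __init__(self, ch_in, ch_out, kernel_t=3):
        super().__init__()
        self.ctx = kernel_t - 1  # number of past frames carried as state
        self.conv = torch.nn.Conv2d(ch_in, ch_out, (1, kernel_t))

    def forward(self, x, state_old=None):
        if state_old is None:
            b, c, f, _ = x.shape
            state_old = x.new_zeros(b, c, f, self.ctx)
        x = torch.cat([state_old, x], dim=3)                    # prepend cached history
        state = x.narrow(3, x.shape[3] - self.ctx, self.ctx)    # tail for the next chunk
        return torch.nn.functional.relu(self.conv(x)), state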
def forward(self, x, state_old=None):
    states = States(state_old)

    tail_size = self.wnd_length - self.hop_length
    x_padded = states.pad_left(x, tail_size, 1)
    X = self.encode(x_padded)  # [B,2,F,T]
    z = X

    # DOWN
    skips = []
    for b in self.blocks_down:
        z, state = b(z, states.state_old)
        states.update(state)
        skips.append(z)
        z = self.pool(z)

    # BOTTOM
    z, state = self.block_bottom(z, states.state_old)
    states.update(state)

    # UP
    for skip, conv_up, block_up in zip(reversed(skips), self.convs_up, self.blocks_up):
        z = torch.nn.functional.interpolate(z, scale_factor=2, mode='nearest')
        # match the skip connection's frequency size after upsampling
        Fs = get_shape(skip)[-2]
        Fz = get_shape(z)[-2]
        if Fz != Fs:
            z = torch.nn.functional.pad(z, (0, 0, 0, 1), mode='replicate')
        z = torch.cat([z, skip], 1)
        pad = self.convs_up_pad
        if pad[0] > 0:
            # left (time) padding comes from the cached history; the rest is zero-padded
            z = states.pad_left(z, pad[0], 3)
            pad = (0,) + pad[1:]
        z = torch.nn.functional.pad(z, pad, mode='constant', value=0)
        z = conv_up(z)
        z, state = block_up(z, states.state_old)
        states.update(state)

    # account for look-ahead: shift the input spectrum right by `ahead` frames
    X = states.pad_left(X, self.ahead, 3, shift_right=True)

    # [B,2,F,T] -> [B,F,T], [B,F,T]
    Mr, Mi = z[:, 0], z[:, 1]
    Xr, Xi = X[:, 0], X[:, 1]

    # mask in complex space
    Yr = Xr * Mr - Xi * Mi
    Yi = Xr * Mi + Xi * Mr

    # [B,F,T], [B,F,T] -> [B,2,F,T]
    Y = torch.stack([Yr, Yi], 1)

    # decode and return only valid samples
    Y_padded = states.pad_left(Y, self.ahead_ifft, 3)
    y = self.decode(Y_padded)
    y = y[:, tail_size:-self.ahead_ifft * self.hop_length]
    # every cached state entry from the previous chunk must have been consumed
    assert not states.state_old
    return y, Y, states.state
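# ---------------------------------------------------------------------------
# Both forwards above rely on a `States` helper that is not defined in this
# excerpt.  The sketch below is one plausible minimal reading of its
# interface, reconstructed purely from how it is used here
# (States(state_old), states.state_old, states.update(...),
# states.pad_left(...), states.state, `assert not states.state_old`); the
# real implementation may well differ, e.g. it could key the state per
# module instead of relying on call order.
import torch


class States:
    """Hypothetical streaming-state container (a sketch, not the real class)."""

    def __init__(self, state_old=None):
        self._old = list(state_old) if state_old else []
        self.state = []  # new state, collected in call order, for the next chunk

    @property
    def state_old(self):
        # hand out the next cached entry; None on the first chunk or once
        # everything has been consumed (so `assert not states.state_old` holds)
        return self._old.pop(0) if self._old else None

    def update(self, state):
        self.state.append(state)

    def pad_left(self, x, size, dim, shift_right=False):
        # prepend `size` frames of history along `dim`; the history is the
        # tail cached from the previous chunk, or zeros on the first call
        old = self.state_old
        if old is None:
            shape = list(x.shape)
            shape[dim] = size
            old = x.new_zeros(shape)
        # cache this chunk's newest `size` frames for the next call
        self.update(x.narrow(dim, x.shape[dim] - size, size))
        x = torch.cat([old, x], dim)
        if shift_right:
            # keep the original length: the content is delayed by `size` frames
            x = x.narrow(dim, 0, x.shape[dim] - size)
        return x
# With this reading, states are produced and consumed in the same call order,
# and the final `assert not states.state_old` checks that every cached entry
# from the previous chunk was actually used.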