def backward_impl(self, inputs, outputs, prop_down, accum):
    """Propagate grads of ``dx0`` back to the forward inputs ``x0`` and ``dy``.

    The gradients written below are consistent with a forward of
    ``dx0 = -dy * sin(x0)`` (i.e. the backward of ``cos``):
    ``d(dx0)/d(x0) = -dy * cos(x0)`` and ``d(dx0)/d(dy) = -sin(x0)``.

    Args:
        inputs: ``[x0, dy]`` variables, where ``.data`` holds the values and
            ``.grad`` is the buffer to write/accumulate into. Layout is
            ``[inputs_fwd_graph] + [inputs_bwd_graph]`` or
            ``[inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]``.
        outputs: ``[dx0]``; only ``outputs[0].grad`` is read.
        prop_down: Per-input flags; a gradient is produced only when true.
        accum: Per-input flags; when true the gradient is accumulated into
            the existing ``.grad`` buffer, otherwise the buffer is overwritten.
    """
    # Values of the inputs.
    x0 = inputs[0].data
    dy = inputs[1].data
    # Gradient buffers of the inputs (destinations).
    g_x0 = inputs[0].grad
    g_dy = inputs[1].grad
    # Incoming gradient of the output.
    g_dx0 = outputs[0].grad

    if prop_down[0]:
        if accum[0]:
            # In-place accumulate: g_x0 += g_dx0 * (-dy * cos(x0)).
            g_x0 -= g_dx0 * dy * F.cos(x0)
        else:
            g_x0.copy_from(-g_dx0 * dy * F.cos(x0))
    if prop_down[1]:
        if accum[1]:
            # In-place accumulate: g_dy += g_dx0 * (-sin(x0)).
            g_dy -= g_dx0 * F.sin(x0)
        else:
            g_dy.copy_from(-g_dx0 * F.sin(x0))
def tan_backward(inputs):
    """Backward of ``tan``: return ``dy * cos(x0) ** -2`` (i.e. ``dy * sec(x0)^2``).

    Args:
        inputs (list of nn.Variable): ``[dy, x0]`` -- the incoming gradient
            of the forward output, followed by the forward input.

    Returns:
        nn.Variable: The gradient with respect to ``x0``.
    """
    grad_out = inputs[0]
    x = inputs[1]
    sec_squared = F.cos(x) ** (-2)
    return grad_out * sec_squared
def sinusoidal_embedding(timesteps, embedding_dim):
    """Embed a 1-D batch of timesteps with sinusoidal features.

    Scheme from "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).

    Args:
        timesteps: 1-D array/Variable of length ``B``.
        embedding_dim (int): Width of each output embedding.

    Returns:
        Array of shape ``(B, embedding_dim)``: ``embedding_dim // 2`` cosine
        columns, then ``embedding_dim // 2`` sine columns, then one zero
        column when ``embedding_dim`` is odd.
    """
    assert len(timesteps.shape) == 1
    half = embedding_dim // 2
    # Frequencies exp(-log(10000) * k / half) for k = 0 .. half - 1.
    log_step = -np.log(10000) / half
    freqs = F.exp(log_step * F.arange(start=0, stop=half))
    # Outer product: (B, 1) * (1, half) -> (B, half).
    angles = F.reshape(timesteps, (-1, 1)) * F.reshape(freqs, (1, -1))
    out = F.concatenate(F.cos(angles), F.sin(angles), axis=1)
    if embedding_dim % 2 == 1:
        # cos + sin give 2 * half columns, one short of an odd embedding_dim;
        # pad one zero column on the right of axis 1.
        out = F.pad(out, [[0, 0], [0, 1]])
    assert out.shape == (timesteps.shape[0], embedding_dim)
    return out
def sinc_backward(inputs):
    """Backward of ``Sinc``: ``dy * (cos(x0) - y0) / x0``, forced to 0 where ``x0 == 0``.

    ``y0`` is the forward output obtained via ``get_output(x0, "Sinc")``.

    Args:
        inputs (list of nn.Variable): ``[dy, x0]`` -- the incoming gradient
            of the forward output, followed by the forward input.

    Returns:
        nn.Variable: The gradient with respect to ``x0``.
    """
    grad_out = inputs[0]
    x = inputs[1]
    # Mask of positions where the division below is well defined; no grad
    # flows through the mask itself.
    nonzero = no_grad(F.not_equal_scalar(x, 0))
    fwd_out = get_output(x, "Sinc")
    raw_grad = grad_out * (F.cos(x) - fwd_out) / x
    zeros = F.constant(0, x.shape)
    # At x == 0 the expression above divides by zero; select 0 there instead.
    return F.where(nonzero, raw_grad, zeros)
def positional_encoding(x, N=6, include_input=True):
    """Frequency (Fourier-feature) encoding of ``x`` along its last axis.

    Args:
        x: Input ``nn.Variable`` or ``nn.NdArray``, e.g. ``(B, R, 3)``.
        N: Highest frequency exponent. ``N + 1`` frequencies ``2**0 .. 2**N``
            are used (callers use N=6 for the implicit network and N=4 for
            the rendering network).
        include_input: When true, ``x`` itself is prepended to the output.

    Returns:
        Concatenation along the last axis of ``[x]`` (optional),
        ``cos(f * x)`` and ``sin(f * x)``. For each input channel, its
        ``N + 1`` frequencies are contiguous.
    """
    n_freqs = N + 1
    freqs_np = 2 ** np.arange(0, n_freqs)
    # Wrap the frequencies in the same container type as x.
    holder = nn.Variable if isinstance(x, nn.Variable) else nn.NdArray
    freqs = holder.from_numpy_array(freqs_np)
    # (1, ..., 1, n_freqs) * (..., D, 1) -> (..., D, n_freqs)
    scaled = F.reshape(freqs, (1,) * x.ndim + (n_freqs,)) \
        * F.reshape(x, x.shape + (1,))
    # Flatten the last two axes -> (..., D * n_freqs).
    scaled = F.reshape(scaled, scaled.shape[:-2] + (-1,))
    parts = [x] if include_input else []
    parts.extend([F.cos(scaled), F.sin(scaled)])
    return F.concatenate(*parts, axis=-1)
def position_encoding(x: nn.Variable) -> nn.Variable:
    """Add fixed sinusoidal position encodings to ``x``.

    Channel ``2k`` of the encoding is ``sin(pos * w_k)`` and channel
    ``2k + 1`` is ``cos(pos * w_k)``, with ``w_k = exp(-2k * log(10000) / dim)``.

    Args:
        x: Variable of shape ``(batch_size, sequence_length, dim)``. ``dim``
            must be even (the ``(1, dim // 2)`` reshape below requires it).

    Returns:
        ``x`` plus the encoding broadcast over the batch axis (same shape as ``x``).
    """
    _, sequence_length, dim = x.shape
    position = F.reshape(F.arange(0, sequence_length),
                         shape=(sequence_length, 1))  # (sequence_length, 1)
    div_term = F.exp(F.arange(0, dim, 2) * -(np.log(10000.0) / dim))  # (dim // 2,)
    # Shared argument of sin and cos: (sequence_length, dim // 2).
    angle = position * F.reshape(div_term, shape=(1, dim // 2))
    sin_val = F.sin(angle)
    cos_val = F.cos(angle)
    # Interleave sin/cos in two graph nodes instead of `dim` slices:
    # (seq, dim//2, 2) flattened row-major -> sin0, cos0, sin1, cos1, ...
    pe = F.stack(sin_val, cos_val, axis=2)
    pe = F.reshape(pe, shape=(1, sequence_length, dim))
    return x + F.broadcast(pe, shape=x.shape)
def test_unnecessary_traverse_1(self): a0 = nn.Variable((2, 3), need_grad=False) # `a1` will not be recomputed since `a2` will not be cleared. a1 = F.sin(a0).apply(recompute=True) a2 = F.cos(a1) a3 = F.sin(a2).apply(recompute=True) # 'a3` will be recomputed. b0 = nn.Variable((2, 3), need_grad=True).apply(recompute=True) b1 = F.identity(b0).apply(recompute=True) c = F.mul2(a3, b1).apply(recompute=True) # Check recomputation recursion stops when `a3.data` is calculated. c.forward(clear_buffer=False) # `a1.data` is cleared because `recompute` flag is `true`. assert(a1.data.clear_called == True) # `a2.data` is not cleared because `recompute` flag is `false`. assert(a2.data.clear_called == False) c.backward(clear_buffer=False) # If the recursive call reached to `a1`, `a1.data` should be set by recomputation. # However, the recursive call stops at `a2` whose data is not cleared. assert(a1.data.clear_called == True)
def slerp(noise_1, noise_2, ratio):
    """Spherically interpolate between paired noise tensors.

    Each pair ``(a, b)`` is normalized over axis 1. The result is rotated
    from ``a`` toward ``b`` by ``ratio`` times the angle between them, then
    re-normalized over axis 1.

    Args:
        noise_1: Sequence of nnabla arrays/Variables (start points).
        noise_2: Sequence of the same length and shapes (end points).
        ratio: Interpolation factor; 0 gives normalized ``a``, 1 gives
            normalized ``b``.

    Returns:
        list: One interpolated, unit-norm (over axis 1) tensor per pair.
        The inputs are not modified.

    Note:
        If ``a`` and ``b`` point in the same direction, ``c`` below has zero
        norm and the result is NaN. That behavior is unchanged here.
    """
    def _row_norm(v):
        # L2 norm over axis 1, kept as a size-1 axis for broadcasting.
        return F.pow_scalar(F.sum(F.pow_scalar(v, 2), axis=1, keepdims=True), 0.5)

    interpolated_noises = []
    for a, b in zip(noise_1, noise_2):
        # Out-of-place division on purpose: in-place `/=` on nnabla arrays
        # can write into the caller's noise buffers.
        a = a / _row_norm(a)
        b = b / _row_norm(b)
        cos_omega = F.sum(a * b, axis=1, keepdims=True)
        # Rounding can push the dot product of unit vectors slightly past
        # +/-1, where acos is NaN; clamp only the acos argument.
        clamped = F.minimum_scalar(F.maximum_scalar(cos_omega, -1.0), 1.0)
        p = ratio * F.acos(clamped)
        # Component of b orthogonal to a, normalized.
        c = b - cos_omega * a
        c = c / _row_norm(c)
        d = a * F.cos(p) + c * F.sin(p)
        interpolated_noises.append(d / _row_norm(d))
    return interpolated_noises
def graph(x):
    """Return ``cos(sin(x))``, with the ``sin`` node flagged for recomputation."""
    return F.cos(F.sin(x).apply(recompute=True))
def __call__(self, x, test=False):
    """Estimate bass/drums/vocals/other masks (and optionally waveforms).

    Args:
        x: Input passed to ``STFT`` (presumably the mixture waveform --
            confirm against ``STFT``).
        test (bool): Forwarded to ``self.fc_bn`` and ``self.lstm``.

    Returns:
        tuple: ``(mix_spec, masks, pred)``.
            ``mix_spec`` is the full mixture spectrogram
            ``(Frames, Bsize, Channels, Fbins)``.
            ``masks`` is the four ReLU masks (bass, drums, vocals, other)
            concatenated along axis 2.
            ``pred`` is the ``istft`` reconstruction of every source, reshaped
            to ``(4, Bsize, Channels, -1)``, or ``None`` when
            ``self.is_predict`` is true.
    """
    # Fixed order: defines parameter creation order and output ordering.
    sources = ('bass', 'drums', 'vocals', 'other')
    n_sources = len(sources)

    fft_real, fft_imag = STFT(x, n_fft=self.n_fft, n_hop=self.n_hop)
    x_theta = F.atan2(fft_imag, fft_real)

    x = Spectrogram(fft_real, fft_imag, power=self.power,
                    mono=(self.nb_channels == 1))
    # nb_bins is the full spectrogram bin count (before cropping below).
    nb_frames, nb_samples, nb_channels, nb_bins = x.shape
    mix_spec = F.identity(x)
    x = x[..., :self.nb_bins]

    in_shape = (1, 1, 1, self.nb_bins)
    out_shape = (1, 1, 1, self.nb_output_bins)

    # Clone per source: the in-place += / *= below must not share one buffer.
    h = {src: F.identity(x) for src in sources}
    # Shift and scale input to mean=0 std=1 (across all bins).
    for src in sources:
        h[src] += F.reshape(getattr(self, 'input_mean_' + src),
                            shape=in_shape, inplace=False)
    for src in sources:
        h[src] *= F.reshape(getattr(self, 'input_scale_' + src),
                            shape=in_shape, inplace=False)

    # Encode and normalize every instance in a batch.
    for src in sources:
        h[src] = self.fc_bn(h[src], self.hidden_size, 'fc1_' + src, test,
                            activation='tanh')

    # Average the sources (same left-to-right summation order as before).
    cross_1 = (h['bass'] + h['drums'] + h['vocals'] + h['other']) / 4.0

    # Apply 3-layers of stacked LSTM.
    lstm_out = {src: self.lstm(cross_1, nb_samples, 'lstm_' + src, test)
                for src in sources}

    # LSTM skip connection.
    for src in sources:
        h[src] = F.concatenate(h[src], lstm_out[src])
    cross_2 = (h['bass'] + h['drums'] + h['vocals'] + h['other']) / 4.0

    # First dense stage + batch norm (every source reads the shared cross_2).
    for src in sources:
        h[src] = self.fc_bn(cross_2, self.hidden_size, 'fc2_' + src, test,
                            activation='relu')
    # Second dense stage + batch norm.
    for src in sources:
        h[src] = self.fc_bn(h[src], nb_channels * nb_bins, 'fc3_' + src, test)

    # Reshape back to (Frames, Bsize, Channels, Fbins).
    for src in sources:
        h[src] = F.reshape(
            h[src], (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

    # Apply output scaling, then output mean.
    for src in sources:
        h[src] *= F.reshape(getattr(self, 'output_scale_' + src),
                            shape=out_shape, inplace=False)
    for src in sources:
        h[src] += F.reshape(getattr(self, 'output_mean_' + src),
                            shape=out_shape, inplace=False)

    # Since our output is non-negative, we can apply RELU.
    masks = {src: F.relu(h[src]) for src in sources}
    # (Frames, Bsize, Channels, Fbins)
    estimates = [masks[src] * mix_spec for src in sources]

    if not self.is_predict:
        tmp = F.stack(*estimates, axis=0)
        # (Sources, Frames, Bsize, Channels, Fbins)
        #   ==> (Sources, Bsize, Channels, Fbins, Frames)
        tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
        # Phase terms are shared by all sources; build them once.
        cos_theta = F.cos(x_theta)
        sin_theta = F.sin(x_theta)
        pred_r = F.stack(*[tmp[i] * cos_theta for i in range(tmp.shape[0])],
                         axis=0)
        pred_i = F.stack(*[tmp[i] * sin_theta for i in range(tmp.shape[0])],
                         axis=0)
        # Use the real bin count instead of a hard-coded 2049 (which only
        # matched n_fft == 4096).
        pred_r = F.reshape(
            pred_r, (n_sources * nb_samples * nb_channels, nb_bins, nb_frames))
        pred_i = F.reshape(
            pred_i, (n_sources * nb_samples * nb_channels, nb_bins, nb_frames))
        pred = istft(pred_r, pred_i, self.n_fft, self.n_hop, self.n_fft,
                     window_type='hanning', center=True)
        pred = F.reshape(pred, (n_sources, nb_samples, nb_channels, -1))
    else:
        pred = None
    return (mix_spec,
            F.concatenate(*[masks[src] for src in sources], axis=2),
            pred)