def backward_impl(self, inputs, outputs, prop_down, accum):
    """Back-propagate through a backward graph whose output is ``dx0``.

    inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
    [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph].
    Here ``inputs[0]`` is ``x0`` and ``inputs[1]`` is ``dy``; ``outputs[0]`` is ``dx0``.

    The propagated terms are the partial derivatives of ``dx0 = -dy * sin(x0)``:
        d(dx0)/d(x0) = -dy * cos(x0)
        d(dx0)/d(dy) = -sin(x0)
    each multiplied by the incoming grad ``g_dx0``.

    Args:
        inputs: Variables ``[x0, dy]`` (data and grad buffers are used).
        outputs: Variables ``[dx0]``; only its grad is read.
        prop_down: Per-input flags; skip an input whose flag is false.
        accum: Per-input flags; when true the result is accumulated into the
            existing grad buffer, otherwise the buffer is overwritten.
    """
    x0 = inputs[0].data
    dy = inputs[1].data
    g_dx0 = outputs[0].grad

    if prop_down[0]:
        g_x0 = inputs[0].grad
        term = g_dx0 * dy * F.cos(x0)
        if accum[0]:
            g_x0 -= term
        else:
            g_x0.copy_from(-term)

    if prop_down[1]:
        g_dy = inputs[1].grad
        term = g_dx0 * F.sin(x0)
        if accum[1]:
            g_dy -= term
        else:
            g_dy.copy_from(-term)
def graph(x0):
    """Chain four ``F.sin`` calls, flagging the first three outputs for recomputation.

    The second intermediate is additionally marked ``persistent``.
    """
    h = F.sin(x0).apply(recompute=True)
    # Set `recompute` and `persistent` flag at the same time
    h = F.sin(h).apply(recompute=True, persistent=True)
    h = F.sin(h).apply(recompute=True)
    return F.sin(h)
def graph(x0):
    """Sin -> in-place reshape to (3, 2) -> sin -> sin.

    The first three outputs, including the in-place reshape output, carry
    ``recompute=True``.
    """
    h = F.sin(x0).apply(recompute=True)
    # Set `recompute` flag to the inplaced variable.
    h = F.reshape(h, (3, 2), inplace=True).apply(recompute=True)
    h = F.sin(h).apply(recompute=True)
    return F.sin(h)
def test_sequential_with_statement(self, f1, f2):
    """Test for sequential use of with statement.

    Each ``nn.recompute(flag)`` block stamps ``flag`` onto the variables created
    inside it; variables created after the block has exited are not flagged.
    """
    x = nn.Variable((2, 3))
    assert not x.recompute

    # First `with` block uses f1, second uses f2; both start again from `x`.
    for flag in (f1, f2):
        with nn.recompute(flag):
            h = F.relu(x)
            assert h.recompute == flag
            h = F.sin(h)
            assert h.recompute == flag
        assert h.recompute == flag
        h = F.relu(h)
        assert not h.recompute
def test_recompute_flag(self):
    """`recompute` is true only on the output where `apply(recompute=True)` was called."""
    x0 = nn.Variable((1, 1), need_grad=True)
    x1 = F.sin(x0).apply(recompute=True)
    x2 = F.sin(x1).apply(recompute=False)
    x3 = F.sin(x2)

    observed = [v.recompute for v in (x0, x1, x2, x3)]
    assert observed == [False, True, False, False]
def test_checkpoint(self):
    """Recomputed segments between checkpoints are refilled lazily during backward.

    ``x3`` and ``x6`` have no recompute flag and act as checkpoints. The
    ``clear_called`` flags of ``x2`` and ``x5`` are traced after every function's
    backward to see when each segment is recomputed.
    """
    x0 = nn.Variable((2, 3), need_grad=True)
    x1 = F.sin(x0).apply(recompute=True)
    x2 = F.sin(x1).apply(recompute=True)
    x3 = F.sin(x2)  # Checkpoint 1 (recompute == False)
    x4 = F.sin(x3).apply(recompute=True)
    x5 = F.sin(x4).apply(recompute=True)
    x6 = F.sin(x5)  # Checkpoint 2 (recompute == False)
    x7 = F.sin(x6).apply(recompute=True)
    x8 = F.sin(x7).apply(recompute=True)

    # All intermediate data except checkpoints will be cleared during forward propagation.
    x8.forward(clear_no_need_grad=True)

    # clear_called flag changes True to False when the data is recomputed.
    observed_flags = []

    def record_flags(_func):
        observed_flags.append([x2.data.clear_called, x5.data.clear_called])

    x8.backward(function_post_hook=record_flags)

    expected_flags = [
        # [x2, x5] clear_called flags
        [True, True],    # After F.sin(x7) backward
        [True, True],    # After F.sin(x6) backward
        [True, False],   # After F.sin(x5) backward
        [True, False],   # After F.sin(x4) backward
        [True, False],   # After F.sin(x3) backward
        [False, False],  # After F.sin(x2) backward
        [False, False],  # After F.sin(x1) backward
        [False, False],  # After F.sin(x0) backward
    ]
    assert expected_flags == observed_flags
def test_clear_data_on_not_bwd_path(self):
    """Data recomputed for a variable off the backward path is cleared again.

    The ``b`` branch has ``need_grad=False`` and is not back-propagated. Its
    recomputed ``b1`` must be cleared during recomputation.
    """
    a0 = nn.Variable((2, 3), need_grad=True)
    a1 = F.identity(a0).apply(recompute=True)
    a2 = F.sin(a1).apply(recompute=True)
    # These three variables are not back-propagated.
    b0 = nn.Variable((2, 3), need_grad=False)
    b1 = F.identity(b0).apply(recompute=True)
    b2 = F.sin(b1).apply(recompute=True)
    c1 = F.add2(a2, b2).apply(recompute=True)
    c2 = F.sin(c1)

    # Forward
    clear_called_flag_recorder.activate_clear_called_flag_recorder()
    try:
        c2.forward(clear_no_need_grad=True)
        # Data which will be recomputed must be cleared during forward propagation.
        expected = [
            [False],  # a0
            [True],  # a1
            [False],  # b0
            [True],  # b1
            [True, True],  # a2, b2
            [True],  # c1
        ]
        self.check_input_data_clear_called_flags(expected)
    finally:
        # Deactivate even when a check fails so the global recorder is not left active.
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

    # Backward
    clear_called_flag_recorder.activate_clear_called_flag_recorder()
    try:
        c2.backward(clear_buffer=True)
        # b1 is not on backward path and must be cleared during recomputation.
        expected = [
            # Recomputation
            [False],  # a0
            [False],  # a1
            [False],  # b0
            [True],  # b1 (not on backward path) must be cleared
            [True, True],  # a2, b2
            [False],  # c1
            # Backward propagation
            [True, True],  # a2, b2
            [False],  # a1
            [False],  # a0
        ]
        self.check_input_data_clear_called_flags(expected)
    finally:
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
def test_double_backward_floating_variables():
    """Double backward through a concatenation w.r.t. only one of its inputs.

    Gradients are never requested for the other concatenated inputs (``y`` and
    ``z``). The test passes if no error is raised.
    """
    x = nn.Variable((2, 2), need_grad=True)
    y = nn.Variable((2, 3), need_grad=True)
    z = nn.Variable((2, 4), need_grad=True)
    out = F.sin(F.concatenate(*[x, y, z], axis=-1))

    grad_x = nn.grad([out], [x])[0]
    _ = nn.grad([grad_x], [x])[0]  # Error must not happen
def test_clear_input_data(self):
    """Forward with ``clear_no_need_grad`` clears only inputs flagged for recompute.

    ``x1`` is the only intermediate with ``recompute=True``, so it is the only
    input whose data is expected to be cleared.
    """
    x0 = nn.Variable((1, 1), need_grad=True)
    # `F.sin` input data is always needed for grad calculation
    x1 = F.sin(x0).apply(recompute=True)
    x2 = F.sin(x1).apply(recompute=False)
    x3 = F.sin(x2)

    answer = [
        [False],  # x0
        [True],  # x1
        [False],  # x2
    ]

    clear_called_flag_recorder.activate_clear_called_flag_recorder()
    try:
        x3.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(answer)
    finally:
        # Deactivate even when the check fails so the global recorder is not left active.
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
def test_clearing_without_recompute_flag(self):
    """Backward must raise if the user clears data that cannot be recomputed.

    The dropout output has no ``recompute`` flag, so ``setup_recompute`` is not
    called for it during forward. Clearing its data by hand and then running
    backward is expected to raise ``RuntimeError``. The test is currently skipped.
    """
    src = nn.Variable((1, 128, 128), need_grad=True)
    h = F.sin(src).apply(recompute=True)
    dropped = F.dropout(h)  # deliberately no `recompute` flag
    y = F.identity(F.sin(F.sin(dropped).apply(recompute=True)).apply(recompute=True))

    # Skip this code temporarily since it causes
    # random crashes in CI testing on windows 10 with nnabla-cuda-ext.
    pytest.skip(
        'Skipped for randomly crash when perform CI testing on windows 10 with nnabla-cuda-ext')

    y.forward(clear_no_need_grad=True)
    dropped.data.clear()
    with pytest.raises(RuntimeError, match="Failed `called_setup_recompute_`"):
        # `dropped.data` cannot be recomputed correctly since `setup_recompute` is
        # not called during forward propagation. Backward should raise when some
        # intermediate variables are cleared by user.
        y.backward()
def cos_backward(inputs):
    """Backward function of ``cos``.

    Args:
        inputs (list of nn.Variable): ``[dy, x0]``, i.e. the incoming gradient of
            the forward output followed by the forward function's input.

    Returns:
        nn.Variable: The gradient with respect to ``x0``, ``-dy * sin(x0)``.
    """
    dy = inputs[0]
    x0 = inputs[1]
    dx0 = - dy * F.sin(x0)
    return dx0
def test_unnecessary_traverse_0(self):
    """Only the branch that needs grad is recomputed during backward.

    ``a1`` lies on a ``need_grad=False`` path, so its data must stay cleared.
    ``b1`` lies on the grad path and must be refilled by recomputation.
    """
    # No need grad path
    a0 = nn.Variable((2, 3), need_grad=False)
    a1 = F.sin(a0).apply(recompute=True)
    # Need grad path
    b0 = nn.Variable((2, 3), need_grad=True)
    b1 = F.sin(b0).apply(recompute=True)
    # branch
    merged = F.add2(a1, b1)

    # Check whether unnecessary recomputation for `a1.data` is performed.
    merged.forward(clear_no_need_grad=True)
    assert a1.data.clear_called
    assert b1.data.clear_called

    # Keep buffers after backward so the `clear_called` flags reveal whether
    # recomputation happened.
    merged.backward(clear_buffer=False)
    assert a1.data.clear_called  # still cleared: not recomputed
    assert not b1.data.clear_called  # set again: recomputed
def test_clear_no_need_grad_during_recomputation(self):
    """Data that backward does not need is cleared again right after recomputation.

    ``x2`` is recomputed only so that ``x3`` can be rebuilt. Its data must be
    cleared immediately after that recomputation.
    """
    x0 = nn.Variable((2, 3), need_grad=True)
    x1 = F.identity(x0).apply(recompute=True)
    # x2.data must be cleared just after recomputation because it is not needed
    # for backward propagation.
    x2 = F.sin(x1).apply(recompute=True)
    x3 = F.identity(x2).apply(recompute=True)
    x4 = F.sin(x3)

    # Forward
    clear_called_flag_recorder.activate_clear_called_flag_recorder()
    try:
        x4.forward(clear_no_need_grad=True)
        # All intermediate data must be cleared.
        expected = [
            [False],  # x0
            [True],  # x1
            [True],  # x2
            [True],  # x3
        ]
        self.check_input_data_clear_called_flags(expected)
    finally:
        # Deactivate even when a check fails so the global recorder is not left active.
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

    # Backward
    clear_called_flag_recorder.activate_clear_called_flag_recorder()
    try:
        x4.backward(clear_buffer=True)
        expected = [
            # Recomputation
            [False],  # x0
            [False],  # x1
            [True],  # x2: not needed for grad calculation
            # Backward propagation
            [False],  # x3
            [True],  # x2
            [False],  # x1
            [False],  # x0
        ]
        self.check_input_data_clear_called_flags(expected)
    finally:
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
def test_unnecessary_traverse_1(self):
    """Recursive recomputation stops at the first input whose data is still set.

    ``a2`` has no recompute flag, so its data survives forward. Recomputing
    ``a3`` must therefore not recurse down to ``a1``.
    """
    a0 = nn.Variable((2, 3), need_grad=False)
    # `a1` will not be recomputed since `a2` will not be cleared.
    a1 = F.sin(a0).apply(recompute=True)
    a2 = F.cos(a1)
    # `a3` will be recomputed.
    a3 = F.sin(a2).apply(recompute=True)

    b0 = nn.Variable((2, 3), need_grad=True).apply(recompute=True)
    b1 = F.identity(b0).apply(recompute=True)

    product = F.mul2(a3, b1).apply(recompute=True)

    # Check recomputation recursion stops when `a3.data` is calculated.
    product.forward(clear_buffer=False)
    assert a1.data.clear_called  # recompute flag is true -> cleared
    assert not a2.data.clear_called  # recompute flag is false -> kept

    product.backward(clear_buffer=False)
    # Had the recursion reached `a1`, its data would have been refilled.
    # It stops at `a2`, whose data is still set, so `a1` remains cleared.
    assert a1.data.clear_called
def test_recomputed_data_value(self, seed):
    """Data recomputed during backward must match the forward data exactly.

    A forward post-hook records the outputs flagged for recompute (a1, b2, c0).
    A backward pre-hook records the flagged inputs seen during backward. The
    two sequences must have the same length and be bit-identical once the
    backward sequence is reversed.
    """
    rng = np.random.RandomState(seed)
    a0 = nn.Variable((2, 3), need_grad=True)
    b0 = nn.Variable((2, 3), need_grad=True)
    a0.d = rng.randn(*a0.shape)
    b0.d = rng.randn(*b0.shape)

    a1 = F.sin(a0).apply(recompute=True)
    a2 = F.sin(a1)
    a3 = F.sin(a2)

    b1 = F.sin(b0)
    b2 = F.sin(b1).apply(recompute=True)
    b3 = F.sin(b2)

    c0 = F.mul2(a3, b3).apply(recompute=True)
    c1 = F.sin(c0)

    # Forward
    # Get output data which will be recomputed.
    ref_data = []  # data of a1, b2 and c0 will be stored.

    def get_output_data(nnabla_func):
        for var in nnabla_func.outputs:
            if var.recompute:
                ref_data.append(copy.deepcopy(var.d))

    c1.forward(function_post_hook=get_output_data)

    # Backward
    # Get recomputed data
    act_data = []

    def get_recomputed_data(nnabla_func):
        for var in nnabla_func.inputs:
            if var.recompute:
                act_data.append(copy.deepcopy(var.d))

    c1.backward(function_pre_hook=get_recomputed_data)
    # Make the order the same as `ref_data`.
    act_data.reverse()

    # `zip` would silently truncate, letting the test pass vacuously if a hook
    # recorded nothing, so require equal lengths first.
    assert len(act_data) == len(ref_data)
    for act, ref in zip(act_data, ref_data):
        assert_allclose(act, ref, rtol=0, atol=0)
def sinusoidal_embedding(timesteps, embedding_dim):
    """
    Sinusoidal embeddings originally proposed in "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).

    Args:
        timesteps: 1-D array/variable of positions (one per batch element).
        embedding_dim (int): Size of the returned embedding.

    Returns:
        Embedding of shape ``(len(timesteps), embedding_dim)``. The first half of
        the columns holds cosines and the second half holds sines, at
        frequencies ``exp(-log(10000) * i / half_dim)``. An odd
        ``embedding_dim`` gets one trailing zero column.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    denominator = -np.log(10000) / half_dim
    emb = F.exp(denominator * F.arange(start=0, stop=half_dim))
    emb = F.reshape(timesteps, (-1, 1)) * F.reshape(emb, (1, -1))
    emb = F.concatenate(F.cos(emb), F.sin(emb), axis=1)

    if embedding_dim & 1:  # zero pad to be divisible by two
        # `F.pad` takes a flat (before, after, ...) list applied to the trailing
        # axes, not numpy's nested form. Pad one zero after the last axis.
        emb = F.pad(emb, (0, 1))

    assert emb.shape == (timesteps.shape[0], embedding_dim)

    return emb
def positional_encoding(x, N=6, include_input=True):
    """Frequency (Fourier-feature) encoding of ``x``.

    Args:
        x: Input (B, R, 3), either ``nn.Variable`` or ``nn.NdArray``.
        N: Highest band exponent. The frequencies used are ``2**0 .. 2**N``,
            i.e. N + 1 bands. N=6 for implicit network and N=4 for rendering
            network.
        include_input: If true, ``x`` itself is prepended to the features.

    Returns:
        ``[x,] cos(2**k * x), sin(2**k * x)`` concatenated on the last axis. For
        each last-axis element of ``x``, its N + 1 bands are contiguous.
    """
    holder = nn.Variable if isinstance(x, nn.Variable) else nn.NdArray
    freqs = holder.from_numpy_array(2**np.arange(0, N + 1))
    freqs = F.reshape(freqs, (1,) * x.ndim + (N + 1, ))

    scaled = freqs * F.reshape(x, x.shape + (1, ))
    scaled = F.reshape(scaled, scaled.shape[:-2] + (-1, ))

    features = [x] if include_input else []
    features.extend([F.cos(scaled), F.sin(scaled)])
    return F.concatenate(*features, axis=-1)
def position_encoding(x: nn.Variable) -> nn.Variable:
    """Add sinusoidal position encodings to ``x``.

    ``x`` is unpacked as ``(batch_size, sequence_length, dim)``. ``dim`` must be
    even, because the angle table is reshaped to ``(1, dim // 2)``. Feature
    ``2k`` receives ``sin(pos * w_k)`` and feature ``2k + 1`` receives
    ``cos(pos * w_k)``, where ``w_k = exp(-2k * log(10000) / dim)``.

    Returns:
        ``x`` plus the encoding broadcast over the batch axis.
    """
    _, sequence_length, dim = x.shape
    position = F.reshape(F.arange(0, sequence_length),
                         shape=(sequence_length, 1))  # -> (sequence_length, 1)
    div_term = F.exp(F.arange(0, dim, 2) * -(np.log(10000.0) / dim))  # -> (dim//2, )
    angle = position * F.reshape(div_term, shape=(1, dim // 2))  # -> (sequence_length, dim//2)

    # Interleave sin/cos in one op: stacking on a new trailing axis and flattening
    # it gives [sin_0, cos_0, sin_1, cos_1, ...], which matches the previous
    # per-column slice + concatenate construction.
    pe = F.stack(F.sin(angle), F.cos(angle), axis=2)  # -> (sequence_length, dim//2, 2)
    pe = F.reshape(pe, shape=(1, sequence_length, dim))
    return x + F.broadcast(pe, shape=x.shape)
def slerp(noise_1, noise_2, ratio):
    """Spherical linear interpolation between paired noise tensors.

    Each pair ``(a, b)`` from ``zip(noise_1, noise_2)`` is L2-normalized along
    axis 1 and then rotated from ``a`` towards ``b`` by ``ratio`` times the angle
    between them.

    Args:
        noise_1: Iterable of arrays/variables (start points).
        noise_2: Iterable of arrays/variables (end points), paired with noise_1.
        ratio: Fraction of the angle to travel (0 -> normalized ``a``,
            1 -> normalized ``b``).

    Returns:
        list: One interpolated, unit-norm (along axis 1) tensor per pair.

    Note:
        If ``a`` and ``b`` are parallel, the orthogonal component has zero
        norm and the result is undefined (division by zero), as before.
    """
    def _l2_norm(v):
        # Norm along axis 1, kept for broadcasting.
        return F.pow_scalar(F.sum(F.pow_scalar(v, 2), axis=1, keepdims=True), 0.5)

    interpolated_noises = []
    for a, b in zip(noise_1, noise_2):
        # Normalize out-of-place. In-place `/=` could write into the caller's
        # arrays.
        a = a / _l2_norm(a)
        b = b / _l2_norm(b)

        cos_theta = F.sum(a * b, axis=1, keepdims=True)
        # Rounding can push the dot product of unit vectors slightly outside
        # [-1, 1], where acos is NaN. Clamp to the valid domain.
        cos_theta = F.minimum_scalar(F.maximum_scalar(cos_theta, -1.0), 1.0)
        p = ratio * F.acos(cos_theta)

        # Unit vector orthogonal to `a` in the plane spanned by `a` and `b`.
        c = b - cos_theta * a
        c = c / _l2_norm(c)

        d = a * F.cos(p) + c * F.sin(p)
        d = d / _l2_norm(d)
        interpolated_noises.append(d)

    return interpolated_noises
def graph(x):
    """Return ``cos(sin(x))``, with the intermediate ``sin`` output flagged for recomputation."""
    return F.cos(F.sin(x).apply(recompute=True))
def __call__(self, x, test=False):
    """Build the four-source (bass/drums/vocals/other) separation graph.

    The STFT of ``x`` is turned into a spectrogram. The spectrogram goes through
    per-source ``fc_bn``/``lstm`` branches, which are averaged at two cross
    points, and the results become ReLU masks applied to the mixture
    spectrogram.

    Args:
        x: Input passed to ``STFT``.
        test (bool): Forwarded to ``fc_bn`` and ``lstm``.

    Returns:
        tuple: ``(mix_spec, masks, pred)``.
            ``mix_spec`` is the mixture spectrogram.
            ``masks`` is the four source masks concatenated on axis 2.
            ``pred`` is the ISTFT of the masked complex spectra, reshaped to
            ``(n_sources, nb_samples, nb_channels, -1)``, or ``None`` when
            ``self.is_predict`` is set.
    """
    fft_real, fft_imag = STFT(x, n_fft=self.n_fft, n_hop=self.n_hop)
    x_theta = F.atan2(fft_imag, fft_real)

    x = Spectrogram(fft_real, fft_imag, power=self.power,
                    mono=(self.nb_channels == 1))

    nb_frames, nb_samples, nb_channels, nb_bins = x.shape

    mix_spec = F.identity(x)

    x = x[..., :self.nb_bins]

    # clone
    x_bass = F.identity(x)
    x_drums = F.identity(x)
    x_vocals = F.identity(x)
    x_other = F.identity(x)

    # shift and scale input to mean=0 std=1 (across all bins)
    x_bass += F.reshape(self.input_mean_bass,
                        shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_drums += F.reshape(self.input_mean_drums,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_vocals += F.reshape(self.input_mean_vocals,
                          shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_other += F.reshape(self.input_mean_other,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)

    x_bass *= F.reshape(self.input_scale_bass,
                        shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_drums *= F.reshape(self.input_scale_drums,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_vocals *= F.reshape(self.input_scale_vocals,
                          shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_other *= F.reshape(self.input_scale_other,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)

    # encode and normalize every instance in a batch
    x_bass = self.fc_bn(x_bass, self.hidden_size,
                        "fc1_bass", test, activation='tanh')
    x_drums = self.fc_bn(x_drums, self.hidden_size,
                         "fc1_drums", test, activation='tanh')
    x_vocals = self.fc_bn(x_vocals, self.hidden_size,
                          "fc1_vocals", test, activation='tanh')
    x_other = self.fc_bn(x_other, self.hidden_size,
                         "fc1_other", test, activation='tanh')

    # Average the sources
    cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

    # apply 3-layers of stacked LSTM
    lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass", test)
    lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums", test)
    lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals", test)
    lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other", test)

    # lstm skip connection
    x_bass = F.concatenate(x_bass, lstm_out_bass)
    x_drums = F.concatenate(x_drums, lstm_out_drums)
    x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
    x_other = F.concatenate(x_other, lstm_out_other)

    cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

    # first dense stage + batch norm
    x_bass = self.fc_bn(cross_2, self.hidden_size,
                        "fc2_bass", test, activation='relu')
    x_drums = self.fc_bn(cross_2, self.hidden_size,
                         "fc2_drums", test, activation='relu')
    x_vocals = self.fc_bn(cross_2, self.hidden_size,
                          "fc2_vocals", test, activation='relu')
    x_other = self.fc_bn(cross_2, self.hidden_size,
                         "fc2_other", test, activation='relu')

    # second dense stage + batch norm
    x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass", test)
    x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums", test)
    x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals", test)
    x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other", test)

    # reshape back to original dim
    x_bass = F.reshape(
        x_bass, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
    x_drums = F.reshape(
        x_drums, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
    x_vocals = F.reshape(
        x_vocals, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
    x_other = F.reshape(
        x_other, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

    # apply output scaling
    x_bass *= F.reshape(self.output_scale_bass,
                        shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_drums *= F.reshape(self.output_scale_drums,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_vocals *= F.reshape(self.output_scale_vocals,
                          shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_other *= F.reshape(self.output_scale_other,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)

    x_bass += F.reshape(self.output_mean_bass,
                        shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_drums += F.reshape(self.output_mean_drums,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_vocals += F.reshape(self.output_mean_vocals,
                          shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_other += F.reshape(self.output_mean_other,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)

    # since our output is non-negative, we can apply RELU
    mask_bass = F.relu(x_bass)
    mask_drums = F.relu(x_drums)
    mask_vocals = F.relu(x_vocals)
    mask_other = F.relu(x_other)

    # (Frames, Bsize, Channels, Fbins)
    x_bass = mask_bass * mix_spec
    x_drums = mask_drums * mix_spec
    x_vocals = mask_vocals * mix_spec
    x_other = mask_other * mix_spec

    if not self.is_predict:
        tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
        # (4(sources), Frames, Bsize(16), 2(channels), Fbins) ==> (4, Bsize, Channels, Fbins, Frames)
        tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
        # Read the source and frequency-bin counts from the tensor instead of the
        # previously hard-coded 4 and 2049, so other STFT sizes also work.
        n_sources = tmp.shape[0]
        n_fbins = tmp.shape[3]

        # The phase terms are the same for every source. Build them once.
        cos_theta = F.cos(x_theta)
        sin_theta = F.sin(x_theta)
        pred_r, pred_i = [], []
        for i in range(n_sources):
            pred_r.append(tmp[i] * cos_theta)
            pred_i.append(tmp[i] * sin_theta)
        pred_r = F.stack(*pred_r, axis=0)
        pred_i = F.stack(*pred_i, axis=0)
        pred_r = F.reshape(
            pred_r, (n_sources * nb_samples * nb_channels, n_fbins, nb_frames))
        pred_i = F.reshape(
            pred_i, (n_sources * nb_samples * nb_channels, n_fbins, nb_frames))
        pred = istft(pred_r, pred_i, self.n_fft, self.n_hop, self.n_fft,
                     window_type='hanning', center=True)
        pred = F.reshape(pred, (n_sources, nb_samples, nb_channels, -1))
    else:
        pred = None

    return mix_spec, F.concatenate(mask_bass, mask_drums, mask_vocals, mask_other, axis=2), pred
def func2(x):
    """Apply ``F.sin`` and check recompute flags against the enclosing ``f0``/``f2``.

    ``f0`` and ``f2`` come from the enclosing scope.
    """
    assert x.recompute == f0
    out = F.sin(x)
    assert out.recompute == f2
    return out