Example #1
    def backward_impl(self, inputs, outputs, prop_down, accum):
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        # Inputs
        x0 = inputs[0].data
        dy = inputs[1].data
        # Outputs
        dx0 = outputs[0].data
        # Grads of inputs
        g_x0 = inputs[0].grad
        g_dy = inputs[1].grad
        # Grads of outputs
        g_dx0 = outputs[0].grad

        if prop_down[0]:
            if accum[0]:
                g_x0 -= g_dx0 * dy * F.cos(x0)
            else:
                g_x0.copy_from(-g_dx0 * dy * F.cos(x0))

        if prop_down[1]:
            if accum[1]:
                g_dy -= g_dx0 * F.sin(x0)
            else:
                g_dy.copy_from(-g_dx0 * F.sin(x0))
Example #2
def tan_backward(inputs):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    dx0 = dy * F.cos(x0)**(-2)
    return dx0
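A quick numeric sanity check of the rule above (d tan(x)/dx = 1/cos(x)^2 = 1 + tan(x)^2), reusing tan_backward as defined in this example; the sample points are illustrative only:

import numpy as np
import nnabla as nn
import nnabla.functions as F

x0 = nn.Variable.from_numpy_array(np.linspace(-1.0, 1.0, 5).astype(np.float32))
dy = nn.Variable.from_numpy_array(np.ones(5, dtype=np.float32))
dx0 = tan_backward([dy, x0])
dx0.forward()
# 1/cos(x)^2 equals 1 + tan(x)^2
assert np.allclose(dx0.d, 1.0 + np.tan(x0.d) ** 2, atol=1e-4)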
Example #3
def sinusoidal_embedding(timesteps, embedding_dim):
    """
    Sinusoidal embeddings originally proposed in "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    denominator = -np.log(10000) / half_dim
    emb = F.exp(denominator * F.arange(start=0, stop=half_dim))
    emb = F.reshape(timesteps, (-1, 1)) * F.reshape(emb, (1, -1))
    emb = F.concatenate(F.cos(emb), F.sin(emb), axis=1)

    if embedding_dim & 1:  # zero-pad the last column when embedding_dim is odd
        emb = F.pad(emb, [[0, 0], [0, 1]])

    assert emb.shape == (timesteps.shape[0], embedding_dim)

    return emb
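A minimal usage sketch for sinusoidal_embedding as defined above; the timestep values and embedding size are illustrative only:

import numpy as np
import nnabla as nn

timesteps = nn.Variable.from_numpy_array(np.arange(8, dtype=np.float32))
emb = sinusoidal_embedding(timesteps, embedding_dim=16)
emb.forward()
print(emb.shape)  # (8, 16)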
Example #4
def sinc_backward(inputs):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    m0 = F.not_equal_scalar(x0, 0)
    m0 = no_grad(m0)
    y0 = get_output(x0, "Sinc")

    dx0 = dy * (F.cos(x0) - y0) / x0
    c0 = F.constant(0, x0.shape)
    dx0 = F.where(m0, dx0, c0)
    return dx0
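For x != 0 the formula used above can be sanity-checked with plain numpy (taking sinc(x) = sin(x)/x, as in nnabla's Sinc); the sample points and step size are illustrative only:

import numpy as np

x = np.linspace(0.1, 2.0, 5)
sinc = np.sin(x) / x
analytic = (np.cos(x) - sinc) / x                        # formula from sinc_backward
numeric = (np.sin(x + 1e-6) / (x + 1e-6) - sinc) / 1e-6  # forward difference
assert np.allclose(analytic, numeric, atol=1e-4)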
Example #5
def positional_encoding(x, N=6, include_input=True):
    """
    Args:
      x: Input (B, R, 3)
      N: Number of bands, N=6 for implicit network and N=4 for rendering network.
    """

    gamma = [x] if include_input else []
    bands = 2**np.arange(0, N + 1)
    data_holder = nn.Variable if isinstance(x, nn.Variable) else nn.NdArray
    bands = data_holder.from_numpy_array(bands)
    bands = F.reshape(bands, tuple([1] * x.ndim) + (N + 1, )) \
        * F.reshape(x, x.shape + (1, ))
    bands = F.reshape(bands, bands.shape[:-2] + (-1, ))
    cos_x = F.cos(bands)
    sin_x = F.sin(bands)

    gamma += [cos_x, sin_x]
    gamma = F.concatenate(*gamma, axis=-1)

    return gamma
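A minimal usage sketch for positional_encoding as defined above; the (B, R, 3) input is illustrative. With include_input=True and N=6 the last axis grows to 3 + 3 * 2 * (N + 1) = 45:

import numpy as np
import nnabla as nn

x = nn.Variable.from_numpy_array(np.random.randn(2, 5, 3).astype(np.float32))
gamma = positional_encoding(x, N=6)
gamma.forward()
print(gamma.shape)  # (2, 5, 45)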
Example #6
def position_encoding(x: nn.Variable) -> nn.Variable:
    batch_size, sequence_length, dim = x.shape

    position = F.reshape(F.arange(0, sequence_length),
                         shape=(sequence_length, 1))
    # -> (sequence_length, 1)
    div_term = F.exp(F.arange(0, dim, 2) * -(np.log(10000.0) / dim))
    # -> (dim//2, )
    sin_val = F.sin(position * F.reshape(div_term, shape=(1, dim // 2)))
    # -> (sequence_length, dim//2)
    cos_val = F.cos(position * F.reshape(div_term, shape=(1, dim // 2)))
    # -> (sequence_length, dim//2)
    ret = []
    for i in range(dim):
        if i % 2 == 0:
            ret.append(sin_val[:, i // 2:i // 2 + 1])
        else:
            ret.append(cos_val[:, i // 2:i // 2 + 1])
    pe = F.reshape(F.concatenate(*ret, axis=1),
                   shape=(1, sequence_length, dim))
    return x + F.broadcast(pe, shape=x.shape)
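A minimal usage sketch for position_encoding as defined above; the shape is illustrative, and an even dim is assumed since the channel axis is split into sin/cos halves:

import numpy as np
import nnabla as nn

x = nn.Variable((4, 10, 32))  # (batch_size, sequence_length, dim)
x.d = np.random.randn(*x.shape)
y = position_encoding(x)
y.forward()
print(y.shape)  # (4, 10, 32)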
Example #7
    def test_unnecessary_traverse_1(self):
        a0 = nn.Variable((2, 3), need_grad=False)
        # `a1` will not be recomputed since `a2` will not be cleared.
        a1 = F.sin(a0).apply(recompute=True)
        a2 = F.cos(a1)
        a3 = F.sin(a2).apply(recompute=True)  # 'a3` will be recomputed.

        b0 = nn.Variable((2, 3), need_grad=True).apply(recompute=True)
        b1 = F.identity(b0).apply(recompute=True)

        c = F.mul2(a3, b1).apply(recompute=True)

        # Check recomputation recursion stops when `a3.data` is calculated.

        c.forward(clear_buffer=False)
        # `a1.data` is cleared because `recompute` flag is `true`.
        assert(a1.data.clear_called == True)
        # `a2.data` is not cleared because `recompute` flag is `false`.
        assert(a2.data.clear_called == False)
        c.backward(clear_buffer=False)
        # If the recursive call had reached `a1`, `a1.data` would have been set by recomputation.
        # However, the recursion stops at `a2`, whose data is not cleared.
        assert(a1.data.clear_called == True)
Example #8
def slerp(noise_1, noise_2, ratio):
    interpolated_noises = []
    for a, b in zip(noise_1, noise_2):
        a_norm = F.pow_scalar(F.sum(F.pow_scalar(a, 2), axis=1, keepdims=True),
                              0.5)
        b_norm = F.pow_scalar(F.sum(F.pow_scalar(b, 2), axis=1, keepdims=True),
                              0.5)

        a /= a_norm
        b /= b_norm

        d = F.sum(a * b, axis=1, keepdims=True)
        p = ratio * F.acos(d)
        c = b - d * a
        c_norm = F.pow_scalar(F.sum(F.pow_scalar(c, 2), axis=1, keepdims=True),
                              0.5)
        c /= c_norm

        d = a * F.cos(p) + c * F.sin(p)
        d = d / F.pow_scalar(F.sum(F.pow_scalar(d, 2), axis=1, keepdims=True),
                             0.5)

        interpolated_noises.append(d)
    return interpolated_noises
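A minimal usage sketch for slerp as defined above; the noise shapes and the 0.5 ratio are illustrative only:

import numpy as np
import nnabla as nn

noise_1 = [nn.Variable.from_numpy_array(np.random.randn(8, 128).astype(np.float32))]
noise_2 = [nn.Variable.from_numpy_array(np.random.randn(8, 128).astype(np.float32))]
out = slerp(noise_1, noise_2, 0.5)[0]
out.forward()
print(out.shape)  # (8, 128)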
Example #9
def graph(x):
    y = F.sin(x).apply(recompute=True)
    y = F.cos(y)
    return y
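A minimal sketch of running such a graph with the recompute flag; the input shape is illustrative only:

import numpy as np
import nnabla as nn

x = nn.Variable((2, 3), need_grad=True)
x.d = np.random.randn(*x.shape)
y = graph(x)
# With clear_no_need_grad=True the sin(x) buffer can be cleared after forward
# and recomputed on demand during backward because recompute=True was set.
y.forward(clear_no_need_grad=True)
y.backward()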
Example #10
    def __call__(self, x, test=False):

        fft_real, fft_imag = STFT(x, n_fft=self.n_fft, n_hop=self.n_hop)
        x_theta = F.atan2(fft_imag, fft_real)

        x = Spectrogram(fft_real,
                        fft_imag,
                        power=self.power,
                        mono=(self.nb_channels == 1))

        nb_frames, nb_samples, nb_channels, nb_bins = x.shape

        mix_spec = F.identity(x)
        x = x[..., :self.nb_bins]

        # clone
        x_bass = F.identity(x)
        x_drums = F.identity(x)
        x_vocals = F.identity(x)
        x_other = F.identity(x)

        # shift and scale input to mean=0 std=1 (across all bins)
        x_bass += F.reshape(self.input_mean_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums += F.reshape(self.input_mean_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals += F.reshape(self.input_mean_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other += F.reshape(self.input_mean_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        x_bass *= F.reshape(self.input_scale_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums *= F.reshape(self.input_scale_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.input_scale_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other *= F.reshape(self.input_scale_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        # encode and normalize every instance in a batch
        x_bass = self.fc_bn(x_bass,
                            self.hidden_size,
                            "fc1_bass",
                            test,
                            activation='tanh')
        x_drums = self.fc_bn(x_drums,
                             self.hidden_size,
                             "fc1_drums",
                             test,
                             activation='tanh')
        x_vocals = self.fc_bn(x_vocals,
                              self.hidden_size,
                              "fc1_vocals",
                              test,
                              activation='tanh')
        x_other = self.fc_bn(x_other,
                             self.hidden_size,
                             "fc1_other",
                             test,
                             activation='tanh')

        # Average the sources
        cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # apply 3-layers of stacked LSTM
        lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass", test)
        lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums", test)
        lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals", test)
        lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other", test)

        # lstm skip connection
        x_bass = F.concatenate(x_bass, lstm_out_bass)
        x_drums = F.concatenate(x_drums, lstm_out_drums)
        x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
        x_other = F.concatenate(x_other, lstm_out_other)

        cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # first dense stage + batch norm
        x_bass = self.fc_bn(cross_2,
                            self.hidden_size,
                            "fc2_bass",
                            test,
                            activation='relu')
        x_drums = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_drums",
                             test,
                             activation='relu')
        x_vocals = self.fc_bn(cross_2,
                              self.hidden_size,
                              "fc2_vocals",
                              test,
                              activation='relu')
        x_other = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_other",
                             test,
                             activation='relu')

        # second dense stage + batch norm
        x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass", test)
        x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums", test)
        x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals",
                              test)
        x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other", test)

        # reshape back to original dim
        x_bass = F.reshape(
            x_bass, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_drums = F.reshape(
            x_drums, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_vocals = F.reshape(
            x_vocals,
            (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_other = F.reshape(
            x_other, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

        # apply output scaling
        x_bass *= F.reshape(self.output_scale_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums *= F.reshape(self.output_scale_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.output_scale_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other *= F.reshape(self.output_scale_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        x_bass += F.reshape(self.output_mean_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums += F.reshape(self.output_mean_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals += F.reshape(self.output_mean_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other += F.reshape(self.output_mean_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        # since our output is non-negative, we can apply ReLU
        mask_bass = F.relu(x_bass)
        mask_drums = F.relu(x_drums)
        mask_vocals = F.relu(x_vocals)
        mask_other = F.relu(x_other)

        # (Frames, Bsize, Channels, Fbins)
        x_bass = mask_bass * mix_spec
        x_drums = mask_drums * mix_spec
        x_vocals = mask_vocals * mix_spec
        x_other = mask_other * mix_spec

        if not self.is_predict:
            tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
            # (4(sources), Frames, Bsize(16), 2(channels), Fbins) ==> (4, Bsize, Channels, Fbins, Frames)
            tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
            pred_r, pred_i = [], []
            for i in range(tmp.shape[0]):
                pred_r.append(tmp[i] * F.cos(x_theta))
                pred_i.append(tmp[i] * F.sin(x_theta))
            pred_r = F.stack(*pred_r, axis=0)
            pred_i = F.stack(*pred_i, axis=0)
            pred_r = F.reshape(pred_r,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred_i = F.reshape(pred_i,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred = istft(pred_r,
                         pred_i,
                         self.n_fft,
                         self.n_hop,
                         self.n_fft,
                         window_type='hanning',
                         center=True)
            pred = F.reshape(pred, (4, nb_samples, nb_channels, -1))

        else:
            pred = None

        return mix_spec, F.concatenate(mask_bass,
                                       mask_drums,
                                       mask_vocals,
                                       mask_other,
                                       axis=2), pred