Example #1
File: cos.py Project: zeta1999/nnabla
    def backward_impl(self, inputs, outputs, prop_down, accum):
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        # Inputs
        x0 = inputs[0].data
        dy = inputs[1].data
        # Outputs
        dx0 = outputs[0].data
        # Grads of inputs
        g_x0 = inputs[0].grad
        g_dy = inputs[1].grad
        # Grads of outputs
        g_dx0 = outputs[0].grad

        if prop_down[0]:
            if accum[0]:
                g_x0 -= g_dx0 * dy * F.cos(x0)
            else:
                g_x0.copy_from(-g_dx0 * dy * F.cos(x0))

        if prop_down[1]:
            if accum[1]:
                g_dy -= g_dx0 * F.sin(x0)
            else:
                g_dy.copy_from(-g_dx0 * F.sin(x0))
Example #2
def graph(x0):
    x1 = F.sin(x0).apply(recompute=True)
    # Set the `recompute` and `persistent` flags at the same time
    x2 = F.sin(x1).apply(recompute=True, persistent=True)
    x3 = F.sin(x2).apply(recompute=True)
    y = F.sin(x3)
    return y
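A minimal usage sketch for the `graph` function above, assuming the definition is in scope and nnabla is installed; the input shape is arbitrary:

import numpy as np
import nnabla as nn

x0 = nn.Variable((2, 3), need_grad=True)
x0.d = np.random.randn(*x0.shape)
y = graph(x0)
# Data of variables marked recompute=True may be cleared during forward
# and recomputed on demand during backward; the persistent x2 is kept.
y.forward(clear_no_need_grad=True)
y.backward(clear_buffer=True)
print(x0.g.shape)  # (2, 3)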
Example #3
def graph(x0):
    x1 = F.sin(x0).apply(recompute=True)
    # Set the `recompute` flag on the in-placed variable.
    x2 = F.reshape(x1, (3, 2), inplace=True).apply(recompute=True)
    x3 = F.sin(x2).apply(recompute=True)
    y = F.sin(x3)
    return y
Example #4
    def test_sequential_with_statement(self, f1, f2):
        """
        Test sequential use of the `with` statement.
        """
        x = nn.Variable((2, 3))
        assert x.recompute == False

        # First `with` block
        with nn.recompute(f1):
            y = F.relu(x)
            assert y.recompute == f1
            y = F.sin(y)
            assert y.recompute == f1

        assert y.recompute == f1

        y = F.relu(y)
        assert y.recompute == False

        # Second `with` block
        with nn.recompute(f2):
            y = F.relu(x)
            assert y.recompute == f2
            y = F.sin(y)
            assert y.recompute == f2

        assert y.recompute == f2

        y = F.relu(y)
        assert y.recompute == False
Example #5
    def test_recompute_flag(self):
        x0 = nn.Variable((1, 1), need_grad=True)
        x1 = F.sin(x0).apply(recompute=True)
        x2 = F.sin(x1).apply(recompute=False)
        x3 = F.sin(x2)

        assert x0.recompute == False
        assert x1.recompute == True
        assert x2.recompute == False
        assert x3.recompute == False
Example #6
    def test_checkpoint(self):
        x0 = nn.Variable((2, 3), need_grad=True)

        x1 = F.sin(x0).apply(recompute=True)
        x2 = F.sin(x1).apply(recompute=True)
        x3 = F.sin(x2)  # Checkpoint 1 (recompute == False)
        x4 = F.sin(x3).apply(recompute=True)
        x5 = F.sin(x4).apply(recompute=True)
        x6 = F.sin(x5)  # Checkpoint 2 (recompute == False)
        x7 = F.sin(x6).apply(recompute=True)
        x8 = F.sin(x7).apply(recompute=True)

        # All intermediate data except checkpoints will be cleared during forward propagation.
        x8.forward(clear_no_need_grad=True)

        # Trace clear_called flags of `x2` and `x5` during backward propagation.
        # The clear_called flag changes from True to False when the data is recomputed.
        act_flags = []

        def get_clear_called_flags(nnabla_func):
            act_flags.append([x2.data.clear_called, x5.data.clear_called])
        x8.backward(function_post_hook=get_clear_called_flags)
        ref_flags = [
                     # [x2, x5] clear_called flags
                     [True, True],  # After F.sin(x7) backward
                     [True, True],  # After F.sin(x6) backward
                     [True, False],  # After F.sin(x5) backward
                     [True, False],  # After F.sin(x4) backward
                     [True, False],  # After F.sin(x3) backward
                     [False, False],  # After F.sin(x2) backward
                     [False, False],  # After F.sin(x1) backward
                     [False, False],  # After F.sin(x0) backward
                    ]

        assert(ref_flags == act_flags)
Example #7
    def test_clear_data_on_not_bwd_path(self):
        a0 = nn.Variable((2, 3), need_grad=True)
        a1 = F.identity(a0).apply(recompute=True)
        a2 = F.sin(a1).apply(recompute=True)

        # These three variables are not back-propagated.
        b0 = nn.Variable((2, 3), need_grad=False)
        b1 = F.identity(b0).apply(recompute=True)
        b2 = F.sin(b1).apply(recompute=True)

        c1 = F.add2(a2, b2).apply(recompute=True)
        c2 = F.sin(c1)

        # Forward
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        c2.forward(clear_no_need_grad=True)
        # Data which will be recomputed must be cleared during forward propagation.
        expected = [
            [False],  # a0
            [True],  # a1
            [False],  # b0
            [True],  # b1
            [True, True],  # a2, b2
            [True],  # c1
        ]
        self.check_input_data_clear_called_flags(expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

        # Backward
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        c2.backward(clear_buffer=True)
        # b1 is not on the backward path and must be cleared during recomputation.
        expected = [
            # Recomputation
            [False],  # a0
            [False],  # a1
            [False],  # b0
            [True],  # b1 (not on backward path) must be cleared
            [True, True],  # a2, b2
            [False],  # c1
            # Backward propagation
            [True, True],  # a2, b2
            [False],  # a1
            [False],  # a0
        ]
        self.check_input_data_clear_called_flags(expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
Example #8
File: test_grad.py Project: sony/nnabla
def test_double_backward_floating_variables():
    x = nn.Variable((2, 2), need_grad=True)
    y = nn.Variable((2, 3), need_grad=True)
    z = nn.Variable((2, 4), need_grad=True)
    w = F.concatenate(*[x, y, z], axis=-1)
    o = F.sin(w)
    dx = nn.grad([o], [x])[0]
    ddx = nn.grad([dx], [x])[0]  # Error must not happen
Example #9
    def test_clear_input_data(self):
        x0 = nn.Variable((1, 1), need_grad=True)
        # `F.sin` input data is always needed for grad calculation
        x1 = F.sin(x0).apply(recompute=True)
        x2 = F.sin(x1).apply(recompute=False)
        x3 = F.sin(x2)

        answer = []
        answer.append([False])  # x0
        answer.append([True])  # x1
        answer.append([False])  # x2

        clear_called_flag_recorder.activate_clear_called_flag_recorder()

        x3.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(answer)

        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
Example #10
    def test_clearing_without_recompute_flag(self):
        x0 = nn.Variable((1, 128, 128), need_grad=True)
        x1 = F.sin(x0).apply(recompute=True)
        x2 = F.dropout(x1)
        x3 = F.sin(x2).apply(recompute=True)
        x4 = F.sin(x3).apply(recompute=True)
        y = F.identity(x4)

        # Skip this test temporarily since it randomly crashes
        # when performing CI testing on Windows 10 with nnabla-cuda-ext.
        pytest.skip(
            'Skipped for randomly crash when perform CI testing on windows 10 with nnabla-cuda-ext')

        y.forward(clear_no_need_grad=True)
        x2.data.clear()
        with pytest.raises(RuntimeError, match="Failed `called_setup_recompute_`"):
            # x2.data cannot be recomputed correctly since `setup_recompute` is not called during forward propagation.
            # Backward should raise when some intermediate variables are cleared by user.
            y.backward()
Example #11
File: cos.py Project: donproc/nnabla
def cos_backward(inputs):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    dx0 = - dy * F.sin(x0)
    return dx0
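A minimal numerical check sketch for `cos_backward`, assuming the definition above is in scope; the test variables below are hypothetical:

import numpy as np
import nnabla as nn

x0 = nn.Variable.from_numpy_array(np.random.randn(2, 3).astype(np.float32))
dy = nn.Variable.from_numpy_array(np.ones((2, 3), dtype=np.float32))
dx0 = cos_backward([dy, x0])
dx0.forward()
# d/dx cos(x) = -sin(x), so dx0 should match -sin(x0) * dy.
assert np.allclose(dx0.d, -np.sin(x0.d) * dy.d)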
Example #12
    def test_unnecessary_traverse_0(self):
        # No need grad path
        a0 = nn.Variable((2, 3), need_grad=False)
        a1 = F.sin(a0).apply(recompute=True)
        # Need grad path
        b0 = nn.Variable((2, 3), need_grad=True)
        b1 = F.sin(b0).apply(recompute=True)
        # branch
        c = F.add2(a1, b1)

        # Check whether unnecessary recomputation for `a1.data` is performed.

        c.forward(clear_no_need_grad=True)
        assert(a1.data.clear_called == True)
        assert(b1.data.clear_called == True)

        # Execute backward without clearing the buffer to check whether recomputation is performed by looking at the `clear_called` flag.
        c.backward(clear_buffer=False)
        # a1.data is still cleared. (Recalculation is not performed)
        assert(a1.data.clear_called == True)
        # b1.data is set. (Recalculation is performed)
        assert(b1.data.clear_called == False)
Example #13
    def test_clear_no_need_grad_during_recomputation(self):
        x0 = nn.Variable((2, 3), need_grad=True)

        x1 = F.identity(x0).apply(recompute=True)
        # x2.data must be cleared just after recomputation because it is not needed for backward propagation.
        x2 = F.sin(x1).apply(recompute=True)
        x3 = F.identity(x2).apply(recompute=True)
        x4 = F.sin(x3)

        # Forward
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        x4.forward(clear_no_need_grad=True)
        # All intermediate data must be cleared.
        expected = [
            [False],  # x0
            [True],  # x1
            [True],  # x2
            [True],  # x3
        ]
        self.check_input_data_clear_called_flags(expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

        # Backward
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        x4.backward(clear_buffer=True)
        expected = [
            # Recomputation
            [False],  # x0
            [False],  # x1
            [True],  # x2: not needed for grad calculation
            # Backward propagation
            [False],  # x3
            [True],  # x2
            [False],  # x1
            [False],  # x0
        ]
        self.check_input_data_clear_called_flags(expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
Example #14
    def test_unnecessary_traverse_1(self):
        a0 = nn.Variable((2, 3), need_grad=False)
        # `a1` will not be recomputed since `a2` will not be cleared.
        a1 = F.sin(a0).apply(recompute=True)
        a2 = F.cos(a1)
        a3 = F.sin(a2).apply(recompute=True)  # `a3` will be recomputed.

        b0 = nn.Variable((2, 3), need_grad=True).apply(recompute=True)
        b1 = F.identity(b0).apply(recompute=True)

        c = F.mul2(a3, b1).apply(recompute=True)

        # Check that recomputation recursion stops once `a3.data` is calculated.

        c.forward(clear_buffer=False)
        # `a1.data` is cleared because its `recompute` flag is True.
        assert(a1.data.clear_called == True)
        # `a2.data` is not cleared because its `recompute` flag is False.
        assert(a2.data.clear_called == False)
        c.backward(clear_buffer=False)
        # If the recursive call reached `a1`, `a1.data` would be set by recomputation.
        # However, the recursive call stops at `a2`, whose data is not cleared.
        assert(a1.data.clear_called == True)
Example #15
File: test_graph.py Project: sony/nnabla
    def test_recomputed_data_value(self, seed):
        rng = np.random.RandomState(seed)
        a0 = nn.Variable((2, 3), need_grad=True)
        b0 = nn.Variable((2, 3), need_grad=True)
        a0.d = rng.randn(*a0.shape)
        b0.d = rng.randn(*b0.shape)

        a1 = F.sin(a0).apply(recompute=True)
        a2 = F.sin(a1)
        a3 = F.sin(a2)

        b1 = F.sin(b0)
        b2 = F.sin(b1).apply(recompute=True)
        b3 = F.sin(b2)

        c0 = F.mul2(a3, b3).apply(recompute=True)
        c1 = F.sin(c0)

        # Forward

        # Get output data which will be recomputed.
        ref_data = []  # data of a1, b2 and c0 will be stored.

        def get_output_data(nnabla_func):
            outputs = nnabla_func.outputs
            for output in outputs:
                if output.recompute:
                    ref_data.append(copy.deepcopy(output.d))

        c1.forward(function_post_hook=get_output_data)

        # Backward

        # Get recomputed data
        act_data = []

        def get_recomputed_data(nnabla_func):
            inputs = nnabla_func.inputs
            for input in inputs:
                if input.recompute:
                    act_data.append(copy.deepcopy(input.d))

        c1.backward(function_pre_hook=get_recomputed_data)
        # Make the order the same as `ref_data`.
        act_data.reverse()

        # Check recomputed data
        for act, ref in zip(act_data, ref_data):
            assert_allclose(act, ref, rtol=0, atol=0)
Example #16
def sinusoidal_embedding(timesteps, embedding_dim):
    """
    Sinusoidal embeddings originally proposed in "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    denominator = -np.log(10000) / half_dim
    emb = F.exp(denominator * F.arange(start=0, stop=half_dim))
    emb = F.reshape(timesteps, (-1, 1)) * F.reshape(emb, (1, -1))
    emb = F.concatenate(F.cos(emb), F.sin(emb), axis=1)

    if embedding_dim & 1:  # zero pad to be divisible by two
        emb = F.pad(emb, [[0, 0], [0, 1]])

    assert emb.shape == (timesteps.shape[0], embedding_dim)

    return emb
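A minimal usage sketch for `sinusoidal_embedding`, assuming the definition above is in scope; the timestep values are hypothetical:

import numpy as np
import nnabla as nn

timesteps = nn.Variable.from_numpy_array(
    np.array([0, 1, 2, 3], dtype=np.float32))  # shape (4,)
emb = sinusoidal_embedding(timesteps, embedding_dim=16)
emb.forward()
print(emb.shape)  # (4, 16)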
Example #17
def positional_encoding(x, N=6, include_input=True):
    """
    Args:
      x: Input (B, R, 3)
      N: Number of bands, N=6 for implicit network and N=4 for rendering network.
    """

    gamma = [x] if include_input else []
    bands = 2**np.arange(0, N + 1)
    data_holder = nn.Variable if isinstance(x, nn.Variable) else nn.NdArray
    bands = data_holder.from_numpy_array(bands)
    bands = F.reshape(bands, tuple([1] * x.ndim) + (N + 1, )) \
        * F.reshape(x, x.shape + (1, ))
    bands = F.reshape(bands, bands.shape[:-2] + (-1, ))
    cos_x = F.cos(bands)
    sin_x = F.sin(bands)

    gamma += [cos_x, sin_x]
    gamma = F.concatenate(*gamma, axis=-1)

    return gamma
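A minimal usage sketch for `positional_encoding`, assuming the definition above is in scope; the input shape is hypothetical:

import numpy as np
import nnabla as nn

x = nn.Variable((4, 8, 3))  # (B, R, 3)
x.d = np.random.randn(*x.shape)
gamma = positional_encoding(x, N=6, include_input=True)
gamma.forward()
# Last-axis size: 3 (input) + 2 * 3 * (N + 1) = 45
print(gamma.shape)  # (4, 8, 45)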
Example #18
def position_encoding(x: nn.Variable) -> nn.Variable:
    batch_size, sequence_length, dim = x.shape

    position = F.reshape(F.arange(0, sequence_length),
                         shape=(sequence_length, 1))
    # -> (sequence_length, 1)
    div_term = F.exp(F.arange(0, dim, 2) * -(np.log(10000.0) / dim))
    # -> (dim//2, )
    sin_val = F.sin(position * F.reshape(div_term, shape=(1, dim // 2)))
    # -> (sequence_length, dim//2)
    cos_val = F.cos(position * F.reshape(div_term, shape=(1, dim // 2)))
    # -> (sequence_length, dim//2)
    ret = []
    for i in range(dim):
        if i % 2 == 0:
            ret.append(sin_val[:, i // 2:i // 2 + 1])
        else:
            ret.append(cos_val[:, i // 2:i // 2 + 1])
    pe = F.reshape(F.concatenate(*ret, axis=1),
                   shape=(1, sequence_length, dim))
    return x + F.broadcast(pe, shape=x.shape)
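A minimal usage sketch for `position_encoding`, assuming the definition above is in scope; the shapes are hypothetical and `dim` is assumed to be even:

import numpy as np
import nnabla as nn

x = nn.Variable((2, 10, 8))  # (batch_size, sequence_length, dim)
x.d = np.random.randn(*x.shape)
y = position_encoding(x)
y.forward()
print(y.shape)  # (2, 10, 8)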
Example #19
def slerp(noise_1, noise_2, ratio):
    interpolated_noises = []
    for a, b in zip(noise_1, noise_2):
        a_norm = F.pow_scalar(F.sum(F.pow_scalar(a, 2), axis=1, keepdims=True),
                              0.5)
        b_norm = F.pow_scalar(F.sum(F.pow_scalar(b, 2), axis=1, keepdims=True),
                              0.5)

        a /= a_norm
        b /= b_norm

        d = F.sum(a * b, axis=1, keepdims=True)
        p = ratio * F.acos(d)
        c = b - d * a
        c_norm = F.pow_scalar(F.sum(F.pow_scalar(c, 2), axis=1, keepdims=True),
                              0.5)
        c /= c_norm

        d = a * F.cos(p) + c * F.sin(p)
        d = d / F.pow_scalar(F.sum(F.pow_scalar(d, 2), axis=1, keepdims=True),
                             0.5)

        interpolated_noises.append(d)
    return interpolated_noises
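A minimal usage sketch for `slerp`, assuming the definition above is in scope; the single-entry noise lists and their shapes below are hypothetical:

import numpy as np
import nnabla as nn

noise_1 = [nn.Variable.from_numpy_array(
    np.random.randn(4, 512).astype(np.float32))]
noise_2 = [nn.Variable.from_numpy_array(
    np.random.randn(4, 512).astype(np.float32))]
out = slerp(noise_1, noise_2, ratio=0.5)
out[0].forward()
print(out[0].shape)  # (4, 512)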
Example #20
def graph(x):
    y = F.sin(x).apply(recompute=True)
    y = F.cos(y)
    return y
Example #21
    def __call__(self, x, test=False):

        fft_real, fft_imag = STFT(x, n_fft=self.n_fft, n_hop=self.n_hop)
        x_theta = F.atan2(fft_imag, fft_real)

        x = Spectrogram(fft_real,
                        fft_imag,
                        power=self.power,
                        mono=(self.nb_channels == 1))

        nb_frames, nb_samples, nb_channels, nb_bins = x.shape

        mix_spec = F.identity(x)
        x = x[..., :self.nb_bins]

        # clone
        x_bass = F.identity(x)
        x_drums = F.identity(x)
        x_vocals = F.identity(x)
        x_other = F.identity(x)

        # shift and scale input to mean=0 std=1 (across all bins)
        x_bass += F.reshape(self.input_mean_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums += F.reshape(self.input_mean_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals += F.reshape(self.input_mean_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other += F.reshape(self.input_mean_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        x_bass *= F.reshape(self.input_scale_bass,
                            shape=(1, 1, 1, self.nb_bins),
                            inplace=False)
        x_drums *= F.reshape(self.input_scale_drums,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.input_scale_vocals,
                              shape=(1, 1, 1, self.nb_bins),
                              inplace=False)
        x_other *= F.reshape(self.input_scale_other,
                             shape=(1, 1, 1, self.nb_bins),
                             inplace=False)

        # encode and normalize every instance in a batch
        x_bass = self.fc_bn(x_bass,
                            self.hidden_size,
                            "fc1_bass",
                            test,
                            activation='tanh')
        x_drums = self.fc_bn(x_drums,
                             self.hidden_size,
                             "fc1_drums",
                             test,
                             activation='tanh')
        x_vocals = self.fc_bn(x_vocals,
                              self.hidden_size,
                              "fc1_vocals",
                              test,
                              activation='tanh')
        x_other = self.fc_bn(x_other,
                             self.hidden_size,
                             "fc1_other",
                             test,
                             activation='tanh')

        # Average the sources
        cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # apply a 3-layer stacked LSTM
        lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass", test)
        lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums", test)
        lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals", test)
        lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other", test)

        # lstm skip connection
        x_bass = F.concatenate(x_bass, lstm_out_bass)
        x_drums = F.concatenate(x_drums, lstm_out_drums)
        x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
        x_other = F.concatenate(x_other, lstm_out_other)

        cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # first dense stage + batch norm
        x_bass = self.fc_bn(cross_2,
                            self.hidden_size,
                            "fc2_bass",
                            test,
                            activation='relu')
        x_drums = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_drums",
                             test,
                             activation='relu')
        x_vocals = self.fc_bn(cross_2,
                              self.hidden_size,
                              "fc2_vocals",
                              test,
                              activation='relu')
        x_other = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_other",
                             test,
                             activation='relu')

        # second dense stage + batch norm
        x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass", test)
        x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums", test)
        x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals",
                              test)
        x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other", test)

        # reshape back to original dim
        x_bass = F.reshape(
            x_bass, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_drums = F.reshape(
            x_drums, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_vocals = F.reshape(
            x_vocals,
            (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_other = F.reshape(
            x_other, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

        # apply output scaling
        x_bass *= F.reshape(self.output_scale_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums *= F.reshape(self.output_scale_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals *= F.reshape(self.output_scale_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other *= F.reshape(self.output_scale_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        x_bass += F.reshape(self.output_mean_bass,
                            shape=(1, 1, 1, self.nb_output_bins),
                            inplace=False)
        x_drums += F.reshape(self.output_mean_drums,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)
        x_vocals += F.reshape(self.output_mean_vocals,
                              shape=(1, 1, 1, self.nb_output_bins),
                              inplace=False)
        x_other += F.reshape(self.output_mean_other,
                             shape=(1, 1, 1, self.nb_output_bins),
                             inplace=False)

        # since our output is non-negative, we can apply ReLU
        mask_bass = F.relu(x_bass)
        mask_drums = F.relu(x_drums)
        mask_vocals = F.relu(x_vocals)
        mask_other = F.relu(x_other)

        # (Frames, Bsize, Channels, Fbins)
        x_bass = mask_bass * mix_spec
        x_drums = mask_drums * mix_spec
        x_vocals = mask_vocals * mix_spec
        x_other = mask_other * mix_spec

        if not self.is_predict:
            tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
            # (4(sources), Frames, Bsize(16), 2(channels), Fbins) ==> (4, Bsize, Channels, Fbins, Frames)
            tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
            pred_r, pred_i = [], []
            for i in range(tmp.shape[0]):
                pred_r.append(tmp[i] * F.cos(x_theta))
                pred_i.append(tmp[i] * F.sin(x_theta))
            pred_r = F.stack(*pred_r, axis=0)
            pred_i = F.stack(*pred_i, axis=0)
            pred_r = F.reshape(pred_r,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred_i = F.reshape(pred_i,
                               (4 * nb_samples * nb_channels, 2049, nb_frames))
            pred = istft(pred_r,
                         pred_i,
                         self.n_fft,
                         self.n_hop,
                         self.n_fft,
                         window_type='hanning',
                         center=True)
            pred = F.reshape(pred, (4, nb_samples, nb_channels, -1))

        else:
            pred = None

        return mix_spec, F.concatenate(mask_bass,
                                       mask_drums,
                                       mask_vocals,
                                       mask_other,
                                       axis=2), pred
Example #22
def func2(x):
    assert x.recompute == f0
    y = F.sin(x)
    assert y.recompute == f2
    return y