Beispiel #1
0
def test_backward_indirectly(
    fix_seed,
    test_input,
    test_weight,
    test_bias,
    test_mode,
    expected_weight_grad,
    expected_input_grad,
):
    """Check gradients produced by back-propagating through binarized_linear.

    ``.backward()`` is called on the forward output. The gradients then
    accumulated on ``test_input`` and ``test_weight`` must match the expected
    fixtures element-wise (rtol=atol=1e-4, NaNs treated as equal).
    """
    # NOTE(review): backward() without a gradient argument needs a scalar
    # output; presumably the fixtures guarantee that — confirm.
    output = binarized_linear(test_input, test_weight, test_bias, test_mode)
    output.backward()

    tolerance = dict(rtol=1e-04, atol=1e-04, equal_nan=True)
    # Input gradient is checked first, then weight gradient.
    checks = (
        (test_input.grad, expected_input_grad),
        (test_weight.grad, expected_weight_grad),
    )
    for actual_grad, wanted_grad in checks:
        assert torch.allclose(input=actual_grad, other=wanted_grad, **tolerance)
Beispiel #2
0
def test_backward_indirectly(
    fix_seed,
    test_input,
    test_weight,
    test_bias,
    test_mode,
    expected_input_grad,
    expected_weight_grad,
    expected_bias_grad,
):
    """Check gradients produced by back-propagating through binarized_linear.

    ``.backward()`` is called on the forward output. The gradients
    accumulated on ``test_input`` and ``test_weight`` are then compared with
    the expected fixtures (rtol=atol=1e-4, NaNs treated as equal). The bias
    gradient is compared only when ``expected_bias_grad`` is not ``None``.
    Actual and expected values are logged at DEBUG level before each check.
    """
    # NOTE(review): backward() without a gradient argument needs a scalar
    # output; presumably the fixtures guarantee that — confirm.
    binarized_linear(test_input, test_weight, test_bias, test_mode).backward()

    logger.debug(f"input grad: {test_input.grad}")
    logger.debug(f"expected input grad: {expected_input_grad}")

    logger.debug(f"weight grad: {test_weight.grad}")
    logger.debug(f"expected weight grad: {expected_weight_grad}")

    assert torch.allclose(
        input=test_input.grad,
        other=expected_input_grad,
        rtol=1e-04,
        atol=1e-04,
        equal_nan=True,
    )

    assert torch.allclose(
        input=test_weight.grad,
        other=expected_weight_grad,
        rtol=1e-04,
        atol=1e-04,
        equal_nan=True,
    )

    # Compare against None explicitly. Truth-testing a tensor with more than
    # one element raises RuntimeError, and a single-element zero tensor would
    # silently skip the bias check.
    if expected_bias_grad is not None:
        logger.debug(f"bias grad: {test_bias.grad}")
        logger.debug(f"expected bias grad: {expected_bias_grad}")

        assert torch.allclose(
            input=test_bias.grad,
            other=expected_bias_grad,
            rtol=1e-04,
            atol=1e-04,
            equal_nan=True,
        )
Beispiel #3
0
def test_forward(fix_seed, test_input, test_weight, test_bias, test_mode,
                 expected):
    """Compare the forward output of binarized_linear with ``expected``.

    Values must agree element-wise within rtol=atol=1e-4. NaN positions
    are treated as equal.
    """
    result = binarized_linear(test_input, test_weight, test_bias, test_mode)
    is_close = torch.allclose(result, expected, rtol=1e-04, atol=1e-04,
                              equal_nan=True)
    assert is_close
Beispiel #4
0
def test_forward(fix_seed, test_input, test_weight, test_bias, test_mode,
                 expected):
    """Compare binarized_linear's forward output with ``expected``.

    Both the computed and expected tensors are logged at DEBUG level before
    the element-wise comparison (rtol=atol=1e-4, NaNs treated as equal).
    """
    result = binarized_linear(test_input, test_weight, test_bias, test_mode)

    # Log both tensors so a failing comparison can be inspected.
    for label, value in (("answer", result), ("expected", expected)):
        logger.debug(f"{label}: {value}")

    assert torch.allclose(result, expected, rtol=1e-04, atol=1e-04,
                          equal_nan=True)
Beispiel #5
0
def test_supported_mode(fix_seed, test_input, test_weight, test_bias,
                        test_mode):
    """Assert that binarized_linear raises RuntimeError for ``test_mode``.

    Despite the name, this test expects a failure. It is presumably
    parametrized with modes the implementation does not support; confirm
    against the fixture definitions.
    """
    call_args = (test_input, test_weight, test_bias, test_mode)
    with pytest.raises(RuntimeError):
        binarized_linear(*call_args)
Beispiel #6
0
 def forward(self, input: torch.Tensor) -> torch.Tensor:
     """Apply the binarized linear transform to ``input``.

     ``self.clipping()`` runs first on every call. Presumably it constrains
     the latent weights; confirm in its definition. The bias is passed to
     ``binarized_linear`` only when ``self.bias`` is set. Otherwise the call
     relies on that function's own handling of an omitted bias.
     """
     self.clipping()
     operands = (input, self.weight)
     if self.bias is not None:
         operands += (self.bias,)
     return binarized_linear(*operands)