def test_backward_indirectly(
    fix_seed,
    test_input,
    test_weight,
    test_bias,
    test_mode,
    expected_weight_grad,
    expected_input_grad,
):
    """Backpropagate through ``binarized_linear`` and verify input/weight grads.

    The result of ``binarized_linear(test_input, test_weight, test_bias,
    test_mode)`` is backpropagated with ``.backward()``. The gradients then
    accumulated on ``test_input`` and ``test_weight`` (checked in that order)
    must match the expected tensors within ``rtol=atol=1e-4``, with NaNs
    treated as equal.

    The arguments are presumably pytest fixtures / parametrized values —
    confirm against the test module. ``.backward()`` is called without a
    ``gradient`` argument, which in PyTorch requires a scalar output.
    """
    output = binarized_linear(test_input, test_weight, test_bias, test_mode)
    output.backward()

    for parameter, expected_grad in (
        (test_input, expected_input_grad),
        (test_weight, expected_weight_grad),
    ):
        assert torch.allclose(
            parameter.grad, expected_grad, rtol=1e-04, atol=1e-04, equal_nan=True
        )
def test_backward_indirectly(
    fix_seed,
    test_input,
    test_weight,
    test_bias,
    test_mode,
    expected_input_grad,
    expected_weight_grad,
    expected_bias_grad,
):
    """Check gradients produced by backpropagating through ``binarized_linear``.

    Runs ``binarized_linear(test_input, test_weight, test_bias, test_mode)``,
    calls ``.backward()`` on the result, and compares the gradients
    accumulated on the input, the weight and (when an expected bias gradient
    tensor is supplied) the bias against the expected tensors within
    ``rtol=atol=1e-4``; NaNs compare equal. Actual and expected values are
    logged at DEBUG level to ease diagnosing mismatches.

    The arguments are presumably pytest fixtures / parametrized values —
    confirm against the test module. ``.backward()`` is called without a
    ``gradient`` argument, which in PyTorch requires a scalar output.
    ``expected_bias_grad`` is presumably ``None`` (or another non-tensor
    placeholder) for cases without a bias; such cases skip the bias check.
    """
    binarized_linear(test_input, test_weight, test_bias, test_mode).backward()

    logger.debug(f"input grad: {test_input.grad}")
    logger.debug(f"expected input grad: {expected_input_grad}")
    logger.debug(f"weight grad: {test_weight.grad}")
    logger.debug(f"expected weight grad: {expected_weight_grad}")

    assert torch.allclose(
        input=test_input.grad,
        other=expected_input_grad,
        rtol=1e-04,
        atol=1e-04,
        equal_nan=True,
    )
    assert torch.allclose(
        input=test_weight.grad,
        other=expected_weight_grad,
        rtol=1e-04,
        atol=1e-04,
        equal_nan=True,
    )

    # Do not truth-test the tensor: ``bool(tensor)`` raises RuntimeError for
    # multi-element tensors, and a single-element zero gradient is falsy and
    # would silently skip the check. Checking the type still skips the
    # comparison for ``None``/``False`` placeholders used by bias-less cases.
    if isinstance(expected_bias_grad, torch.Tensor):
        logger.debug(f"bias grad: {test_bias.grad}")
        logger.debug(f"expected bias grad: {expected_bias_grad}")
        assert torch.allclose(
            input=test_bias.grad,
            other=expected_bias_grad,
            rtol=1e-04,
            atol=1e-04,
            equal_nan=True,
        )
def test_forward(fix_seed, test_input, test_weight, test_bias, test_mode, expected):
    """Compare the forward output of ``binarized_linear`` with ``expected``.

    The result of ``binarized_linear(test_input, test_weight, test_bias,
    test_mode)`` must match ``expected`` within ``rtol=atol=1e-4``, with NaNs
    treated as equal. Arguments are presumably pytest fixtures /
    parametrized values — confirm against the test module.
    """
    result = binarized_linear(test_input, test_weight, test_bias, test_mode)
    assert torch.allclose(result, expected, rtol=1e-04, atol=1e-04, equal_nan=True)
def test_forward(fix_seed, test_input, test_weight, test_bias, test_mode, expected):
    """Compare the forward output of ``binarized_linear`` with ``expected``.

    Logs the computed and expected tensors at DEBUG level, then asserts they
    agree within ``rtol=atol=1e-4`` with NaNs treated as equal. Arguments
    are presumably pytest fixtures / parametrized values — confirm against
    the test module.
    """
    tolerances = {"rtol": 1e-04, "atol": 1e-04, "equal_nan": True}

    answer = binarized_linear(test_input, test_weight, test_bias, test_mode)
    # f-strings kept as-is: they format 0-dim tensors differently from %s.
    logger.debug(f"answer: {answer}")
    logger.debug(f"expected: {expected}")

    assert torch.allclose(answer, expected, **tolerances)
def test_supported_mode(fix_seed, test_input, test_weight, test_bias, test_mode):
    """Expect ``binarized_linear`` to reject the given ``test_mode``.

    Passes only if calling ``binarized_linear`` with these arguments raises
    ``RuntimeError``. Presumably ``test_mode`` is parametrized with mode
    values the function does not support — confirm against the
    parametrization.
    """
    call_args = (test_input, test_weight, test_bias, test_mode)
    with pytest.raises(RuntimeError):
        binarized_linear(*call_args)
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Apply the binarized linear transform to ``input``.

    ``self.clipping()`` is invoked first on every call (presumably to clamp
    the latent weights — confirm against its definition). The bias is passed
    to ``binarized_linear`` only when the module has one; otherwise the
    function is called with just the input and the weight, so its own
    default for the bias applies.
    """
    self.clipping()
    operands = (input, self.weight)
    if self.bias is not None:
        operands += (self.bias,)
    return binarized_linear(*operands)