def test_taps_error(self):
    """Check that ``scan_checkpoints`` raises when ``taps`` are used in ``outputs_info``.

    The output spec below requests ``taps=[-2]``; the call is expected to
    raise ``RuntimeError``.
    """
    outputs_spec_with_taps = {"initial": self.A, "taps": [-2]}
    with pytest.raises(RuntimeError):
        scan_checkpoints(lambda: None, [], outputs_spec_with_taps)
def setup_method(self):
    """Build a plain ``scan`` and a ``scan_checkpoints`` graph of the same recurrence.

    Both graphs start from ``ones_like(A)`` and multiply the running value by
    ``A`` for ``k`` steps (``A`` is passed as a non-sequence). The checkpointed
    variant is built with ``save_every_N=100``. The last step of each graph is
    stored as ``self.result`` / ``self.result_check``. The gradient of its sum
    with respect to ``A`` is stored as ``self.grad_A`` / ``self.grad_A_check``.
    """
    self.k = iscalar("k")
    self.A = vector("A")

    def multiply_step(prior_result, A):
        # One recurrence step: previous value times the non-sequence A.
        return prior_result * A

    plain_steps, _ = scan(
        fn=multiply_step,
        outputs_info=aet.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k,
    )
    checkpointed_steps, _ = scan_checkpoints(
        fn=multiply_step,
        outputs_info=aet.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k,
        save_every_N=100,
    )

    # Keep only the final step of each sequence of outputs.
    self.result, self.result_check = plain_steps[-1], checkpointed_steps[-1]

    # Gradients are built in the same order as before: plain first, then checkpointed.
    self.grad_A, self.grad_A_check = (
        aesara.grad(final_value.sum(), self.A)
        for final_value in (self.result, self.result_check)
    )