def test_bernoulli_with_logits_overflow_gradient(init_tensor_type):
    """Check Bernoulli log-density and gradient for an extremely large logit.

    A Bernoulli is built from a single logit of 1e40 (requires_grad=True),
    and its batch log-density is evaluated at the observation 1. The test
    asserts two things:
    - the log-density is exactly 0;
    - the gradient with respect to the logit is exactly 0.
    """
    # A logit this large should saturate the success probability; the
    # expectations below require that no inf/NaN leaks into value or gradient.
    logit = Variable(init_tensor_type([1e40]), requires_grad=True)
    dist = Bernoulli(logits=logit)
    observed = Variable(init_tensor_type([1]))
    log_prob = dist.batch_log_pdf(observed)
    log_prob.sum().backward()

    assert_equal(log_prob.data[0], 0)
    assert_equal(logit.grad.data[0], 0)
def test_bernoulli_underflow_gradient(init_tensor_type):
    """Check Bernoulli log-density and gradient when the probability is 0.

    The success probability is ``sigmoid(x) * 0.0`` for a single input
    ``x = 0`` (requires_grad=True), and the batch log-density is evaluated
    at the observation 0. The test asserts two things:
    - the log-density is exactly 0;
    - the gradient with respect to ``x`` is exactly 0.
    """
    x = Variable(init_tensor_type([0]), requires_grad=True)
    # Multiplying by 0.0 (rather than passing a constant) presumably keeps x
    # in the autograd graph, so x.grad is populated by backward().
    zero_prob = sigmoid(x) * 0.0
    dist = Bernoulli(zero_prob)
    observed = Variable(init_tensor_type([0]))
    log_prob = dist.batch_log_pdf(observed)
    log_prob.sum().backward()

    assert_equal(log_prob.data[0], 0)
    assert_equal(x.grad.data[0], 0)