Example #1
 def get_reg_deriv(self):
     a, b = self.get_params()
     dkl_a = (self.alpha / a**2) * (-0.577215664901532 - torch.digamma(b) -
                                    1. / b) + a
     dkl_a = dkl_a * torch.sigmoid(self.a_uc)
     dkl_b = (1 - self.alpha / a) * (
         1. / b**2 - torch.polygamma(1, b)) + b - (1. / b**2)
     dkl_b = dkl_b * torch.sigmoid(self.b_uc)  # b_uc (was a_uc, likely a typo; cf. get_reg_deriv_b below)
     derivs_kl = torch.stack([dkl_a, dkl_b]).squeeze_()
     return derivs_kl
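
A quick aside, not part of the original snippet: the hard-coded constant above is the negative Euler-Mascheroni constant, which PyTorch can produce directly as digamma(1), so it need not be spelled out by hand.

import torch

# psi(1) = -gamma (Euler-Mascheroni) ~ -0.5772156649
assert torch.allclose(torch.digamma(torch.tensor(1.0)),
                      torch.tensor(-0.577215664901532))
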
Example #2
File: gamma.py Project: ceisenach/torchkit
 def kl_divergence(self,states,alpha0,beta0):
     """
     KL(pi_old || pi_new). alpha0, beta0 == pi_old
     """
     alpha1, beta1 = self._net(states.detach())
     I = torch.sum(alpha1*(torch.log(beta0) - torch.log(beta1)),dim=1,keepdim=True)
     II = torch.sum((alpha0-alpha1)*torch.polygamma(0,alpha0),dim=1,keepdim=True)
     III = torch.sum((beta1-beta0)*(alpha0/beta0),dim=1,keepdim=True)
     IV = torch.sum(torch.lgamma(alpha1) - torch.lgamma(alpha0),dim=1,keepdim=True)
     kl = I+II+III+IV
     return kl
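
As a sanity check (my addition, not from the torchkit project), the four terms above agree with the closed-form Gamma KL shipped in torch.distributions:

import torch
from torch.distributions import Gamma, kl_divergence

alpha0, beta0 = torch.tensor([2.0]), torch.tensor([1.5])  # pi_old
alpha1, beta1 = torch.tensor([3.0]), torch.tensor([0.5])  # pi_new

I = alpha1 * (torch.log(beta0) - torch.log(beta1))
II = (alpha0 - alpha1) * torch.polygamma(0, alpha0)  # polygamma(0, .) == digamma
III = (beta1 - beta0) * (alpha0 / beta0)
IV = torch.lgamma(alpha1) - torch.lgamma(alpha0)

assert torch.allclose(I + II + III + IV,
                      kl_divergence(Gamma(alpha0, beta0), Gamma(alpha1, beta1)))
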
Example #3
File: gamma.py Project: ceisenach/torchkit
    def fisher_information_params(self,alpha,beta,backend='pytorch'):
        I_11 = torch.polygamma(1, alpha.data)
        I_12 = -1. / beta.data
        I_22 = alpha.data * (I_12 ** 2)  # detach alpha here too, matching I_11/I_12

        if backend == 'pytorch':
            pass
        elif backend == 'numpy':
            I_11 = I_11.numpy()
            I_12 = I_12.numpy()
            I_22 = I_22.numpy()
        else:
            raise RuntimeError('unknown backend: %s' % backend)

        return I_11,I_12,I_22
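
For context, these are the entries of the Fisher information of Gamma(alpha, beta) in the shape/rate parameterization. A hedged sketch of my own (not project code) assembling them into per-element 2x2 matrices, e.g. for natural-gradient steps:

import torch

alpha = torch.tensor([2.0, 5.0])
beta = torch.tensor([1.0, 0.5])

I_11 = torch.polygamma(1, alpha)  # trigamma(alpha)
I_12 = -1.0 / beta
I_22 = alpha / beta ** 2          # == alpha * I_12**2

F = torch.stack([torch.stack([I_11, I_12], dim=-1),
                 torch.stack([I_12, I_22], dim=-1)], dim=-2)  # (B, 2, 2)
F_inv = torch.linalg.inv(F)  # det = (alpha * trigamma(alpha) - 1) / beta**2 > 0
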
Example #4
def get_Hinv_g(gamma: torch.Tensor, eta: torch.Tensor,
               alpha: torch.Tensor) -> torch.Tensor:
    """
    """
    log_theta_tilde = expected_log_dirichlet(concentration=gamma)  #(C, K)

    # calculate first derivative
    g = torch.matmul(input=eta.T, other=log_theta_tilde)  # (L, K)
    g = g - expected_log_dirichlet(concentration=alpha) * torch.sum(
        input=eta, dim=0)[:, None]  # (L, K)

    # Q holds the diagonal of the (L, K, K) Hessian block, stored as (L, K)
    Q = -torch.sum(input=eta, dim=0)[:, None] * torch.polygamma(input=alpha,
                                                                n=1)  # (L, K)

    u = torch.sum(input=eta, dim=0) * torch.polygamma(
        input=torch.sum(input=alpha, dim=-1), n=1)  # (L, )

    # Sherman-Morrison: (H^{-1} g)_k = (g_k - b) / Q_k, with b as below
    b = torch.sum(input=g / Q, dim=-1) / (
        1 / u + torch.sum(input=1 / Q, dim=-1))  # (L, )

    Hinv_g = (g - b[:, None]) / Q  # (L, K)

    return Hinv_g
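
The elementwise division by Q works because each row's Hessian is diagonal plus rank one, so Sherman-Morrison gives the inverse in closed form. A standalone sanity check for one row (my own, hedged):

import torch

K = 4
g = torch.randn(K)
Q = -torch.rand(K) - 0.5  # diagonal entries, negative as above
u = torch.tensor(0.1)     # rank-one coefficient

H = torch.diag(Q) + u * torch.ones(K, K)
b = torch.sum(g / Q) / (1 / u + torch.sum(1 / Q))
assert torch.allclose((g - b) / Q, torch.linalg.solve(H, g), atol=1e-4)
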
Example #5
    def _update(self, ss0, ss1, ss2):
        """ Update CMM parameters.

        Args:
            ss0 (torch.tensor): sum(Z)          (K)
            ss1 (torch.tensor): sum(Z * log(X)) (K)
            ss2 (torch.tensor): sum(Z * X**2)   (K)

        """
        K = len(ss0)
        tiny = 1e-32

        def log(x):
            return (x + tiny).log_()

        # Update parameters (using means and variances)
        for k in range(K):

            max_iter = 10000 if self.update_dof else 1
            for i in range(max_iter):
                # print(self.sig[k].item(), self.dof[k].item())
                if self.update_sig:
                    # Closed form update of sig
                    self.sig[k] = (ss2[k] / (self.dof[k] * ss0[k])).sqrt()

                if self.update_dof:
                    # Gauss-Newton update of dof
                    gkl = torch.digamma(self.dof[k] / 2) + log(
                        2 * self.sig[k].square())
                    gkl = 0.5 * ss0[k] * gkl - ss1[k]  # gradient w.r.t. dof
                    hkl = 0.25 * ss0[k] * torch.polygamma(
                        1, self.dof[k] / 2)  # Hessian w.r.t. dof
                    self.dof[k].sub_(gkl / hkl)  # G-N update
                    if self.dof[k] < 2:
                        self.dof[k] = 2
                        break
                    if gkl * gkl < 1e-9:
                        break
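
The inner loop is Newton's method on the degrees of freedom: its fixed point satisfies digamma(dof / 2) + log(2 * sig**2) = 2 * ss1 / ss0. The digamma-inversion mechanics can be isolated as follows (a hedged sketch of mine, not from the original source):

import torch

def inv_digamma_half(c, nu=torch.tensor(4.0), iters=50):
    """Solve digamma(nu / 2) = c for nu by Newton's method (illustrative)."""
    for _ in range(iters):
        f = torch.digamma(nu / 2) - c
        fprime = 0.5 * torch.polygamma(1, nu / 2)  # d/dnu digamma(nu / 2)
        nu = nu - f / fprime
    return nu

nu = inv_digamma_half(torch.tensor(0.3))
assert torch.allclose(torch.digamma(nu / 2), torch.tensor(0.3), atol=1e-5)
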
Example #6
    def test_unary_propagate_names_fns(self):
        def _test(testcase, names=('N', 'D'), device='cpu'):
            sizes = [2] * len(names)
            tensor = torch.empty(sizes, names=names, device=device)
            out = testcase.lambd(tensor)
            self.assertEqual(out.names, tensor.names,
                             message=testcase.name)

        def fn(name, *args, **kwargs):
            return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))]

        def method(name, *args, **kwargs):
            return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))]

        def out_function(name, *args, **kwargs):
            out_fn = getattr(torch, name)

            def fn(tensor):
                result = tensor.new_empty([0])
                out_fn(tensor, *args, out=result, **kwargs)
                return result

            return [Function(name + '_out', fn)]

        def fn_method_and_inplace(name, *args, **kwargs):
            return (
                method(name, *args, **kwargs) +
                method(name + '_', *args, **kwargs) +
                out_function(name, *args, **kwargs)
            )

        # All of these operate on 2x2 tensors.
        tests = [
            # unary pointwise
            fn_method_and_inplace('abs'),
            fn_method_and_inplace('acos'),
            fn_method_and_inplace('asin'),
            fn_method_and_inplace('atan'),
            fn_method_and_inplace('ceil'),
            fn_method_and_inplace('clamp', -1, 1),
            fn_method_and_inplace('clamp_min', -2),
            fn_method_and_inplace('clamp_max', 2),
            method('cauchy_'),
            method('clone'),
            method('contiguous'),
            fn_method_and_inplace('cos'),
            fn_method_and_inplace('cosh'),
            fn_method_and_inplace('digamma'),
            fn_method_and_inplace('erf'),
            fn_method_and_inplace('erfc'),
            fn_method_and_inplace('erfinv'),
            fn_method_and_inplace('exp'),
            fn_method_and_inplace('expm1'),
            method('exponential_'),
            fn_method_and_inplace('floor'),
            fn_method_and_inplace('frac'),
            method('geometric_', p=0.5),
            fn_method_and_inplace('lgamma'),
            fn_method_and_inplace('log'),
            fn_method_and_inplace('log10'),
            fn_method_and_inplace('log1p'),
            fn_method_and_inplace('log2'),
            method('log_normal_'),
            fn_method_and_inplace('neg'),
            method('normal_'),
            [Function('polygamma', lambda t: torch.polygamma(1, t))],
            method('polygamma_', 1),
            fn_method_and_inplace('reciprocal'),
            method('random_', 0, 1),
            method('random_', 1),
            method('random_'),
            fn_method_and_inplace('round'),
            fn_method_and_inplace('rsqrt'),
            fn_method_and_inplace('sigmoid'),
            fn_method_and_inplace('sign'),
            fn_method_and_inplace('sin'),
            fn_method_and_inplace('sinh'),
            fn_method_and_inplace('sqrt'),
            fn_method_and_inplace('tan'),
            fn_method_and_inplace('tanh'),
            fn_method_and_inplace('trunc'),
            method('uniform_'),
            method('zero_'),
            method('fill_', 1),
            method('fill_', torch.tensor(3.14)),

            # conversions
            method('to', dtype=torch.long),
            method('to', device='cpu'),
            method('to', torch.empty([])),
            method('bool'),
            method('byte'),
            method('char'),
            method('cpu'),
            method('double'),
            method('float'),
            method('long'),
            method('half'),
            method('int'),
            method('short'),
            method('type', dtype=torch.long),

            # views
            method('narrow', 0, 0, 1),

            # creation functions
            fn('empty_like'),

            # bernoulli variants
            method('bernoulli_', 0.5),
            method('bernoulli_', torch.tensor(0.5)),

            [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))],
            [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))],
        ]
        tests = flatten(tests)

        for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()):
            _test(testcase, device=device)
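
What the test exercises can be seen in isolation; a minimal sketch of mine, assuming a build with named-tensor support:

import torch

t = torch.rand(2, 2, names=('N', 'D'))
# unary pointwise ops propagate dimension names unchanged
assert torch.polygamma(1, t).names == ('N', 'D')
assert t.digamma().names == ('N', 'D')
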
Example #7
File: math_ops.py Project: malfet/pytorch
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
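
A note on the polygamma calls above (my addition, hedged): polygamma(0, x) coincides with digamma(x), and polygamma(n, x) is the n-th derivative of digamma, which yields known closed-form values to check against:

import math
import torch

x = torch.tensor([1.0, 0.5])
assert torch.allclose(torch.polygamma(0, x), torch.digamma(x))
# trigamma(1) = pi**2 / 6
assert torch.allclose(torch.polygamma(1, torch.tensor([1.0])),
                      torch.tensor([math.pi ** 2 / 6]))
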
Example #8
def mvtrigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    i = torch.arange(p, dtype=x.dtype, device=x.device)
    return torch.polygamma(1, x.unsqueeze(-1) - .5 * i).sum(-1)
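
Since this is the second derivative of the log multivariate gamma function, autograd through torch.mvlgamma gives an independent check (my own usage sketch, hedged):

import torch

x = torch.tensor(3.0, requires_grad=True)
first, = torch.autograd.grad(torch.mvlgamma(x, p=2), x, create_graph=True)
second, = torch.autograd.grad(first, x)
assert torch.allclose(second, mvtrigamma(torch.tensor(3.0), 2))
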
Example #9
    def test_unary_fns(self):
        TestCase = namedtuple('TestCase', ['name', 'lambd'])

        def _test(testcase, names=('N', 'D'), device='cpu'):
            sizes = [2] * len(names)
            tensor = torch.empty(sizes, names=names, device=device)
            out = testcase.lambd(tensor)
            self.assertEqual(out.names, tensor.names, message=testcase.name)

        def method(name, *args, **kwargs):
            return [
                TestCase(name, lambda t: getattr(t, name)(*args, **kwargs))
            ]

        def out_function(name, *args, **kwargs):
            out_fn = getattr(torch, name)

            def fn(tensor):
                result = tensor.new_empty([0])
                out_fn(tensor, *args, out=result, **kwargs)
                return result

            return [TestCase(name + '_out', fn)]

        def fn_method_and_inplace(name, *args, **kwargs):
            return (method(name, *args, **kwargs) +
                    method(name + '_', *args, **kwargs) +
                    out_function(name, *args, **kwargs))

        def flatten(lst):
            return [item for sublist in lst for item in sublist]

        tests = [
            fn_method_and_inplace('abs'),
            fn_method_and_inplace('acos'),
            fn_method_and_inplace('asin'),
            fn_method_and_inplace('atan'),
            fn_method_and_inplace('ceil'),
            fn_method_and_inplace('clamp', -1, 1),
            fn_method_and_inplace('clamp_min', -2),
            fn_method_and_inplace('clamp_max', 2),
            fn_method_and_inplace('cos'),
            fn_method_and_inplace('cosh'),
            fn_method_and_inplace('digamma'),
            fn_method_and_inplace('erf'),
            fn_method_and_inplace('erfc'),
            fn_method_and_inplace('erfinv'),
            fn_method_and_inplace('exp'),
            fn_method_and_inplace('expm1'),
            fn_method_and_inplace('floor'),
            fn_method_and_inplace('frac'),
            fn_method_and_inplace('lgamma'),
            fn_method_and_inplace('log'),
            fn_method_and_inplace('log10'),
            fn_method_and_inplace('log1p'),
            fn_method_and_inplace('log2'),
            fn_method_and_inplace('neg'),
            [TestCase('polygamma', lambda t: torch.polygamma(1, t))],
            method('polygamma_', 1),
            fn_method_and_inplace('reciprocal'),
            fn_method_and_inplace('round'),
            fn_method_and_inplace('rsqrt'),
            fn_method_and_inplace('sigmoid'),
            fn_method_and_inplace('sin'),
            fn_method_and_inplace('sinh'),
            fn_method_and_inplace('sqrt'),
            fn_method_and_inplace('tan'),
            fn_method_and_inplace('tanh'),
            fn_method_and_inplace('trunc'),
            method('zero_'),
            method('fill_', 1),
            method('fill_', torch.tensor(3.14)),
        ]
        tests = flatten(tests)

        for testcase, device in itertools.product(
                tests, torch.testing.get_all_device_types()):
            _test(testcase, device=device)
Example #10
 def get_reg_deriv_b(self):
     a, b = self.get_params()
     dkl_b = (1 - self.alpha / a) * (
         1. / b**2 - torch.polygamma(1, b)) + b - (1. / b**2)
     return dkl_b * torch.sigmoid(self.b_uc)
Example #11
# nan_to_num
w = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
torch.nan_to_num(w)
torch.nan_to_num(w, nan=2.0)
torch.nan_to_num(w, nan=2.0, posinf=1.0)

# neg/negative
torch.neg(torch.randn(5))

# nextafter
eps = torch.finfo(torch.float32).eps
torch.nextafter(torch.tensor([1, 2]),
                torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps])

# polygamma
torch.polygamma(1, torch.tensor([1, 0.5]))
torch.polygamma(2, torch.tensor([1, 0.5]))
torch.polygamma(3, torch.tensor([1, 0.5]))
torch.polygamma(4, torch.tensor([1, 0.5]))

# pow
a = torch.randn(4)  # `a` is not defined in this excerpt; assumed example input
torch.pow(a, 2)
torch.pow(torch.arange(1., 5.), torch.arange(1., 5.))

# rad2deg
torch.rad2deg(torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570,
                                                               -1.570]]))

# real
torch.randn(4, dtype=torch.cfloat).real
Example #12
    # bw = torch.max(gamma_sample) / 200
    # print(min(gamma_sample), max(gamma_sample))
    # plt.hist(gamma_sample.numpy(), normed=True, bins=np.arange(min(gamma_sample), max(gamma_sample) + bw, bw))
    # plt.title('alpha :%6.4f, beta :%6.4f' % (concentration[0], rate[0]))
    # plt.show()

    ## Gradient Correction checking code
    n_sample = 100000
    n_batch = 20
    eps = 1e-4
    concentration0 = (torch.exp(torch.randn(n_batch))) * 0 + torch.rand(1) * 2  # "* 0" collapses the batch to one shared random value
    concentration1 = concentration0 + eps
    rate = torch.exp(torch.randn(n_batch)) * 0 + torch.rand(1) * 2
    func = lambda x: torch.log(x)
    print('For logarithm gradient should be polygamma(1) x concentration')
    print('Exact gradient      : ' + ' '.join(['%+6.4f' % (torch.polygamma(1, concentration0) * concentration0)[i] for i in range(n_batch)]))
    # func = lambda x: x
    # print('For identity gradient should be 1.0 / rate x concentration')
    # print('Exact gradient      : ' + ' '.join(['%+6.4f' % (1.0 / rate * concentration0)[i] for i in range(n_batch)]))
    # func = lambda x: x ** 2
    # print('For square gradient should be (1.0 + 2.0 concentration) / rate ** 2 x concentration')
    # print('Exact gradient      : ' + ' '.join(['%+6.4f' % ((1.0 + 2.0 * concentration0) / rate ** 2 * concentration0)[i] for i in range(n_batch)]))
    reparam_module0 = GammaReparametrizedSample(torch.Size([n_batch]))
    reparam_module0.log_shape.data = torch.log(concentration0)
    reparam_module0.log_rate.data = torch.log(rate)
    gamma_sample0 = reparam_module0(n_sample)
    func_val0 = func(gamma_sample0)
    reparam_module1 = GammaReparametrizedSample(torch.Size([n_batch]))
    reparam_module1.log_shape.data = torch.log(concentration1)
    reparam_module1.log_rate.data = torch.log(rate)
    gamma_sample1 = reparam_module1(n_sample)
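
For reference (my addition, hedged): torch.distributions.Gamma.rsample already implements implicit reparameterization, so the alpha * polygamma(1, alpha) gradient that this script checks numerically can also be obtained directly:

import torch

log_shape = torch.zeros(1, requires_grad=True)  # alpha = exp(log_shape) = 1
dist = torch.distributions.Gamma(log_shape.exp(), torch.ones(1))
dist.rsample(torch.Size([200000])).log().mean().backward()

alpha = log_shape.exp().detach()
expected = alpha * torch.polygamma(1, alpha)  # d E[log X] / d log(alpha)
assert torch.allclose(log_shape.grad, expected, atol=0.05)  # Monte Carlo noise
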