def get_reg_deriv(self):
    a, b = self.get_params()
    # Derivative of the KL regularizer w.r.t. a; the hard-coded constant is
    # digamma(1), the negative Euler-Mascheroni constant.
    dkl_a = (self.alpha / a**2) * (-0.577215664901532 - torch.digamma(b) - 1. / b) + a
    dkl_a = dkl_a * torch.sigmoid(self.a_uc)  # chain rule through the unconstrained parameter
    # Derivative of the KL regularizer w.r.t. b (matches get_reg_deriv_b below).
    dkl_b = (1 - self.alpha / a) * (1. / b**2 - torch.polygamma(1, b)) + b - (1. / b**2)
    dkl_b = dkl_b * torch.sigmoid(self.b_uc)
    derivs_kl = torch.stack([dkl_a, dkl_b]).squeeze_()
    return derivs_kl
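# Sanity check (added sketch, not part of the module above): the hard-coded
# -0.577215664901532 is digamma(1), and torch.polygamma(0, x) is exactly
# torch.digamma(x).
import torch

print(torch.digamma(torch.tensor(1.0)))       # tensor(-0.5772)
print(torch.polygamma(0, torch.tensor(1.0)))  # tensor(-0.5772), same value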
def kl_divergence(self, states, alpha0, beta0):
    """KL(pi_old || pi_new) between Gamma policies; alpha0, beta0 == pi_old."""
    alpha1, beta1 = self._net(states.detach())
    # Closed-form KL between Gamma(alpha0, beta0) and Gamma(alpha1, beta1),
    # summed over action dimensions. polygamma(0, .) is the digamma function.
    I = torch.sum(alpha1 * (torch.log(beta0) - torch.log(beta1)), dim=1, keepdim=True)
    II = torch.sum((alpha0 - alpha1) * torch.polygamma(0, alpha0), dim=1, keepdim=True)
    III = torch.sum((beta1 - beta0) * (alpha0 / beta0), dim=1, keepdim=True)
    IV = torch.sum(torch.lgamma(alpha1) - torch.lgamma(alpha0), dim=1, keepdim=True)
    kl = I + II + III + IV
    return kl
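# A minimal check (added sketch) that the closed form above agrees with
# torch.distributions.kl_divergence for Gamma(concentration=alpha, rate=beta);
# the random alpha*/beta* stand in for network outputs.
import torch
from torch.distributions import Gamma, kl_divergence

alpha0, beta0 = torch.rand(5, 3) + 0.5, torch.rand(5, 3) + 0.5
alpha1, beta1 = torch.rand(5, 3) + 0.5, torch.rand(5, 3) + 0.5

I = torch.sum(alpha1 * (torch.log(beta0) - torch.log(beta1)), dim=1, keepdim=True)
II = torch.sum((alpha0 - alpha1) * torch.polygamma(0, alpha0), dim=1, keepdim=True)
III = torch.sum((beta1 - beta0) * (alpha0 / beta0), dim=1, keepdim=True)
IV = torch.sum(torch.lgamma(alpha1) - torch.lgamma(alpha0), dim=1, keepdim=True)

reference = kl_divergence(Gamma(alpha0, beta0), Gamma(alpha1, beta1)).sum(dim=1, keepdim=True)
print(torch.allclose(I + II + III + IV, reference, atol=1e-4))  # True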
def fisher_information_params(self, alpha, beta, backend='pytorch'):
    # Fisher information of Gamma(alpha, beta) w.r.t. (alpha, beta):
    # I = [[polygamma(1, alpha), -1/beta], [-1/beta, alpha/beta^2]]
    I_11 = torch.polygamma(1, alpha.data)
    I_12 = -1. / beta.data
    I_22 = alpha.data * (I_12**2)
    if backend == 'pytorch':
        pass
    elif backend == 'numpy':
        I_11 = I_11.numpy()
        I_12 = I_12.numpy()
        I_22 = I_22.numpy()
    else:
        raise RuntimeError("unknown backend: %r" % backend)
    return I_11, I_12, I_22
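# Quick positive-definiteness check (added sketch; the example values are
# arbitrary): det(I) = (alpha * polygamma(1, alpha) - 1) / beta**2, which is
# positive because alpha * trigamma(alpha) > 1 for all alpha > 0.
import torch

alpha, beta = torch.tensor(2.5), torch.tensor(1.3)
I_11 = torch.polygamma(1, alpha)
I_12 = -1. / beta
I_22 = alpha * I_12**2
F = torch.stack([torch.stack([I_11, I_12]), torch.stack([I_12, I_22])])
torch.linalg.cholesky(F)      # succeeds iff F is positive definite
print(torch.linalg.det(F))    # > 0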
def get_Hinv_g(gamma: torch.Tensor, eta: torch.Tensor,
               alpha: torch.Tensor) -> torch.Tensor:
    """Newton direction H^{-1} g for the Dirichlet parameters alpha.

    Each row of the Hessian has the form H = diag(Q) + u * 1 1^T, so it can be
    inverted in O(K) via the Sherman-Morrison identity instead of a full solve.
    """
    log_theta_tilde = expected_log_dirichlet(concentration=gamma)  # (C, K)

    # Calculate first derivative.
    g = torch.matmul(input=eta.T, other=log_theta_tilde)  # (L, K)
    g = g - expected_log_dirichlet(concentration=alpha) * torch.sum(
        input=eta, dim=0)[:, None]  # (L, K)

    # Diagonal part of the Hessian. Full form would be
    # Q = diag_embed(L, K) = (L, K, K); here only the diagonal is kept, Q = (L, K).
    Q = -torch.sum(input=eta, dim=0)[:, None] * torch.polygamma(
        input=alpha, n=1)  # (L, K)
    # Rank-one weight of the Hessian.
    u = torch.sum(input=eta, dim=0) * torch.polygamma(
        input=torch.sum(input=alpha, dim=-1), n=1)  # (L,)

    # Sherman-Morrison: (H^{-1} g)_k = (g_k - b) / Q_k,
    # with b = sum_j(g_j / Q_j) / (1/u + sum_j(1 / Q_j)).
    b = torch.sum(input=g / Q, dim=-1) / (
        1 / u + torch.sum(input=1 / Q, dim=-1))  # (L,)
    Hinv_g = (g - b[:, None]) / Q  # (L, K)
    return Hinv_g
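# Self-check of the rank-one inversion (added sketch with arbitrary positive
# values; the identity holds for any signs as long as H is invertible): for
# H = diag(Q_l) + u_l * 1 1^T per row l, (H^{-1} g)_k = (g_k - b_l) / Q_lk.
import torch

L, K = 3, 4
Q = torch.rand(L, K) + 1.0   # diagonal part
u = torch.rand(L) + 1.0      # rank-one weight
g = torch.randn(L, K)

b = torch.sum(g / Q, dim=-1) / (1 / u + torch.sum(1 / Q, dim=-1))
x_fast = (g - b[:, None]) / Q

H = torch.diag_embed(Q) + u[:, None, None] * torch.ones(L, K, K)
x_ref = torch.linalg.solve(H, g.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(x_fast, x_ref, atol=1e-5))  # True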
def _update(self, ss0, ss1, ss2):
    """ Update CMM parameters.

    Args:
        ss0 (torch.tensor): sum(Z) (K)
        ss1 (torch.tensor): sum(Z * log(X)) (K)
        ss2 (torch.tensor): sum(Z * X**2) (K)
    """
    K = len(ss0)
    tiny = 1e-32

    def log(x):
        return (x + tiny).log_()

    # Update parameters (using means and variances)
    for k in range(K):
        max_iter = 10000 if self.update_dof else 1
        for i in range(max_iter):
            # print(self.sig[k].item(), self.dof[k].item())
            if self.update_sig:
                # Closed form update of sig
                self.sig[k] = (ss2[k] / (self.dof[k] * ss0[k])).sqrt()
            if self.update_dof:
                # Gauss-Newton update of dof
                gkl = torch.digamma(self.dof[k] / 2) + log(2 * self.sig[k].square())
                gkl = 0.5 * ss0[k] * gkl - ss1[k]  # gradient w.r.t. dof
                hkl = 0.25 * ss0[k] * torch.polygamma(1, self.dof[k] / 2)  # Hessian w.r.t. dof
                self.dof[k].sub_(gkl / hkl)  # G-N update
                if self.dof[k] < 2:
                    self.dof[k] = 2
                    break
                if gkl * gkl < 1e-9:
                    break
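# The same Newton pattern in standalone form (added sketch): inverting the
# digamma function with polygamma(1, .) as the derivative of digamma. The
# initialization is the usual exp(y) + 0.5 / -1 / (y - digamma(1)) heuristic.
import torch

def inv_digamma(y: torch.Tensor, n_iter: int = 20) -> torch.Tensor:
    x = torch.where(y >= -2.22, y.exp() + 0.5,
                    -1.0 / (y - torch.digamma(torch.tensor(1.0))))
    for _ in range(n_iter):
        x = x - (torch.digamma(x) - y) / torch.polygamma(1, x)  # Newton step
    return x

y = torch.tensor([0.3, -1.0, 2.0])
print(torch.allclose(torch.digamma(inv_digamma(y)), y, atol=1e-5))  # True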
def test_unary_propagate_names_fns(self):
    def _test(testcase, names=('N', 'D'), device='cpu'):
        sizes = [2] * len(names)
        tensor = torch.empty(sizes, names=names, device=device)
        out = testcase.lambd(tensor)
        self.assertEqual(out.names, tensor.names, message=testcase.name)

    def fn(name, *args, **kwargs):
        return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))]

    def method(name, *args, **kwargs):
        return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))]

    def out_function(name, *args, **kwargs):
        out_fn = getattr(torch, name)

        def fn(tensor):
            result = tensor.new_empty([0])
            out_fn(tensor, *args, out=result, **kwargs)
            return result

        return [Function(name + '_out', fn)]

    def fn_method_and_inplace(name, *args, **kwargs):
        return (
            method(name, *args, **kwargs) +
            method(name + '_', *args, **kwargs) +
            out_function(name, *args, **kwargs)
        )

    # All of these operate on 2x2 tensors.
    tests = [
        # unary pointwise
        fn_method_and_inplace('abs'),
        fn_method_and_inplace('acos'),
        fn_method_and_inplace('asin'),
        fn_method_and_inplace('atan'),
        fn_method_and_inplace('ceil'),
        fn_method_and_inplace('clamp', -1, 1),
        fn_method_and_inplace('clamp_min', -2),
        fn_method_and_inplace('clamp_max', 2),
        method('cauchy_'),
        method('clone'),
        method('contiguous'),
        fn_method_and_inplace('cos'),
        fn_method_and_inplace('cosh'),
        fn_method_and_inplace('digamma'),
        fn_method_and_inplace('erf'),
        fn_method_and_inplace('erfc'),
        fn_method_and_inplace('erfinv'),
        fn_method_and_inplace('exp'),
        fn_method_and_inplace('expm1'),
        method('exponential_'),
        fn_method_and_inplace('floor'),
        fn_method_and_inplace('frac'),
        method('geometric_', p=0.5),
        fn_method_and_inplace('lgamma'),
        fn_method_and_inplace('log'),
        fn_method_and_inplace('log10'),
        fn_method_and_inplace('log1p'),
        fn_method_and_inplace('log2'),
        method('log_normal_'),
        fn_method_and_inplace('neg'),
        method('normal_'),
        [Function('polygamma', lambda t: torch.polygamma(1, t))],
        method('polygamma_', 1),
        fn_method_and_inplace('reciprocal'),
        method('random_', 0, 1),
        method('random_', 1),
        method('random_'),
        fn_method_and_inplace('round'),
        fn_method_and_inplace('rsqrt'),
        fn_method_and_inplace('sigmoid'),
        fn_method_and_inplace('sign'),
        fn_method_and_inplace('sin'),
        fn_method_and_inplace('sinh'),
        fn_method_and_inplace('sqrt'),
        fn_method_and_inplace('tan'),
        fn_method_and_inplace('tanh'),
        fn_method_and_inplace('trunc'),
        method('uniform_'),
        method('zero_'),
        method('fill_', 1),
        method('fill_', torch.tensor(3.14)),
        # conversions
        method('to', dtype=torch.long),
        method('to', device='cpu'),
        method('to', torch.empty([])),
        method('bool'),
        method('byte'),
        method('char'),
        method('cpu'),
        method('double'),
        method('float'),
        method('long'),
        method('half'),
        method('int'),
        method('short'),
        method('type', dtype=torch.long),
        # views
        method('narrow', 0, 0, 1),
        # creation functions
        fn('empty_like'),
        # bernoulli variants
        method('bernoulli_', 0.5),
        method('bernoulli_', torch.tensor(0.5)),
        [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))],
        [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))],
    ]
    tests = flatten(tests)

    for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()):
        _test(testcase, device=device)
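# A minimal illustration of the property the test above asserts (added sketch;
# named tensors are a prototype feature): unary ops such as polygamma propagate
# dimension names from input to output.
import torch

t = torch.ones(2, 2, names=('N', 'D'))
print(torch.polygamma(1, t).names)  # ('N', 'D')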
def pointwise_ops(self):
    a = torch.randn(4)
    b = torch.randn(4)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
    f = torch.zeros(3)
    g = torch.tensor([-1, 0, 1])
    w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
    return (
        torch.abs(torch.tensor([-1, -2, 3])),
        torch.absolute(torch.tensor([-1, -2, 3])),
        torch.acos(a),
        torch.arccos(a),
        torch.acosh(a.uniform_(1.0, 2.0)),
        torch.add(a, 20),
        torch.add(a, torch.randn(4, 1), alpha=10),
        torch.addcdiv(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.addcmul(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.angle(a),
        torch.asin(a),
        torch.arcsin(a),
        torch.asinh(a),
        torch.arcsinh(a),
        torch.atan(a),
        torch.arctan(a),
        torch.atanh(a.uniform_(-1.0, 1.0)),
        torch.arctanh(a.uniform_(-1.0, 1.0)),
        torch.atan2(a, a),
        torch.bitwise_not(t),
        torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.ceil(a),
        torch.clamp(a, min=-0.5, max=0.5),
        torch.clamp(a, min=0.5),
        torch.clamp(a, max=0.5),
        torch.clip(a, min=-0.5, max=0.5),
        torch.conj(a),
        torch.copysign(a, 1),
        torch.copysign(a, b),
        torch.cos(a),
        torch.cosh(a),
        torch.deg2rad(torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])),
        torch.div(a, b),
        torch.divide(a, b, rounding_mode="trunc"),
        torch.divide(a, b, rounding_mode="floor"),
        torch.digamma(torch.tensor([1.0, 0.5])),
        torch.erf(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
        torch.exp(torch.tensor([0.0, math.log(2.0)])),
        torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
        torch.expm1(torch.tensor([0.0, math.log(2.0)])),
        torch.fake_quantize_per_channel_affine(
            torch.randn(2, 2, 2),
            (torch.randn(2) + 1) * 0.05,
            torch.zeros(2),
            1,
            0,
            255,
        ),
        torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
        torch.float_power(torch.randint(10, (4, )), 2),
        torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
        torch.floor(a),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
        torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
        torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.frac(torch.tensor([1.0, 2.5, -3.2])),
        torch.randn(4, dtype=torch.cfloat).imag,
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
        torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
        torch.lerp(
            torch.arange(1.0, 5.0),
            torch.empty(4).fill_(10),
            torch.full_like(torch.arange(1.0, 5.0), 0.5),
        ),
        torch.lgamma(torch.arange(0.5, 2, 0.5)),
        torch.log(torch.arange(5) + 10),
        torch.log10(torch.rand(5)),
        torch.log1p(torch.randn(5)),
        torch.log2(torch.rand(5)),
        torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logical_and(r, s),
        torch.logical_and(r.double(), s.double()),
        torch.logical_and(r.double(), s),
        torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
        torch.logical_not(torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
        torch.logical_not(
            torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
            out=torch.empty(3, dtype=torch.int16),
        ),
        torch.logical_or(r, s),
        torch.logical_or(r.double(), s.double()),
        torch.logical_or(r.double(), s),
        torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_xor(r, s),
        torch.logical_xor(r.double(), s.double()),
        torch.logical_xor(r.double(), s),
        torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logit(torch.rand(5), eps=1e-6),
        torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
        torch.i0(torch.arange(5, dtype=torch.float32)),
        torch.igamma(a, b),
        torch.igammac(a, b),
        torch.mul(torch.randn(3), 100),
        torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
        torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
        torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
        torch.nan_to_num(w),
        torch.nan_to_num(w, nan=2.0),
        torch.nan_to_num(w, nan=2.0, posinf=1.0),
        torch.neg(torch.randn(5)),
        # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
        torch.polygamma(1, torch.tensor([1.0, 0.5])),
        torch.polygamma(2, torch.tensor([1.0, 0.5])),
        torch.polygamma(3, torch.tensor([1.0, 0.5])),
        torch.polygamma(4, torch.tensor([1.0, 0.5])),
        torch.pow(a, 2),
        torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
        torch.rad2deg(torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])),
        torch.randn(4, dtype=torch.cfloat).real,
        torch.reciprocal(a),
        torch.remainder(torch.tensor([-3.0, -2.0]), 2),
        torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.round(a),
        torch.rsqrt(a),
        torch.sigmoid(a),
        torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sgn(a),
        torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sin(a),
        torch.sinc(a),
        torch.sinh(a),
        torch.sqrt(a),
        torch.square(a),
        torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
        torch.tan(a),
        torch.tanh(a),
        torch.trunc(a),
        torch.xlogy(f, g),
        torch.xlogy(f, g),
        torch.xlogy(f, 4),
        torch.xlogy(2, g),
    )
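# Spot-check of the polygamma calls above (added sketch): at x = 1 the known
# closed forms are psi_1(1) = pi**2 / 6 and psi_3(1) = pi**4 / 15.
import math
import torch

print(torch.polygamma(1, torch.tensor(1.0)), math.pi**2 / 6)   # both ≈ 1.6449
print(torch.polygamma(3, torch.tensor(1.0)), math.pi**4 / 15)  # both ≈ 6.4939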
def mvtrigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    """Multivariate trigamma: sum_{i=0}^{p-1} polygamma(1, x - i/2)."""
    i = torch.arange(p, dtype=x.dtype, device=x.device)
    return torch.polygamma(1, x.unsqueeze(-1) - .5 * i).sum(-1)
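# mvtrigamma is the second derivative of the multivariate log-gamma function.
# Added sketch: check against autograd by differentiating the equivalent sum of
# digammas (the first derivative of torch.mvlgamma) once more, for p = 2.
import torch

x = torch.tensor(3.0, requires_grad=True)
i = torch.arange(2, dtype=x.dtype)
mvdigamma = torch.digamma(x.unsqueeze(-1) - .5 * i).sum(-1)  # d/dx mvlgamma(x, 2)
(h,) = torch.autograd.grad(mvdigamma, x)
print(torch.allclose(h, mvtrigamma(torch.tensor(3.0), 2)))   # True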
def test_unary_fns(self):
    TestCase = namedtuple('TestCase', ['name', 'lambd'])

    def _test(testcase, names=('N', 'D'), device='cpu'):
        sizes = [2] * len(names)
        tensor = torch.empty(sizes, names=names, device=device)
        out = testcase.lambd(tensor)
        self.assertEqual(out.names, tensor.names, message=testcase.name)

    def method(name, *args, **kwargs):
        return [
            TestCase(name, lambda t: getattr(t, name)(*args, **kwargs))
        ]

    def out_function(name, *args, **kwargs):
        out_fn = getattr(torch, name)

        def fn(tensor):
            result = tensor.new_empty([0])
            out_fn(tensor, *args, out=result, **kwargs)
            return result

        return [TestCase(name + '_out', fn)]

    def fn_method_and_inplace(name, *args, **kwargs):
        return (method(name, *args, **kwargs) +
                method(name + '_', *args, **kwargs) +
                out_function(name, *args, **kwargs))

    def flatten(lst):
        return [item for sublist in lst for item in sublist]

    tests = [
        fn_method_and_inplace('abs'),
        fn_method_and_inplace('acos'),
        fn_method_and_inplace('asin'),
        fn_method_and_inplace('atan'),
        fn_method_and_inplace('ceil'),
        fn_method_and_inplace('clamp', -1, 1),
        fn_method_and_inplace('clamp_min', -2),
        fn_method_and_inplace('clamp_max', 2),
        fn_method_and_inplace('cos'),
        fn_method_and_inplace('cosh'),
        fn_method_and_inplace('digamma'),
        fn_method_and_inplace('erf'),
        fn_method_and_inplace('erfc'),
        fn_method_and_inplace('erfinv'),
        fn_method_and_inplace('exp'),
        fn_method_and_inplace('expm1'),
        fn_method_and_inplace('floor'),
        fn_method_and_inplace('frac'),
        fn_method_and_inplace('lgamma'),
        fn_method_and_inplace('log'),
        fn_method_and_inplace('log10'),
        fn_method_and_inplace('log1p'),
        fn_method_and_inplace('log2'),
        fn_method_and_inplace('neg'),
        [TestCase('polygamma', lambda t: torch.polygamma(1, t))],
        method('polygamma_', 1),
        fn_method_and_inplace('reciprocal'),
        fn_method_and_inplace('round'),
        fn_method_and_inplace('rsqrt'),
        fn_method_and_inplace('sigmoid'),
        fn_method_and_inplace('sin'),
        fn_method_and_inplace('sinh'),
        fn_method_and_inplace('sqrt'),
        fn_method_and_inplace('tan'),
        fn_method_and_inplace('tanh'),
        fn_method_and_inplace('trunc'),
        method('zero_'),
        method('fill_', 1),
        method('fill_', torch.tensor(3.14)),
    ]
    tests = flatten(tests)

    for testcase, device in itertools.product(
            tests, torch.testing.get_all_device_types()):
        _test(testcase, device=device)
def get_reg_deriv_b(self):
    a, b = self.get_params()
    # Derivative of the KL regularizer w.r.t. b only.
    dkl_b = (1 - self.alpha / a) * (1. / b**2 - torch.polygamma(1, b)) + b - (1. / b**2)
    return dkl_b * torch.sigmoid(self.b_uc)
# nan_to_num
x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
torch.nan_to_num(x)
torch.nan_to_num(x, nan=2.0)
torch.nan_to_num(x, nan=2.0, posinf=1.0)

# neg/negative
torch.neg(torch.randn(5))

# nextafter
eps = torch.finfo(torch.float32).eps
torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps])

# polygamma
torch.polygamma(1, torch.tensor([1, 0.5]))
torch.polygamma(2, torch.tensor([1, 0.5]))
torch.polygamma(3, torch.tensor([1, 0.5]))
torch.polygamma(4, torch.tensor([1, 0.5]))

# pow
torch.pow(a, 2)
torch.pow(torch.arange(1., 5.), torch.arange(1., 5.))

# rad2deg
torch.rad2deg(torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]))

# real
torch.randn(4, dtype=torch.cfloat).real
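# Recurrence check for the polygamma calls above (added sketch, double precision
# to keep the subtraction accurate): psi_n(x + 1) = psi_n(x) + (-1)**n * n! / x**(n + 1).
import math
import torch

x = torch.tensor(0.5, dtype=torch.float64)
for n in (1, 2, 3, 4):
    lhs = torch.polygamma(n, x + 1)
    rhs = torch.polygamma(n, x) + (-1)**n * math.factorial(n) / x**(n + 1)
    print(n, torch.allclose(lhs, rhs))  # True for each n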
# bw = torch.max(gamma_sample) / 200
# print(min(gamma_sample), max(gamma_sample))
# plt.hist(gamma_sample.numpy(), normed=True, bins=np.arange(min(gamma_sample), max(gamma_sample) + bw, bw))
# plt.title('alpha :%6.4f, beta :%6.4f' % (concentration[0], rate[0]))
# plt.show()

## Gradient Correction checking code
n_sample = 100000
n_batch = 20
eps = 1e-4
concentration0 = torch.exp(torch.randn(n_batch)) * 0 + torch.rand(1) * 2
concentration1 = concentration0 + eps
rate = torch.exp(torch.randn(n_batch)) * 0 + torch.rand(1) * 2

func = lambda x: torch.log(x)
print('For logarithm gradient should be polygamma(1) x concentration')
print('Exact gradient : ' + ' '.join(['%+6.4f' % (torch.polygamma(1, concentration0) * concentration0)[i] for i in range(n_batch)]))
# func = lambda x: x
# print('For identity gradient should be 1.0 / rate x concentration')
# print('Exact gradient : ' + ' '.join(['%+6.4f' % (1.0 / rate * concentration0)[i] for i in range(n_batch)]))
# func = lambda x: x ** 2
# print('For square gradient should be (1.0 + 2.0 concentration) / rate ** 2 x concentration')
# print('Exact gradient : ' + ' '.join(['%+6.4f' % ((1.0 + 2.0 * concentration0) / rate ** 2 * concentration0)[i] for i in range(n_batch)]))

reparam_module0 = GammaReparametrizedSample(torch.Size([n_batch]))
reparam_module0.log_shape.data = torch.log(concentration0)
reparam_module0.log_rate.data = torch.log(rate)
gamma_sample0 = reparam_module0(n_sample)
func_val0 = func(gamma_sample0)

reparam_module1 = GammaReparametrizedSample(torch.Size([n_batch]))
reparam_module1.log_shape.data = torch.log(concentration1)
reparam_module1.log_rate.data = torch.log(rate)
gamma_sample1 = reparam_module1(n_sample)
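# Background identity for the 'Exact gradient' above (added sketch): for
# X ~ Gamma(alpha, rate), E[log X] = digamma(alpha) - log(rate), so
# d E[log X] / d alpha = polygamma(1, alpha); parametrizing by log_shape
# multiplies this by a further factor of alpha via the chain rule.
import torch

alpha, rate = torch.tensor(1.7), torch.tensor(0.9)
samples = torch.distributions.Gamma(alpha, rate).sample((200000,))
print(samples.log().mean())                    # Monte Carlo estimate
print(torch.digamma(alpha) - torch.log(rate))  # closed form, ≈ same value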