# Shared imports for the snippets below. Repo-specific names that are not
# reproduced here -- `Distribution`, `Flow`, `MetaNetwork`, `utils`, `log`,
# `norm`, `eps`, `inf`, `@overrides`, `TransformDistribution`, `Exp`,
# `Cauchy`, `Poisson`, and the ADKL helper modules (`TaskEncoderNet`,
# `FeaturesExtractorFactory`, `GlobalAvgPool1d`, `NonStationaryKernel`,
# `StationaryKernel`, `pack_episodes`) -- are assumed to come from the
# surrounding repositories.
import math
from typing import Dict, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dists
from torch.nn import Linear, Module, Parameter, init
from torch.nn.functional import hardtanh, softplus
from torch.nn.modules.utils import _pair
from torch.distributions import (Binomial as BN, MultivariateNormal,
                                 NegativeBinomial as NB, Normal,
                                 kl_divergence)


class RepNormal(torch.nn.Module):
    """Reparameterized univariate normal: mu + sigma * eps with eps ~ N(0, 1)."""

    def __init__(self):
        super().__init__()
        self.mu = Parameter(torch.zeros(1))
        self.log_variance = Parameter(torch.zeros(1))

    def forward(self):
        # `Variable` is deprecated; plain tensors are differentiable.
        z = torch.randn(1)
        # The standard deviation is exp(log_variance / 2); the original
        # multiplied by exp(log_variance), i.e. by the variance.
        return self.mu + (0.5 * self.log_variance).exp() * z

    def _repr_pretty_(self, p, cycle):
        p.text("mu = {}".format(self.mu))
        p.text("std = {}".format((0.5 * self.log_variance).exp()))

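# Minimal usage sketch (added for illustration, not from the original
# source): a draw from RepNormal stays differentiable w.r.t. mu and
# log_variance, so a loss on the sample yields gradients for both.
_model = RepNormal()
_loss = (_model() - 3.0).pow(2)
_loss.backward()
print(_model.mu.grad, _model.log_variance.grad)
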
class RelaxedBernoulli(Distribution):
    def __init__(self, probs=[0.5], temperature=1.0, learnable=True):
        super().__init__()
        self.n_dims = len(probs)
        self.temperature = torch.tensor(temperature)
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs)
        # Stored as log-probabilities (not log-odds); `probs` below
        # recovers them with exp.
        self.logits = utils.log(probs.float())
        if learnable:
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        model = dists.RelaxedBernoulli(self.temperature, self.probs)
        return model.log_prob(value).sum(-1)

    def sample(self, batch_size):
        model = dists.RelaxedBernoulli(self.temperature, self.probs)
        return model.sample((batch_size,))

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        if self.n_dims == 1:
            return {'probs': self.probs.item()}
        return {'probs': self.probs.detach().numpy()}

class Binomial(Distribution):
    def __init__(self, total_count=10, probs=[0.5], learnable=True):
        super().__init__()
        # The original tested `probs` here, so a non-tensor `total_count`
        # was only converted by accident; test `total_count` directly.
        if not isinstance(total_count, torch.Tensor):
            total_count = torch.tensor(total_count)
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.total_count = total_count.float()
        self.logits = log(probs.float())
        if learnable:
            self.total_count = Parameter(self.total_count)
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        return BN(self.total_count, probs=self.probs).log_prob(value).sum(-1)

    def sample(self, batch_size):
        return BN(self.total_count, probs=self.probs).sample((batch_size,))

    def entropy(self):
        # Gaussian approximation: 0.5 * log(2 * pi * e * n * p * (1 - p)).
        return 0.5 * (2 * math.pi * math.e * self.total_count * self.probs *
                      (1 - self.probs)).log()

    @property
    def expectation(self):
        return self.total_count * self.probs

    @property
    def mode(self):
        return ((self.total_count + 1) * self.probs).floor()

    @property
    def median(self):
        return (self.total_count * self.probs).floor()

    @property
    def variance(self):
        return self.total_count * self.probs * (1 - self.probs)

    @property
    def skewness(self):
        return (1 - 2 * self.probs) / (self.total_count * self.probs *
                                       (1 - self.probs)).sqrt()

    @property
    def kurtosis(self):
        pq = self.probs * (1 - self.probs)
        return (1 - 6 * pq) / (self.total_count * pq)

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        return {'total_count': self.total_count.detach().numpy(),
                'probs': self.probs.detach().numpy()}

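# Sanity check (added for illustration, assuming the repo's Distribution
# base and `log` helper): closed-form mean and variance of
# Binomial(n=10, p=0.3) against Monte Carlo estimates.
_b = Binomial(total_count=10, probs=[0.3], learnable=False)
print(_b.expectation, _b.variance)        # 3.0, 2.1
_s = _b.sample(100000)
print(_s.mean().item(), _s.var().item())  # ≈ 3.0, ≈ 2.1
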
class LinReg(torch.nn.Module):
    def __init__(self, p):
        super(LinReg, self).__init__()
        # `requires_grad=True` is implied by Parameter.
        self.b = Parameter(torch.randn(p))
        self.log_sig = Parameter(torch.randn(1))

    def forward(self, y, X):
        # Mean log-likelihood of y under N(X @ b, exp(log_sig)).
        return Normal(X.matmul(self.b), self.log_sig.exp()).log_prob(y).mean()

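# Minimal training sketch (added for illustration): maximize the mean
# log-likelihood returned by LinReg.forward on synthetic data.
_X = torch.randn(100, 3)
_y = _X @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(100)
_reg = LinReg(3)
_opt = torch.optim.Adam(_reg.parameters(), lr=0.05)
for _ in range(200):
    _opt.zero_grad()
    nll = -_reg(_y, _X)   # negative log-likelihood
    nll.backward()
    _opt.step()
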
class GaussianLayer1(Module):
    def __init__(self, in_features, in_dim, out_features, out_dim,
                 num_components, sigma_gamma):
        super(GaussianLayer1, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_components = num_components
        self.sigma_gamma = sigma_gamma
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mus = Parameter(
            torch.rand(num_components, self.in_dim + self.out_dim))
        # Initial variance shrinks with the number of components.
        log_var = (1 / torch.sqrt(
            torch.Tensor([float(num_components)]))).pow(2).log()
        self.log_vars = Parameter(
            log_var - torch.rand(num_components, self.in_dim + self.out_dim))
        self.weights = Parameter(torch.rand(num_components) - 0.5)
        self.bias = Parameter(torch.rand(out_features) - 0.5)
        # Normalized coordinates of the input/output units.
        self.x_in_idx = torch.linspace(0, 1, self.in_features)
        self.x_out_idx = torch.linspace(0, 1, self.out_features)

    def forward(self, x):
        # Note: the hard-coded mus[:, 0] / mus[:, 1] indexing below assumes
        # in_dim == out_dim == 1.
        self.x_in_idx = self.x_in_idx.to(self.mus.device)
        self.x_out_idx = self.x_out_idx.to(self.mus.device)
        vars = self.log_vars.exp()
        # Unnormalized Gaussian log-densities of each input unit under each
        # component, along the input axis ...
        mus_x0 = self.mus[:, 0]
        deltas_x0 = (self.x_in_idx.view(-1, 1) - mus_x0).pow(2)
        log_prob_x0 = -deltas_x0 / (2 * vars[:, 0])
        # ... and of each output unit along the output axis.
        mus_x1 = self.mus[:, 1]
        deltas_x1 = (self.x_out_idx.view(-1, 1) - mus_x1).pow(2)
        log_prob_x1 = -deltas_x1 / (2 * vars[:, 1])
        log_prob_x1 = log_prob_x1.unsqueeze_(1)
        # [out_features, in_features, num_components]
        log_probs = log_prob_x1 + log_prob_x0
        # Mixture weights turn this into an effective
        # [out_features, in_features] connection matrix.
        probs = (log_probs.exp() * self.weights).sum(dim=-1)
        return torch.mm(x, probs.transpose(1, 0)) + self.bias

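# Usage sketch (added for illustration): with in_dim = out_dim = 1 (the
# only case forward's indexing supports), the layer maps 784 input units
# to 100 output units through a Gaussian mixture of connections.
_layer = GaussianLayer1(784, 1, 100, 1, num_components=32, sigma_gamma=1.0)
print(_layer(torch.rand(8, 784)).shape)   # torch.Size([8, 100])
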
class NegativeBinomial(Distribution):
    def __init__(self, total_count=10, probs=[0.5], learnable=True):
        super().__init__()
        # As in Binomial above, the original tested `probs` here; test
        # `total_count` directly.
        if not isinstance(total_count, torch.Tensor):
            total_count = torch.tensor(total_count)
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.total_count = total_count.float()
        self.logits = log(probs.float())
        if learnable:
            self.total_count = Parameter(self.total_count)
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        return NB(self.total_count, probs=self.probs).log_prob(value).sum(-1)

    def sample(self, batch_size):
        return NB(self.total_count, probs=self.probs).sample((batch_size,))

    @property
    def expectation(self):
        return (self.probs * self.total_count) / (1 - self.probs)

    @property
    def mode(self):
        # Scalar total_count only.
        if self.total_count > 1:
            return (self.probs * (self.total_count - 1) /
                    (1 - self.probs)).floor()
        return torch.tensor(0.).float()

    @property
    def variance(self):
        return self.probs * self.total_count / (1 - self.probs).pow(2)

    @property
    def skewness(self):
        return (1 + self.probs) / (self.probs * self.total_count).sqrt()

    @property
    def kurtosis(self):
        return 6. / self.total_count + (1 - self.probs).pow(2) / (
            self.probs * self.total_count)

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        return {'total_count': self.total_count.detach().numpy(),
                'probs': self.probs.detach().numpy()}

class LogNormal(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        return dists.LogNormal(self.loc, self.scale).log_prob(value).sum(-1)

    def sample(self, batch_size):
        return dists.LogNormal(self.loc, self.scale).rsample((batch_size,))

    @property
    def expectation(self):
        return (self.loc + self.scale.pow(2) / 2).exp()

    @property
    def variance(self):
        s_square = self.scale.pow(2)
        return (s_square.exp() - 1) * (2 * self.loc + s_square).exp()

    @property
    def median(self):
        return self.loc.exp()

    @property
    def mode(self):
        return (self.loc - self.scale.pow(2)).exp()

    def entropy(self):
        return dists.LogNormal(self.loc, self.scale).entropy()

    @property
    def scale(self):
        # Stored through softplus_inverse so the learnable parameter is
        # unconstrained; softplus keeps the effective scale positive.
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {'loc': self.loc.detach().numpy(),
                'scale': self.scale.detach().numpy()}

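# Sanity check (added for illustration): the closed-form expectation of
# LogNormal(0, 0.5), exp(0 + 0.25 / 2) ≈ 1.133, against a Monte Carlo
# estimate.
_ln = LogNormal(loc=0., scale=0.5, learnable=False)
print(_ln.expectation.item(), _ln.sample(100000).mean().item())
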
class GaussianLayer2(Module):
    def __init__(self, in_features, out_features, num_components,
                 sigma_gamma):
        super(GaussianLayer2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_components = num_components
        self.sigma_gamma = sigma_gamma
        self.mus = Parameter(torch.rand(num_components, 2))
        log_var = (1 / torch.sqrt(
            torch.Tensor([float(num_components)]))).pow(2).log()
        self.log_vars = Parameter(log_var - torch.rand(num_components, 2))
        self.weights = Parameter(torch.rand(num_components) - 0.5)
        self.x_in_idx = torch.linspace(0, 1, self.in_features)
        self.x_out_idx = torch.linspace(0, 1, self.out_features)

    def forward(self, x):
        # Follow the parameters' device (the original allocated everything
        # on the CPU default, which breaks on CUDA); mirrors GaussianLayer1.
        self.x_in_idx = self.x_in_idx.to(self.mus.device)
        self.x_out_idx = self.x_out_idx.to(self.mus.device)
        vars = self.log_vars.exp()
        sigmas = vars.sqrt()
        x_out = torch.zeros(x.shape[0], self.out_features, device=x.device)
        # C[:, m]: the input projected onto component m along the input axis.
        C = torch.zeros(x.shape[0], self.num_components, device=x.device)
        for m in range(self.num_components):
            mu_x0 = self.mus[m, 0]
            deltas = (self.x_in_idx - mu_x0).pow(2)
            exp_terms = -deltas / (2 * vars[m, 0])
            C[:, m] = torch.mm(
                x, exp_terms.exp().reshape(-1, 1)).squeeze() / torch.sqrt(
                    2. * math.pi * vars[m, 0]) * self.weights[m]
        for j in range(self.out_features):
            # Probability of output unit j under each component's marginal
            # along the output axis, weighted by C and summed over
            # components.
            dist = Normal(self.mus[:, 1], sigmas[:, 1])
            log_probs = dist.log_prob(self.x_out_idx[j])
            weighted_probs = log_probs.exp() * C
            result_prob = weighted_probs.sum(dim=-1)
            x_out[:, j] = result_prob / (self.in_features *
                                         self.weights.sum())
        return x_out

class LogCauchy(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        model = TransformDistribution(
            Cauchy(self.loc, self.scale, learnable=False), [Exp()])
        return model.log_prob(value)

    def sample(self, batch_size):
        model = TransformDistribution(
            Cauchy(self.loc, self.scale, learnable=False), [Exp()])
        return model.sample(batch_size)

    def cdf(self, value):
        std_term = (value.log() - self.loc) / self.scale
        return (1. / math.pi) * torch.atan(std_term) + 0.5

    def icdf(self, value):
        cauchy_icdf = (torch.tan(math.pi * (value - 0.5)) * self.scale +
                       self.loc)
        return cauchy_icdf.exp()

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def median(self):
        return self.loc.exp()

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {'loc': self.loc.detach().numpy(),
                'scale': self.scale.detach().numpy()}

class MaskedConv2d(nn.Module):
    """Conv2d with mask and weight normalization."""

    def __init__(self, in_channels, out_channels, kernel_size, mask_type='A',
                 order='A', masked_channels=None, stride=1, dilation=1,
                 groups=1):
        super(MaskedConv2d, self).__init__()
        assert mask_type in {'A', 'B'}
        assert order in {'A', 'B'}
        self.mask_type = mask_type
        self.order = order
        kernel_size = _pair(kernel_size)
        for k in kernel_size:
            # The original formatted `self.kernel_size`, which is not set
            # at this point; use the local value.
            assert k % 2 == 1, \
                'kernel cannot include even number: {}'.format(kernel_size)
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        stride = _pair(stride)
        dilation = _pair(dilation)
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Mask all input channels by default.
        masked_channels = (in_channels
                           if masked_channels is None else masked_channels)
        self.masked_channels = masked_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight_v = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        self.weight_g = Parameter(torch.Tensor(out_channels, 1, 1, 1))
        self.bias = Parameter(torch.Tensor(out_channels))
        self.register_buffer('mask', torch.ones(self.weight_v.size()))
        _, _, kH, kW = self.weight_v.size()
        mask = np.ones([*self.mask.size()], dtype=np.float32)
        # Zero out positions at and after the center pixel ('A') or
        # strictly after it ('B'), plus all rows below the center.
        mask[:, :masked_channels, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
        mask[:, :masked_channels, kH // 2 + 1:] = 0
        # Reverse the scan order.
        if order == 'B':
            reverse_mask = mask[:, :, ::-1, :]
            reverse_mask = reverse_mask[:, :, :, ::-1]
            mask = reverse_mask.copy()
        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        nn.init.constant_(self.bias, 0)

    def init(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out = self(x)
            n_channels = out.size(1)
            out = out.transpose(0, 1).contiguous().view(n_channels, -1)
            # [n_channels]
            mean = out.mean(dim=1)
            std = out.std(dim=1)
            inv_stdv = init_scale / (std + 1e-6)
            self.weight_g.add_(inv_stdv.log().view(n_channels, 1, 1, 1))
            self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.conv2d(input, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def extra_repr(self):
        s = ('{in_channels}({masked_channels}), {out_channels}'
             ', kernel_size={kernel_size}, stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ', type={mask_type}, order={order}'
        return s.format(**self.__dict__)

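# Usage sketch (added for illustration, assuming the repo's `norm` helper
# is in scope): in PixelCNN-style stacks, mask type 'A' excludes the
# current pixel (first layer) and 'B' includes it (subsequent layers),
# keeping the model autoregressive.
_conv_a = MaskedConv2d(3, 32, kernel_size=5, mask_type='A')
_conv_b = MaskedConv2d(32, 32, kernel_size=5, mask_type='B')
_h = _conv_b(torch.relu(_conv_a(torch.randn(8, 3, 28, 28))))
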
class GaussianLayerBase(Module):
    def __init__(self, in_shape, out_shape, num_components, sigma_gamma):
        super(GaussianLayerBase, self).__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape
        self.num_components = num_components
        self.sigma_gamma = sigma_gamma
        self.in_dim = len(self.in_shape)
        self.out_dim = len(self.out_shape)
        self.num_in_units = np.prod(self.in_shape)
        self.num_out_units = np.prod(self.out_shape)
        self.mus = Parameter(
            torch.rand(num_components, self.in_dim + self.out_dim))
        # Initial variance shrinks with the number of components.
        log_var = (1 / torch.sqrt(
            torch.Tensor([float(num_components)]))).pow(2).log()
        self.log_vars = Parameter(
            log_var - torch.rand(num_components, self.in_dim + self.out_dim))
        self.weights = Parameter(torch.rand(num_components) - 0.5)
        self.bias = Parameter(torch.rand(self.num_out_units) - 0.5)
        self.x_in_idx = [
            torch.linspace(0, 1, in_shape_i) for in_shape_i in self.in_shape
        ]
        self.x_out_idx = [
            torch.linspace(0, 1, out_shape_i)
            for out_shape_i in self.out_shape
        ]

    def forward(self, x):
        batch_size = x.shape[0]
        vars = self.log_vars.exp()
        # Accumulate per-axis Gaussian log-densities, adding one broadcast
        # dimension per input axis ...
        log_probs = torch.zeros(self.num_components, device=self.mus.device)
        for i in range(self.in_dim):
            mus_x_i = self.mus[:, i]
            deltas_x_i = (self.x_in_idx[i].view(-1, 1) - mus_x_i).pow(2)
            log_prob_x_i = -deltas_x_i / (2 * vars[:, i])
            log_probs = log_probs.unsqueeze_(-2)
            log_probs = log_probs + log_prob_x_i
        # ... and per output axis.
        for j in range(self.out_dim):
            mus_x_j = self.mus[:, self.in_dim + j]
            deltas_x_j = (self.x_out_idx[j].view(-1, 1) - mus_x_j).pow(2)
            log_prob_x_j = -deltas_x_j / (2 * vars[:, self.in_dim + j])
            log_probs = log_probs.unsqueeze_(-2)
            log_probs = log_probs + log_prob_x_j
        # Mixture weights give an effective [num_in_units, num_out_units]
        # connection matrix.
        probs = (log_probs.exp() * self.weights).sum(dim=-1)
        probs = probs.view(self.num_in_units, self.num_out_units)
        x = x.view(batch_size, self.num_in_units)
        x_out = torch.mm(x, probs) + self.bias
        return x_out.view(batch_size, *self.out_shape)

    def to(self, *args, **kwargs):
        # nn.Module has no `device` attribute; follow the parameters
        # instead, and return self as nn.Module.to does.
        super().to(*args, **kwargs)
        device = self.mus.device
        if device != self.x_in_idx[0].device:
            self.x_in_idx = [idx.to(device) for idx in self.x_in_idx]
            self.x_out_idx = [idx.to(device) for idx in self.x_out_idx]
        return self

class BayesLinearMF(Module):
    r"""Applies a Bayesian linear transformation with a mean-field
    Gaussian posterior over weights and biases.

    .. note:: other arguments follow `torch.nn.Linear` of pytorch 1.2.0.
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
    """
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, single_eps, local_reparam, in_features, out_features,
                 bias=True, deterministic=False):
        super(BayesLinearMF, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.single_eps = single_eps
        self.local_reparam = local_reparam
        self.deterministic = deterministic
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_log_sigma = Parameter(
            torch.Tensor(out_features, in_features))
        self.bias = not (bias is None or bias is False)
        if self.bias:
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_log_sigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_log_sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        self.weight_log_sigma.data.fill_(-5)
        if self.bias:
            self.bias_mu.data.uniform_(-stdv, stdv)
            self.bias_log_sigma.data.fill_(-5)

    def forward(self, input):
        if self.single_eps or self.deterministic:
            if self.deterministic:
                weight = self.weight_mu
                bias = self.bias_mu if self.bias else None
            else:
                # One shared weight sample for the whole batch.
                weight = self.weight_mu + torch.exp(
                    self.weight_log_sigma) * torch.randn(
                        self.out_features, self.in_features,
                        device=input.device, dtype=input.dtype)
                bias = self.bias_mu + torch.exp(
                    self.bias_log_sigma) * torch.randn(
                        self.out_features, device=input.device,
                        dtype=input.dtype) if self.bias else None
            out = F.linear(input, weight, bias)
        elif self.local_reparam:
            # Local reparameterization: sample activations, not weights.
            act_mu = F.linear(input, self.weight_mu,
                              self.bias_mu if self.bias else None)
            act_var = F.linear(
                input ** 2, self.weight_log_sigma.exp() ** 2,
                self.bias_log_sigma.exp() ** 2 if self.bias else None)
            act_std = torch.sqrt(act_var + 1e-16)
            out = act_mu + act_std * torch.randn_like(act_mu)
        else:
            # Independent weight sample per example.
            weight = self.weight_mu + torch.exp(
                self.weight_log_sigma) * torch.randn(
                    input.shape[0], self.out_features, self.in_features,
                    device=input.device, dtype=input.dtype)
            # squeeze(2), not a bare squeeze(): the original would also
            # drop a batch dimension of size 1.
            out = torch.bmm(weight, input.unsqueeze(2)).squeeze(2)
            if self.bias:
                bias = self.bias_mu + torch.exp(
                    self.bias_log_sigma) * torch.randn(
                        input.shape[0], self.out_features,
                        device=input.device, dtype=input.dtype)
                out = out + bias
        return out

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias)

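# Usage sketch (added for illustration): the local-reparameterization path
# samples activations instead of weights, which reduces gradient variance.
_bl = BayesLinearMF(single_eps=False, local_reparam=True,
                    in_features=16, out_features=8)
_out = _bl(torch.randn(32, 16))   # fresh activation noise on every call
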
class MaskedLinear(nn.Module):
    """Masked linear module with weight normalization."""

    def __init__(self, in_features, out_features, mask_type, total_units,
                 max_units=None, bias=True):
        """
        Args:
            in_features: number of units in the inputs.
            out_features: number of units in the outputs.
            mask_type: type of the masked linear.
            total_units: the total number of units to assign.
            max_units: the list containing the maximum units each input
                unit depends on.
            bias: using bias vector.
        """
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_v = Parameter(torch.Tensor(out_features, in_features))
        self.weight_g = Parameter(torch.Tensor(out_features, 1))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        layer_type, order = mask_type
        self.layer_type = layer_type
        self.order = order
        assert layer_type in {'input-hidden', 'hidden-hidden',
                              'hidden-output', 'input-output'}
        assert order in {'A', 'B'}
        self.register_buffer('mask', self.weight_v.data.clone())
        # Override the max_units for the input layer.
        if layer_type.startswith('input'):
            max_units = np.arange(in_features) + 1
        else:
            assert max_units is not None and len(max_units) == in_features
        if layer_type.endswith('output'):
            assert out_features > total_units
            self.max_units = np.arange(out_features)
            self.max_units[total_units:] = total_units
        else:
            units_per_units = float(total_units) / out_features
            self.max_units = np.zeros(out_features, dtype=np.int32)
            for i in range(out_features):
                self.max_units[i] = np.ceil((i + 1) * units_per_units)
        # Connect output i to input j only if output i's unit budget covers
        # input j's dependency.
        mask = np.zeros([out_features, in_features], dtype=np.float32)
        for i in range(out_features):
            for j in range(in_features):
                mask[i, j] = float(self.max_units[i] >= max_units[j])
        # Reverse the order.
        if order == 'B':
            reverse_mask = mask[::-1, :]
            reverse_mask = reverse_mask[:, ::-1]
            mask = np.copy(reverse_mask)
        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()
        self._init = True

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)

    def initialize(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, out_features]
            out = self(x)
            # [out_features]
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            std = std + std.le(0).float()
            inv_stdv = init_scale / (std + 1e-6)
            self.weight_g.add_(inv_stdv.log().unsqueeze(1))
            if self.bias is not None:
                self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.linear(input, weight, self.bias)

    @overrides
    def extra_repr(self):
        return ('in_features={}, out_features={}, bias={}, type={}, '
                'order={}').format(self.in_features, self.out_features,
                                   self.bias is not None, self.layer_type,
                                   self.order)

class Geometric(Distribution):
    def __init__(self, probs=0.5, learnable=True):
        super().__init__()
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.logits = log(probs.float())
        if learnable:
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        # log p(k) = k * log(1 - p) + log(p), for k failures before success.
        return (value * (-self.probs).log1p() + log(self.probs)).sum(-1)

    def sample(self, batch_size):
        # Inverse-CDF sampling: floor(log(u) / log(1 - p)).
        u = torch.rand((batch_size, self.n_dims))
        return (u.log() / (-self.probs).log1p()).floor()

    def cdf(self, value):
        return 1 - (1 - self.probs).pow(value + 1.)

    def icdf(self, value):
        return ((-value).log1p() / (-self.probs).log1p()) - 1.

    def entropy(self):
        q = 1. - self.probs
        return -(q * utils.log(q) +
                 self.probs * utils.log(self.probs)) / self.probs

    def kl(self, other):
        if isinstance(other, Geometric):
            # KL(p || q) = -H(p) - ((1 - p) / p) * log(1 - q) - log(q).
            # The original copied torch's formula, which assumes `logits`
            # are log-odds; here `logits` store log-probabilities, so the
            # expression is written out in full.
            return (-self.entropy()
                    - ((1. - self.probs) / self.probs) *
                    (-other.probs).log1p()
                    - log(other.probs)).sum()
        return None

    @property
    def expectation(self):
        return 1. / self.probs - 1.

    @property
    def variance(self):
        return (1. / self.probs - 1.) / self.probs

    @property
    def mode(self):
        return torch.tensor(0.).float()

    @property
    def skewness(self):
        return (2 - self.probs) / (1 - self.probs).sqrt()

    @property
    def kurtosis(self):
        return 6. + (self.probs.pow(2) / (1. - self.probs))

    @property
    def median(self):
        return (-1 / (-self.probs).log1p()).ceil() - 1.

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        return {'probs': self.probs.detach().numpy()}

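# Sanity check (added for illustration): inverse-CDF sampling should
# reproduce the closed-form mean (1 - p) / p of the failure-count
# geometric.
_g = Geometric(probs=0.3, learnable=False)
print(_g.expectation.item())             # (1 - 0.3) / 0.3 ≈ 2.333
print(_g.sample(100000).mean().item())   # ≈ 2.33
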
class ActNormFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNormFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_features]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of
                :math:`\partial output / \partial input`
        """
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if input.dim() > 2:
            num = np.prod(input.size()[1:-1])
            logdet = logdet * num.astype(float)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Inverse of `forward` (same shapes)."""
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if input.dim() > 2:
            num = np.prod(input.size()[1:-1])
            logdet = logdet * num.astype(float)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Data-dependent initialization from the first batch
        (same shapes as `forward`)."""
        with torch.no_grad():
            out, _ = self.forward(data)
            mean = out.view(-1, self.in_features).mean(dim=0)
            std = out.view(-1, self.in_features).std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)
            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse,
                                                   self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNormFlow":
        return ActNormFlow(**params)

class ActNorm2dFlow(Flow):
    def __init__(self, in_channels, inverse=False):
        super(ActNorm2dFlow, self).__init__(inverse)
        self.in_channels = in_channels
        self.log_scale = Parameter(torch.Tensor(in_channels, 1, 1))
        self.bias = Parameter(torch.Tensor(in_channels, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, in_channels, H, W]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, in_channels, H, W], the output of the flow
            logdet: [batch], the log determinant of
                :math:`\partial output / \partial input`
        """
        batch, channels, H, W = input.size()
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(H * W)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Inverse of `forward` (same shapes)."""
        batch, channels, H, W = input.size()
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(-H * W)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out, _ = self.forward(data)
            out = out.transpose(0, 1).contiguous().view(self.in_channels, -1)
            # [n_channels, 1, 1]
            mean = out.mean(dim=1).view(self.in_channels, 1, 1)
            std = out.std(dim=1).view(self.in_channels, 1, 1)
            inv_stdv = init_scale / (std + 1e-6)
            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_channels={}'.format(self.inverse,
                                                   self.in_channels)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm2dFlow":
        return ActNorm2dFlow(**params)

class ActNorm1dFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNorm1dFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor,
                mask: Union[torch.Tensor, None] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_features]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ..., Nl]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of
                :math:`\partial output / \partial input`
        """
        dim = input.dim()
        out = input * self.log_scale.exp() + self.bias
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)
        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if dim > 2:
            num = (np.prod(input.size()[1:-1]).astype(float)
                   if mask is None
                   else mask.view(out.size(0), -1).sum(dim=1))
            logdet = logdet * num
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor,
                 mask: Union[torch.Tensor, None] = None
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Inverse of `forward` (same shapes)."""
        dim = input.dim()
        out = (input - self.bias).div(self.log_scale.exp() + 1e-8)
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)
        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if dim > 2:
            num = (np.prod(input.size()[1:-1]).astype(float)
                   if mask is None
                   else mask.view(out.size(0), -1).sum(dim=1))
            logdet = logdet * num
        return out, logdet

    @overrides
    def init(self, data, mask: Union[torch.Tensor, None] = None,
             init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Data-dependent initialization from the first batch
        (same shapes as `forward`)."""
        with torch.no_grad():
            # [batch * N1 * ... * Nl, in_features]
            out, _ = self.forward(data, mask=mask)
            out = out.view(-1, self.in_features)
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)
            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data, mask=mask)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse,
                                                   self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm1dFlow":
        return ActNorm1dFlow(**params)

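# Usage sketch (added for illustration, assuming the repo's Flow base
# class): ActNorm layers are typically initialized from the first batch
# so the initial outputs are roughly zero-mean, unit-variance
# ("data-dependent initialization").
_flow = ActNorm1dFlow(in_features=64)
_out, _logdet = _flow.init(torch.randn(128, 64), init_scale=1.0)
print(_out.mean().item(), _out.std().item())  # ≈ 0, ≈ 1
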
class ADKL_KRR_net(MetaNetwork):
    TRAIN = 0
    DESCR = 1
    BOTH = 2
    NB_KERNEL_PARAMS = 1

    def __init__(self, input_features_extractor_params,
                 target_features_extractor_params, condition_on='train',
                 task_descr_extractor_params=None,
                 dataset_encoder_params=None, hp_mode='fixe', l2=0.1,
                 device='cuda', task_encoder_reg=0., n_pseudo_inputs=0,
                 pseudo_inputs_reg=0, stationary_kernel=False):
        """Meta-network pairing learned task and input encoders with a
        kernel ridge regression (KRR) head."""
        super(ADKL_KRR_net, self).__init__()
        if condition_on.lower() in ['train', 'train_samples']:
            assert dataset_encoder_params is not None, \
                'dataset_encoder_params must be specified'
            self.condition_on = self.TRAIN
        elif condition_on.lower() in ['descr', 'task_descr']:
            assert task_descr_extractor_params is not None, \
                'task_descr_extractor_params must be specified'
            self.condition_on = self.DESCR
        elif condition_on.lower() in ['both']:
            assert dataset_encoder_params is not None, \
                'dataset_encoder_params must be specified'
            assert task_descr_extractor_params is not None, \
                'task_descr_extractor_params must be specified'
            self.condition_on = self.BOTH
        else:
            raise ValueError('Invalid option for parameter condition_on')

        self.task_encoder_reg = task_encoder_reg

        if input_features_extractor_params.get('pooling_fn', 0) is None:
            pooling = GlobalAvgPool1d(dim=1)
            spectral_kernel = True
        else:
            pooling = None
            spectral_kernel = False

        task_encoder = TaskEncoderNet(
            input_features_extractor_params,
            target_features_extractor_params,
            dataset_encoder_params,
            complement_module_input_fextractor=pooling)

        self.features_extractor = task_encoder.input_fextractor
        fe_dim = self.features_extractor.output_dim

        self.task_descr_extractor = None
        tde_dim, de_dim = 0, 0
        if self.condition_on in [self.DESCR, self.BOTH]:
            self.task_descr_extractor = FeaturesExtractorFactory()(
                **task_descr_extractor_params)
            tde_dim = self.task_descr_extractor.output_dim
        if self.condition_on in [self.TRAIN, self.BOTH]:
            self.dataset_encoder = task_encoder
            de_dim = self.dataset_encoder.output_dim

        self.l2 = l2
        self.pseudo_inputs_reg = pseudo_inputs_reg
        self.hp_mode = hp_mode
        self.device = device

        if not stationary_kernel:
            self.kernel_network = NonStationaryKernel(
                fe_dim, de_dim + tde_dim, fe_dim, spectral_kernel)
        else:
            self.kernel_network = StationaryKernel(
                fe_dim, de_dim + tde_dim, fe_dim, spectral_kernel)

        if n_pseudo_inputs > 0:
            # The original had identical branches for spectral and
            # non-spectral kernels and wrapped `.to(device)` around the
            # Parameter (which yields a non-leaf tensor); move `.to`
            # inside instead.
            self.pseudo_inputs = Parameter(
                torch.Tensor(n_pseudo_inputs, fe_dim).to(device))
        else:
            self.pseudo_inputs = None

        self.phis_train_mean, self.phis_train_std = 0, 0

        if hp_mode.lower() in ['learn', 'learned', 'l']:
            self.hp_mode = 'l'
        elif hp_mode.lower() in ['predicted', 'predict', 'p', 't',
                                 'task-specific', 'per-task']:
            self.hp_mode = 't'
            d = (de_dim if de_dim else 0) + (tde_dim if tde_dim else 0)
            self.hp_net = Linear(d, self.NB_KERNEL_PARAMS)
        elif hp_mode.lower() in ['fixe', 'fixed', 'f']:
            # The original raised on its own default ('fixe'); treat it as
            # a fixed, non-learned hyperparameter.
            self.hp_mode = 'f'
        else:
            raise Exception('hp_mode should be one of those: fixe, learn, cv')
        self._init_kernel_params(device)

    def _init_kernel_params(self, device):
        if self.pseudo_inputs is not None:
            init.kaiming_uniform_(self.pseudo_inputs, a=math.sqrt(5))
        self.l2 = torch.FloatTensor([self.l2]).to(device)
        if self.hp_mode == 'l':
            self.l2 = Parameter(self.l2)

    def compute_batch_gram_matrix(self, x, y, task_phis):
        k_ = self.kernel_network(x, y, task_phis)
        if self.pseudo_inputs is not None:
            ps = self.pseudo_inputs.unsqueeze(0).expand(
                x.shape[0], *self.pseudo_inputs.shape)
            k_g = self.kernel_network(x, ps, task_phis)
            k_ = torch.cat((k_, k_g), dim=-1)
        return k_

    def set_kernel_params(self, task_phis):
        if self.hp_mode == 't':
            self.l2 = self.hp_net(task_phis).squeeze(-1)
        l2 = hardtanh(self.l2.exp(), 1e-4, 1e1)
        return l2

    def add_pseudo_inputs_loss(self, loss):
        if self.pseudo_inputs.dim() == 2:
            pi_mean = torch.mean(self.pseudo_inputs, dim=0)
            pi_std = torch.std(self.pseudo_inputs, dim=0)
        elif self.pseudo_inputs.dim() == 3:
            # The original left a trailing comma after this expression,
            # turning pi_mean into a one-element tuple.
            pi_mean = torch.mean(self.pseudo_inputs, dim=(0, 1))
            pi_std = torch.std(self.pseudo_inputs, dim=(0, 1))
        else:
            raise Exception(
                'Pseudo inputs: the number of dimensions is incorrect')
        kl = kl_divergence(
            MultivariateNormal(pi_mean, torch.diag(pi_std + 0.1)),
            MultivariateNormal(self.phis_train_mean,
                               torch.diag(self.phis_train_std + 0.1)))
        res = self.pseudo_inputs_reg * kl
        return loss + res, res

    def add_task_encoder_loss(self, loss):
        reg = self.task_encoder_reg * self.task_encoder_loss
        return loss + reg, reg

    def get_alphas(self, phis, ys, masks, task_phis=None):
        l2 = self.set_kernel_params(task_phis)
        bsize, n_train = phis.shape[:2]
        k_ = self.compute_batch_gram_matrix(phis, phis, task_phis=task_phis)
        k = torch.bmm(k_, k_.transpose(1, 2))
        k_mask = masks[:, None, :] * masks[:, :, None]
        k = k * k_mask
        identity = torch.eye(n_train, device=k.device).unsqueeze(0).expand(
            (bsize, n_train, n_train))
        batch_K_inv = torch.inverse(
            k + l2.unsqueeze(1).unsqueeze(1) * identity)
        alphas = torch.bmm(batch_K_inv, ys)
        return alphas, k_

    def get_preds(self, alphas, K_train, phis_train, masks_train, phis_test,
                  masks_test, task_phis=None):
        k = self.compute_batch_gram_matrix(phis_test, phis_train,
                                           task_phis=task_phis)
        k = torch.bmm(k, K_train.transpose(1, 2))
        k_mask = masks_test[:, :, None] * masks_train[:, None, :]
        k = k * k_mask
        preds = torch.bmm(k, alphas)
        return preds

    def get_task_phis(self, tasks_descr, xs_train, ys_train, mask_train):
        if self.condition_on == self.DESCR:
            task_phis = self.task_descr_extractor(tasks_descr)
        elif self.condition_on == self.TRAIN:
            task_phis = self.dataset_encoder(None, xs_train, ys_train,
                                             mask_train)
        else:
            task_phis = torch.cat([
                self.task_descr_extractor(tasks_descr),
                self.dataset_encoder(None, xs_train, ys_train, mask_train)
            ], dim=1)
        return task_phis

    def get_phis(self, xs, train=False):
        phis = self.features_extractor(xs.reshape(-1, xs.shape[2]))
        if train:
            # Exponential moving average of the training features' moments,
            # used as the prior for the pseudo-input KL term.
            alpha = 0.8
            if phis.dim() == 2:
                self.phis_train_mean = (
                    (1 - alpha) * self.phis_train_mean +
                    alpha * torch.mean(phis, dim=0).detach())
                self.phis_train_std = (
                    (1 - alpha) * self.phis_train_std +
                    alpha * torch.std(phis, dim=0).detach())
            elif phis.dim() == 3:
                self.phis_train_mean = (
                    (1 - alpha) * self.phis_train_mean +
                    alpha * torch.mean(phis, dim=(0, 1)).detach())
                self.phis_train_std = (
                    (1 - alpha) * self.phis_train_std +
                    alpha * torch.std(phis, dim=(0, 1)).detach())
        phis = phis.reshape(*xs.shape[:2], *phis.shape[1:])
        return phis

    def forward(self, episodes):
        if self.condition_on == self.TRAIN:
            train, test = pack_episodes(episodes, return_tasks_descr=False)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(None, xs_train, ys_train,
                                           mask_train)
        else:
            train, test, tasks_descr = pack_episodes(
                episodes, return_tasks_descr=True)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(tasks_descr, xs_train, ys_train,
                                           mask_train)
        phis_train = self.get_phis(xs_train, train=True)
        phis_test = self.get_phis(xs_test, train=False)
        # Fit the KRR head on the support set ...
        alphas, K_train = self.get_alphas(phis_train, ys_train, mask_train,
                                          task_phis)
        # ... and predict on the query set.
        preds = self.get_preds(alphas, K_train, phis_train, mask_train,
                               phis_test, mask_test, task_phis)
        self.compute_task_encoder_loss_last_batch(episodes)
        if isinstance(preds, tuple):
            return [tuple(x[:n] for x in pred)
                    for n, pred in zip(lens_test, preds)]
        return [x[:n] for n, x in zip(lens_test, preds)]

    def compute_task_encoder_loss_last_batch(self, episodes):
        # train x test
        set_code = self.dataset_encoder(episodes)
        y_preds_class = torch.arange(len(set_code))
        if set_code.is_cuda:
            y_preds_class = y_preds_class.to(set_code.device)
        accuracy = (set_code.argmax(dim=1) ==
                    y_preds_class).sum().item() / len(set_code)
        b = set_code.size(0)
        # Contrastive mutual-information estimate: matched (diagonal)
        # pairs against mismatched (off-diagonal) pairs.
        eye_b = torch.eye(b, device=set_code.device)
        mi = set_code.diagonal().mean() - torch.log(
            (set_code * (1 - eye_b)).exp().sum() / (b * (b - 1)))
        loss = -mi
        self.accuracy = accuracy
        self.task_encoder_loss = loss

class Bernoulli(Distribution):
    def __init__(self, probs=[0.5], learnable=True):
        super().__init__()
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.logits = log(probs.float())
        if learnable:
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        q = 1. - self.probs
        return (value * (self.probs + eps).log() +
                (1. - value) * (q + eps).log()).sum(-1)

    def sample(self, batch_size):
        return torch.bernoulli(
            self.probs.unsqueeze(0).expand((batch_size, *self.probs.shape)))

    def entropy(self):
        q = 1. - self.probs
        return -q * (q + eps).log() - self.probs * (self.probs + eps).log()

    def kl(self, other):
        if isinstance(other, Bernoulli):
            t1 = self.probs * (self.probs / other.probs).log()
            t1[other.probs == 0] = inf
            t1[self.probs == 0] = 0
            t2 = (1 - self.probs) * ((1 - self.probs) /
                                     (1 - other.probs)).log()
            t2[other.probs == 1] = inf
            t2[self.probs == 1] = 0
            return (t1 + t2).sum()
        if isinstance(other, Poisson):
            return (-self.entropy() -
                    (self.probs * other.rate.log() - other.rate)).sum()
        return None

    @property
    def expectation(self):
        return self.probs

    @property
    def variance(self):
        return self.probs * (1 - self.probs)

    @property
    def skewness(self):
        return (1 - 2 * self.probs) / (self.probs * (1 - self.probs)).sqrt()

    @property
    def kurtosis(self):
        q = 1 - self.probs
        return (1 - 6. * self.probs * q) / (self.probs * q)

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        if self.n_dims == 1:
            return {'probs': self.probs.detach().item()}
        return {'probs': self.probs.detach().numpy()}

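# Sanity check (added for illustration): the entropy of a fair Bernoulli
# is log 2, and the KL divergence between identical Bernoullis is zero.
_b1 = Bernoulli(probs=[0.5], learnable=False)
_b2 = Bernoulli(probs=[0.5], learnable=False)
print(_b1.entropy().item())   # ≈ 0.6931
print(_b1.kl(_b2).item())     # ≈ 0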