class ParameterFromRuntimeStatsScaling(brevitas.jit.ScriptModule):
    """
    ScriptModule implementation of a learned scale factor initialized from runtime statistics.
    The implementation works in two phases. During the first phase, statistics are collected in
    the same fashion as batchnorm, meaning that while the module is in training mode a set of
    per-batch statistics are computed and returned, while in background an average of them is
    retained and returned in inference mode. During the second phase, the average accumulated
    during the first phase is used to initialize a learned torch.nn.Parameter, and then the
    behaviour is the same as ParameterScaling.

    Args:
        collect_stats_steps (int): Number of calls to the forward method in training mode to
            collect statistics for.
        scaling_stats_impl (Module): Implementation of the statistics computed during the
            collection phase.
        scaling_stats_input_view_shape_impl (Module): Implementation of the view applied to the
            runtime input during the statistics collection phase. Default: OverBatchOverTensorView().
        scaling_shape (Tuple[int, ...]): shape of the torch.nn.Parameter used in the second phase.
            Default: SCALAR_SHAPE.
        restrict_scaling_impl (Module): restrict the learned scale factor according to some
            criteria. Default: None.
        scaling_stats_momentum (float): Momentum for the statistics moving average.
            Default: DEFAULT_MOMENTUM.
        scaling_min_val (float): force a lower-bound on the learned scale factor. Default: None.

    Returns:
        Tensor: learned scale factor wrapped in a float torch.tensor.

    Examples:
        >>> scaling_impl = ParameterFromRuntimeStatsScaling(collect_stats_steps=1, scaling_stats_impl=AbsMax())
        >>> scaling_impl.training
        True
        >>> x = torch.arange(-3, 2, 0.1)
        >>> scaling_impl(x)
        tensor(3.)
        >>> scaling_impl(torch.randn_like(x))
        tensor(3., grad_fn=<AbsBinarySignGradFnBackward>)

    Note:
        Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining
        from a floating point state dict.

    Note:
        Maps to scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS == 'PARAMETER_FROM_STATS'
        == 'parameter_from_stats' when applied to runtime values (inputs/outputs/activations) in
        higher-level APIs.
    """
    __constants__ = ['collect_stats_steps', 'momentum']

    def __init__(
            self,
            collect_stats_steps: int,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
            scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
            restrict_scaling_impl: Optional[Module] = None,
            scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
            scaling_min_val: Optional[float] = None) -> None:
        super(ParameterFromRuntimeStatsScaling, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        self.collect_stats_steps = collect_stats_steps
        # Number of training-mode forward calls seen so far; drives the phase switch.
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
        self.stats = _Stats(scaling_stats_impl, scaling_shape)
        self.momentum = scaling_stats_momentum
        # Running average of the statistics during the collection phase (never saved).
        self.register_buffer('buffer', torch.full(scaling_shape, 1.0))
        # Learned scale used after the collection phase.
        self.value = Parameter(torch.full(scaling_shape, 1.0))
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        if restrict_scaling_impl is not None:
            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
        else:
            self.restrict_inplace_preprocess = Identity()
            self.restrict_preprocess = Identity()

    @brevitas.jit.script_method
    def training_forward(self, stats_input: Tensor) -> Tensor:
        """Training-mode step: collect stats, then switch over to the learned parameter."""
        if self.counter < self.collect_stats_steps:
            stats_input = self.stats_input_view_shape_impl(stats_input)
            stats = self.stats(stats_input)
            new_counter = self.counter + 1
            if self.counter == 0:
                _inplace_init(self.buffer, stats.detach())
            else:
                _inplace_update(self.buffer, stats.detach(), self.momentum, self.counter, new_counter)
            self.counter = new_counter
            return stats
        elif self.counter == self.collect_stats_steps:
            # First step of the second phase: seed the learned value from the accumulated stats.
            self.restrict_inplace_preprocess(self.buffer)
            _inplace_init(self.value.detach(), self.buffer)
            self.counter = self.counter + 1
            return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
        else:
            return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))

    @brevitas.jit.script_method
    def forward(self, stats_input: Tensor) -> Tensor:
        if self.training:
            return self.training_forward(stats_input)
        else:
            # Until the learned value has been initialized, inference uses the running average.
            if self.counter <= self.collect_stats_steps:
                out = self.buffer
                out = self.restrict_preprocess(out)
            else:
                out = self.value
            out = abs_binary_sign_grad(self.restrict_clamp_scaling(out))
        return out

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        output_dict = super(ParameterFromRuntimeStatsScaling, self).state_dict(
            destination, prefix, keep_vars)
        # Avoid saving the buffer
        del output_dict[prefix + 'buffer']
        # Avoid saving the init value
        if self.counter == 0:
            del output_dict[prefix + 'value']
        # Save buffer into value for any non-zero number of collection steps
        elif self.counter <= self.collect_stats_steps:
            output_dict[prefix + 'value'] = self.restrict_preprocess(self.buffer)
        return output_dict

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
        value_key = prefix + 'value'
        # Retrocompatibility with older ParameterScaling, for when scaling impl is switched over.
        # The rename has to happen before the parent loader runs: afterwards, 'value' has
        # already been looked up (and not found) and 'learned_value' would never be loaded.
        retrocomp_value_key = prefix + 'learned_value'
        if retrocomp_value_key in state_dict:
            state_dict[value_key] = state_dict.pop(retrocomp_value_key)
        super(ParameterFromRuntimeStatsScaling, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        # Buffer is normally missing (state_dict() never saves it), but tolerate state dicts
        # that do include it instead of failing with ValueError from list.remove.
        buffer_key = prefix + 'buffer'
        if buffer_key in missing_keys:
            missing_keys.remove(buffer_key)
        # Pytorch stores training flag as a buffer with JIT enabled
        training_key = prefix + 'training'
        if training_key in missing_keys:
            missing_keys.remove(training_key)
        # disable stats collection when a pretrained value is loaded
        if value_key not in missing_keys:
            self.counter = self.collect_stats_steps + 1
        if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
            missing_keys.remove(value_key)
class InfiniteMixtureModel(Distribution):
    """Gamma scale mixture of Normal distributions with an explicit latent weight.

    Sampling draws a latent weight ``w ~ Gamma(df / 2, df / 2)`` per sample and then
    an observation from ``Normal(loc, scale / w)``. ``scale`` and ``df`` are stored
    unconstrained (through ``softplus_inverse``) and exposed through ``softplus``.

    Args:
        df: degrees-of-freedom value(s) used for the Gamma latent.
        loc: location value(s) of the Normal component.
        scale: scale value(s) of the Normal component.
        loc_learnable, scale_learnable, df_learnable: wrap the corresponding value in a
            ``Parameter`` when True.
    """

    def __init__(self, df, loc, scale, loc_learnable=True, scale_learnable=True, df_learnable=True):
        super().__init__()
        self.loc = torch.tensor(loc).view(-1)
        self.n_dims = len(self.loc)
        if loc_learnable:
            self.loc = Parameter(self.loc)
        self._scale = utils.softplus_inverse(torch.tensor(scale).view(-1))
        if scale_learnable:
            self._scale = Parameter(self._scale)
        self._df = utils.softplus_inverse(torch.tensor(df).view(-1))
        if df_learnable:
            self._df = Parameter(self._df)

    def sample(self, batch_size, return_latents=False):
        """Draw ``batch_size`` samples, optionally also returning the Gamma latents."""
        weight_model = Gamma(self.df / 2, self.df / 2, learnable=False)
        latent_samples = weight_model.sample(batch_size)
        # NOTE(review): loc.expand(batch_size) only works when n_dims == 1 — confirm intended.
        normal_model = Normal(self.loc.expand(batch_size),
                              (self.scale / latent_samples).squeeze(1),
                              learnable=False)
        samples = normal_model.sample(1).squeeze(0).unsqueeze(1)
        if return_latents:
            return samples, latent_samples
        return samples

    def log_prob(self, samples, latents=None):
        """Joint log-density of ``samples`` and their ``latents``.

        Raises:
            NotImplementedError: if ``latents`` is None (the marginal is not implemented).
        """
        if latents is None:
            raise NotImplementedError(
                "InfiniteMixtureModel log_prob not implemented")
        weight_model = Gamma(self.df / 2, self.df / 2, learnable=False)
        normal_model = Normal(self.loc.expand(latents.size(0)),
                              (self.scale / latents).squeeze(1),
                              learnable=False)
        return normal_model.log_prob(samples) + weight_model.log_prob(latents)

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def df(self):
        return softplus(self._df)

    @property
    def has_latents(self):
        return True

    def get_parameters(self):
        """Return loc/scale/df as Python floats (1-D case) or NumPy arrays.

        Tensors are moved to the CPU before conversion because ``Tensor.numpy()``
        only supports CPU tensors.
        """
        if self.n_dims == 1:
            return {
                "loc": self.loc.item(),
                "scale": self.scale.item(),
                "df": self.df.item(),
            }
        return {
            "loc": self.loc.detach().cpu().numpy(),
            "scale": self.scale.detach().cpu().numpy(),
            "df": self.df.detach().cpu().numpy(),
        }
class SpatialTransformerPooled3d(nn.Module):
    """Spatial-transformer readout over a pyramid of average-pooled feature maps.

    For every output unit a learned 2-D grid location is sampled (``F.grid_sample``)
    from the input and from ``pool_steps`` successively average-pooled copies of it.
    The sampled features are weighted by learned per-unit ``features`` and summed,
    separately for every time step.

    Args:
        in_shape: ``(c, t, w, h)`` shape of one input sample (batch excluded).
        outdims: number of output units.
        pool_steps: number of additional average-pooling levels.
        positive: if True, the project helper ``positive`` is called on the features
            in ``forward``.
        bias: whether to learn a per-unit bias.
        init_range: half-width of the uniform grid initialization.
        kernel_size, stride: parameters of the ``nn.AvgPool2d`` between levels.
        grid: optional externally supplied grid; a ``(1, outdims, 1, 2)`` Parameter
            is created when None.
        stop_grad: if True, the input is detached in ``forward``.
    """

    def __init__(self, in_shape, outdims, pool_steps=1, positive=False, bias=True,
                 init_range=.05, kernel_size=2, stride=2, grid=None, stop_grad=False):
        super().__init__()
        self._pool_steps = pool_steps
        self.in_shape = in_shape
        c, t, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        if grid is None:
            self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        else:
            self.grid = grid
        self.features = Parameter(torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims))
        self.register_buffer('mask', torch.ones_like(self.features))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter('bias', bias)
        else:
            self.register_parameter('bias', None)

        self.avg = nn.AvgPool2d(kernel_size, stride=stride, count_include_pad=False)
        self.init_range = init_range
        self.initialize()
        self.stop_grad = stop_grad

    @property
    def pool_steps(self):
        return self._pool_steps

    @pool_steps.setter
    def pool_steps(self, value):
        """Change the number of pooling levels, reallocating and re-initializing features."""
        assert value >= 0 and int(value) - value == 0, 'new pool steps must be a non-negative integer'
        if value != self._pool_steps:
            print('Resizing readout features')
            c, t, w, h = self.in_shape
            self._pool_steps = int(value)
            # Allocate from the existing features so the new Parameter keeps the module's
            # device and dtype; torch.Tensor(...) would always give a default-dtype CPU tensor.
            self.features = Parameter(
                self.features.new_empty((1, c * (self._pool_steps + 1), 1, self.outdims)))
            self.mask = torch.ones_like(self.features)
            self.features.data.fill_(1 / self.in_shape[0])

    def initialize(self, init_noise=1e-3, grid=True):
        """Reset features to ``1 / c``, bias to 0 and (optionally) the grid uniformly.

        ``init_noise`` is accepted for interface compatibility but is not used.
        """
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.bias.data.fill_(0)
        if grid:
            # randomly pick centers within the spatial map
            self.grid.data.uniform_(-self.init_range, self.init_range)

    def feature_l1(self, average=True, subs_idx=None):
        """L1 norm of the feature weights (mean if ``average`` else sum)."""
        subs_idx = subs_idx if subs_idx is not None else slice(None)
        if average:
            return self.features[..., subs_idx].abs().mean()
        else:
            return self.features[..., subs_idx].abs().sum()

    def reset_fisher_prune_scores(self):
        self._prune_n = 0
        self._prune_scores = self.features.detach() * 0

    def update_fisher_prune_scores(self):
        """Accumulate ``0.5 * grad**2 * w**2`` for the feature weights.

        Raises:
            ValueError: if no gradient is available yet.
        """
        self._prune_n += 1
        if self.features.grad is None:
            raise ValueError('You need to run backward first')
        self._prune_scores += (0.5 * self.features.grad.pow(2) * self.features.pow(2)).detach()

    @property
    def fisher_prune_scores(self):
        return self._prune_scores / self._prune_n

    def prune(self):
        """Zero out, for every output unit, the unmasked feature with the lowest score."""
        # Already-masked entries get a large penalty so they are never picked again.
        idx = (self.fisher_prune_scores + 1e6 * (1 - self.mask)).squeeze().argmin(dim=0)
        seq = torch.arange(len(idx), device=idx.device)
        self.mask[:, idx, :, seq] = 0
        self.features.data[:, idx, :, seq] = 0

    def forward(self, x, shift=None, subs_idx=None):
        """Read out ``x`` of shape ``(N, c, t, w, h)`` into ``(N, t, outdims)``."""
        if self.stop_grad:
            x = x.detach()

        self.features.data *= self.mask

        if self.positive:
            positive(self.features)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        N, c, t, w, h = x.size()
        m = self._pool_steps + 1
        if subs_idx is not None:
            feat = self.features[..., subs_idx].contiguous()
            outdims = feat.size(-1)
            feat = feat.view(1, m * c, outdims)
            grid = self.grid[:, subs_idx, ...]
        else:
            grid = self.grid
            feat = self.features.view(1, m * c, self.outdims)
            outdims = self.outdims

        if shift is None:
            grid = grid.expand(N * t, outdims, 1, 2)
        else:
            grid = grid.expand(N, outdims, 1, 2)
            grid = torch.stack([grid + shift[:, i, :][:, None, None, :] for i in range(t)], 1)
            grid = grid.contiguous().view(-1, outdims, 1, 2)
        # Fold time into the batch dimension so every frame is sampled independently.
        z = x.contiguous().transpose(2, 1).contiguous().view(-1, c, w, h)
        pools = [F.grid_sample(z, grid)]
        for i in range(self._pool_steps):
            z = self.avg(z)
            pools.append(F.grid_sample(z, grid))
        y = torch.cat(pools, dim=1)
        y = (y.squeeze(-1) * feat).sum(1).view(N, t, outdims)

        if self.bias is not None:
            if subs_idx is None:
                y = y + self.bias
            else:
                y = y + self.bias[subs_idx]

        return y

    def __repr__(self):
        c, _, w, h = self.in_shape
        r = self.__class__.__name__ + \
            ' (' + '{} x {} x {}'.format(c, w, h) + ' -> ' + str(self.outdims) + ')'
        if self.bias is not None:
            r += ' with bias'
        if self.stop_grad:
            r += ', stop_grad=True'
        r += '\n'

        for ch in self.children():
            r += '  -> ' + ch.__repr__() + '\n'
        return r
class _NMF(Base):
    """Base class for NMF-style factorizations fitted with multiplicative updates.

    Subclasses implement ``reconstruct`` (how ``H`` and ``W`` are combined into an
    approximation of ``V``) and ``get_W_positive`` / ``get_H_positive``, whose first
    return value is handed to ``_mu_update`` (presumably the positive part of the
    gradient — confirm in subclasses).
    """

    def __init__(self, W_size, H_size, rank):
        super().__init__()
        self.rank = rank
        self.W = Parameter(torch.rand(*W_size))
        self.H = Parameter(torch.rand(*H_size))

    def forward(self, H=None, W=None):
        """Reconstruct from ``H``/``W``, defaulting to the learned factors."""
        if H is None:
            H = self.H
        if W is None:
            W = self.W
        return self.reconstruct(H, W)

    def reconstruct(self, H, W):
        raise NotImplementedError

    def get_W_positive(self, WH, beta, H_sum) -> (torch.Tensor, None or torch.Tensor):
        raise NotImplementedError

    def get_H_positive(self, WH, beta, W_sum) -> (torch.Tensor, None or torch.Tensor):
        raise NotImplementedError

    def fit(self,
            V,
            W=None,
            H=None,
            update_W=True,
            update_H=True,
            beta=1,
            tol=1e-5,
            max_iter=200,
            verbose=0,
            initial='random',
            alpha=0,
            l1_ratio=0
            ):
        """Fit the factors to ``V`` by alternating multiplicative updates.

        Args:
            V: target tensor; passed through ``fix_neg`` first.
            W, H: optional initial values copied into the learned factors.
            update_W, update_H: which factors are updated.
            beta: beta-divergence parameter of the loss.
            tol: relative loss-improvement threshold (relative to the first loss) for
                early stopping.
            max_iter: maximum number of iterations.
            verbose: show a tqdm progress bar when truthy.
            initial: currently unused.
            alpha, l1_ratio: regularization strength and its L1/L2 split.

        Returns:
            int: index of the last iteration that ran (0 when ``max_iter`` <= 0).

        Raises:
            ValueError: if neither factor is to be updated.
        """
        # Without any factor to update the loop would never compute a loss.
        if not (update_W or update_H):
            raise ValueError('At least one of update_W and update_H must be True')

        V = self.fix_neg(V)
        if W is not None:
            self.W.data.copy_(W)
        # (special initialization for W/H when not given may be added in the future)
        self.W.requires_grad = update_W

        if H is not None:
            self.H.data.copy_(H)
        self.H.requires_grad = update_H

        # Exponent forwarded to _mu_update, chosen from beta.
        if beta < 1:
            gamma = 1 / (2 - beta)
        elif beta > 2:
            gamma = 1 / (beta - 1)
        else:
            gamma = 1

        l1_reg = alpha * l1_ratio
        l2_reg = alpha * (1 - l1_ratio)

        loss_scale = torch.prod(torch.tensor(V.shape)).float()

        # Sums cached by get_*_positive; invalidated whenever the other factor changes.
        H_sum, W_sum = None, None
        n_iter = 0
        with tqdm(total=max_iter, disable=not verbose) as pbar:
            for n_iter in range(max_iter):
                if self.W.requires_grad:
                    self.zero_grad()
                    WH = self.reconstruct(self.H.detach(), self.W)
                    loss = Beta_divergence(self.fix_neg(WH), V, beta)
                    loss.backward()

                    with torch.no_grad():
                        positive_comps, H_sum = self.get_W_positive(WH, beta, H_sum)
                        _mu_update(self.W, positive_comps, gamma, l1_reg, l2_reg)
                    W_sum = None

                if self.H.requires_grad:
                    self.zero_grad()
                    WH = self.reconstruct(self.H, self.W.detach())
                    loss = Beta_divergence(self.fix_neg(WH), V, beta)
                    loss.backward()

                    with torch.no_grad():
                        positive_comps, W_sum = self.get_H_positive(WH, beta, W_sum)
                        _mu_update(self.H, positive_comps, gamma, l1_reg, l2_reg)
                    H_sum = None

                # Out-of-place: avoid mutating the autograd output in place.
                loss = (loss / loss_scale).item()

                pbar.set_postfix(loss=loss)
                pbar.update()

                if not n_iter:
                    loss_init = loss
                elif loss_init == 0 or (previous_loss - loss) / loss_init < tol:
                    # A zero initial loss is already a perfect fit (and would divide by zero).
                    break
                previous_loss = loss

        return n_iter

    def fit_transform(self, *args, **kwargs):
        """Run ``fit`` and return ``(n_iter, reconstruction)``."""
        n_iter = self.fit(*args, **kwargs)
        return n_iter, self.forward()
class ArcMarginProduct_virface(nn.Module):
    """ArcFace-style additive angular margin head with VirFace extensions.

    Without unlabeled data it returns ArcFace logits. With ``unlabel_x`` the
    normalized unlabeled features are appended to the class weights as extra
    ("virtual") classes; with ``unlabel_aug`` as well, augmented views of those
    unlabeled samples are scored against the virtual classes.

    Tensors created internally follow the device of the inputs. ``device`` is kept
    as an attribute for backward compatibility.
    """

    def __init__(self, in_features=512, out_features=84281, s=32, m=0.5, easy_margin=False, device='cuda'):
        super(ArcMarginProduct_virface, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.dropout = nn.Dropout(0.5)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # make the function cos(theta+m) monotonic decreasing while theta in [0, 180] degrees
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.device = device

    def _apply_margin(self, cosine):
        """Return the margin-adjusted ``cos(theta + m)`` for ``cosine = cos(theta)``."""
        sine = 1.0 - torch.pow(cosine, 2)
        sine = torch.where(sine > 0, sine, torch.zeros(sine.size(), device=sine.device))
        sine = torch.sqrt(sine)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            return torch.where(cosine > 0, phi, cosine)
        return torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

    def _scaled_margin_logits(self, cosine, labels):
        """Apply the margin at each row's ``labels`` column and scale everything by ``s``."""
        phi = self._apply_margin(cosine)
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, labels.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

    def forward(self, x, label, unlabel_x=None, unlabel_aug=None, overlap=False):
        """Compute labeled logits and, optionally, virtual-instance logits.

        Returns:
            ``(output_label, output_unlabel, unlabel_label)``; the last two are None
            unless both ``unlabel_x`` and ``unlabel_aug`` are given.
        """
        if unlabel_x is not None:
            if overlap:
                # filter overlap: keep unlabeled samples not confidently matching a known class
                unlabel_dot_w = F.linear(F.normalize(unlabel_x.detach()),
                                         F.normalize(self.weight.detach())) * self.s
                prob = F.softmax(unlabel_dot_w, dim=1)
                max_prob = torch.max(prob, dim=1)[0]
                idx = max_prob.lt(0.8)
            else:
                idx = torch.ones(unlabel_x.shape[0], dtype=torch.bool, device=unlabel_x.device)
            unlabel_data = unlabel_x[idx]
            weight_all = torch.cat([F.normalize(self.weight), F.normalize(unlabel_data)], dim=0)
            # Same weights but with virtual classes detached, used for the virinstance branch.
            weight_all_fix = torch.cat(
                [F.normalize(self.weight), F.normalize(unlabel_data).detach()], dim=0)
            # virclass
            x = self.dropout(x)
        else:
            # no unlabel data, return arcface
            weight_all = F.normalize(self.weight)

        cosine_label = F.linear(F.normalize(x), weight_all)
        output_label = self._scaled_margin_logits(cosine_label, label)

        if unlabel_aug is None or unlabel_x is None:
            # no aug, return virclass or arcface
            return output_label, None, None

        # virinstance: unlabel_aug holds consecutive chunks of augmented views, one chunk per view
        n_unlabel = unlabel_x.shape[0]
        n_views = int(unlabel_aug.shape[0] / n_unlabel)
        aug_data = torch.cat(
            [unlabel_aug[i * n_unlabel:(i + 1) * n_unlabel][idx] for i in range(n_views)], dim=0)

        cosine_unlabel = F.linear(F.normalize(aug_data), weight_all_fix)
        # Each augmented view targets its source sample's virtual class (ids after out_features).
        unlabel_label = torch.arange(self.out_features,
                                     self.out_features + unlabel_data.size(0),
                                     device=cosine_unlabel.device).repeat(n_views)
        output_unlabel = self._scaled_margin_logits(cosine_unlabel, unlabel_label)

        return output_label, output_unlabel, unlabel_label
class GPRegression(GPModel):
    r"""
    Gaussian Process Regression model.

    The core of a Gaussian Process is a covariance function :math:`k` which governs
    the similarity between input points. Given :math:`k`, we can establish a
    distribution over functions :math:`f` by a multivariate normal distribution

    .. math:: p(f(X)) = \mathcal{N}(0, k(X, X)),

    where :math:`X` is any set of input points and :math:`k(X, X)` is a covariance
    matrix whose entries are outputs :math:`k(x, z)` of :math:`k` over input pairs
    :math:`(x, z)`. This distribution is usually denoted by

    .. math:: f \sim \mathcal{GP}(0, k).

    .. note:: Generally, besides a covariance function :math:`k`, a Gaussian Process can
        also be specified by a mean function :math:`m` (which is a zero-value function
        by default). In that case, its distribution will be

        .. math:: p(f(X)) = \mathcal{N}(m(X), k(X, X)).

    Given inputs :math:`X` and their noisy observations :math:`y`, the Gaussian Process
    Regression model takes the form

    .. math::
        f &\sim \mathcal{GP}(0, k(X, X)),\\
        y & \sim f + \epsilon,

    where :math:`\epsilon` is Gaussian noise.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for training,
        :math:`\mathcal{O}(N^3)` complexity for testing. Here, :math:`N` is the number
        of train inputs.

    Reference:

    [1] `Gaussian Processes for Machine Learning`,
    Carl E. Rasmussen, Christopher K. I. Williams

    :param torch.Tensor X: Input data for training. Its first dimension is the number
        of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the
        number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param torch.Tensor noise: Variance of Gaussian noise of this model. Defaults to
        ``1.`` when not given.
    :param callable mean_function: An optional mean function :math:`m` of this Gaussian
        process. By default, we use zero mean.
    :param float jitter: A small positive term which is added into the diagonal part of
        a covariance matrix to help stabilize its Cholesky decomposition.
    """

    def __init__(self, X, y, kernel, noise=None, mean_function=None, jitter=1e-6):
        super(GPRegression, self).__init__(X, y, kernel, mean_function, jitter)

        noise = self.X.new_tensor(1.) if noise is None else noise
        self.noise = Parameter(noise)
        self.set_constraint("noise", torchdist.constraints.positive)

    @autoname.scope(prefix="GPR")
    def model(self):
        self.set_mode("model")

        N = self.X.size(0)
        # The strided view(-1)[::N + 1] below addresses the diagonal only for a contiguous
        # matrix; make that explicit, as forward() and iter_sample() already do.
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter + self.noise  # add noise to diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.X.size(0))
        f_loc = zero_loc + self.mean_function(self.X)
        if self.y is None:
            f_var = Lff.pow(2).sum(dim=-1)
            return f_loc, f_var
        else:
            return pyro.sample(
                "y",
                dist.MultivariateNormal(f_loc, scale_tril=Lff).expand_by(
                    self.y.shape[:-1]).to_event(self.y.dim() - 1),
                obs=self.y)

    @autoname.scope(prefix="GPR")
    def guide(self):
        self.set_mode("guide")

    def forward(self, Xnew, full_cov=False, noiseless=True):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian Process
        posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, \epsilon) = \mathcal{N}(loc, cov).

        .. note:: The noise parameter ``noise`` (:math:`\epsilon`) together with
            kernel's parameters have been learned from a training procedure (MCMC or
            SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full covariance
            matrix or just variance.
        :param bool noiseless: A flag to decide if we want to include noise in the
            prediction output or not.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter + self.noise  # add noise to the diagonal
        Lff = Kff.cholesky()

        y_residual = self.y - self.mean_function(self.X)
        loc, cov = conditional(Xnew, self.X, self.kernel, y_residual, None, Lff,
                               full_cov, jitter=self.jitter)

        if not noiseless:
            if full_cov:
                M = Xnew.size(0)
                cov = cov.contiguous()
                cov.view(-1, M * M)[:, ::M + 1] += self.noise  # add noise to the diagonal
            else:
                cov = cov + self.noise

        return loc + self.mean_function(Xnew), cov

    def iter_sample(self, noiseless=True):
        r"""
        Iteratively constructs a sample from the Gaussian Process posterior.

        Recall that at test input points :math:`X_{new}`, the posterior is
        multivariate Gaussian distributed with mean and covariance matrix
        given by :func:`forward`.

        This method samples lazily from this multivariate Gaussian. The advantage
        of this approach is that later query points can depend upon earlier ones.
        Particularly useful when the querying is to be done by an optimisation
        routine.

        .. note:: The noise parameter ``noise`` (:math:`\epsilon`) together with
            kernel's parameters have been learned from a training procedure (MCMC or
            SVI).

        :param bool noiseless: A flag to decide if we want to add sampling noise
            to the samples beyond the noise inherent in the GP posterior.
        :returns: sampler
        :rtype: function
        """
        noise = self.noise.detach()
        X = self.X.clone().detach()
        y = self.y.clone().detach()
        N = X.size(0)
        Kff = self.kernel(X).contiguous()
        Kff.view(-1)[::N + 1] += noise  # add noise to the diagonal

        # Mutable state shared across calls of the returned sampler.
        outside_vars = {"X": X, "y": y, "N": N, "Kff": Kff}

        def sample_next(xnew, outside_vars):
            """Repeatedly samples from the Gaussian process posterior,
            conditioning on previously sampled values.
            """
            warn_if_nan(xnew)

            # Variables from outer scope
            X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars["Kff"]

            # Compute Cholesky decomposition of kernel matrix
            Lff = Kff.cholesky()
            y_residual = y - self.mean_function(X)

            # Compute conditional mean and variance
            loc, cov = conditional(xnew, X, self.kernel, y_residual, None, Lff,
                                   False, jitter=self.jitter)
            if not noiseless:
                cov = cov + noise

            ynew = torchdist.Normal(loc + self.mean_function(xnew), cov.sqrt()).rsample()

            # Update kernel matrix
            N = outside_vars["N"]
            Kffnew = Kff.new_empty(N + 1, N + 1)
            Kffnew[:N, :N] = Kff
            cross = self.kernel(X, xnew).squeeze()
            end = self.kernel(xnew, xnew).squeeze()
            Kffnew[N, :N] = cross
            Kffnew[:N, N] = cross
            # No noise, just jitter for numerical stability
            Kffnew[N, N] = end + self.jitter
            # Heuristic to avoid adding degenerate points
            if Kffnew.logdet() > -15.:
                outside_vars["Kff"] = Kffnew
                outside_vars["N"] += 1
                outside_vars["X"] = torch.cat((X, xnew))
                outside_vars["y"] = torch.cat((y, ynew))

            return ynew

        return lambda xnew: sample_next(xnew, outside_vars)
class ParameterFromRuntimeStatsScaling(brevitas.jit.ScriptModule):
    """Learned scale factor whose initial value is estimated from runtime statistics.

    While in training mode, the first ``collect_stats_steps`` forward calls compute
    statistics of the (optionally permuted, then reshaped) input with
    ``scaling_stats_impl``, accumulate them into ``value`` with an exponential moving
    average of momentum ``scaling_stats_momentum``, and return the per-batch statistics.
    The next training call preprocesses ``value`` in place with
    ``restrict_scaling_impl.restrict_init_inplace_module()``. From then on, and in
    inference mode, the module returns ``restrict_clamp_scaling(abs(value))``.

    Raises:
        RuntimeError: if ``scaling_shape != SCALAR_SHAPE`` and
            ``scaling_stats_permute_dims`` is None.
    """
    __constants__ = ['stats_permute_dims', 'collect_stats_steps', 'momentum']

    def __init__(self,
                 collect_stats_steps: int,
                 restrict_scaling_impl: Module,
                 scaling_stats_impl: Module,
                 scaling_shape: Tuple[int, ...],
                 scaling_stats_input_view_shape_impl: Module,
                 scaling_stats_permute_dims: Optional[Tuple[int, ...]] = None,
                 scaling_stats_momentum: float = DEFAULT_MOMENTUM,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ParameterFromRuntimeStatsScaling, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        # Non-scalar (per-channel) stats need the channel axis moved before the view.
        if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None:
            raise RuntimeError(
                "Per channel runtime stats require a permute shape")
        self.collect_stats_steps = collect_stats_steps
        # Number of training-mode forward calls seen so far; drives the phase switch.
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_permute_dims = scaling_stats_permute_dims
        self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
        self.stats = _Stats(scaling_stats_impl, scaling_shape)
        self.momentum = scaling_stats_momentum
        # Starts at 1.0 so that the first in-place multiply in forward() sets it to the
        # first batch's statistics.
        self.value = Parameter(torch.full(scaling_shape, 1.0))
        self.restrict_clamp_scaling = _RestrictClampValue(
            scaling_min_val, restrict_scaling_impl)
        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()

    @brevitas.jit.script_method_110_disabled
    def forward(self, stats_input) -> torch.Tensor:
        if self.training:
            if self.counter < self.collect_stats_steps:
                # Collection phase: return the batch statistics and fold them into value.
                if self.stats_permute_dims is not None:
                    stats_input = stats_input.permute(
                        *self.stats_permute_dims).contiguous()
                stats_input = self.stats_input_view_shape_impl(stats_input)
                stats = self.stats(stats_input)
                if self.counter == 0:
                    # value == 1.0 here, so this sets value to the first statistics.
                    self.value.detach().mul_(stats.detach())
                else:
                    # Exponential moving average: value = (1 - m) * value + m * stats.
                    self.value.detach().mul_(1 - self.momentum)
                    self.value.detach().add_(self.momentum * stats.detach())
                self.counter = self.counter + 1
                return stats
            elif self.counter == self.collect_stats_steps:
                # One-time switch to the learned phase: preprocess the accumulated value.
                self.restrict_inplace_preprocess(self.value.detach())
                self.counter = self.counter + 1
                return self.restrict_clamp_scaling(torch.abs(self.value))
            else:
                return self.restrict_clamp_scaling(torch.abs(self.value))
        # NOTE(review): in inference mode this path is also taken while statistics are
        # still being collected (counter <= collect_stats_steps), i.e. before value has
        # gone through restrict_inplace_preprocess — confirm this is intended.
        out = self.restrict_clamp_scaling(torch.abs(self.value))
        return out

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(ParameterFromRuntimeStatsScaling, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
        value_key = prefix + 'value'
        # Pytorch stores training flag as a buffer with JIT enabled
        training_key = prefix + 'training'
        if training_key in missing_keys:
            missing_keys.remove(training_key)
        # NOTE(review): loading 'value' does not change self.counter, so a subsequent
        # training run will collect statistics again and overwrite the loaded value —
        # confirm this is the intended behaviour for pretrained checkpoints.
        if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
            missing_keys.remove(value_key)
class SpatialTransformerPooled3d(Readout):
    """Spatial-transformer readout over a pyramid of average-pooled feature maps.

    For every output unit a learned 2-D grid location is sampled (``F.grid_sample``)
    from the input and from ``pool_steps`` successively average-pooled copies of it.
    The sampled features are weighted by learned per-unit ``features`` and summed,
    separately for every time step.

    Args:
        in_shape: ``(c, t, w, h)`` shape of one input sample (batch excluded).
        outdims: number of output units.
        pool_steps: number of additional average-pooling levels.
        positive: if True, features are clamped to be non-negative in ``forward``.
        bias: whether to learn a per-unit bias (initialized via ``initialize_bias``).
        init_range: half-width of the uniform grid initialization.
        kernel_size, stride: parameters of the ``nn.AvgPool2d`` between levels.
        grid: optional externally supplied grid; a ``(1, outdims, 1, 2)`` Parameter
            is created when None.
        stop_grad: if True, the input is detached in ``forward``.
        align_corners: forwarded to ``F.grid_sample``.
        mean_activity: forwarded to ``initialize_bias`` when a bias is used.
        feature_reg_weight: multiplier of ``feature_l1`` in ``regularizer``.
        gamma_readout: deprecated alias resolved by ``resolve_deprecated_gamma_readout``.
    """

    def __init__(
        self,
        in_shape,
        outdims,
        pool_steps=1,
        positive=False,
        bias=True,
        init_range=0.05,
        kernel_size=2,
        stride=2,
        grid=None,
        stop_grad=False,
        align_corners=True,
        mean_activity=None,
        feature_reg_weight=1.0,
        gamma_readout=None,  # deprecated, use feature_reg_weight instead
    ):
        super().__init__()
        self._pool_steps = pool_steps
        self.in_shape = in_shape
        c, t, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(
            feature_reg_weight, gamma_readout)
        self.mean_activity = mean_activity

        if grid is None:
            self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        else:
            self.grid = grid
        self.features = Parameter(
            torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims))
        self.register_buffer("mask", torch.ones_like(self.features))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter("bias", bias)
        else:
            self.register_parameter("bias", None)

        self.avg = nn.AvgPool2d(kernel_size, stride=stride, count_include_pad=False)
        self.init_range = init_range
        # Pass by keyword: the first positional parameter of initialize() is init_noise.
        self.initialize(mean_activity=mean_activity)
        self.stop_grad = stop_grad
        self.align_corners = align_corners

    @property
    def pool_steps(self):
        return self._pool_steps

    @pool_steps.setter
    def pool_steps(self, value):
        """Change the number of pooling levels, reallocating and re-initializing features."""
        assert value >= 0 and int(
            value) - value == 0, "new pool steps must be a non-negative integer"
        if value != self._pool_steps:
            print("Resizing readout features")
            c, t, w, h = self.in_shape
            self._pool_steps = int(value)
            # Allocate from the existing features so the new Parameter keeps the module's
            # device and dtype; torch.Tensor(...) would always give a default-dtype CPU tensor.
            self.features = Parameter(
                self.features.new_empty((1, c * (self._pool_steps + 1), 1, self.outdims)))
            self.mask = torch.ones_like(self.features)
            self.features.data.fill_(1 / self.in_shape[0])

    def initialize(self, init_noise=1e-3, grid=True, mean_activity=None):
        """Reset features to ``1 / c``, the grid uniformly, and the bias.

        ``init_noise`` is accepted for interface compatibility but is not used.
        ``mean_activity`` falls back to ``self.mean_activity`` when None.
        """
        if mean_activity is None:
            mean_activity = self.mean_activity
        self.features.data.fill_(1 / self.in_shape[0])
        if grid:
            # randomly pick centers within the spatial map
            self.grid.data.uniform_(-self.init_range, self.init_range)
        if self.bias is not None:
            self.initialize_bias(mean_activity=mean_activity)

    def feature_l1(self, reduction="sum", average=None, subs_idx=None):
        """L1 penalty of the feature weights, reduced by ``apply_reduction``."""
        subs_idx = subs_idx if subs_idx is not None else slice(None)
        return self.apply_reduction(self.features[..., subs_idx].abs(),
                                    reduction=reduction,
                                    average=average)

    def regularizer(self, reduction="sum", average=None):
        return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight

    def reset_fisher_prune_scores(self):
        self._prune_n = 0
        self._prune_scores = self.features.detach() * 0

    def update_fisher_prune_scores(self):
        """Accumulate ``0.5 * grad**2 * w**2`` for the feature weights.

        Raises:
            ValueError: if no gradient is available yet.
        """
        self._prune_n += 1
        if self.features.grad is None:
            raise ValueError("You need to run backward first")
        self._prune_scores += (0.5 * self.features.grad.pow(2) * self.features.pow(2)).detach()

    @property
    def fisher_prune_scores(self):
        return self._prune_scores / self._prune_n

    def prune(self):
        """Zero out, for every output unit, the unmasked feature with the lowest score."""
        # Already-masked entries get a large penalty so they are never picked again.
        idx = (self.fisher_prune_scores + 1e6 * (1 - self.mask)).squeeze().argmin(dim=0)
        seq = torch.arange(len(idx), device=idx.device)
        self.mask[:, idx, :, seq] = 0
        self.features.data[:, idx, :, seq] = 0

    def forward(self, x, shift=None, subs_idx=None):
        """Read out ``x`` of shape ``(N, c, t, w, h)`` into ``(N, t, outdims)``."""
        if self.stop_grad:
            x = x.detach()

        self.features.data *= self.mask

        if self.positive:
            self.features.data.clamp_min_(0)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        N, c, t, w, h = x.size()
        m = self._pool_steps + 1
        if subs_idx is not None:
            feat = self.features[..., subs_idx].contiguous()
            outdims = feat.size(-1)
            feat = feat.view(1, m * c, outdims)
            grid = self.grid[:, subs_idx, ...]
        else:
            grid = self.grid
            feat = self.features.view(1, m * c, self.outdims)
            outdims = self.outdims

        if shift is None:
            grid = grid.expand(N * t, outdims, 1, 2)
        else:
            grid = grid.expand(N, outdims, 1, 2)
            grid = torch.stack(
                [grid + shift[:, i, :][:, None, None, :] for i in range(t)], 1)
            grid = grid.contiguous().view(-1, outdims, 1, 2)
        # Fold time into the batch dimension so every frame is sampled independently.
        z = x.contiguous().transpose(2, 1).contiguous().view(-1, c, w, h)
        pools = [F.grid_sample(z, grid, align_corners=self.align_corners)]
        for i in range(self._pool_steps):
            z = self.avg(z)
            pools.append(F.grid_sample(z, grid, align_corners=self.align_corners))
        y = torch.cat(pools, dim=1)
        y = (y.squeeze(-1) * feat).sum(1).view(N, t, outdims)

        if self.bias is not None:
            if subs_idx is None:
                y = y + self.bias
            else:
                y = y + self.bias[subs_idx]

        return y

    def __repr__(self):
        c, _, w, h = self.in_shape
        r = self.__class__.__name__ + " (" + "{} x {} x {}".format(
            c, w, h) + " -> " + str(self.outdims) + ")"
        if self.bias is not None:
            r += " with bias"
        if self.stop_grad:
            r += ", stop_grad=True"
        r += "\n"

        for ch in self.children():
            r += "  -> " + ch.__repr__() + "\n"
        return r
class ResNet(nn.Module):
    """ResNet/ResNeSt backbone fused with a graph-convolution branch over label embeddings.

    Image features go through a stem and four residual stages. In parallel, label
    embeddings (``self.word``) are processed by a stack of ``GraphConvolution``
    layers over a label adjacency matrix built by ``gen_adj_num``. After the stem
    and after each of the first three stages, the current graph features are merged
    into the image features with ``merge_gcn_residual``. The output logits are
    ``pooled_features @ gc4_output^T`` plus a 1x1-conv classifier applied to the
    pooled features.

    Parameters
    ----------
    block : Block
        Residual block class. It must expose an ``expansion`` attribute and accept
        the keyword arguments passed in ``_make_layer``.
    layers : list of int
        Number of blocks in each of the four stages.
    num_classes : int, default 1000
        Ignored: the number of classes is always ``len(extract_fields.split(','))``.
    dilated : bool, default False
        Apply the dilation strategy (stride-8 model), typically used in semantic
        segmentation.
    in_channels : int, default 300
        Size of each label embedding fed to the first graph convolution.
    word_file : str
        ``.npy`` file with the label embeddings. When the path does not exist the
        embeddings are drawn from a standard normal distribution instead.
    extract_fields, agree_rate, csv_path
        Forwarded to ``gen_adj_num`` to build the label adjacency matrix.
    norm_layer : type, default :class:`torch.nn.BatchNorm2d`
        Normalization layer used in the backbone.

    Reference:

        - He, Kaiming, et al. "Deep residual learning for image recognition."
          Proceedings of the IEEE conference on computer vision and pattern
          recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by
          dilated convolutions."
    """

    # pylint: disable=unused-variable
    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
                 num_classes=1000, dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False, use_se=False, in_channels=300,
                 word_file='/workspace/Projects/cxr/models/feature_extraction/diseases_embeddings.npy',
                 extract_fields='0,1,2,3,4,5', agree_rate=0.5, csv_path='',
                 norm_layer=nn.BatchNorm2d):
        # Initialise nn.Module before assigning any attribute.
        super(ResNet, self).__init__()
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first
        self.use_se = use_se
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg

        if rectified_conv:
            # Optional dependency, only needed for rectified convolutions.
            from rfconv import RFConv2d
            conv_layer = RFConv2d
        else:
            conv_layer = nn.Conv2d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}

        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1,
                           bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1,
                           bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1,
                           bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False, **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=4, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        elif dilation == 2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilation=1, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-style init for the backbone convs; unit/zero init for its norm layers.
        # Runs before the graph-branch modules below are created, so those keep
        # their default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, norm_layer):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # The label count comes from extract_fields; the num_classes argument is overridden.
        num_classes = len(extract_fields.split(','))
        _adj = gen_adj_num(labels=extract_fields, agree_rate=agree_rate, csv_path=csv_path)
        # Stored as Parameters but detached in forward(), so they receive no gradient.
        self.adj = Parameter(torch.from_numpy(_adj).float())

        if not os.path.exists(word_file):
            # Random embeddings must match gc0's input size; a hard-coded width of
            # 300 broke every in_channels != 300.
            word = np.random.randn(num_classes, in_channels)
            print('graph input: random')
        else:
            with open(word_file, 'rb') as point:
                word = np.load(point)
            print('graph input: loaded from {}'.format(word_file))
        self.word = Parameter(torch.from_numpy(word).float())

        self.gc0 = GraphConvolution(in_channels, 128, bias=True)
        self.gc1 = GraphConvolution(128, 256, bias=True)
        self.gc2 = GraphConvolution(256, 512, bias=True)
        self.gc3 = GraphConvolution(512, 1024, bias=True)
        self.gc4 = GraphConvolution(1024, 2048, bias=True)
        self.gc_relu = nn.LeakyReLU(0.2)
        self.gc_tanh = nn.Tanh()
        # 1x1 convs that map per-label graph features to each stage's channel count
        # (used by merge_gcn_residual).
        self.merge_conv0 = nn.Conv2d(num_classes, 128, kernel_size=1, stride=1, bias=False)
        self.merge_conv1 = nn.Conv2d(num_classes, 256, kernel_size=1, stride=1, bias=False)
        self.merge_conv2 = nn.Conv2d(num_classes, 512, kernel_size=1, stride=1, bias=False)
        self.merge_conv3 = nn.Conv2d(num_classes, 1024, kernel_size=1, stride=1, bias=False)
        # Assumes 512 * block.expansion == 2048 (bottleneck blocks) -- TODO confirm for other blocks.
        self.conv1x1 = conv1x1(in_channels=2048, out_channels=num_classes, bias=True)
        # self.spatial_attention = SAModule(2048)
        # self.spatial_attention = SpatialCGNL(2048, 1024)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        """Build one residual stage of ``blocks`` blocks with ``planes`` base channels.

        The first block receives ``stride`` and, when the stride or channel count
        changes, a downsample branch. It uses dilation 1 when the stage dilation is
        1 or 2, and dilation 2 when the stage dilation is 4. The remaining blocks use
        ``dilation`` as given. ``self.inplanes`` is updated to the stage's output
        channel count.

        Raises:
            RuntimeError: if ``dilation`` is not 1, 2 or 4.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            down_layers = []
            if self.avg_down:
                # ResNet-D: downsample with average pooling, then a stride-1 1x1 conv.
                pool = stride if dilation == 1 else 1
                down_layers.append(nn.AvgPool2d(kernel_size=pool, stride=pool,
                                                ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv2d(self.inplanes, out_planes,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv2d(self.inplanes, out_planes,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(out_planes))
            downsample = nn.Sequential(*down_layers)

        if dilation in (1, 2):
            first_dilation = 1
        elif dilation == 4:
            first_dilation = 2
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        # Keyword arguments shared by every block of the stage.
        shared = dict(radix=self.radix, cardinality=self.cardinality,
                      bottleneck_width=self.bottleneck_width,
                      avd=self.avd, avd_first=self.avd_first,
                      rectified_conv=self.rectified_conv,
                      rectify_avg=self.rectify_avg,
                      norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                      last_gamma=self.last_gamma, use_se=self.use_se)

        layers = [block(self.inplanes, planes, stride, downsample=downsample,
                        dilation=first_dilation, is_first=is_first, **shared)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, **shared))

        return nn.Sequential(*layers)

    def forward(self, feature):
        """Return per-label logits for an image batch.

        The graph branch starts from the detached label embeddings and normalised
        adjacency (``gen_adj``). Its tanh-activated output is merged into the image
        features after the stem and after stages 1-3.
        """
        adj = gen_adj(self.adj).detach()
        word = self.word.detach()

        feature = self.conv1(feature)
        feature = self.bn1(feature)
        feature = self.relu(feature)
        feature = self.maxpool(feature)

        x_raw = self.gc0(word, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv0)

        feature = self.layer1(feature)
        x = self.gc_relu(x_raw)
        x_raw = self.gc1(x, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv1)

        feature = self.layer2(feature)
        x = self.gc_relu(x_raw)
        x_raw = self.gc2(x, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv2)

        feature = self.layer3(feature)
        x = self.gc_relu(x_raw)
        x_raw = self.gc3(x, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv3)

        feature = self.layer4(feature)
        # feature = self.spatial_attention(feature)
        feature_raw = self.global_pool(feature)
        if self.drop is not None:
            feature_raw = self.drop(feature_raw)

        feature = feature_raw.view(feature_raw.size(0), -1)

        # Final graph layer yields one classifier vector per label.
        x = self.gc_relu(x_raw)
        x = self.gc4(x, adj)
        x = self.gc_tanh(x)
        x = x.transpose(0, 1)
        x = torch.matmul(feature, x)

        y = self.conv1x1(feature_raw)
        y = y.view(y.size(0), -1)

        return x + y
class Cifp(nn.Module):
    """Implementation of CIFP (CVPR 2021, "Consistent Instance False Positive
    Improves Fairness in Face Recognition").

    Cosine logits get an additive margin on the target class, plus an extra
    penalty on the target logit. The penalty grows with the mean squared cosine of
    the hardest negatives that lie above a false-positive threshold shared across
    all distributed ranks.
    """

    def __init__(self, in_features, out_features, scale=64.0, margin=0.35):
        """
        Args:
            in_features: size of each input feature
            out_features: number of classes (columns of the class kernel)
            scale: multiplier applied to the output logits
            margin: additive margin subtracted from the target cosine
        """
        super(Cifp, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.kernel = Parameter(torch.Tensor(in_features, out_features))
        nn.init.normal_(self.kernel, std=0.01)

    def forward(self, embeddings, label):
        """Compute CIFP-adjusted logits.

        Args:
            embeddings: features passed to ``calc_logits`` together with
                ``self.kernel``.
            label: integer class index per sample.

        Returns:
            ``(output, origin)``: the adjusted cosine logits and the second value
            returned by ``calc_logits``, both multiplied by ``self.scale``.

        Note:
            Requires an initialised ``torch.distributed`` process group, because
            ``dist.all_reduce`` and ``all_gather_tensor`` are collective calls.
        """
        cos_theta, origin_cos = calc_logits(embeddings, self.kernel)
        # Same logits through a detached kernel: used where the penalty must not
        # back-propagate into the class weights.
        cos_theta_, _ = calc_logits(embeddings, self.kernel.detach())

        label_idx = label.view(-1, 1).long()
        mask = torch.zeros_like(cos_theta)
        mask.scatter_(1, label_idx, 1.0)
        sample_num = embeddings.size(0)
        # Subtract 2 at the target column so it never ranks as a negative.
        tmp_cos_theta = cos_theta - 2 * mask
        tmp_cos_theta_ = cos_theta_ - 2 * mask
        rows = torch.arange(0, sample_num, device=label.device)
        target_cos_theta = cos_theta[rows, label].view(-1, 1)
        target_cos_theta_ = cos_theta_[rows, label].view(-1, 1)
        target_cos_theta_m = target_cos_theta - self.margin

        far = 1 / (self.out_features - 1)
        # far = 1e-4
        # Negatives already scoring above their own target, counted over all ranks.
        topk_mask = torch.greater(tmp_cos_theta, target_cos_theta)
        topk_sum = torch.sum(topk_mask.to(torch.int32))
        dist.all_reduce(topk_sum)
        # At least 1: with k == 0 the threshold lookup below would index an empty tensor.
        far_rank = max(1, math.ceil(
            far * (sample_num * (self.out_features - 1) * dist.get_world_size() - topk_sum)))
        # Global threshold: the far_rank-th largest remaining negative over all ranks.
        cos_theta_neg_topk = torch.topk(
            (tmp_cos_theta - 2 * topk_mask.to(torch.float32)).flatten(), k=far_rank)[0]
        cos_theta_neg_topk = all_gather_tensor(cos_theta_neg_topk.contiguous())
        cos_theta_neg_th = torch.topk(cos_theta_neg_topk, k=far_rank)[0][-1]

        # Hard negatives: not already above the target, but above the global threshold.
        cond = torch.mul(torch.bitwise_not(topk_mask),
                         torch.greater(tmp_cos_theta, cos_theta_neg_th))
        cos_theta_neg_topk = torch.mul(cond.to(torch.float32), tmp_cos_theta)
        cos_theta_neg_topk_ = torch.mul(cond.to(torch.float32), tmp_cos_theta_)
        # Keep gradients only where the margined target still exceeds the negative.
        cond = torch.greater(target_cos_theta_m, cos_theta_neg_topk)
        cos_theta_neg_topk = torch.where(cond, cos_theta_neg_topk, cos_theta_neg_topk_)
        cos_theta_neg_topk = torch.pow(cos_theta_neg_topk, 2)
        # Per-sample mean over selected (non-zero) entries; avoid division by zero.
        times = torch.sum(torch.greater(cos_theta_neg_topk, 0).to(torch.float32),
                          dim=1, keepdim=True)
        times = torch.where(torch.greater(times, 0), times, torch.ones_like(times))
        cos_theta_neg_topk = torch.sum(cos_theta_neg_topk, dim=1, keepdim=True) / times
        target_cos_theta_m = target_cos_theta_m - (1 + target_cos_theta_) * cos_theta_neg_topk
        cos_theta.scatter_(1, label_idx, target_cos_theta_m)

        output = cos_theta * self.scale
        return output, origin_cos * self.scale
class AsymmetricLaplace(Distribution):
    """Asymmetric Laplace distribution with learnable loc, scale and asymmetry.

    With ``s = sign(x - loc)``, ``scale = softplus(_scale)`` and
    ``asymmetry = softplus(_asymmetry)``, the per-dimension density is::

        p(x) = scale / (asymmetry + 1 / asymmetry)
               * exp(-|x - loc| * scale * asymmetry ** s)

    The probability mass below ``loc`` is ``asymmetry**2 / (1 + asymmetry**2)``.
    ``scale`` and ``asymmetry`` are stored in softplus-inverse space so that they
    stay positive when learned. ``log_prob`` sums over the last dimension.
    """

    def __init__(self, loc=0., scale=1., asymmetry=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        if not isinstance(asymmetry, torch.Tensor):
            asymmetry = torch.tensor(asymmetry).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        self._asymmetry = utils.softplus_inverse(asymmetry.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)
            self._asymmetry = Parameter(self._asymmetry)

    def log_prob(self, value):
        s = (value - self.loc).sign()
        exponent = -(value - self.loc).abs() * self.scale * self.asymmetry.pow(s)
        coeff = self.scale.log() - (self.asymmetry + (1 / self.asymmetry)).log()
        return (coeff + exponent).sum(-1)

    def sample(self, batch_size):
        # U ~ Uniform(-asymmetry, 1/asymmetry). sign(U) selects the side of loc and
        # the log term maps |U| to the exponential tail on that side.
        U = Uniform(low=-self.asymmetry, high=(1. / self.asymmetry),
                    learnable=False).sample(batch_size)
        s = U.sign()
        log_term = (1. - U * s * self.asymmetry.pow(s)).log()
        return self.loc - (1. / (self.scale * s * self.asymmetry.pow(s))) * log_term

    def cdf(self, value):
        above = (value > self.loc).float()
        # +1 strictly above loc, -1 at or below it. Using -1 at value == loc picks
        # the lower branch, so cdf(loc) = asym^2 / (1 + asym^2) rather than the 0
        # that sign(0) == 0 used to produce.
        s = 2. * above - 1.
        tail = (-(value - self.loc).abs() * self.scale * self.asymmetry.pow(s)).exp()
        return above - s * self.asymmetry.pow(1 - s) / (1 + self.asymmetry.pow(2)) * tail

    def entropy(self):
        return (utils.e * (1 + self.asymmetry.pow(2)) /
                (self.asymmetry * self.scale)).log().sum()

    @property
    def expectation(self):
        return self.loc + ((1 - self.asymmetry.pow(2)) / (self.scale * self.asymmetry))

    @property
    def variance(self):
        return (1 + self.asymmetry.pow(4)) / (self.scale.pow(2) * self.asymmetry.pow(2))

    @property
    def mode(self):
        return self.loc

    @property
    def median(self):
        k = self.asymmetry
        k2 = k.pow(2)
        # asymmetry >= 1: at least half the mass lies at or below loc.
        below = self.loc + (k / self.scale) * ((1 + k2) / (2 * k2)).log()
        # asymmetry < 1: the median lies above loc and needs the upper-tail formula.
        above = self.loc - (1 / (self.scale * k)) * ((1 + k2) / 2).log()
        return torch.where(k >= 1, below, above)

    @property
    def skewness(self):
        return (2 * (1 - self.asymmetry.pow(6))) / (1 + self.asymmetry.pow(4)).pow(3. / 2.)

    @property
    def kurtosis(self):
        # Equals 3 when asymmetry == 1 (the symmetric Laplace case).
        return (6 * (1 + self.asymmetry.pow(8))) / (1 + self.asymmetry.pow(4)).pow(2)

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def asymmetry(self):
        return softplus(self._asymmetry)

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                'loc': self.loc.item(),
                'scale': self.scale.item(),
                'asymmetry': self.asymmetry.item()
            }
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy(),
            'asymmetry': self.asymmetry.detach().numpy()
        }
class ParameterFromRuntimeMinZeroPoint(brevitas.jit.ScriptModule):
    """Learned zero point initialised from runtime statistics of the input.

    Phase 1 covers the first ``collect_stats_steps`` calls in training mode. The
    zero point is computed on the fly by ``NegativeMinOrZero`` over the input,
    which is optionally permuted by ``stats_permute_dims`` and then reshaped by
    ``stats_input_view_shape_impl``. The statistic is also accumulated into the
    ``buffer`` buffer: the first call adds it, later calls go through
    ``inplace_momentum_update``.

    On the next training call, ``buffer`` is added into ``value``. ``value`` is a
    Parameter initialised to zeros, so this copies the accumulated statistic. From
    then on ``value`` is returned.

    In eval mode, ``buffer`` is returned until that hand-off has happened, and
    ``value`` afterwards. Whatever is selected goes through ``abs_binary_sign_grad``
    and then ``int_quant.to_int(scale, min_int, bit_width, ...)``.

    Raises:
        RuntimeError: if ``zero_point_shape != SCALAR_SHAPE`` and
            ``zero_point_stats_permute_dims`` is None.
    """
    __constants__ = ['stats_permute_dims', 'collect_stats_steps', 'momentum']

    def __init__(
            self,
            collect_stats_steps: int,
            int_quant: Module,
            stats_reduce_dim: Optional[int],
            zero_point_shape: Tuple[int, ...],
            zero_point_stats_input_view_shape_impl: Module,
            zero_point_stats_permute_dims: Optional[Tuple[int, ...]] = None,
            zero_point_stats_momentum: Optional[float] = DEFAULT_MOMENTUM) -> None:
        super(ParameterFromRuntimeMinZeroPoint, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        # Non-scalar (per-channel) zero points need the input permuted before the view.
        if zero_point_shape != SCALAR_SHAPE and zero_point_stats_permute_dims is None:
            raise RuntimeError("Per channel runtime stats require a permute shape")
        self.collect_stats_steps = collect_stats_steps
        # Number of training calls seen so far; drives the collect -> learn transition.
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_permute_dims = zero_point_stats_permute_dims
        self.stats_input_view_shape_impl = zero_point_stats_input_view_shape_impl
        self.momentum = zero_point_stats_momentum
        # Learned zero point, filled from `buffer` once stats collection ends.
        self.value = Parameter(torch.full(zero_point_shape, 0.0))
        # Accumulated statistic during the collection phase.
        self.register_buffer('buffer', torch.full(zero_point_shape, 0.0))
        self.negative_min_or_zero = NegativeMinOrZero(stats_reduce_dim)
        self.int_quant = int_quant

    @brevitas.jit.script_method
    def training_forward(self, x) -> Tensor:
        # Phase 1: return the per-call statistic and accumulate it into `buffer`.
        if self.counter < self.collect_stats_steps:
            if self.stats_permute_dims is not None:
                x = x.permute(*self.stats_permute_dims).contiguous()
            stats_input = self.stats_input_view_shape_impl(x)
            stats = self.negative_min_or_zero(stats_input)
            new_counter = self.counter + 1
            if self.counter == 0:
                # First call: buffer starts at zero, so this stores the statistic.
                inplace_tensor_add(self.buffer, stats.detach())
            else:
                inplace_momentum_update(
                    self.buffer, stats.detach(), self.momentum, self.counter, new_counter)
            self.counter = new_counter
            # work around find_unused_parameters=True in DDP: keep `value` in the graph
            out = stats + 0. * self.value
        # Hand-off: copy the accumulated statistic into the learned parameter (value is 0 here).
        elif self.counter == self.collect_stats_steps:
            inplace_tensor_add(self.value.detach(), self.buffer)
            self.counter = self.counter + 1
            out = self.value
        # Phase 2: learned parameter.
        else:
            out = self.value
        return out

    @brevitas.jit.script_method
    def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        if self.training:
            out = self.training_forward(x)
        else:
            # Before the hand-off, eval mode uses the accumulated statistic.
            if self.counter <= self.collect_stats_steps:
                out = self.buffer
            else:
                out = self.value
        out = abs_binary_sign_grad(out)
        min_int = self.int_quant.min_int(bit_width)
        out = self.int_quant.to_int(scale, min_int, bit_width, out)
        return out

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(ParameterFromRuntimeMinZeroPoint, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        value_key = prefix + 'value'
        # Pytorch stores training flag as a buffer with JIT enabled
        training_key = prefix + 'training'
        if training_key in missing_keys:
            missing_keys.remove(training_key)
        # disable stats collection when a pretrained value is loaded
        if value_key not in missing_keys:
            self.counter = self.collect_stats_steps + 1
        # Optionally tolerate a missing `value` (e.g. loading a float checkpoint).
        if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
            missing_keys.remove(value_key)
class Laplace(Distribution):
    """Laplace distribution with location ``loc`` and scale ``b = softplus(_scale)``.

    The scale is kept in softplus-inverse space so that it remains positive while
    being learned. ``log_prob`` sums over the last dimension; ``kl`` is only
    defined against another ``Laplace`` and returns None otherwise.
    """

    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        loc_t = loc if isinstance(loc, torch.Tensor) else torch.tensor(loc).view(-1)
        scale_t = scale if isinstance(scale, torch.Tensor) else torch.tensor(scale).view(-1)
        self.n_dims = len(loc_t)

        location = loc_t.float()
        raw_scale = utils.softplus_inverse(scale_t.float())
        if learnable:
            location, raw_scale = Parameter(location), Parameter(raw_scale)
        self.loc = location
        self._scale = raw_scale

    def log_prob(self, value):
        b = self.scale
        distance = (value - self.loc).abs()
        return (-(2. * b).log() - distance / b).sum(-1)

    def sample(self, batch_size):
        return dists.Laplace(self.loc, self.scale).rsample((batch_size, ))

    def cdf(self, value):
        offset = value - self.loc
        return 0.5 - 0.5 * offset.sign() * (-offset.abs() / self.scale).expm1()

    def icdf(self, value):
        centered = value - 0.5
        return self.loc - self.scale * centered.sign() * (-2 * centered.abs()).log1p()

    def entropy(self):
        return 1 + (2 * self.scale).log()

    def kl(self, other):
        # Closed form KL(self || other) between two Laplace distributions.
        if not isinstance(other, Laplace):
            return None
        ratio = self.scale / other.scale
        gap = (self.loc - other.loc).abs()
        return (-ratio.log() + gap / other.scale + ratio * (-gap / self.scale).exp() - 1.).sum()

    @property
    def expectation(self):
        return self.loc

    # Symmetric distribution: mean, median and mode all equal loc.
    median = expectation
    mode = expectation

    @property
    def variance(self):
        return 2 * self.scale.pow(2)

    @property
    def stddev(self):
        return (2**0.5) * self.scale

    @property
    def skewness(self):
        return torch.tensor(0.).float()

    @property
    def kurtosis(self):
        return torch.tensor(3.).float()

    @property
    def scale(self):
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            convert = lambda t: t.item()
        else:
            convert = lambda t: t.detach().numpy()
        return {'loc': convert(self.loc), 'scale': convert(self.scale)}
class StudentT(Distribution):
    """Student's t distribution with learnable ``df``, ``loc`` and ``scale``.

    ``df`` and ``scale`` are stored in softplus-inverse space (``_df``,
    ``_scale``) so that they stay positive when learned. Computations delegate to
    ``dists.StudentT`` built from the current parameters. ``log_prob`` sums over
    the last dimension.
    """

    def __init__(self, df=1., loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        if not isinstance(df, torch.Tensor):
            df = torch.tensor(df).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        self._df = utils.softplus_inverse(df.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)
            self._df = Parameter(self._df)

    def _student_t(self):
        # Fresh distribution object reflecting the current (possibly learned) parameters.
        return dists.StudentT(self.df, self.loc, self.scale)

    def log_prob(self, value):
        return self._student_t().log_prob(value).sum(-1)

    def sample(self, batch_size):
        return self._student_t().rsample((batch_size, ))

    def entropy(self):
        return self._student_t().entropy()

    @property
    def expectation(self):
        # NaN where df <= 1 (the mean is undefined there).
        return self._student_t().mean

    @property
    def mode(self):
        # The mode is loc for every df > 0. The mean is undefined (NaN) for
        # df <= 1, including the default df=1, so it must not be used here.
        return self.loc

    @property
    def variance(self):
        return self._student_t().variance

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def df(self):
        return softplus(self._df)

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                'loc': self.loc.item(),
                'scale': self.scale.item(),
                'df': self.df.item()
            }
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy(),
            'df': self.df.detach().numpy()
        }
class MultiheadAttention_v2(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Extends the classic implementation with optional "stable init": when
    ``stable_init_ratio > 0`` and ``prev_weight`` is passed to ``forward``, a
    random subset of heads reuses the attention weights from ``prev_weight``.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, stable_init_ratio=0.,
                 self_attention=False, encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        # Head-copy masks below are only referenced by the disabled head-copying code
        # in reset_parameters. They are allocated on the default device: a
        # hard-coded 'cuda:0' made the module impossible to build without a GPU.
        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_mask = torch.ones(embed_dim, self.kdim, dtype=torch.half)
            self.v_mask = torch.ones(embed_dim, self.vdim, dtype=torch.half)
            self.q_mask = torch.ones(embed_dim, embed_dim, dtype=torch.half)

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj_mask = torch.ones(embed_dim, embed_dim, dtype=torch.half)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        self.num_copied_heads = int(self.num_heads * stable_init_ratio)

        self.onnx_trace = False
        # Use PyTorch's fused implementation when it is available.
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")

        # The projection weights above are allocated with torch.Tensor(...), which
        # leaves them uninitialised; initialise them here.
        self.reset_parameters()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self, prev_weight=None, copied_heads=None):
        """Re-initialise projections (Xavier) and biases (zeros / Xavier-normal).

        ``prev_weight`` and ``copied_heads`` are accepted for the head-copying
        scheme, which is currently commented out below, so they are unused.
        """
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

        # if self.num_copied_heads > 0 and prev_weight is not None:
        #     if self.qkv_same_dim:
        #         same_dim_heads = np.concatenate((copied_heads, copied_heads + self.embed_dim, \
        #                                          copied_heads + 2 * self.embed_dim), axis=None)
        #         self.in_proj_weight[same_dim_heads, :].data = prev_weight["in"][same_dim_heads, :].data
        #     else:
        #         self.k_proj_weight[copied_heads, :].data = prev_weight["k"][copied_heads, :].data
        #         self.v_proj_weight[copied_heads, :].data = prev_weight["v"][copied_heads, :].data
        #         self.q_proj_weight[copied_heads, :].data = prev_weight["q"][copied_heads, :].data
        #     self.out_proj.weight[copied_heads, :].data = prev_weight["out"][copied_heads, :].data
        #     if self.qkv_same_dim:
        #         self.q_mask[copied_heads, :] = 0
        #         mask = self.q_mask.repeat(3, 1)
        #         self.in_proj_weight = Parameter(self.in_proj_weight * mask + prev_weight["in"] * (1 - mask))
        #     else:
        #         self.k_mask[copied_heads, :], self.v_mask[copied_heads, :] = 0, 0
        #         self.k_proj_weight = Parameter(self.k_proj_weight * self.k_mask + prev_weight["k"]* (1 - self.k_mask))
        #         self.v_proj_weight = Parameter(self.v_proj_weight * self.v_mask + prev_weight["v"] * (1 - self.v_mask))
        #         self.q_proj_weight = Parameter(self.q_proj_weight * self.q_mask + prev_weight["q"] * (1 - self.q_mask))
        #     self.out_proj_mask[copied_heads, :] = 0
        #     self.out_proj.weight = Parameter(self.out_proj.weight * self.out_proj_mask + prev_weight["out"] * (1 - self.out_proj_mask))

    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                need_weights=False, static_kv=False, attn_mask=None, prev_weight=None):
        """Input shape: Time x Batch x Channel

        Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.

        Returns:
            When the fused PyTorch path is taken (no incremental state, no
            static_kv, not ONNX tracing), the result of
            ``F.multi_head_attention_forward`` is returned unchanged. Otherwise it
            returns ``(attn, attention_weights)``, where ``attention_weights`` is a
            dict that is empty unless ``need_weights`` is set.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        copied_heads = None
        if self.num_copied_heads > 0 and prev_weight is not None:
            # Pick heads to copy and expand them to their row indices in the projections.
            copied_idx = np.random.choice(self.num_heads, self.num_copied_heads, replace=False)
            copied_heads = np.reshape(
                np.array([
                    list(range(s, e)) for (s, e) in zip(copied_idx * self.head_dim,
                                                        (copied_idx + 1) * self.head_dim)
                ]), -1)
            self.reset_parameters(prev_weight, copied_heads)

        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(
                    query, key, value, self.embed_dim, self.num_heads,
                    self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v,
                    self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias,
                    self.training, key_padding_mask, need_weights, attn_mask)
            else:
                return F.multi_head_attention_forward(
                    query, key, value, self.embed_dim, self.num_heads,
                    torch.empty([0]), self.in_proj_bias, self.bias_k, self.bias_v,
                    self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias,
                    self.training, key_padding_mask, need_weights, attn_mask,
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ], dim=1)

        q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)
        # q: [bz * num_heads, tgt_len, head_dim]
        # k, v: [bz * num_heads, src_len, head_dim]

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)
                ], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if self.onnx_trace:
                attn_weights = torch.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    torch.Tensor([float("-Inf")]),
                    attn_weights.float()).type_as(attn_weights)
            else:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace,
        ).type_as(attn_weights)

        if self.num_copied_heads > 0 and prev_weight is not None:
            # Overwrite the copied heads' attention maps with the previous ones.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights[:, copied_idx, :, :] = prev_weight["nonavg"][:, copied_idx, :, :]
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attention_weights = {}
        if need_weights:
            # un-averaged attention weight, [bz, num_heads, tgt_len, src_len]
            attention_weights["nonavg"] = attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len).detach().clone()

        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads, [bz, tgt_len, src_len]
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attention_weights["avg"] = attn_weights.sum(dim=1) / self.num_heads
            # projection weight snapshots
            if self.qkv_same_dim:
                attention_weights["in"] = self.in_proj_weight.detach().clone()
            else:
                attention_weights["k"], attention_weights["q"], attention_weights["v"] = \
                    self.k_proj_weight.detach().clone(), self.q_proj_weight.detach().clone(), \
                    self.v_proj_weight.detach().clone()
            attention_weights["out"] = self.out_proj.weight.detach().clone()

        return attn, attention_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        # Slice rows [start:end) of the packed q/k/v projection.
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(0, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        utils.set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        return attn_weights