Example #1
class ParameterFromRuntimeStatsScaling(brevitas.jit.ScriptModule):
    """
    ScriptModule implementation of a learned scale factor initialized from runtime statistics.
    The implementation works in two phases. During the first phase, statistics are collected in
    the same fashion as batch normalization: while the module is in training mode, per-batch
    statistics are computed and returned, while in the background a running average of them is
    retained and returned in inference mode. During the second phase, the average accumulated
    during the first phase is used to initialize a learned torch.nn.Parameter, and from then on
    the behaviour is the same as ParameterScaling.

    Args:
        collect_stats_steps (int): Number of calls to the forward method in training mode to collect statistics for.
        scaling_stats_impl (Module): Implementation of the statistics computed during the collection phase.
        scaling_stats_input_view_shape_impl (Module): Implementation of the view applied to the runtime
            input during the statistics collection phase. Default: OverBatchOverTensorView().
        scaling_shape (Tuple[int, ...]): Shape of the torch.nn.Parameter used in the second phase. Default: SCALAR_SHAPE.
        restrict_scaling_impl (Module): Restrict the learned scale factor according to some criteria. Default: None.
        scaling_stats_momentum (float): Momentum for the statistics moving average. Default: DEFAULT_MOMENTUM.
        scaling_min_val (float): Force a lower bound on the learned scale factor. Default: None.

    Returns:
        Tensor: learned scale factor wrapped in a float torch.tensor.

    Examples:
        >>> scaling_impl = ParameterFromRuntimeStatsScaling(collect_stats_steps=1, scaling_stats_impl=AbsMax())
        >>> scaling_impl.training
        True
        >>> x = torch.arange(-3, 2, 0.1)
        >>> scaling_impl(x)
        tensor(3.)
        >>> scaling_impl(torch.randn_like(x))
        tensor(3., grad_fn=<AbsBinarySignGradFnBackward>)

    Note:
        Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining
        from a floating point state dict.

    Note:
        Maps to scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS == 'PARAMETER_FROM_STATS'
        == 'parameter_from_stats' when applied to runtime values (inputs/outputs/activations) in higher-level APIs.
    """
    __constants__ = ['collect_stats_steps', 'momentum']

    def __init__(
            self,
            collect_stats_steps: int,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
            scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
            restrict_scaling_impl: Optional[Module] = None,
            scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
            scaling_min_val: Optional[float] = None) -> None:
        super(ParameterFromRuntimeStatsScaling, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        self.collect_stats_steps = collect_stats_steps
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
        self.stats = _Stats(scaling_stats_impl, scaling_shape)
        self.momentum = scaling_stats_momentum
        self.register_buffer('buffer', torch.full(scaling_shape, 1.0))
        self.value = Parameter(torch.full(scaling_shape, 1.0))
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        if restrict_scaling_impl is not None:
            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
        else:
            self.restrict_inplace_preprocess = Identity()
            self.restrict_preprocess = Identity()
    
    @brevitas.jit.script_method
    def training_forward(self, stats_input: Tensor) -> Tensor:
        if self.counter < self.collect_stats_steps:
            stats_input = self.stats_input_view_shape_impl(stats_input)
            stats = self.stats(stats_input)
            new_counter = self.counter + 1
            if self.counter == 0:
                _inplace_init(self.buffer, stats.detach())
            else:
                _inplace_update(self.buffer, stats.detach(), self.momentum, self.counter, new_counter)
            self.counter = new_counter
            return stats
        elif self.counter == self.collect_stats_steps:
            self.restrict_inplace_preprocess(self.buffer)
            _inplace_init(self.value.detach(), self.buffer)
            self.counter = self.counter + 1
            return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
        else:
            return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))

    @brevitas.jit.script_method
    def forward(self, stats_input: Tensor) -> Tensor:
        if self.training:
            return self.training_forward(stats_input)
        else:
            if self.counter <= self.collect_stats_steps:
                out = self.buffer
                out = self.restrict_preprocess(out)
            else:
                out = self.value
            out = abs_binary_sign_grad(self.restrict_clamp_scaling(out))
        return out

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        output_dict = super(ParameterFromRuntimeStatsScaling, self).state_dict(
            destination, prefix, keep_vars)        
        # Avoid saving the buffer
        del output_dict[prefix + 'buffer']
        # Avoid saving the init value
        if self.counter == 0:
            del output_dict[prefix + 'value']
        # Save buffer into value for any non-zero number of collection steps
        elif self.counter <= self.collect_stats_steps:
            output_dict[prefix + 'value'] = self.restrict_preprocess(self.buffer)
        return output_dict

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(ParameterFromRuntimeStatsScaling, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        value_key = prefix + 'value'
        # Buffer is supposed to be always missing
        missing_keys.remove(prefix + 'buffer')
        # Retrocompatibility with older ParameterScaling, for when scaling impl is switched over
        retrocomp_value_key = prefix + 'learned_value'
        if retrocomp_value_key in state_dict:
            state_dict[value_key] = state_dict.pop(retrocomp_value_key)
        # Pytorch stores training flag as a buffer with JIT enabled
        training_key = prefix + 'training'
        if training_key in missing_keys:
            missing_keys.remove(training_key)
        # disable stats collection when a pretrained value is loaded
        if value_key not in missing_keys:
            self.counter = self.collect_stats_steps + 1
        if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
            missing_keys.remove(value_key)
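
A minimal usage sketch of the two-phase behaviour described in the docstring (assuming brevitas is installed; the AbsMax import path is an assumption):

# Phase 1: the first collect_stats_steps training calls return per-batch stats
# and keep a running average in a buffer. Phase 2: the average initializes the
# learned Parameter, which is returned from then on.
import torch
from brevitas.core.stats import AbsMax  # import path assumed

scaling_impl = ParameterFromRuntimeStatsScaling(
    collect_stats_steps=2, scaling_stats_impl=AbsMax())
scaling_impl(torch.arange(-3, 2, 0.1))   # phase 1: returns per-batch stats
scaling_impl(torch.randn(50))            # phase 1: running average updated
out = scaling_impl(torch.randn(50))      # phase 2: learned scale from now on
print(out.requires_grad)                 # True: backed by torch.nn.Parameter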
Example #2
class InfiniteMixtureModel(Distribution):
    def __init__(self,
                 df,
                 loc,
                 scale,
                 loc_learnable=True,
                 scale_learnable=True,
                 df_learnable=True):
        super().__init__()
        self.loc = torch.tensor(loc).view(-1)
        self.n_dims = len(self.loc)
        if loc_learnable:
            self.loc = Parameter(self.loc)
        self._scale = utils.softplus_inverse(torch.tensor(scale).view(-1))
        if scale_learnable:
            self._scale = Parameter(self._scale)
        self._df = utils.softplus_inverse(torch.tensor(df).view(-1))
        if df_learnable:
            self._df = Parameter(self._df)

    def sample(self, batch_size, return_latents=False):
        weight_model = Gamma(self.df / 2, self.df / 2, learnable=False)
        latent_samples = weight_model.sample(batch_size)
        normal_model = Normal(self.loc.expand(batch_size),
                              (self.scale / latent_samples).squeeze(1),
                              learnable=False)
        if return_latents:
            return normal_model.sample(1).squeeze(0).unsqueeze(
                1), latent_samples
        else:
            return normal_model.sample(1).squeeze(0).unsqueeze(1)

    def log_prob(self, samples, latents=None):
        if latents is None:
            raise NotImplementedError(
                "InfiniteMixtureModel log_prob not implemented")
        weight_model = Gamma(self.df / 2, self.df / 2, learnable=False)
        normal_model = Normal(self.loc.expand(latents.size(0)),
                              (self.scale / latents).squeeze(1),
                              learnable=False)
        return normal_model.log_prob(samples) + weight_model.log_prob(latents)

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def df(self):
        return softplus(self._df)

    @property
    def has_latents(self):
        return True

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                "loc": self.loc.item(),
                "scale": self.scale.item(),
                "df": self.df.item(),
            }
        return {
            "loc": self.loc.detach().numpy(),
            "scale": self.scale.detach().numpy(),
            "df": self.df.detach().numpy(),
        }
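
The sampler above is the classic Gaussian scale mixture: a Gamma-distributed latent reweights the Normal's spread, which marginally yields a Student-t-like distribution. A self-contained sketch of the textbook construction with stock torch.distributions (note the textbook form divides the scale by sqrt(w), while the class above divides by the latent directly, following its own Normal wrapper's convention):

# Textbook Student-t as a Gamma scale mixture of Normals, for comparison.
import torch
from torch.distributions import Gamma, Normal

df, loc, scale = 3.0, 0.0, 1.0
w = Gamma(df / 2, df / 2).sample((100000,))   # latent precision weights
x = Normal(loc, scale / w.sqrt()).sample()    # marginally Student-t(df)
print(x.var())  # roughly df / (df - 2) = 3.0 (heavy-tailed, so noisy)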
Example #3
class SpatialTransformerPooled3d(nn.Module):

    def __init__(self, in_shape, outdims, pool_steps=1, positive=False, bias=True,
                 init_range=.05, kernel_size=2, stride=2, grid=None, stop_grad=False):
        super().__init__()
        self._pool_steps = pool_steps
        self.in_shape = in_shape
        c, t, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        if grid is None:
            self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        else:
            self.grid = grid
        self.features = Parameter(torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims))
        self.register_buffer('mask', torch.ones_like(self.features))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter('bias', bias)
        else:
            self.register_parameter('bias', None)

        self.avg = nn.AvgPool2d(kernel_size, stride=stride, count_include_pad=False)
        self.init_range = init_range
        self.initialize()
        self.stop_grad = stop_grad

    @property
    def pool_steps(self):
        return self._pool_steps

    @pool_steps.setter
    def pool_steps(self, value):
        assert value >= 0 and int(value) - value == 0, 'new pool steps must be a non-negative integer'
        if value != self._pool_steps:
            print('Resizing readout features')
            c, t, w, h = self.in_shape
            outdims = self.outdims
            self._pool_steps = int(value)
            self.features = Parameter(torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims))
            self.mask = torch.ones_like(self.features)
            self.features.data.fill_(1 / self.in_shape[0])

    def initialize(self, init_noise=1e-3, grid=True):
        # randomly pick centers within the spatial map

        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.bias.data.fill_(0)
        if grid:
            self.grid.data.uniform_(-self.init_range, self.init_range)

    def feature_l1(self, average=True, subs_idx=None):
        subs_idx = subs_idx if subs_idx is not None else slice(None)
        if average:
            return self.features[..., subs_idx].abs().mean()
        else:
            return self.features[..., subs_idx].abs().sum()

    def reset_fisher_prune_scores(self):
        self._prune_n = 0
        self._prune_scores = self.features.detach() * 0

    def update_fisher_prune_scores(self):
        self._prune_n += 1
        if self.features.grad is None:
            raise ValueError('You need to run backward first')
        self._prune_scores += (0.5 * self.features.grad.pow(2) * self.features.pow(2)).detach()

    @property
    def fisher_prune_scores(self):
        return self._prune_scores / self._prune_n

    def prune(self):
        idx = (self.fisher_prune_scores + 1e6 * (1 - self.mask)).squeeze().argmin(dim=0)
        seq = torch.arange(len(idx), device=idx.device)
        self.mask[:, idx, :, seq] = 0
        self.features.data[:, idx, :, seq] = 0

    def forward(self, x, shift=None, subs_idx=None):
        if self.stop_grad:
            x = x.detach()

        self.features.data *= self.mask

        if self.positive:
            positive(self.features)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)

        N, c, t, w, h = x.size()
        m = self._pool_steps + 1
        if subs_idx is not None:
            feat = self.features[..., subs_idx].contiguous()
            outdims = feat.size(-1)
            feat = feat.view(1, m * c, outdims)
            grid = self.grid[:, subs_idx, ...]
        else:
            grid = self.grid
            feat = self.features.view(1, m * c, self.outdims)
            outdims = self.outdims

        if shift is None:
            grid = grid.expand(N * t, outdims, 1, 2)
        else:
            grid = grid.expand(N, outdims, 1, 2)
            grid = torch.stack([grid + shift[:, i, :][:, None, None, :] for i in range(t)], 1)
            grid = grid.contiguous().view(-1, outdims, 1, 2)
        z = x.contiguous().transpose(2, 1).contiguous().view(-1, c, w, h)
        pools = [F.grid_sample(z, grid)]
        for i in range(self._pool_steps):
            z = self.avg(z)
            pools.append(F.grid_sample(z, grid))
        y = torch.cat(pools, dim=1)
        y = (y.squeeze(-1) * feat).sum(1).view(N, t, outdims)

        if self.bias is not None:
            if subs_idx is None:
                y = y + self.bias
            else:
                y = y + self.bias[subs_idx]

        return y

    def __repr__(self):
        c, _, w, h = self.in_shape
        r = self.__class__.__name__ + \
            ' (' + '{} x {} x {}'.format(c, w, h) + ' -> ' + str(self.outdims) + ')'
        if self.bias is not None:
            r += ' with bias'
        if self.stop_grad:
            r += ', stop_grad=True'
        r += '\n'

        for ch in self.children():
            r += '  -> ' + ch.__repr__() + '\n'
        return r
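
A shape-level sketch of the readout's contract, with hypothetical dimensions: it grid-samples a spatial pyramid per time step and maps (N, c, t, w, h) to (N, t, outdims).

# Hypothetical shapes; the readout contracts (N, c, t, w, h) -> (N, t, outdims).
import torch

readout = SpatialTransformerPooled3d(in_shape=(8, 5, 36, 64), outdims=10, pool_steps=2)
x = torch.randn(2, 8, 5, 36, 64)
y = readout(x)
print(y.shape)  # torch.Size([2, 5, 10])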
Example #4
class _NMF(Base):
    def __init__(self, W_size, H_size, rank):
        super().__init__()
        self.rank = rank
        self.W = Parameter(torch.rand(*W_size))
        self.H = Parameter(torch.rand(*H_size))

    def forward(self, H=None, W=None):
        if H is None:
            H = self.H
        if W is None:
            W = self.W
        return self.reconstruct(H, W)

    def reconstruct(self, H, W):
        raise NotImplementedError

    def get_W_positive(self, WH, beta, H_sum) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError

    def get_H_positive(self, WH, beta, W_sum) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError

    def fit(self,
            V,
            W=None,
            H=None,
            update_W=True,
            update_H=True,
            beta=1,
            tol=1e-5,
            max_iter=200,
            verbose=0,
            initial='random',
            alpha=0,
            l1_ratio=0
            ):

        V = self.fix_neg(V)

        if W is None:
            pass  # will do special initialization in the future
        else:
            self.W.data.copy_(W)
            self.W.requires_grad = update_W

        if H is None:
            pass
        else:
            self.H.data.copy_(H)
            self.H.requires_grad = update_H

        if beta < 1:
            gamma = 1 / (2 - beta)
        elif beta > 2:
            gamma = 1 / (beta - 1)
        else:
            gamma = 1

        l1_reg = alpha * l1_ratio
        l2_reg = alpha * (1 - l1_ratio)

        loss_scale = torch.prod(torch.tensor(V.shape)).float()

        H_sum, W_sum = None, None
        with tqdm(total=max_iter, disable=not verbose) as pbar:
            for n_iter in range(max_iter):
                if self.W.requires_grad:
                    self.zero_grad()
                    WH = self.reconstruct(self.H.detach(), self.W)
                    loss = Beta_divergence(self.fix_neg(WH), V, beta)
                    loss.backward()

                    with torch.no_grad():
                        positive_comps, H_sum = self.get_W_positive(WH, beta, H_sum)
                        _mu_update(self.W, positive_comps, gamma, l1_reg, l2_reg)
                    W_sum = None

                if self.H.requires_grad:
                    self.zero_grad()
                    WH = self.reconstruct(self.H, self.W.detach())
                    loss = Beta_divergence(self.fix_neg(WH), V, beta)
                    loss.backward()

                    with torch.no_grad():
                        positive_comps, W_sum = self.get_H_positive(WH, beta, W_sum)
                        _mu_update(self.H, positive_comps, gamma, l1_reg, l2_reg)
                    H_sum = None

                loss = loss.div_(loss_scale).item()

                pbar.set_postfix(loss=loss)
                # pbar.set_description('Beta loss=%.4f' % error)
                pbar.update()

                if n_iter == 0:
                    loss_init = loss
                elif (previous_loss - loss) / loss_init < tol:
                    break
                previous_loss = loss

        return n_iter

    def fit_transform(self, *args, **kwargs):
        n_iter = self.fit(*args, **kwargs)
        return n_iter, self.forward()
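
reconstruct and the multiplicative-update terms are left abstract above. A hedged sketch of what a plain V ≈ W @ H subclass with Euclidean loss (beta=2) could look like; the positive-gradient terms below are the standard ones for that factorization, not necessarily this library's exact code:

# Hedged sketch of a concrete subclass for V ~ W @ H with beta=2.
class PlainNMF(_NMF):
    def __init__(self, K, M, rank):
        super().__init__(W_size=(K, rank), H_size=(rank, M), rank=rank)

    def reconstruct(self, H, W):
        return W @ H

    def get_W_positive(self, WH, beta, H_sum):
        # For beta=2 the positive part of the gradient w.r.t. W is (W H) H^T.
        return WH @ self.H.t(), H_sum

    def get_H_positive(self, WH, beta, W_sum):
        # ... and w.r.t. H it is W^T (W H).
        return self.W.t() @ WH, W_sum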
Example #5
class ArcMarginProduct_virface(nn.Module):
    def __init__(self,
                 in_features=512,
                 out_features=84281,
                 s=32,
                 m=0.5,
                 easy_margin=False,
                 device='cuda'):
        super(ArcMarginProduct_virface, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.dropout = nn.Dropout(0.5)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # keep cos(theta + m) monotonically decreasing for theta in [0°, 180°]
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.device = device

    def forward(self,
                x,
                label,
                unlabel_x=None,
                unlabel_aug=None,
                overlap=False):
        if unlabel_x is not None:
            # filter overlap
            if overlap:
                unlabel_dot_w = F.linear(F.normalize(unlabel_x.detach()),
                                         F.normalize(
                                             self.weight.detach())) * self.s
                prob = F.softmax(unlabel_dot_w, dim=1)
                max_prob = torch.max(prob, dim=1)[0]
                idx_lt = max_prob.lt(0.8)

                idx = idx_lt
            else:
                idx = torch.ones(unlabel_x.shape[0], dtype=torch.bool, device=self.device)

            unlabel_data = unlabel_x[idx]

            weight_all = torch.cat(
                [F.normalize(self.weight),
                 F.normalize(unlabel_data)], dim=0)
            weight_all_fix = torch.cat(
                [F.normalize(self.weight),
                 F.normalize(unlabel_data).detach()],
                dim=0)
        else:  # no unlabel data, return arcface
            weight_all = F.normalize(self.weight)

        # virclass
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if unlabel_x is not None:
            x = self.dropout(x)
        cosine_label = F.linear(F.normalize(x), weight_all)
        sine_label = 1.0 - torch.pow(cosine_label, 2)
        sine_label = torch.where(
            sine_label > 0, sine_label,
            torch.zeros(sine_label.size(), device=self.device))
        sine_label = torch.sqrt(sine_label)

        phi_label = cosine_label * self.cos_m - sine_label * self.sin_m
        if self.easy_margin:
            phi_label = torch.where(cosine_label > 0, phi_label, cosine_label)
        else:
            phi_label = torch.where((cosine_label - self.th) > 0, phi_label,
                                    cosine_label - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine_label.size(), device=self.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        output_label = (one_hot * phi_label) + ((1.0 - one_hot) * cosine_label)
        output_label *= self.s

        if unlabel_aug is None or unlabel_x is None:  # no aug, return virclass or arcface
            return output_label, None, None

        # virinstance
        aug_data = torch.cat([
            unlabel_aug[i * unlabel_x.shape[0]:(i + 1) *
                        unlabel_x.shape[0]][idx]
            for i in range(int(unlabel_aug.shape[0] / unlabel_x.shape[0]))
        ],
                             dim=0)
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine_unlabel = F.linear(F.normalize(aug_data), weight_all_fix)
        sine_unlabel = 1.0 - torch.pow(cosine_unlabel, 2)
        sine_unlabel = torch.where(
            sine_unlabel > 0, sine_unlabel,
            torch.zeros(sine_unlabel.size(), device=self.device))
        sine_unlabel = torch.sqrt(sine_unlabel)

        phi_unlabel = cosine_unlabel * self.cos_m - sine_unlabel * self.sin_m
        if self.easy_margin:
            phi_unlabel = torch.where(cosine_unlabel > 0, phi_unlabel,
                                      cosine_unlabel)
        else:
            phi_unlabel = torch.where((cosine_unlabel - self.th) > 0,
                                      phi_unlabel, cosine_unlabel - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        unlabel_label = torch.arange(self.out_features,
                                     self.out_features + unlabel_data.size(0),
                                     device=self.device).repeat(
                                         int(unlabel_aug.shape[0] /
                                             unlabel_x.shape[0]))

        one_hot_unlabel = torch.zeros(cosine_unlabel.size(),
                                      device=self.device)
        one_hot_unlabel.scatter_(1, unlabel_label.view(-1, 1).long(), 1)

        output_unlabel = (one_hot_unlabel * phi_unlabel) + (
            (1.0 - one_hot_unlabel) * cosine_unlabel)
        output_unlabel *= self.s

        return output_label, output_unlabel, unlabel_label
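
The phi computation in forward is the angle-addition identity cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m), applied to the target logits; a tiny numeric check:

# Numeric check of the margin identity used for phi_label / phi_unlabel.
import math

theta, m = 0.7, 0.5
phi = math.cos(theta) * math.cos(m) - math.sin(theta) * math.sin(m)
assert abs(phi - math.cos(theta + m)) < 1e-12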
Example #6
class GPRegression(GPModel):
    r"""
    Gaussian Process Regression model.

    The core of a Gaussian Process is a covariance function :math:`k` which governs
    the similarity between input points. Given :math:`k`, we can establish a
    distribution over functions :math:`f` by a multivariate normal distribution

    .. math:: p(f(X)) = \mathcal{N}(0, k(X, X)),

    where :math:`X` is any set of input points and :math:`k(X, X)` is a covariance
    matrix whose entries are outputs :math:`k(x, z)` of :math:`k` over input pairs
    :math:`(x, z)`. This distribution is usually denoted by

    .. math:: f \sim \mathcal{GP}(0, k).

    .. note:: Generally, besides a covariance function :math:`k`, a Gaussian Process can
        also be specified by a mean function :math:`m` (which is the zero-value function
        by default). In that case, its distribution will be

        .. math:: p(f(X)) = \mathcal{N}(m(X), k(X, X)).

    Given inputs :math:`X` and their noisy observations :math:`y`, the Gaussian Process
    Regression model takes the form

    .. math::
        f &\sim \mathcal{GP}(0, k(X, X)),\\
        y & \sim f + \epsilon,

    where :math:`\epsilon` is Gaussian noise.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for training,
        :math:`\mathcal{O}(N^3)` complexity for testing. Here, :math:`N` is the number
        of train inputs.

    Reference:

    [1] `Gaussian Processes for Machine Learning`,
    Carl E. Rasmussen, Christopher K. I. Williams

    :param torch.Tensor X: Input data for training. Its first dimension is the number
        of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the
        number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param torch.Tensor noise: Variance of Gaussian noise of this model.
    :param callable mean_function: An optional mean function :math:`m` of this Gaussian
        process. By default, we use zero mean.
    :param float jitter: A small positive term which is added into the diagonal part of
        a covariance matrix to help stabilize its Cholesky decomposition.
    """
    def __init__(self,
                 X,
                 y,
                 kernel,
                 noise=None,
                 mean_function=None,
                 jitter=1e-6):
        super(GPRegression, self).__init__(X, y, kernel, mean_function, jitter)

        noise = self.X.new_tensor(1.) if noise is None else noise
        self.noise = Parameter(noise)
        self.set_constraint("noise", torchdist.constraints.positive)

    @autoname.scope(prefix="GPR")
    def model(self):
        self.set_mode("model")

        N = self.X.size(0)
        Kff = self.kernel(self.X)
        Kff.view(-1)[::N + 1] += self.jitter + self.noise  # add noise to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.X.size(0))
        f_loc = zero_loc + self.mean_function(self.X)
        if self.y is None:
            f_var = Lff.pow(2).sum(dim=-1)
            return f_loc, f_var
        else:
            return pyro.sample(
                "y",
                dist.MultivariateNormal(f_loc, scale_tril=Lff).expand_by(
                    self.y.shape[:-1]).to_event(self.y.dim() - 1),
                obs=self.y)

    @autoname.scope(prefix="GPR")
    def guide(self):
        self.set_mode("guide")

    def forward(self, Xnew, full_cov=False, noiseless=True):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian Process
        posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, \epsilon) = \mathcal{N}(loc, cov).

        .. note:: The noise parameter ``noise`` (:math:`\epsilon`) together with
            kernel's parameters have been learned from a training procedure (MCMC or
            SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full covariance
            matrix or just variance.
        :param bool noiseless: A flag to decide if we want to include noise in the
            prediction output or not.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter + self.noise  # add noise to the diagonal
        Lff = Kff.cholesky()

        y_residual = self.y - self.mean_function(self.X)
        loc, cov = conditional(Xnew,
                               self.X,
                               self.kernel,
                               y_residual,
                               None,
                               Lff,
                               full_cov,
                               jitter=self.jitter)

        if full_cov and not noiseless:
            M = Xnew.size(0)
            cov = cov.contiguous()
            cov.view(-1, M * M)[:, ::M + 1] += self.noise  # add noise to the diagonal
        if not full_cov and not noiseless:
            cov = cov + self.noise

        return loc + self.mean_function(Xnew), cov

    def iter_sample(self, noiseless=True):
        r"""
        Iteratively constructs a sample from the Gaussian Process posterior.

        Recall that at test input points :math:`X_{new}`, the posterior is
        multivariate Gaussian distributed with mean and covariance matrix
        given by :func:`forward`.

        This method samples lazily from this multivariate Gaussian. The advantage
        of this approach is that later query points can depend upon earlier ones.
        Particularly useful when the querying is to be done by an optimisation
        routine.

        .. note:: The noise parameter ``noise`` (:math:`\epsilon`) together with
            kernel's parameters have been learned from a training procedure (MCMC or
            SVI).

        :param bool noiseless: A flag to decide if we want to add sampling noise
            to the samples beyond the noise inherent in the GP posterior.
        :returns: sampler
        :rtype: function
        """
        noise = self.noise.detach()
        X = self.X.clone().detach()
        y = self.y.clone().detach()
        N = X.size(0)
        Kff = self.kernel(X).contiguous()
        Kff.view(-1)[::N + 1] += noise  # add noise to the diagonal

        outside_vars = {"X": X, "y": y, "N": N, "Kff": Kff}

        def sample_next(xnew, outside_vars):
            """Repeatedly samples from the Gaussian process posterior,
            conditioning on previously sampled values.
            """
            warn_if_nan(xnew)

            # Variables from outer scope
            X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars[
                "Kff"]

            # Compute Cholesky decomposition of kernel matrix
            Lff = Kff.cholesky()
            y_residual = y - self.mean_function(X)

            # Compute conditional mean and variance
            loc, cov = conditional(xnew,
                                   X,
                                   self.kernel,
                                   y_residual,
                                   None,
                                   Lff,
                                   False,
                                   jitter=self.jitter)
            if not noiseless:
                cov = cov + noise

            ynew = torchdist.Normal(loc + self.mean_function(xnew),
                                    cov.sqrt()).rsample()

            # Update kernel matrix
            N = outside_vars["N"]
            Kffnew = Kff.new_empty(N + 1, N + 1)
            Kffnew[:N, :N] = Kff
            cross = self.kernel(X, xnew).squeeze()
            end = self.kernel(xnew, xnew).squeeze()
            Kffnew[N, :N] = cross
            Kffnew[:N, N] = cross
            # No noise, just jitter for numerical stability
            Kffnew[N, N] = end + self.jitter
            # Heuristic to avoid adding degenerate points
            if Kffnew.logdet() > -15.:
                outside_vars["Kff"] = Kffnew
                outside_vars["N"] += 1
                outside_vars["X"] = torch.cat((X, xnew))
                outside_vars["y"] = torch.cat((y, ynew))

            return ynew

        return lambda xnew: sample_next(xnew, outside_vars)
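
A minimal fitting sketch with Pyro's GP API (the training loop follows the standard Pyro GP tutorial; hyperparameters are illustrative):

# Fit the GP hyperparameters by maximizing the marginal likelihood via the
# model/guide pair, then query the posterior through forward().
import torch
import pyro
import pyro.contrib.gp as gp

X = torch.linspace(0., 5., 20)
y = torch.sin(X) + 0.1 * torch.randn(20)
gpr = gp.models.GPRegression(X, y, gp.kernels.RBF(input_dim=1),
                             noise=torch.tensor(0.1))

optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    loc, var = gpr(torch.linspace(0., 5., 50), full_cov=False, noiseless=False)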
Example #7
class ParameterFromRuntimeStatsScaling(brevitas.jit.ScriptModule):
    __constants__ = ['stats_permute_dims', 'collect_stats_steps', 'momentum']

    def __init__(self,
                 collect_stats_steps: int,
                 restrict_scaling_impl: Module,
                 scaling_stats_impl: Module,
                 scaling_shape: Tuple[int, ...],
                 scaling_stats_input_view_shape_impl: Module,
                 scaling_stats_permute_dims: Optional[Tuple[int, ...]] = None,
                 scaling_stats_momentum: float = DEFAULT_MOMENTUM,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ParameterFromRuntimeStatsScaling, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None:
            raise RuntimeError(
                "Per channel runtime stats require a permute shape")
        self.collect_stats_steps = collect_stats_steps
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_permute_dims = scaling_stats_permute_dims
        self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
        self.stats = _Stats(scaling_stats_impl, scaling_shape)
        self.momentum = scaling_stats_momentum
        self.value = Parameter(torch.full(scaling_shape, 1.0))
        self.restrict_clamp_scaling = _RestrictClampValue(
            scaling_min_val, restrict_scaling_impl)
        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module(
        )

    @brevitas.jit.script_method_110_disabled
    def forward(self, stats_input) -> torch.Tensor:
        if self.training:
            if self.counter < self.collect_stats_steps:
                if self.stats_permute_dims is not None:
                    stats_input = stats_input.permute(
                        *self.stats_permute_dims).contiguous()
                stats_input = self.stats_input_view_shape_impl(stats_input)
                stats = self.stats(stats_input)
                if self.counter == 0:
                    self.value.detach().mul_(stats.detach())
                else:
                    self.value.detach().mul_(1 - self.momentum)
                    self.value.detach().add_(self.momentum * stats.detach())
                self.counter = self.counter + 1
                return stats
            elif self.counter == self.collect_stats_steps:
                self.restrict_inplace_preprocess(self.value.detach())
                self.counter = self.counter + 1
                return self.restrict_clamp_scaling(torch.abs(self.value))
            else:
                return self.restrict_clamp_scaling(torch.abs(self.value))
        out = self.restrict_clamp_scaling(torch.abs(self.value))
        return out

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(ParameterFromRuntimeStatsScaling,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)
        value_key = prefix + 'value'
        # Pytorch stores training flag as a buffer with JIT enabled
        training_key = prefix + 'training'
        if training_key in missing_keys:
            missing_keys.remove(training_key)
        if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
            missing_keys.remove(value_key)
Example #8
class SpatialTransformerPooled3d(Readout):
    def __init__(
            self,
            in_shape,
            outdims,
            pool_steps=1,
            positive=False,
            bias=True,
            init_range=0.05,
            kernel_size=2,
            stride=2,
            grid=None,
            stop_grad=False,
            align_corners=True,
            mean_activity=None,
            feature_reg_weight=1.0,
            gamma_readout=None,  # deprecated, use feature_reg_weight instead
    ):
        super().__init__()
        self._pool_steps = pool_steps
        self.in_shape = in_shape
        c, t, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(
            feature_reg_weight, gamma_readout)
        self.mean_activity = mean_activity
        if grid is None:
            self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        else:
            self.grid = grid
        self.features = Parameter(
            torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims))
        self.register_buffer("mask", torch.ones_like(self.features))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter("bias", bias)
        else:
            self.register_parameter("bias", None)

        self.avg = nn.AvgPool2d(kernel_size,
                                stride=stride,
                                count_include_pad=False)
        self.init_range = init_range
        self.initialize(mean_activity)
        self.stop_grad = stop_grad
        self.align_corners = align_corners

    @property
    def pool_steps(self):
        return self._pool_steps

    @pool_steps.setter
    def pool_steps(self, value):
        assert value >= 0 and int(
            value
        ) - value == 0, "new pool steps must be a non-negative integer"
        if value != self._pool_steps:
            print("Resizing readout features")
            c, t, w, h = self.in_shape
            outdims = self.outdims
            self._pool_steps = int(value)
            self.features = Parameter(
                torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims))
            self.mask = torch.ones_like(self.features)
            self.features.data.fill_(1 / self.in_shape[0])

    def initialize(self, init_noise=1e-3, grid=True, mean_activity=None):
        if mean_activity is None:
            mean_activity = self.mean_activity
        # randomly pick centers within the spatial map
        self.features.data.fill_(1 / self.in_shape[0])
        if grid:
            self.grid.data.uniform_(-self.init_range, self.init_range)
        if self.bias is not None:
            self.initialize_bias(mean_activity=mean_activity)

    def feature_l1(self, reduction="sum", average=None, subs_idx=None):
        subs_idx = subs_idx if subs_idx is not None else slice(None)
        return self.apply_reduction(self.features[..., subs_idx].abs(),
                                    reduction=reduction,
                                    average=average)

    def regularizer(self, reduction="sum", average=None):
        return self.feature_l1(reduction=reduction,
                               average=average) * self.feature_reg_weight

    def reset_fisher_prune_scores(self):
        self._prune_n = 0
        self._prune_scores = self.features.detach() * 0

    def update_fisher_prune_scores(self):
        self._prune_n += 1
        if self.features.grad is None:
            raise ValueError("You need to run backward first")
        self._prune_scores += (0.5 * self.features.grad.pow(2) *
                               self.features.pow(2)).detach()

    @property
    def fisher_prune_scores(self):
        return self._prune_scores / self._prune_n

    def prune(self):
        idx = (self.fisher_prune_scores + 1e6 *
               (1 - self.mask)).squeeze().argmin(dim=0)
        seq = torch.arange(len(idx), device=idx.device)
        self.mask[:, idx, :, seq] = 0
        self.features.data[:, idx, :, seq] = 0

    def forward(self, x, shift=None, subs_idx=None):
        if self.stop_grad:
            x = x.detach()

        self.features.data *= self.mask

        if self.positive:
            self.features.data.clamp_min_(0)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)

        N, c, t, w, h = x.size()
        m = self._pool_steps + 1
        if subs_idx is not None:
            feat = self.features[..., subs_idx].contiguous()
            outdims = feat.size(-1)
            feat = feat.view(1, m * c, outdims)
            grid = self.grid[:, subs_idx, ...]
        else:
            grid = self.grid
            feat = self.features.view(1, m * c, self.outdims)
            outdims = self.outdims

        if shift is None:
            grid = grid.expand(N * t, outdims, 1, 2)
        else:
            grid = grid.expand(N, outdims, 1, 2)
            grid = torch.stack(
                [grid + shift[:, i, :][:, None, None, :] for i in range(t)], 1)
            grid = grid.contiguous().view(-1, outdims, 1, 2)
        z = x.contiguous().transpose(2, 1).contiguous().view(-1, c, w, h)
        pools = [F.grid_sample(z, grid, align_corners=self.align_corners)]
        for i in range(self._pool_steps):
            z = self.avg(z)
            pools.append(
                F.grid_sample(z, grid, align_corners=self.align_corners))
        y = torch.cat(pools, dim=1)
        y = (y.squeeze(-1) * feat).sum(1).view(N, t, outdims)

        if self.bias is not None:
            if subs_idx is None:
                y = y + self.bias
            else:
                y = y + self.bias[subs_idx]

        return y

    def __repr__(self):
        c, _, w, h = self.in_shape
        r = self.__class__.__name__ + " (" + "{} x {} x {}".format(
            c, w, h) + " -> " + str(self.outdims) + ")"
        if self.bias is not None:
            r += " with bias"
        if self.stop_grad:
            r += ", stop_grad=True"
        r += "\n"

        for ch in self.children():
            r += "  -> " + ch.__repr__() + "\n"
        return r
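
The Fisher-pruning hooks are meant to be driven in a loop: accumulate squared-gradient scores over a few batches, then zero out the lowest-scoring feature channel per readout unit. A hedged sketch (model, data, and loss are placeholders):

# Hedged driver loop for the Fisher-pruning API above; loader and criterion
# are placeholders for the user's data pipeline and objective.
readout.reset_fisher_prune_scores()
for x, target in loader:
    readout.zero_grad()
    loss = criterion(readout(x), target)
    loss.backward()
    readout.update_fisher_prune_scores()
readout.prune()  # zeroes the lowest-score channel for each unit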
Example #9
class ResNet(nn.Module):
    """ResNet Variants

    Parameters
    ----------
    block : Block
        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
    layers : list of int
        Numbers of layers in each block
    classes : int, default 1000
        Number of classification classes.
    dilated : bool, default False
        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
        typically used in Semantic Segmentation.
    norm_layer : object
        Normalization layer used in the backbone network (default: :class:`torch.nn.BatchNorm2d`;
        pass a synchronized cross-GPU batch normalization layer for multi-GPU training).

    Reference:

        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """

    # pylint: disable=unused-variable
    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64, num_classes=1000,
                 dilated=False, dilation=1, deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False, avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0, last_gamma=False, use_se=False, in_channels=300,
                 word_file='/workspace/Projects/cxr/models/feature_extraction/diseases_embeddings.npy',
                 # word_file='diseases_embeddings.npy',
                 # word_file='/home/hoangvu/Projects/cxr/models/feature_extraction/diseases_embeddings.npy',
                 extract_fields='0,1,2,3,4,5', agree_rate=0.5, csv_path='',
                 norm_layer=nn.BatchNorm2d):
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first
        self.use_se = use_se

        super(ResNet, self).__init__()
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg
        if rectified_conv:
            from rfconv import RFConv2d
            conv_layer = RFConv2d
        else:
            conv_layer = nn.Conv2d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False,
                           **conv_kwargs), norm_layer(stem_width), nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False,
                           **conv_kwargs), norm_layer(stem_width), nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1,
                           bias=False, **conv_kwargs), )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, bias=False,
                                    **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
                                           norm_layer=norm_layer, dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
                                           norm_layer=norm_layer, dropblock_prob=dropblock_prob)
        elif dilation == 2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilation=1,
                                           norm_layer=norm_layer, dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2,
                                           norm_layer=norm_layer, dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, norm_layer):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        num_classes = len(extract_fields.split(','))
        _adj = gen_adj_num(labels=extract_fields, agree_rate=agree_rate, csv_path=csv_path)
        self.adj = Parameter(torch.from_numpy(_adj).float())

        if not os.path.exists(word_file):
            word = np.random.randn(num_classes, 300)
            print('graph input: random')
        else:
            with open(word_file, 'rb') as point:
                word = np.load(point)
                print('graph input: loaded from {}'.format(word_file))
        self.word = Parameter(torch.from_numpy(word).float())

        self.gc0 = GraphConvolution(in_channels, 128, bias=True)
        self.gc1 = GraphConvolution(128, 256, bias=True)
        self.gc2 = GraphConvolution(256, 512, bias=True)
        self.gc3 = GraphConvolution(512, 1024, bias=True)
        self.gc4 = GraphConvolution(1024, 2048, bias=True)
        self.gc_relu = nn.LeakyReLU(0.2)
        self.gc_tanh = nn.Tanh()
        self.merge_conv0 = nn.Conv2d(num_classes, 128, kernel_size=1, stride=1, bias=False)
        self.merge_conv1 = nn.Conv2d(num_classes, 256, kernel_size=1, stride=1, bias=False)
        self.merge_conv2 = nn.Conv2d(num_classes, 512, kernel_size=1, stride=1, bias=False)
        self.merge_conv3 = nn.Conv2d(num_classes, 1024, kernel_size=1, stride=1, bias=False)
        self.conv1x1 = conv1x1(in_channels=2048, out_channels=num_classes, bias=True)
        # self.spatial_attention = SAModule(2048)
        # self.spatial_attention = SpatialCGNL(2048, 1024)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []
            if self.avg_down:
                if dilation == 1:
                    down_layers.append(
                        nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True,
                                     count_include_pad=False))
                else:
                    down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True,
                                                    count_include_pad=False))
                down_layers.append(
                    nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1,
                              bias=False))
            else:
                down_layers.append(
                    nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride,
                              bias=False))
            down_layers.append(norm_layer(planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if dilation == 1 or dilation == 2:
            layers.append(
                block(self.inplanes, planes, stride, downsample=downsample, radix=self.radix,
                      cardinality=self.cardinality, bottleneck_width=self.bottleneck_width,
                      avd=self.avd, avd_first=self.avd_first, dilation=1, is_first=is_first,
                      rectified_conv=self.rectified_conv, rectify_avg=self.rectify_avg,
                      norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                      last_gamma=self.last_gamma, use_se=self.use_se))
        elif dilation == 4:
            layers.append(
                block(self.inplanes, planes, stride, downsample=downsample, radix=self.radix,
                      cardinality=self.cardinality, bottleneck_width=self.bottleneck_width,
                      avd=self.avd, avd_first=self.avd_first, dilation=2, is_first=is_first,
                      rectified_conv=self.rectified_conv, rectify_avg=self.rectify_avg,
                      norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                      last_gamma=self.last_gamma, use_se=self.use_se))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, radix=self.radix, cardinality=self.cardinality,
                      bottleneck_width=self.bottleneck_width, avd=self.avd,
                      avd_first=self.avd_first, dilation=dilation,
                      rectified_conv=self.rectified_conv, rectify_avg=self.rectify_avg,
                      norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                      last_gamma=self.last_gamma, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, feature):
        adj = gen_adj(self.adj).detach()
        word = self.word.detach()

        feature = self.conv1(feature)
        feature = self.bn1(feature)
        feature = self.relu(feature)
        feature = self.maxpool(feature)

        x_raw = self.gc0(word, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv0)

        feature = self.layer1(feature)
        x = self.gc_relu(x_raw)
        x_raw = self.gc1(x, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv1)

        feature = self.layer2(feature)
        x = self.gc_relu(x_raw)
        x_raw = self.gc2(x, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv2)

        feature = self.layer3(feature)
        x = self.gc_relu(x_raw)
        x_raw = self.gc3(x, adj)
        x = self.gc_tanh(x_raw)
        feature = merge_gcn_residual(feature, x, self.merge_conv3)

        feature = self.layer4(feature)
        # feature = self.spatial_attention(feature)
        feature_raw = self.global_pool(feature)
        if self.drop is not None:
            feature_raw = self.drop(feature_raw)

        feature = feature_raw.view(feature_raw.size(0), -1)

        x = self.gc_relu(x_raw)
        x = self.gc4(x, adj)
        x = self.gc_tanh(x)
        x = x.transpose(0, 1)
        x = torch.matmul(feature, x)

        y = self.conv1x1(feature_raw)
        y = y.view(y.size(0), -1)
        x = x + y
        return x
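
GraphConvolution is imported from elsewhere; from its usage here (gc(x, adj)) it presumably follows the Kipf & Welling propagation rule x' = adj @ x @ W. A minimal sketch of such a layer, as an assumption about that class rather than its actual source:

# Assumed shape of the GraphConvolution layers used above: x' = adj @ x @ W.
import torch
import torch.nn as nn

class GraphConvolutionSketch(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, adj):
        out = adj @ (x @ self.weight)  # neighbor aggregation, then projection
        return out if self.bias is None else out + self.bias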
Example #10
class Cifp(nn.Module):
    """ Implement of  (CVPR2021 Consistent Instance False Positive Improves Fairness in Face Recognition)
    """
    def __init__(self, in_features, out_features, scale=64.0, margin=0.35):
        """ Args:
            in_features: size of each input features
            out_features: size of each output features
            scale: norm of input feature
            margin: margin
        """
        super(Cifp, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.kernel = Parameter(torch.Tensor(in_features, out_features))
        nn.init.normal_(self.kernel, std=0.01)

    def forward(self, embeddings, label):
        cos_theta, origin_cos = calc_logits(embeddings, self.kernel)
        cos_theta_, _ = calc_logits(embeddings, self.kernel.detach())

        mask = torch.zeros_like(cos_theta)
        mask.scatter_(1, label.view(-1, 1).long(), 1.0)

        sample_num = embeddings.size(0)
        tmp_cos_theta = cos_theta - 2 * mask
        tmp_cos_theta_ = cos_theta_ - 2 * mask
        target_cos_theta = cos_theta[torch.arange(0, sample_num),
                                     label].view(-1, 1)
        target_cos_theta_ = cos_theta_[torch.arange(0, sample_num),
                                       label].view(-1, 1)

        target_cos_theta_m = target_cos_theta - self.margin

        far = 1 / (self.out_features - 1)
        # far = 1e-4
        topk_mask = torch.greater(tmp_cos_theta, target_cos_theta)
        topk_sum = torch.sum(topk_mask.to(torch.int32))
        dist.all_reduce(topk_sum)
        far_rank = math.ceil(
            far * (sample_num *
                   (self.out_features - 1) * dist.get_world_size() - topk_sum))
        cos_theta_neg_topk = torch.topk(
            (tmp_cos_theta - 2 * topk_mask.to(torch.float32)).flatten(),
            k=far_rank)[0]
        cos_theta_neg_topk = all_gather_tensor(cos_theta_neg_topk.contiguous())
        cos_theta_neg_th = torch.topk(cos_theta_neg_topk, k=far_rank)[0][-1]

        cond = torch.mul(torch.bitwise_not(topk_mask),
                         torch.greater(tmp_cos_theta, cos_theta_neg_th))
        _, cos_theta_neg_topk_index = torch.where(cond)
        cos_theta_neg_topk = torch.mul(cond.to(torch.float32), tmp_cos_theta)
        cos_theta_neg_topk_ = torch.mul(cond.to(torch.float32), tmp_cos_theta_)

        cond = torch.greater(target_cos_theta_m, cos_theta_neg_topk)
        cos_theta_neg_topk = torch.where(cond, cos_theta_neg_topk,
                                         cos_theta_neg_topk_)
        cos_theta_neg_topk = torch.pow(cos_theta_neg_topk, 2)
        times = torch.sum(torch.greater(cos_theta_neg_topk,
                                        0).to(torch.float32),
                          dim=1,
                          keepdim=True)
        times = torch.where(torch.greater(times, 0), times,
                            torch.ones_like(times))
        cos_theta_neg_topk = torch.sum(cos_theta_neg_topk, dim=1,
                                       keepdim=True) / times
        target_cos_theta_m = target_cos_theta_m - (
            1 + target_cos_theta_) * cos_theta_neg_topk

        cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m)
        output = cos_theta * self.scale

        return output, origin_cos * self.scale
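
calc_logits is not shown; from its call sites it presumably returns the cosine similarities between L2-normalized embeddings and the kernel's class centers, plus the raw cosines. A sketch of that assumption:

# Assumed behaviour of the calc_logits helper used above; this is an
# inference from the call sites, not the actual source.
import torch.nn.functional as F

def calc_logits_sketch(embeddings, kernel):
    cos_theta = F.normalize(embeddings) @ F.normalize(kernel, dim=0)
    return cos_theta.clamp(-1, 1), cos_theta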
Example #11
class AsymmetricLaplace(Distribution):
    def __init__(self, loc=0., scale=1., asymmetry=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        if not isinstance(asymmetry, torch.Tensor):
            asymmetry = torch.tensor(asymmetry).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        self._asymmetry = utils.softplus_inverse(asymmetry.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)
            self._asymmetry = Parameter(self._asymmetry)

    def log_prob(self, value):
        s = (value - self.loc).sign()
        exponent = -(value -
                     self.loc).abs() * self.scale * self.asymmetry.pow(s)
        coeff = self.scale.log() - (self.asymmetry +
                                    (1 / self.asymmetry)).log()
        return (coeff + exponent).sum(-1)

    def sample(self, batch_size):
        U = Uniform(low=-self.asymmetry,
                    high=(1. / self.asymmetry),
                    learnable=False).sample(batch_size)
        s = U.sign()
        log_term = (1. - U * s * self.asymmetry.pow(s)).log()
        return self.loc - (1. /
                           (self.scale * s * self.asymmetry.pow(s))) * log_term

    def cdf(self, value):
        s = (value - self.loc).sign()
        exponent = -(value -
                     self.loc).abs() * self.scale * self.asymmetry.pow(s)
        exponent = exponent.exp()
        return (value > self.loc).float() - s * self.asymmetry.pow(1 - s) / (
            1 + self.asymmetry.pow(2)) * exponent

    # def icdf(self, value):
    #     return

    def entropy(self):
        return (utils.e * (1 + self.asymmetry.pow(2)) /
                (self.asymmetry * self.scale)).log().sum()

    @property
    def expectation(self):
        return self.loc + ((1 - self.asymmetry.pow(2)) /
                           (self.scale * self.asymmetry))

    @property
    def variance(self):
        return (1 + self.asymmetry.pow(4)) / (self.scale.pow(2) *
                                              self.asymmetry.pow(2))

    @property
    def mode(self):
        return self.loc

    @property
    def median(self):
        return self.loc + (self.asymmetry / self.scale) * (
            (1 + self.asymmetry.pow(2)) / (2 * self.asymmetry.pow(2))).log()

    @property
    def skewness(self):
        return (2 *
                (1 - self.asymmetry.pow(6))) / (1 + self.asymmetry.pow(4)).pow(
                    3. / 2.)

    @property
    def kurtosis(self):
        return (6 *
                (1 + self.asymmetry.pow(8))) / (1 +
                                                self.asymmetry.pow(4)).pow(2)

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def asymmetry(self):
        return softplus(self._asymmetry)

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                'loc': self.loc.item(),
                'scale': self.scale.item(),
                'asymmetry': self.asymmetry.item()
            }
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy(),
            'asymmetry': self.asymmetry.detach().numpy()
        }
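The sampler above is inverse-transform sampling for the asymmetric Laplace distribution, drawing U ~ Uniform(-kappa, 1/kappa) and inverting the cdf piecewise. A self-contained sanity check, with names local to this sketch, comparing the empirical mean of such samples against the closed-form expectation loc + (1 - kappa^2) / (scale * kappa):

import torch

def sample_asymmetric_laplace(loc, scale, kappa, n):
    # Inverse-transform sampling: U ~ Uniform(-kappa, 1/kappa), then invert
    # the ALD cdf piecewise according to the sign of U.
    u = torch.empty(n).uniform_(-kappa, 1.0 / kappa)
    s = u.sign()
    return loc - torch.log(1 - u * s * kappa ** s) / (scale * s * kappa ** s)

torch.manual_seed(0)
loc, scale, kappa = 0.5, 2.0, 1.5
x = sample_asymmetric_laplace(loc, scale, kappa, 200000)
print(x.mean().item())                           # empirical mean
print(loc + (1 - kappa ** 2) / (scale * kappa))  # closed-form: ~0.0833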
Example #12
class ParameterFromRuntimeMinZeroPoint(brevitas.jit.ScriptModule):
    __constants__ = ['stats_permute_dims',
                     'collect_stats_steps',
                     'momentum']

    def __init__(
            self,
            collect_stats_steps: int,
            int_quant: Module,
            stats_reduce_dim: Optional[int],
            zero_point_shape: Tuple[int, ...],
            zero_point_stats_input_view_shape_impl: Module,
            zero_point_stats_permute_dims: Optional[Tuple[int, ...]] = None,
            zero_point_stats_momentum: Optional[float] = DEFAULT_MOMENTUM) -> None:
        super(ParameterFromRuntimeMinZeroPoint, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        if zero_point_shape != SCALAR_SHAPE and zero_point_stats_permute_dims is None:
            raise RuntimeError("Per channel runtime stats require a permute shape")
        self.collect_stats_steps = collect_stats_steps
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_permute_dims = zero_point_stats_permute_dims
        self.stats_input_view_shape_impl = zero_point_stats_input_view_shape_impl
        self.momentum = zero_point_stats_momentum
        self.value = Parameter(torch.full(zero_point_shape, 0.0))
        self.register_buffer('buffer', torch.full(zero_point_shape, 0.0))
        self.negative_min_or_zero = NegativeMinOrZero(stats_reduce_dim)
        self.int_quant = int_quant

    @brevitas.jit.script_method
    def training_forward(self, x) -> Tensor:
        if self.counter < self.collect_stats_steps:
            if self.stats_permute_dims is not None:
                x = x.permute(*self.stats_permute_dims).contiguous()
            stats_input = self.stats_input_view_shape_impl(x)
            stats = self.negative_min_or_zero(stats_input)
            new_counter = self.counter + 1
            if self.counter == 0:
                inplace_tensor_add(self.buffer, stats.detach())
            else:
                inplace_momentum_update(
                    self.buffer, stats.detach(), self.momentum, self.counter, new_counter)
            self.counter = new_counter
            # work around find_unused_parameters=True in DDP
            out = stats + 0. * self.value
        elif self.counter == self.collect_stats_steps:
            inplace_tensor_add(self.value.detach(), self.buffer)
            self.counter = self.counter + 1
            out = self.value
        else:
            out = self.value
        return out

    @brevitas.jit.script_method
    def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        if self.training:
            out = self.training_forward(x)
        else:
            if self.counter <= self.collect_stats_steps:
                out = self.buffer
            else:
                out = self.value
        out = abs_binary_sign_grad(out)
        min_int = self.int_quant.min_int(bit_width)
        out = self.int_quant.to_int(scale, min_int, bit_width, out)
        return out

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(ParameterFromRuntimeMinZeroPoint, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        value_key = prefix + 'value'
        # PyTorch stores the training flag as a buffer when JIT is enabled
        training_key = prefix + 'training'
        if training_key in missing_keys:
            missing_keys.remove(training_key)
        # disable stats collection when a pretrained value is loaded
        if value_key not in missing_keys:
            self.counter = self.collect_stats_steps + 1
        if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
            missing_keys.remove(value_key)
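The module above follows a collect-then-learn pattern: a buffer accumulates a running statistic for collect_stats_steps training steps, after which the accumulated value is copied once into a learnable Parameter. A framework-free sketch of the same pattern (a plain EMA stands in for Brevitas' inplace_momentum_update, and every name below is local to the sketch):

import torch
import torch.nn as nn

class CollectThenLearn(nn.Module):
    # Minimal sketch: track a running statistic for `steps` forward calls,
    # then freeze it into a learnable Parameter.
    def __init__(self, steps, momentum=0.1):
        super().__init__()
        self.steps, self.momentum, self.counter = steps, momentum, 0
        self.register_buffer('buffer', torch.zeros(1))
        self.value = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        if self.training and self.counter < self.steps:
            # negative-min-or-zero statistic, as used for zero points
            stat = x.detach().min().clamp(max=0.).neg().view(1)
            if self.counter == 0:
                self.buffer.copy_(stat)
            else:
                self.buffer.mul_(1 - self.momentum).add_(self.momentum * stat)
            self.counter += 1
            return self.buffer + 0. * self.value  # keep `value` in the graph (DDP)
        if self.training and self.counter == self.steps:
            with torch.no_grad():
                self.value.copy_(self.buffer)  # one-time initialization
            self.counter += 1
        return self.value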
Example #13
class Laplace(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        return (-(2. * self.scale).log() -
                ((value - self.loc).abs() / self.scale)).sum(-1)

    def sample(self, batch_size):
        return dists.Laplace(self.loc, self.scale).rsample((batch_size, ))

    def cdf(self, value):
        return 0.5 - 0.5 * (value - self.loc).sign() * (
            -(value - self.loc).abs() / self.scale).expm1()

    def icdf(self, value):
        term = value - 0.5
        return self.loc - self.scale * term.sign() * (-2 * term.abs()).log1p()

    def entropy(self):
        return 1 + (2 * self.scale).log()

    def kl(self, other):
        if isinstance(other, Laplace):
            scale_ratio = self.scale / other.scale
            loc_abs_diff = (self.loc - other.loc).abs()
            t1 = -scale_ratio.log()
            t2 = loc_abs_diff / other.scale
            t3 = scale_ratio * (-loc_abs_diff / self.scale).exp()
            return (t1 + t2 + t3 - 1.).sum()
        return None

    @property
    def expectation(self):
        return self.loc

    @property
    def variance(self):
        return 2 * self.scale.pow(2)

    @property
    def median(self):
        return self.loc

    @property
    def stddev(self):
        return (2**0.5) * self.scale

    @property
    def mode(self):
        return self.loc

    @property
    def skewness(self):
        return torch.tensor(0.).float()

    @property
    def kurtosis(self):
        return torch.tensor(3.).float()

    @property
    def scale(self):
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
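The closed-form `cdf` and `icdf` above are exact inverses of each other; a quick stand-alone check of the two expressions (plain functions local to this sketch, not the class methods):

import torch

loc, scale = torch.tensor(1.0), torch.tensor(2.0)

def laplace_cdf(x):
    return 0.5 - 0.5 * (x - loc).sign() * (-(x - loc).abs() / scale).expm1()

def laplace_icdf(p):
    t = p - 0.5
    return loc - scale * t.sign() * (-2 * t.abs()).log1p()

x = torch.linspace(-5.0, 7.0, 25)
assert torch.allclose(laplace_icdf(laplace_cdf(x)), x, atol=1e-4)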
Example #14
class StudentT(Distribution):
    def __init__(self, df=1., loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        if not isinstance(df, torch.Tensor):
            df = torch.tensor(df).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        self._df = utils.softplus_inverse(df.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)
            self._df = Parameter(self._df)

    def log_prob(self, value):
        model = dists.StudentT(self.df, self.loc, self.scale)
        return model.log_prob(value).sum(-1)

    def sample(self, batch_size):
        model = dists.StudentT(self.df, self.loc, self.scale)
        return model.rsample((batch_size, ))

    def entropy(self):
        return dists.StudentT(self.df, self.loc, self.scale).entropy()

    @property
    def expectation(self):
        return dists.StudentT(self.df, self.loc, self.scale).mean

    @property
    def mode(self):
        # The mode of Student's t is always `loc` (unlike the mean, which
        # is undefined for df <= 1), so return it directly.
        return self.loc

    @property
    def variance(self):
        return dists.StudentT(self.df, self.loc, self.scale).variance

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def df(self):
        return softplus(self._df)

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                'loc': self.loc.item(),
                'scale': self.scale.item(),
                'df': self.df.item()
            }
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy(),
            'df': self.df.detach().numpy()
        }
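Because `sample` above calls `rsample`, gradients reach the unconstrained raw parameters through `softplus`. A minimal stand-alone illustration with torch.distributions (names local to this sketch; the `+ 2.0` shift is only there to keep the variance finite):

import torch
import torch.distributions as dists
import torch.nn.functional as F

raw_df = torch.zeros(1, requires_grad=True)   # unconstrained raw parameter
df = F.softplus(raw_df) + 2.0                 # keep df > 2 so the variance exists
samples = dists.StudentT(df, loc=0.0, scale=1.0).rsample((1024,))
loss = samples.pow(2).mean()
loss.backward()
print(raw_df.grad)                            # non-None: rsample is reparameterized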
Example #15
class MultiheadAttention_v2(nn.Module):
    """Multi-headed attention.
    See "Attention Is All You Need" for more details.
    """
    def __init__(self,
                 embed_dim,
                 num_heads,
                 kdim=None,
                 vdim=None,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 stable_init_ratio=0.,
                 self_attention=False,
                 encoder_decoder_attention=False):

        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(
                torch.Tensor(3 * embed_dim,
                             embed_dim))  # embed_dim // num_heads = head_dim
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
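            # NOTE: the copy-masks below (k_mask/v_mask here, and q_mask and
            # out_proj_mask further down) are created in fp16 on cuda:0 at
            # construction time, so this module assumes a single-GPU setup.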
            self.k_mask = torch.ones(embed_dim,
                                     self.kdim,
                                     dtype=torch.half,
                                     device=torch.device('cuda:0'))
            self.v_mask = torch.ones(embed_dim,
                                     self.vdim,
                                     dtype=torch.half,
                                     device=torch.device('cuda:0'))
        self.q_mask = torch.ones(embed_dim,
                                 embed_dim,
                                 dtype=torch.half,
                                 device=torch.device('cuda:0'))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj_mask = torch.ones(embed_dim,
                                        embed_dim,
                                        dtype=torch.half,
                                        device=torch.device('cuda:0'))

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        self.num_copied_heads = int(self.num_heads * stable_init_ratio)
        self.onnx_trace = False

        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self, prev_weight=None, copied_heads=None):

        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

        # if self.num_copied_heads > 0 and prev_weight is not None:

        # if self.qkv_same_dim:
        #     same_dim_heads = np.concatenate((copied_heads, copied_heads + self.embed_dim, \
        #         copied_heads + 2 * self.embed_dim), axis=None)

        #     self.in_proj_weight[same_dim_heads, :].data = prev_weight["in"][same_dim_heads, :].data
        # else:
        #     self.k_proj_weight[copied_heads, :].data = prev_weight["k"][copied_heads, :].data
        #     self.v_proj_weight[copied_heads, :].data = prev_weight["v"][copied_heads, :].data
        #     self.q_proj_weight[copied_heads, :].data = prev_weight["q"][copied_heads, :].data

        # self.out_proj.weight[copied_heads, :].data = prev_weight["out"][copied_heads, :].data

        # if self.qkv_same_dim:
        #     self.q_mask[copied_heads, :] = 0
        #     mask = self.q_mask.repeat(3, 1)
        #     self.in_proj_weight = Parameter(self.in_proj_weight * mask + prev_weight["in"] * (1 - mask))
        # else:
        #     self.k_mask[copied_heads, :], self.v_mask[copied_heads, :] = 0, 0
        #     self.k_proj_weight = Parameter(self.k_proj_weight * self.k_mask + prev_weight["k"]* (1 - self.k_mask))
        #     self.v_proj_weight = Parameter(self.v_proj_weight * self.v_mask + prev_weight["v"] * (1 - self.v_mask))
        #     self.q_proj_weight = Parameter(self.q_proj_weight * self.q_mask + prev_weight["q"] * (1 - self.q_mask))

        # self.out_proj_mask[copied_heads, :] = 0
        # self.out_proj.weight = Parameter(self.out_proj.weight * self.out_proj_mask + prev_weight["out"] * (1 - self.out_proj_mask))

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=False,
                static_kv=False,
                attn_mask=None,
                prev_weight=None):
        """Input shape: Time x Batch x Channel
        Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        copied_heads = None
        if self.num_copied_heads > 0 and prev_weight is not None:
            copied_idx = np.random.choice(self.num_heads,
                                          self.num_copied_heads,
                                          replace=False)
            copied_heads = np.reshape(
                np.array([
                    list(range(s, e))
                    for (s, e) in zip(copied_idx *
                                      self.head_dim, (copied_idx + 1) *
                                      self.head_dim)
                ]), -1)

        self.reset_parameters(prev_weight, copied_heads)

        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(
                    query, key, value, self.embed_dim, self.num_heads,
                    self.in_proj_weight, self.in_proj_bias, self.bias_k,
                    self.bias_v, self.add_zero_attn, self.dropout,
                    self.out_proj.weight, self.out_proj.bias, self.training,
                    key_padding_mask, need_weights, attn_mask)
            else:
                return F.multi_head_attention_forward(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    torch.empty([0]),
                    self.in_proj_bias,
                    self.bias_k,
                    self.bias_v,
                    self.add_zero_attn,
                    self.dropout,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    self.training,
                    key_padding_mask,
                    need_weights,
                    attn_mask,
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ],
                                             dim=1)

        q = q.contiguous().view(-1, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)
        # q: [bz * num_heads, tgt_len, head_dim]
        # k, v: [bz * num_heads, src_len, head_dim]

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
            []):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask)
                ],
                                             dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if self.onnx_trace:
                attn_weights = torch.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    torch.Tensor([float("-Inf")]),
                    attn_weights.float()).type_as(attn_weights)
            else:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attn_weights = utils.softmax(
            attn_weights,
            dim=-1,
            onnx_trace=self.onnx_trace,
        ).type_as(attn_weights)

        if self.num_copied_heads > 0 and prev_weight is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attn_weights[:, copied_idx, :, :] = prev_weight[
                "nonavg"][:, copied_idx, :, :]
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attention_weights = {}
        if need_weights:
            # un-averaged attention weight
            attention_weights["nonavg"] = attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len).detach().clone()
            # attention_weights["nonavg"] = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) # [bz, num_heads, tgt_len, src_len]

        attn_weights = F.dropout(attn_weights,
                                 p=self.dropout,
                                 training=self.training)
        attn = torch.bmm(attn_weights, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attention_weights["avg"] = attn_weights.sum(
                dim=1) / self.num_heads  # [bz, tgt_len, src_len]

            # projection weight
            if self.qkv_same_dim:
                attention_weights["in"] = self.in_proj_weight.detach().clone()
            else:
                attention_weights["k"], attention_weights["q"], attention_weights["v"] =\
                 self.k_proj_weight.detach().clone(), self.q_proj_weight.detach().clone(), self.v_proj_weight.detach().clone()
            attention_weights["out"] = self.out_proj.weight.detach().clone()

        return attn, attention_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key,
                                 start=self.embed_dim,
                                 end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(0, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        utils.set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        return attn_weights
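MultiheadAttention_v2 keeps fairseq's (tgt_len, batch, embed_dim) layout and masking conventions. Because its copy-masks are pinned to cuda:0 (see the note in __init__), a portable way to exercise the same interface contract is torch.nn.MultiheadAttention, which uses identical shapes: a float additive attn_mask of shape (L, S) and a boolean key_padding_mask of shape (N, S) where True marks padding:

import torch
import torch.nn as nn

embed_dim, num_heads, tgt_len, bsz = 16, 4, 5, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)  # (L, N, E) layout by default

query = torch.randn(tgt_len, bsz, embed_dim)       # Time x Batch x Channel
key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)   # True = padding
attn_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)

attn, avg_weights = mha(query, query, query,
                        key_padding_mask=key_padding_mask,
                        attn_mask=attn_mask)
print(attn.shape, avg_weights.shape)  # (5, 2, 16) and (2, 5, 5): head-averaged

The `_in_proj` helpers above slice one packed (3 * embed_dim, embed_dim) weight so that q, k and v come out of a single matmul; the equivalence with slicing the packed weight per block can be checked directly:

import torch
import torch.nn.functional as F

embed_dim = 8
x = torch.randn(3, 2, embed_dim)                        # (time, batch, channel)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # packed [q; k; v] rows

q, k, v = F.linear(x, in_proj_weight).chunk(3, dim=-1)  # one matmul, three views
assert torch.allclose(q, F.linear(x, in_proj_weight[:embed_dim]), atol=1e-6)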