Ejemplo n.º 1
0
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Sum the signed weight-diagonal terms of all backpropagated factors.

        Args:
            ext: Not used by this method.
            module: Layer whose ``weight`` determines the shape, dtype and
                device of the accumulator.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            backproped: Mapping with a ``"matrices"`` entry and a ``"signs"``
                entry; the two are iterated in lockstep.

        Returns:
            Tensor shaped like ``module.weight`` containing the sum over all
            (matrix, sign) pairs of
            ``sign * LinUtils.extract_weight_diagonal(module, matrix)``.
        """
        # NOTE(review): the key/variable names suggest these are square-root
        # factors of a Hessian with their signs — confirm against the caller.
        signed_factors = zip(backproped["matrices"], backproped["signs"])
        accumulated = torch.zeros_like(module.weight)

        for factor, factor_sign in signed_factors:
            # In-place update so the accumulator is never re-allocated.
            accumulated += factor_sign * LinUtils.extract_weight_diagonal(
                module, factor
            )

        return accumulated
Ejemplo n.º 2
0
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Sum the signed, batch-summed weight-diagonal terms of all factors.

        Args:
            ext: Not used by this method.
            module: Layer whose ``weight`` determines the shape, dtype and
                device of the accumulator.
            g_inp: Not used by this method.
            g_out: Not used by this method.
            backproped: Mapping with a ``"matrices"`` entry and a ``"signs"``
                entry; the two are iterated in lockstep.

        Returns:
            Tensor shaped like ``module.weight`` containing the sum over all
            (matrix, sign) pairs of ``sign`` times
            ``LinUtils.extract_weight_diagonal(module, matrix, sum_batch=True)``.
        """
        # NOTE(review): the key/variable names suggest these are square-root
        # factors of a Hessian with their signs — confirm against the caller.
        signed_factors = zip(backproped["matrices"], backproped["signs"])
        accumulated = torch.zeros_like(module.weight)

        for factor, factor_sign in signed_factors:
            contribution = LinUtils.extract_weight_diagonal(
                module, factor, sum_batch=True
            )
            # ``alpha`` scales the contribution inside the in-place add.
            accumulated.add_(contribution, alpha=factor_sign)

        return accumulated
Ejemplo n.º 3
0
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Sum the signed weight-diagonal terms without reducing over samples.

        Args:
            ext: Not used by this method.
            module: Layer providing ``input0`` (its leading dimension sizes
                the first axis of the result) and ``weight`` (shape, dtype
                and device of the remaining axes).
            g_inp: Not used by this method.
            g_out: Not used by this method.
            backproped: Mapping with a ``"matrices"`` entry and a ``"signs"``
                entry; the two are iterated in lockstep.

        Returns:
            Tensor of shape ``(module.input0.shape[0], *module.weight.shape)``
            containing the sum over all (matrix, sign) pairs of ``sign`` times
            ``LinUtils.extract_weight_diagonal(module, matrix, sum_batch=False)``.
        """
        # NOTE(review): presumably the leading dimension of ``input0`` is the
        # batch size — confirm against the caller.
        leading_dim = module.input0.shape[0]
        signed_factors = zip(backproped["matrices"], backproped["signs"])

        params = module.weight
        accumulated = torch.zeros(
            (leading_dim, *params.shape),
            device=params.device,
            dtype=params.dtype,
        )

        for factor, factor_sign in signed_factors:
            contribution = LinUtils.extract_weight_diagonal(
                module, factor, sum_batch=False
            )
            # ``alpha`` scales the contribution inside the in-place add.
            accumulated.add_(contribution, alpha=factor_sign)

        return accumulated
Ejemplo n.º 4
0
 def weight(self, ext, module, grad_inp, grad_out, backproped):
     """Return the weight diagonal extracted from ``backproped``.

     Args:
         ext: Not used by this method.
         module: Layer forwarded to ``LinUtils.extract_weight_diagonal``.
         grad_inp: Not used by this method.
         grad_out: Not used by this method.
         backproped: Backpropagated quantity forwarded unchanged to
             ``LinUtils.extract_weight_diagonal``.

     Returns:
         Whatever ``LinUtils.extract_weight_diagonal`` returns when called
         with ``sum_batch=False``.
     """
     weight_diagonal = LinUtils.extract_weight_diagonal(
         module, backproped, sum_batch=False
     )
     return weight_diagonal