def double_soft_orthoreg(weights, config):
    """Double Soft Orthogonality (DSO) regularization.

    Extension of the Soft Orthogonality regularization: pushes both Gram
    matrices ``W^T W`` and ``W W^T`` of the weight matrix towards the identity,
    i.e. penalizes ||W^T W - I||_F + ||W W^T - I||_F (each term squared when
    ``config["squared"]`` is true).

    References:
        * Can We Gain More from Orthogonality Regularizations in Training Deep CNNs?
          Bansal et al. NeurIPS 2018

    :param weights: Learned parameters of shape (n_classes, n_features).
    :param config: Mapping with keys ``"squared"`` (if true, each Frobenius
                   norm is squared) and ``"factor"`` (multiplier applied to
                   the sum of both terms).
    :return: A float scalar loss (0-dim tensor).
    """
    squared = config["squared"]

    def _gram_deviation(gram):
        # Identity built directly on the Gram matrix's device and dtype.
        identity = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
        diff = gram - identity
        if squared:
            # Sum of squares == squared Frobenius norm, without the sqrt
            # (whose gradient is undefined at 0).
            return diff.pow(2).sum()
        # Frobenius norm of a matrix == 2-norm of its flattened entries.
        return torch.linalg.vector_norm(diff)

    so_1 = _gram_deviation(torch.mm(weights.t(), weights))
    so_2 = _gram_deviation(torch.mm(weights, weights.t()))
    return config["factor"] * (so_1 + so_2)
def pod_spatial_loss(old_fmaps, fmaps, normalize=True):
    """Spatial Pooled-Output Distillation loss between two lists of feature maps.

    Each map is squared element-wise and then sum-pooled along each spatial
    axis separately. The two pooled descriptors are concatenated per sample.
    A layer's loss is the batch mean of the L2 distance between the descriptors
    of ``old_fmaps`` and ``fmaps``. The returned value is the mean over layers.

    :param old_fmaps: List of tensors of shape [bs, c, w, h].
    :param fmaps: List of tensors with the same length and shapes as ``old_fmaps``.
    :param normalize: If True, L2-normalize each descriptor before comparing.
    :return: A scalar tensor loss.
    """
    # zip() would silently drop unmatched layers while the final division
    # still uses len(fmaps), so require equal lengths up front.
    assert len(old_fmaps) == len(fmaps), 'Length error'

    def _descriptor(x):
        x_h = x.sum(dim=3).view(x.shape[0], -1)  # [bs, c*w]
        x_w = x.sum(dim=2).view(x.shape[0], -1)  # [bs, c*h]
        return torch.cat([x_h, x_w], dim=-1)

    loss = torch.zeros((), device=fmaps[0].device)
    for a, b in zip(old_fmaps, fmaps):
        assert a.shape == b.shape, 'Shape error'
        a = _descriptor(torch.pow(a, 2))
        b = _descriptor(torch.pow(b, 2))
        if normalize:
            a = F.normalize(a, dim=1, p=2)
            b = F.normalize(b, dim=1, p=2)
        layer_loss = torch.mean(torch.linalg.vector_norm(a - b, dim=-1))
        loss = loss + layer_loss
    return loss / len(fmaps)
def spatial_pyramid_pooling(
    list_attentions_a,
    list_attentions_b,
    levels=(1, 2),
    pool_type="avg",
    weight_by_level=True,
    normalize=True,
    **kwargs
):
    """Distillation loss between spatial-pyramid-pooled attention maps.

    Both maps of each pair are squared element-wise. For each pyramid
    ``level`` they are then pooled into a ``level x level`` grid of cells:
    ``level=1`` is global pooling, ``level=2`` a 2x2 grid, and so on. The
    kernel (and the default stride) is ``spatial_size // level``, so trailing
    rows/columns that do not fill a whole cell are dropped. Pooled cells are
    flattened per sample and compared with the L2 distance, averaged over the
    batch. Losses are summed over levels and over maps.

    :param list_attentions_a: List of attention maps, each of shape (b, n, w, h).
    :param list_attentions_b: List of attention maps with the same length and shapes.
    :param levels: Iterable of pyramid levels. Each must be a positive int no
                   larger than either spatial dimension.
    :param pool_type: "avg" or "max".
    :param weight_by_level: If True, the j-th level's loss is scaled by 1 / 2**j.
    :param normalize: If True, L2-normalize pooled features per sample.
    :param kwargs: Ignored; accepted for call-site compatibility.
    :return: A scalar tensor loss (summed, not averaged, over maps).
    :raises ValueError: On a non-positive or too large level, or an unknown pool_type.
    """
    assert len(list_attentions_a) == len(list_attentions_b)

    if pool_type == "avg":
        pool = F.avg_pool2d
    elif pool_type == "max":
        pool = F.max_pool2d
    else:
        raise ValueError("Invalid pool type {}.".format(pool_type))

    loss = torch.zeros((), device=list_attentions_a[0].device)

    for a, b in zip(list_attentions_a, list_attentions_b):
        # shape of (b, n, w, h)
        assert a.shape == b.shape, (a.shape, b.shape)

        a = torch.pow(a, 2)
        b = torch.pow(b, 2)
        w, h = a.shape[2], a.shape[3]

        for j, level in enumerate(levels):
            if level <= 0:
                raise ValueError("Level {} must be a positive integer.".format(level))
            if level > min(w, h):
                raise ValueError(
                    "Level {} is too big for spatial dim ({}, {}).".format(level, w, h)
                )
            # Kernel sized so each spatial axis is split into `level` cells.
            kernel_size = (w // level, h // level)

            a_features = pool(a, kernel_size).reshape(a.shape[0], -1)
            b_features = pool(b, kernel_size).reshape(b.shape[0], -1)

            if normalize:
                a_features = F.normalize(a_features, dim=-1)
                b_features = F.normalize(b_features, dim=-1)

            level_loss = torch.linalg.vector_norm(a_features - b_features, dim=-1).mean(0)
            if weight_by_level:
                # Give less importance for smaller cells.
                level_loss = level_loss / 2**j
            loss = loss + level_loss

    return loss
def perceptual_style_reconstruction(list_attentions_a, list_attentions_b, factor=1.):
    """Style (Gram-matrix) reconstruction loss between two lists of feature maps.

    For each pair of maps of shape (bs, c, w, h), the channel Gram matrices
    (normalized by ``c * w * h``) are compared with the squared Frobenius
    norm. The result is averaged over the batch, then averaged over the maps,
    and finally scaled by ``factor``.

    :param list_attentions_a: List of tensors of shape (bs, c, w, h).
    :param list_attentions_b: List of tensors with the same length and shapes.
    :param factor: Multiplier applied to the final loss.
    :return: A scalar tensor loss.
    """
    assert len(list_attentions_a) == len(list_attentions_b)

    loss = 0.
    for a, b in zip(list_attentions_a, list_attentions_b):
        assert a.shape == b.shape, (a.shape, b.shape)
        bs, c, w, h = a.shape
        # reshape (not view) so non-contiguous maps are accepted.
        a = a.reshape(bs, c, w * h)
        b = b.reshape(bs, c, w * h)

        gram_a = torch.bmm(a, a.transpose(2, 1)) / (c * w * h)
        gram_b = torch.bmm(b, b.transpose(2, 1)) / (c * w * h)

        # Squared Frobenius norm per sample, computed without an intermediate sqrt.
        layer_loss = (gram_a - gram_b).pow(2).sum(dim=(1, 2))
        loss += layer_loss.mean()

    return factor * (loss / len(list_attentions_a))
def forward(self, B, Sim):
    """Frobenius-norm loss between ``Sim`` and the scaled inner products of ``B``.

    Computes ``||Sim - (1 / nbits) * B @ B.T||_F`` with ``nbits = B.shape[1]``.
    NOTE(review): presumably ``B`` holds (relaxed) binary codes of shape
    (n, nbits) and ``Sim`` is an (n, n) similarity matrix; confirm against callers.

    :param B: 2-D tensor; its second dimension is used as ``nbits``.
    :param Sim: Tensor broadcast-compatible with ``B @ B.T``.
    :return: A scalar tensor loss.
    """
    # Use the tensor's own shape instead of a non-builtin global `size`.
    nbits = B.shape[1]
    code_similarity = B.mm(B.t()) / nbits
    # Frobenius norm == 2-norm over all flattened entries.
    return torch.linalg.vector_norm(Sim - code_similarity)
def _pod_collapse(x, collapse_channels):
    """Pool a (b, n, w, h) map into a per-sample 2-D descriptor of shape (b, features).

    :raises ValueError: If ``collapse_channels`` is not a known method.
    """
    if collapse_channels == "channels":
        return x.sum(dim=1).view(x.shape[0], -1)  # shape of (b, w * h)
    if collapse_channels == "width":
        return x.sum(dim=2).view(x.shape[0], -1)  # shape of (b, n * h)
    if collapse_channels == "height":
        return x.sum(dim=3).view(x.shape[0], -1)  # shape of (b, n * w)
    if collapse_channels == "gap":
        return F.adaptive_avg_pool2d(x, (1, 1))[..., 0, 0]  # shape of (b, n)
    if collapse_channels == "spatial":
        x_h = x.sum(dim=3).view(x.shape[0], -1)  # shape of (b, n * w)
        x_w = x.sum(dim=2).view(x.shape[0], -1)  # shape of (b, n * h)
        return torch.cat([x_h, x_w], dim=-1)
    raise ValueError("Unknown method to collapse: {}".format(collapse_channels))


def pod(
    list_attentions_a,
    list_attentions_b,
    collapse_channels="spatial",
    normalize=True,
    memory_flags=None,
    only_old=False,
    **kwargs
):
    """Pooled Output Distillation.

    Reference:
        * Douillard et al. PODNet: Pooled Outputs Distillation for Small-Tasks
          Incremental Learning. ECCV 2020.

    Each map is squared element-wise and collapsed into a per-sample descriptor
    according to ``collapse_channels``. A layer's loss is the batch mean of the
    L2 distance between the (optionally normalized) descriptors. The result is
    the sum of the layer losses divided by the number of layers; layers skipped
    because no sample was selected still count in the denominator.

    :param list_attentions_a: A list of attention maps, each of shape (b, n, w, h).
    :param list_attentions_b: A list of attention maps, each of shape (b, n, w, h).
    :param collapse_channels: How to pool the maps: "channels", "width",
                              "height", "gap" or "spatial".
    :param normalize: If True, L2-normalize descriptors before comparing.
    :param memory_flags: Selector applied as ``a[memory_flags]`` when
                         ``only_old`` is True. NOTE(review): a bool mask selects
                         the flagged samples, while an integer 0/1 tensor would
                         be treated as indices; confirm what callers pass.
    :param only_old: Only apply loss to the samples selected by ``memory_flags``.
    :param kwargs: Ignored; accepted for call-site compatibility.
    :return: A float scalar loss (0-dim tensor).
    :raises ValueError: If ``only_old`` is set without ``memory_flags``, or on
                        an unknown ``collapse_channels``.
    """
    assert len(list_attentions_a) == len(list_attentions_b)
    if only_old and memory_flags is None:
        # a[None] would silently add a leading axis instead of selecting samples.
        raise ValueError("memory_flags must be provided when only_old=True.")

    loss = torch.zeros((), device=list_attentions_a[0].device)

    for a, b in zip(list_attentions_a, list_attentions_b):
        # shape of (b, n, w, h)
        assert a.shape == b.shape, (a.shape, b.shape)

        if only_old:
            a = a[memory_flags]
            b = b[memory_flags]
            if len(a) == 0:
                continue

        a = _pod_collapse(torch.pow(a, 2), collapse_channels)
        b = _pod_collapse(torch.pow(b, 2), collapse_channels)

        if normalize:
            a = F.normalize(a, dim=1, p=2)
            b = F.normalize(b, dim=1, p=2)

        layer_loss = torch.mean(torch.linalg.vector_norm(a - b, dim=-1))
        loss = loss + layer_loss

    return loss / len(list_attentions_a)