def __init__(self, recurrence=2, **kwargs):
    if CrissCrossAttention is None:
        raise RuntimeError('Please install mmcv-full for '
                           'CrissCrossAttention ops')
    super(CCHead, self).__init__(num_convs=2, **kwargs)
    self.recurrence = recurrence
    self.cca = CrissCrossAttention(self.channels)
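For context, a minimal sketch of how the recurrence and cca attributes set up above are typically consumed in the head's forward pass; the helper names (_transform_inputs, self.convs, cls_seg) are assumed to come from the surrounding decode-head base class and may differ in the actual code.

def forward(self, inputs):
    # sketch, assuming the usual decode-head helpers exist
    x = self._transform_inputs(inputs)    # select/merge backbone features
    output = self.convs[0](x)             # first of the num_convs=2 conv blocks
    for _ in range(self.recurrence):      # apply criss-cross attention recurrently
        output = self.cca(output)
    output = self.convs[1](output)        # second conv block
    output = self.cls_seg(output)         # per-pixel classification
    return output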
def __init__(self, cfg, **kwargs):
    super(CCNet, self).__init__(cfg, **kwargs)
    align_corners, norm_cfg, act_cfg = self.align_corners, self.norm_cfg, self.act_cfg
    # build criss-cross attention
    cca_cfg = cfg['cca']
    self.conv_before_cca = nn.Sequential(
        nn.Conv2d(cca_cfg['in_channels'], cca_cfg['out_channels'], kernel_size=3, stride=1, padding=1, bias=False),
        BuildNormalization(norm_cfg['type'], (cca_cfg['out_channels'], norm_cfg['opts'])),
        BuildActivation(act_cfg['type'], **act_cfg['opts']),
    )
    self.cca = CrissCrossAttention(cca_cfg['out_channels'])
    self.conv_after_cca = nn.Sequential(
        nn.Conv2d(cca_cfg['out_channels'], cca_cfg['out_channels'], kernel_size=3, stride=1, padding=1, bias=False),
        BuildNormalization(norm_cfg['type'], (cca_cfg['out_channels'], norm_cfg['opts'])),
        BuildActivation(act_cfg['type'], **act_cfg['opts']),
    )
    # build decoder
    decoder_cfg = cfg['decoder']
    self.decoder = nn.Sequential(
        nn.Conv2d(decoder_cfg['in_channels'], decoder_cfg['out_channels'], kernel_size=3, stride=1, padding=1, bias=False),
        BuildNormalization(norm_cfg['type'], (decoder_cfg['out_channels'], norm_cfg['opts'])),
        BuildActivation(act_cfg['type'], **act_cfg['opts']),
        nn.Dropout2d(decoder_cfg['dropout']),
        nn.Conv2d(decoder_cfg['out_channels'], cfg['num_classes'], kernel_size=1, stride=1, padding=0),
    )
    # build auxiliary decoder
    self.setauxiliarydecoder(cfg['auxiliary'])
    # freeze normalization layer if necessary
    if cfg.get('is_freeze_norm', False):
        self.freezenormalization()
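A hedged sketch of how these modules would be wired together at inference time; the names self.backbone_net and self.transforminputs, the choice of the last backbone stage x4, and the default recurrence count are assumptions about the surrounding framework and are not confirmed by the constructor above.

def forward(self, x):
    # sketch: backbone_net/transforminputs are assumed helper attributes
    x1, x2, x3, x4 = self.transforminputs(self.backbone_net(x))
    feats = self.conv_before_cca(x4)
    for _ in range(self.cfg['cca'].get('recurrence', 2)):  # recurrent criss-cross attention
        feats = self.cca(feats)
    feats = self.conv_after_cca(feats)
    # decoder consumes backbone features concatenated with attended features
    preds = self.decoder(torch.cat([x4, feats], dim=1))
    return preds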
def test_cc_attention(self):
    if not torch.cuda.is_available():
        return
    from mmcv.ops import CrissCrossAttention
    loss_func = Loss()

    # load the reference input/output produced by a known-good run
    input = np.fromfile(
        'tests/data/for_ccattention/ccattention_input.bin', dtype=np.float32)
    output = np.fromfile(
        'tests/data/for_ccattention/ccattention_output.bin', dtype=np.float32)
    input = input.reshape((1, 32, 45, 45))
    output = output.reshape((1, 32, 45, 45))
    label = torch.ones((1, 32, 45, 45))

    input = torch.FloatTensor(input)
    output = torch.FloatTensor(output)
    input.requires_grad = True

    shape = input.shape
    channel = shape[1]

    cca = CrissCrossAttention(channel)
    cca.cuda()
    input = input.cuda()
    label = label.cuda()
    cca.train()

    # forward + backward to make sure gradients flow through the op
    test_output = cca(input)
    test_loss = loss_func(test_output, label)
    test_loss.backward()
    test_output = test_output.detach().cpu().numpy()
    output = output.numpy()

    # results must match the stored reference and keep the input shape
    assert np.allclose(test_output, output)
    assert test_output.shape == shape
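The Loss helper referenced above is defined elsewhere in the test file. For readers who only want to exercise the op itself, a minimal standalone usage might look like the following; the shapes are illustrative, and as in the test above the module is run on CUDA.

# minimal usage sketch for mmcv's CrissCrossAttention op
import torch
from mmcv.ops import CrissCrossAttention

cca = CrissCrossAttention(in_channels=32).cuda()
feats = torch.randn(1, 32, 45, 45, device='cuda')
out = cca(feats)                     # output keeps the input shape (N, C, H, W)
assert out.shape == feats.shape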