def __init__(self, num_classes, block, layers, pretrained=False):
    """Build a RetinaNet detector on top of a ResNet backbone.

    The backbone (``conv1``/``bn1``/``layer1``..``layer4``) feeds the output
    channel counts of ``layer2``..``layer4`` into ``PyramidFeatures``; the
    regression and classification heads are built with 256 input channels.

    Args:
        num_classes: number of classes predicted by the classification head.
        block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
        layers: number of residual blocks in each of the four ResNet stages.
        pretrained: when truthy, load backbone weights from ``self.model_path``
            and set ``requires_grad = False`` on every backbone BatchNorm
            parameter.

    Raises:
        ValueError: if ``block`` is neither ``BasicBlock`` nor ``Bottleneck``.
    """
    super(retina, self).__init__()
    self.model_path = 'data/pretrained_model/resnet50_caffe.pth'
    self.pretrained = pretrained
    self.inplanes = 64  # presumably the running channel count used by self._make_layer — TODO confirm

    # ResNet stem and the four residual stages.
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    # ids of the modules whose weights come from the checkpoint. The random
    # initialisation further down must skip them; otherwise it overwrites the
    # pretrained backbone it was just loaded into.
    pretrained_module_ids = set()
    if self.pretrained:
        print("Loading pretrained weights from %s" % (self.model_path))
        # map_location='cpu' allows a GPU-saved checkpoint to be read on a
        # CPU-only host; load_state_dict copies values into the parameters.
        state_dict = torch.load(self.model_path, map_location='cpu')
        # Only the backbone exists at this point, so this keeps just the
        # backbone entries of the checkpoint. Build the key set once rather
        # than calling self.state_dict() for every checkpoint entry.
        own_keys = set(self.state_dict())
        self.load_state_dict({k: v for k, v in state_dict.items() if k in own_keys})

        def set_bn_fix(m):
            # Freeze the parameters of every BatchNorm-type module.
            if m.__class__.__name__.find('BatchNorm') != -1:
                for p in m.parameters():
                    p.requires_grad = False

        self.apply(set_bn_fix)
        pretrained_module_ids = {id(m) for m in self.modules()}

    # Channel counts of the last block of layer2..layer4, fed to the FPN.
    if block == BasicBlock:
        fpn_sizes = [self.layer2[layers[1] - 1].conv2.out_channels,
                     self.layer3[layers[2] - 1].conv2.out_channels,
                     self.layer4[layers[3] - 1].conv2.out_channels]
    elif block == Bottleneck:
        fpn_sizes = [self.layer2[layers[1] - 1].conv3.out_channels,
                     self.layer3[layers[2] - 1].conv3.out_channels,
                     self.layer4[layers[3] - 1].conv3.out_channels]
    else:
        raise ValueError(f"Block type {block} not understood")

    self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])

    self.regressionModel = RegressionModel(256)
    self.classificationModel = ClassificationModel(256, num_classes=num_classes)

    self.anchors = Anchors()
    self.regressBoxes = BBoxTransform()
    self.clipBoxes = ClipBoxes()
    self.focalLoss = losses.FocalLoss()

    # Weight initialisation: convs ~ N(0, sqrt(2 / (kh * kw * out_channels))),
    # BatchNorm weight = 1 and bias = 0. Pretrained backbone modules are skipped.
    for m in self.modules():
        if id(m) in pretrained_module_ids:
            continue
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    # Classification output starts with zero weights and bias
    # -log((1 - prior) / prior); presumably so that a sigmoid on it yields
    # ~prior for every anchor at the start of training.
    prior = 0.01
    self.classificationModel.output.weight.data.fill_(0)
    self.classificationModel.output.bias.data.fill_(-math.log((1.0 - prior) / prior))

    # Regression output starts as exactly zero.
    self.regressionModel.output.weight.data.fill_(0)
    self.regressionModel.output.bias.data.fill_(0)

    self.freeze_bn()

    # Stem container that reuses (shares) conv1/bn1/relu/maxpool.
    self.resnet_base = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool)
def __init__(self, num_classes, block, layers, n_head=1, attention_type='concat', shot_mode='mean', num_way=2, num_shot=5, pos_encoding=True, pretrained=False):
    """Build the attention-augmented RetinaNet on top of a ResNet backbone.

    Args:
        num_classes: number of classes predicted by the classification head.
        block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
        layers: number of residual blocks in each of the four ResNet stages.
        n_head: number of query/key projection pairs; when != 1 an extra
            ``multihead_layer`` is created.
        attention_type: ``'product'`` builds 256-channel heads; any other
            value (e.g. the default ``'concat'``) builds 512-channel heads.
        shot_mode: stored on ``self.shot_mode``.
        num_way: accepted for interface compatibility; not used here.
        num_shot: stored on ``self.num_shot``.
        pos_encoding: when truthy, build ``self.pos_encoding_layers``.
        pretrained: when truthy, load backbone weights from ``self.model_path``
            and set ``requires_grad = False`` on every backbone BatchNorm
            parameter.

    Raises:
        ValueError: if ``block`` is neither ``BasicBlock`` nor ``Bottleneck``.
    """
    super(ceaa_retinanet, self).__init__()
    self.model_path = 'data/pretrained_model/resnet50_caffe.pth'
    self.pretrained = pretrained
    self.inplanes = 64  # presumably the running channel count used by self._make_layer — TODO confirm
    self.n_head = n_head
    self.attention_type = attention_type
    self.shot_mode = shot_mode
    self.num_shot = num_shot
    self.pos_encoding = pos_encoding
    self.support_im_size = 320

    # ResNet stem and the four residual stages.
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    # ids of the modules whose weights come from the checkpoint. The random
    # initialisation further down must skip them; otherwise it overwrites the
    # pretrained backbone it was just loaded into.
    pretrained_module_ids = set()
    if self.pretrained:
        print("Loading pretrained weights from %s" % (self.model_path))
        # map_location='cpu' allows a GPU-saved checkpoint to be read on a
        # CPU-only host; load_state_dict copies values into the parameters.
        state_dict = torch.load(self.model_path, map_location='cpu')
        # Only the backbone exists at this point, so this keeps just the
        # backbone entries of the checkpoint. Build the key set once rather
        # than calling self.state_dict() for every checkpoint entry.
        own_keys = set(self.state_dict())
        self.load_state_dict({k: v for k, v in state_dict.items() if k in own_keys})

        def set_bn_fix(m):
            # Freeze the parameters of every BatchNorm-type module.
            if m.__class__.__name__.find('BatchNorm') != -1:
                for p in m.parameters():
                    p.requires_grad = False

        self.apply(set_bn_fix)
        pretrained_module_ids = {id(m) for m in self.modules()}

    # Channel counts of the last block of layer2..layer4, fed to the FPN.
    if block == BasicBlock:
        fpn_sizes = [self.layer2[layers[1] - 1].conv2.out_channels,
                     self.layer3[layers[2] - 1].conv2.out_channels,
                     self.layer4[layers[3] - 1].conv2.out_channels]
    elif block == Bottleneck:
        fpn_sizes = [self.layer2[layers[1] - 1].conv3.out_channels,
                     self.layer3[layers[2] - 1].conv3.out_channels,
                     self.layer4[layers[3] - 1].conv3.out_channels]
    else:
        raise ValueError(f"Block type {block} not understood")

    self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])  # presumably [512, 1024, 2048] for Bottleneck
    # Assumed to match PyramidFeatures' default output width — TODO confirm.
    self.fpn_dim = 256

    attention_output_dim = 256 if attention_type == 'product' else 512
    self.regressionModel = RegressionModel(attention_output_dim)
    self.classificationModel = ClassificationModel(attention_output_dim, num_classes=num_classes)

    self.anchors = Anchors([4, 5, 6, 7])
    self.regressBoxes = BBoxTransform()
    self.clipBoxes = ClipBoxes()
    self.focalLoss = losses.FocalLoss()

    # Weight initialisation: convs ~ N(0, sqrt(2 / (kh * kw * out_channels))),
    # BatchNorm weight = 1 and bias = 0. Pretrained backbone modules are skipped.
    for m in self.modules():
        if id(m) in pretrained_module_ids:
            continue
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    # Classification output starts with zero weights and bias
    # -log((1 - prior) / prior); presumably so that a sigmoid on it yields
    # ~prior for every anchor at the start of training.
    prior = 0.01
    self.classificationModel.output.weight.data.fill_(0)
    self.classificationModel.output.bias.data.fill_(-math.log((1.0 - prior) / prior))

    # Regression output starts as exactly zero.
    self.regressionModel.output.weight.data.fill_(0)
    self.regressionModel.output.bias.data.fill_(0)

    self.freeze_bn()

    # Stem container that reuses (shares) conv1/bn1/relu/maxpool.
    self.resnet_base = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool)

    # One query and one key projection (fpn_dim -> d_k) per head,
    # initialised with N(0, 0.01) weights and zero bias.
    self.d_k = 64
    Q_list = []
    K_list = []
    for _ in range(self.n_head):
        Q_weight = nn.Linear(self.fpn_dim, self.d_k)
        K_weight = nn.Linear(self.fpn_dim, self.d_k)
        for proj in (Q_weight, K_weight):
            init.normal_(proj.weight, std=0.01)
            init.constant_(proj.bias, 0)
        Q_list.append(Q_weight)
        K_list.append(K_weight)
    self.pyramid_Q_layers = nn.ModuleList(Q_list)
    self.pyramid_K_layers = nn.ModuleList(K_list)

    if self.pos_encoding:
        # One encoding per pyramid level 4..7; the max_len values presumably
        # match the support feature-map areas for a 320-px support image
        # (20x20, 10x10, 5x5, 3x3) — TODO confirm.
        pel_4 = PositionalEncoding(d_model=256, max_len=20 * 20)
        pel_5 = PositionalEncoding(d_model=256, max_len=10 * 10)
        pel_6 = PositionalEncoding(d_model=256, max_len=5 * 5)
        pel_7 = PositionalEncoding(d_model=256, max_len=3 * 3)
        self.pos_encoding_layers = nn.ModuleList([pel_4, pel_5, pel_6, pel_7])

    if n_head != 1:
        # Maps the concatenation of n_head fpn_dim-wide vectors back to
        # fpn_dim (presumably one vector per head — TODO confirm). Uses
        # self.fpn_dim; `feature_size` is not defined in this scope.
        self.multihead_layer = nn.Linear(n_head * self.fpn_dim, self.fpn_dim)
def __init__(self, num_classes, block, layers, attention_type, reduce_dim, beta, num_way=2, num_shot=5, pos_encoding=True, pretrained=False):
    """Build the SEPAA RetinaNet variant on top of a ResNet backbone.

    Args:
        num_classes: number of classes predicted by the classification head.
        block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
        layers: number of residual blocks in each of the four ResNet stages.
        attention_type: ``'product'`` builds a 256-wide FPN over the plain
            backbone channels; any other value builds a 512-wide FPN over
            twice the backbone channels.
        reduce_dim: output width of the per-level adapt Q/K projections.
        beta: stored on ``self.beta``.
        num_way: accepted for interface compatibility; not used here.
        num_shot: stored on ``self.num_shot``.
        pos_encoding: when truthy, build ``self.pos_encoding_layers``.
        pretrained: when truthy, load backbone weights from ``self.model_path``
            and set ``requires_grad = False`` on every backbone BatchNorm
            parameter.

    Raises:
        ValueError: if ``block`` is neither ``BasicBlock`` nor ``Bottleneck``.
    """
    super(SEPAA_retinanet, self).__init__()
    self.model_path = 'data/pretrained_model/resnet50_caffe.pth'
    self.pretrained = pretrained
    self.inplanes = 64  # presumably the running channel count used by self._make_layer — TODO confirm
    self.attention_type = attention_type
    self.num_shot = num_shot
    self.pos_encoding = pos_encoding
    self.support_im_size = 320
    self.reduce_dim = reduce_dim
    self.beta = beta
    self.unary_gamma = 0.1

    # ResNet stem and the four residual stages.
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    # ids of the modules whose weights come from the checkpoint. The random
    # initialisation further down must skip them; otherwise it overwrites the
    # pretrained backbone it was just loaded into.
    pretrained_module_ids = set()
    if self.pretrained:
        print("Loading pretrained weights from %s" % (self.model_path))
        # map_location='cpu' allows a GPU-saved checkpoint to be read on a
        # CPU-only host; load_state_dict copies values into the parameters.
        state_dict = torch.load(self.model_path, map_location='cpu')
        # Only the backbone exists at this point, so this keeps just the
        # backbone entries of the checkpoint. Build the key set once rather
        # than calling self.state_dict() for every checkpoint entry.
        own_keys = set(self.state_dict())
        self.load_state_dict({k: v for k, v in state_dict.items() if k in own_keys})

        def set_bn_fix(m):
            # Freeze the parameters of every BatchNorm-type module.
            if m.__class__.__name__.find('BatchNorm') != -1:
                for p in m.parameters():
                    p.requires_grad = False

        self.apply(set_bn_fix)
        pretrained_module_ids = {id(m) for m in self.modules()}

    # Channel counts of the last block of layer2..layer4
    # (presumably [512, 1024, 2048] for Bottleneck).
    if block == BasicBlock:
        fpn_sizes = [self.layer2[layers[1] - 1].conv2.out_channels,
                     self.layer3[layers[2] - 1].conv2.out_channels,
                     self.layer4[layers[3] - 1].conv2.out_channels]
    elif block == Bottleneck:
        fpn_sizes = [self.layer2[layers[1] - 1].conv3.out_channels,
                     self.layer3[layers[2] - 1].conv3.out_channels,
                     self.layer4[layers[3] - 1].conv3.out_channels]
    else:
        raise ValueError(f"Block type {block} not understood")

    attention_output_dim = 256 if self.attention_type == 'product' else 512
    if self.attention_type == 'product':
        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2],
                                   feature_size=attention_output_dim)
    else:
        # Non-product attention feeds the FPN twice the backbone channels.
        self.fpn = PyramidFeatures(fpn_sizes[0] * 2, fpn_sizes[1] * 2, fpn_sizes[2] * 2,
                                   feature_size=attention_output_dim)

    self.regressionModel = RegressionModel(attention_output_dim)
    self.classificationModel = ClassificationModel(attention_output_dim, num_classes=num_classes)

    self.anchors = Anchors([4, 5, 6, 7])
    self.regressBoxes = BBoxTransform()
    self.clipBoxes = ClipBoxes()
    self.focalLoss = losses.FocalLoss()

    # Weight initialisation: convs ~ N(0, sqrt(2 / (kh * kw * out_channels))),
    # BatchNorm weight = 1 and bias = 0. Pretrained backbone modules are skipped.
    for m in self.modules():
        if id(m) in pretrained_module_ids:
            continue
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    # Classification output starts with zero weights and bias
    # -log((1 - prior) / prior); presumably so that a sigmoid on it yields
    # ~prior for every anchor at the start of training.
    prior = 0.01
    self.classificationModel.output.weight.data.fill_(0)
    self.classificationModel.output.bias.data.fill_(-math.log((1.0 - prior) / prior))

    # Regression output starts as exactly zero.
    self.regressionModel.output.weight.data.fill_(0)
    self.regressionModel.output.bias.data.fill_(0)

    self.freeze_bn()

    # Stem container that reuses (shares) conv1/bn1/relu/maxpool.
    self.resnet_base = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool)

    def _small_linear(in_dim, out_dim):
        # Linear layer with N(0, 0.01) weights and zero bias.
        layer = nn.Linear(in_dim, out_dim)
        init.normal_(layer.weight, std=0.01)
        init.constant_(layer.bias, 0)
        return layer

    # Per-level projections over the backbone channels. Derived from
    # fpn_sizes (rather than hard-coded Bottleneck widths) so BasicBlock
    # backbones get matching layer widths.
    unary_list = []
    adapt_q_list = []
    adapt_k_list = []
    channel_k_list = []
    self.fpn_dims = list(fpn_sizes)
    for fpn_dim in self.fpn_dims:
        unary_list.append(_small_linear(fpn_dim, 1))
        adapt_q_list.append(_small_linear(fpn_dim, reduce_dim))
        adapt_k_list.append(_small_linear(fpn_dim, reduce_dim))
        channel_k_list.append(_small_linear(fpn_dim, 1))
    self.unary_layers = nn.ModuleList(unary_list)
    self.adapt_Q_layers = nn.ModuleList(adapt_q_list)
    self.adapt_K_layers = nn.ModuleList(adapt_k_list)
    self.channel_K_layers = nn.ModuleList(channel_k_list)

    if self.pos_encoding:
        # One encoding per backbone level (C3..C5), with d_model equal to that
        # level's channel count. The max_len values presumably match the
        # support feature-map areas for a 320-px support image
        # (40x40, 20x20, 10x10) — TODO confirm.
        pel_3 = PositionalEncoding(d_model=self.fpn_dims[0], max_len=40 * 40)
        pel_4 = PositionalEncoding(d_model=self.fpn_dims[1], max_len=20 * 20)
        pel_5 = PositionalEncoding(d_model=self.fpn_dims[2], max_len=10 * 10)
        self.pos_encoding_layers = nn.ModuleList([pel_3, pel_4, pel_5])