def inference(self, x, feature, mask):
    """Iteratively upsample the coarse mask and re-predict the most uncertain points."""
    num_points = 512

    while mask.shape[-1] != x.shape[-1]:
        # Upsample the coarse prediction by 2x, then refine num_points uncertain points.
        mask = F.interpolate(mask, scale_factor=2, mode="bilinear", align_corners=False)

        points_idx, points = sampling_points_v2(torch.sigmoid(mask), num_points, training=self.training)

        coarse = sampling_features(mask, points, align_corners=False)
        fine = sampling_features(feature, points, align_corners=False)
        feature_representation = torch.cat([coarse, fine], dim=1)

        rend = self.mlp(feature_representation)

        # Scatter the refined point predictions back into the upsampled mask.
        B, C, H, W = mask.shape
        points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
        mask = (mask.reshape(B, C, -1)
                    .scatter_(2, points_idx, rend)
                    .view(B, C, H, W))

    return {"fine": mask}
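# NOTE (assumption): `sampling_features` and `sampling_points_v2` are imported from a
# point-sampling utility module that is not shown in this file. Below is a minimal sketch
# of what they are expected to do, following the standard PointRend formulation and the
# same `torch` / `torch.nn.functional as F` imports used by the rest of the code; the
# helpers actually used by this repo may differ in signature and in how uncertainty is
# computed, so treat this as an illustration rather than the project's implementation.

def sampling_features(features, points, **kwargs):
    """Sample `features` (B, C, H, W) at normalized point coords `points` (B, N, 2) in [0, 1]."""
    add_dim = False
    if points.dim() == 3:
        add_dim = True
        points = points.unsqueeze(2)            # (B, N, 1, 2) so grid_sample sees a 2D grid
    # grid_sample expects coordinates in [-1, 1]
    output = F.grid_sample(features, 2.0 * points - 1.0, **kwargs)
    if add_dim:
        output = output.squeeze(3)               # back to (B, C, N)
    return output


def sampling_points_v2(mask, N, k=3, beta=0.75, training=True):
    """Pick N points from `mask` (probabilities, B, C, H, W), biased toward uncertain pixels."""
    B, C, H, W = mask.shape
    device = mask.device

    # Per-pixel uncertainty: distance from 0.5 for a single-channel (sigmoid) mask,
    # negative top-2 margin for a multi-channel (softmax) mask.
    if C == 1:
        uncertainty = -(mask - 0.5).abs()
    else:
        top2, _ = mask.topk(2, dim=1)
        uncertainty = -(top2[:, 0:1] - top2[:, 1:2])

    if not training:
        # Inference: the N most uncertain pixels, returned as flat indices plus the
        # normalized (x, y) coordinates of their centers.
        _, idx = uncertainty.view(B, -1).topk(N, dim=1)
        points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
        points[:, :, 0] = 0.5 / W + (idx % W).float() / W
        points[:, :, 1] = 0.5 / H + (idx // W).float() / H
        return idx, points

    # Training: oversample k*N random points, keep the beta*N most uncertain of them,
    # and fill the remaining (1 - beta)*N points uniformly at random.
    n_importance = int(beta * N)
    over = torch.rand(B, k * N, 2, device=device)
    over_unc = sampling_features(uncertainty, over, align_corners=False)   # (B, 1, k*N)
    _, idx = over_unc.squeeze(1).topk(n_importance, dim=1)
    shift = (k * N) * torch.arange(B, device=device).unsqueeze(1)
    importance = over.view(-1, 2)[(idx + shift).view(-1)].view(B, n_importance, 2)
    coverage = torch.rand(B, N - n_importance, 2, device=device)
    return torch.cat([importance, coverage], dim=1)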
def forward(self, output, mask):
    # coarse, stage1, stage2, stage3, stage4, stage5 = output.values()
    coarse, stage3, stage4, stage5 = output.values()

    pred0 = F.interpolate(coarse, mask.shape[-2:], mode="bilinear", align_corners=True)

    # rend1 = stage1[1]
    # gt_points1 = sampling_features(mask, stage1[0], mode='nearest', align_corners=False).argmax(dim=1)
    # point_loss1 = F.cross_entropy(rend1, gt_points1)
    #
    # rend2 = stage2[1]
    # gt_points2 = sampling_features(mask, stage2[0], mode='nearest', align_corners=False).argmax(dim=1)
    # point_loss2 = F.cross_entropy(rend2, gt_points2)

    rend3 = stage3[1]
    gt_points3 = sampling_features(mask, stage3[0], mode='nearest', align_corners=True).argmax(dim=1)
    point_loss3 = F.cross_entropy(rend3, gt_points3)

    rend4 = stage4[1]
    gt_points4 = sampling_features(mask, stage4[0], mode='nearest', align_corners=True).argmax(dim=1)
    point_loss4 = F.cross_entropy(rend4, gt_points4)

    rend5 = stage5[1]
    gt_points5 = sampling_features(mask, stage5[0], mode='nearest', align_corners=True).argmax(dim=1)
    point_loss5 = F.cross_entropy(rend5, gt_points5)

    mask = mask.argmax(dim=1)
    seg_loss = F.cross_entropy(pred0, mask)

    # point_loss = point_loss1 + point_loss2 + point_loss3 + point_loss4 + point_loss5
    point_loss = point_loss3 + point_loss4 + point_loss5
    loss = point_loss + seg_loss
    return loss
def forward(self, refine, x0, x1, x2, x3, coarse):
    if not self.training:
        return self.inference(refine, x0, x1, x2, x3, coarse)

    # coarse size: 48x48
    # rend stage 1 with layer3
    temp1 = coarse
    points1 = sampling_points_v2(torch.sigmoid(temp1), N=512, k=3, beta=0.75)
    coarse_feature = sampling_features(temp1, points1, align_corners=False)
    fine_feature = sampling_features(x3, points1, align_corners=False)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend1 = self.mlp3(feature_representation)

    # coarse size: 48x48
    # rend stage 2 with layer2
    temp2 = coarse
    points2 = sampling_points_v2(torch.sigmoid(temp2), N=512, k=3, beta=0.75)
    coarse_feature = sampling_features(temp2, points2, align_corners=False)
    fine_feature = sampling_features(x2, points2, align_corners=False)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend2 = self.mlp2(feature_representation)

    # coarse size: 96x96
    # rend stage 3 with layer1
    temp3 = F.interpolate(temp2, scale_factor=2, mode='bilinear', align_corners=False)
    points3 = sampling_points_v2(torch.sigmoid(temp3), N=2048, k=3, beta=0.75)
    coarse_feature = sampling_features(temp3, points3, align_corners=False)
    fine_feature = sampling_features(x1, points3, align_corners=False)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend3 = self.mlp1(feature_representation)

    # coarse size: 192x192
    # rend stage 4 with layer0
    temp4 = F.interpolate(temp3, scale_factor=2, mode='bilinear', align_corners=False)
    points4 = sampling_points_v2(torch.sigmoid(temp4), N=2048, k=3, beta=0.75)
    coarse_feature = sampling_features(temp4, points4, align_corners=False)
    fine_feature = sampling_features(x0, points4, align_corners=False)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend4 = self.mlp0(feature_representation)

    # coarse size: 384x384
    # rend stage 5 with the refined layer
    temp5 = F.interpolate(temp4, scale_factor=2, mode='bilinear', align_corners=False)
    points5 = sampling_points_v2(torch.sigmoid(temp5), N=2048, k=3, beta=0.75)
    coarse_feature = sampling_features(temp5, points5, align_corners=False)
    fine_feature = sampling_features(refine, points5, align_corners=False)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend5 = self.mlp_refine(feature_representation)

    return {
        "coarse": coarse,
        "stage1": [points1, rend1],
        "stage2": [points2, rend2],
        "stage3": [points3, rend3],
        "stage4": [points4, rend4],
        "stage5": [points5, rend5],
    }
def forward(self, x, feature, mask):
    if not self.training:
        return self.inference(x, feature, mask)

    num_points = 2048
    points = sampling_points_v2(torch.sigmoid(mask), num_points, self.k, self.beta)

    coarse = sampling_features(mask, points, align_corners=False)
    fine = sampling_features(feature, points, align_corners=False)
    feature_representation = torch.cat([coarse, fine], dim=1)

    rend = self.mlp(feature_representation)

    return {"rend": rend, "points": points, "coarse": mask}
def forward(self, output, mask):
    pred = F.interpolate(output['coarse'], mask.shape[-2:], mode="bilinear", align_corners=True)

    gt_points = sampling_features(mask, output['points'], mode='bilinear', align_corners=True).argmax(dim=1)
    mask = mask.argmax(dim=1)

    seg_loss = F.cross_entropy(pred, mask)
    point_loss = F.cross_entropy(output['rend'], gt_points)

    loss = seg_loss + point_loss
    return loss
def forward(self, output, mask):
    coarse, stage1, stage2, stage3, stage4, stage5 = output.values()
    # coarse, stage3, stage4, stage5 = output.values()

    pred0 = F.interpolate(coarse, mask.shape[-2:], mode="bilinear", align_corners=False)
    seg_loss = F.binary_cross_entropy_with_logits(pred0, mask)

    rend1 = stage1[1]
    gt_points1 = sampling_features(mask, stage1[0], mode='nearest')
    point_loss1 = F.binary_cross_entropy_with_logits(rend1, gt_points1)
    # point_loss1 = self.loss(torch.sigmoid(rend1), gt_points1)

    rend2 = stage2[1]
    gt_points2 = sampling_features(mask, stage2[0], mode='nearest')
    point_loss2 = F.binary_cross_entropy_with_logits(rend2, gt_points2)
    # point_loss2 = self.loss(torch.sigmoid(rend2), gt_points2)

    rend3 = stage3[1]
    gt_points3 = sampling_features(mask, stage3[0], mode='nearest')
    point_loss3 = F.binary_cross_entropy_with_logits(rend3, gt_points3)
    # point_loss3 = self.loss(torch.sigmoid(rend3), gt_points3)

    rend4 = stage4[1]
    gt_points4 = sampling_features(mask, stage4[0], mode='nearest')
    point_loss4 = F.binary_cross_entropy_with_logits(rend4, gt_points4)
    # point_loss4 = self.loss(torch.sigmoid(rend4), gt_points4)

    rend5 = stage5[1]
    gt_points5 = sampling_features(mask, stage5[0], mode='nearest')
    point_loss5 = F.binary_cross_entropy_with_logits(rend5, gt_points5)
    # point_loss5 = self.loss(torch.sigmoid(rend5), gt_points5)

    # point_loss = point_loss1 + point_loss2 + point_loss3 + point_loss4 + point_loss5
    point_loss = point_loss3 + point_loss4 + point_loss5
    loss = seg_loss + point_loss
    return loss
def forward(self, output, mask):
    # F.upsample is deprecated; F.interpolate is the equivalent call.
    pred = torch.sigmoid(
        F.interpolate(output['coarse'], mask.shape[-2:], mode="bilinear", align_corners=True))
    gt_points = sampling_features(mask, output['points'], mode='nearest')

    # Dice loss on the upsampled coarse prediction
    N = mask.size(0)
    smooth = 1
    input_flat = pred.view(N, -1)
    target_flat = mask.view(N, -1)
    intersection = input_flat * target_flat
    seg_loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
    seg_loss = 1 - seg_loss.sum() / N

    point_loss = F.binary_cross_entropy(torch.sigmoid(output['rend']), gt_points)
    loss = seg_loss + point_loss
    return loss
def forward(self, output, mask): pred = F.interpolate(output['coarse'], mask.shape[-2:], mode="bilinear", align_corners=False) gt_points = sampling_features(mask, output['points'], mode='nearest') # N = mask.size(0) # smooth = 1 # input_flat = pred.view(N, -1) # target_flat = mask.view(N, -1) # intersection = input_flat * target_flat # seg_loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth) # seg_loss = 1 - seg_loss.sum() / N seg_loss = F.binary_cross_entropy_with_logits(pred, mask) point_loss = F.binary_cross_entropy_with_logits( output['rend'], gt_points) loss = seg_loss + point_loss return loss
def inference(self, refine, x0, x1, x2, x3, coarse):
    # stage 1 (disabled), coarse size: 48x48
    # temp = coarse
    # points_idx, points = sampling_points_v2(torch.softmax(temp, dim=1), 512, training=self.training)
    # coarse_feature = sampling_features(temp, points, align_corners=False)
    # fine_feature = sampling_features(x3, points, align_corners=False)
    # feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    # rend = self.mlp3(feature_representation)
    # B, C, H, W = coarse.shape
    # points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
    # coarse1 = (coarse.reshape(B, C, -1)
    #            .scatter_(2, points_idx, rend)
    #            .view(B, C, H, W))

    # stage 2 (disabled), coarse size: 48x48
    # temp = coarse1
    # points_idx, points = sampling_points_v2(torch.softmax(temp, dim=1), 512, training=self.training)
    # coarse_feature = sampling_features(temp, points, align_corners=True)
    # fine_feature = sampling_features(x2, points, align_corners=True)
    # feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    # rend = self.mlp2(feature_representation)
    # B, C, H, W = coarse1.shape
    # points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
    # coarse2 = (coarse1.reshape(B, C, -1)
    #            .scatter_(2, points_idx, rend)
    #            .view(B, C, H, W))

    # stage 3, coarse size: 96x96
    coarse3 = F.interpolate(coarse, scale_factor=2, mode='bilinear', align_corners=True)
    temp = coarse3
    points_idx, points = sampling_points_v2(torch.softmax(temp, dim=1), 512, training=self.training)
    coarse_feature = sampling_features(temp, points, align_corners=True)
    fine_feature = sampling_features(x1, points, align_corners=True)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend = self.mlp1(feature_representation)
    B, C, H, W = coarse3.shape
    points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
    coarse3 = (coarse3.reshape(B, C, -1)
               .scatter_(2, points_idx, rend)
               .view(B, C, H, W))

    # stage 4, coarse size: 192x192
    coarse4 = F.interpolate(coarse3, scale_factor=2, mode='bilinear', align_corners=True)
    temp = coarse4
    points_idx, points = sampling_points_v2(torch.softmax(temp, dim=1), 512, training=self.training)
    coarse_feature = sampling_features(temp, points, align_corners=True)
    fine_feature = sampling_features(x0, points, align_corners=True)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend = self.mlp0(feature_representation)
    B, C, H, W = coarse4.shape
    points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
    coarse4 = (coarse4.reshape(B, C, -1)
               .scatter_(2, points_idx, rend)
               .view(B, C, H, W))

    # stage 5, coarse size: 384x384
    coarse5 = F.interpolate(coarse4, scale_factor=2, mode='bilinear', align_corners=True)
    temp = coarse5
    points_idx, points = sampling_points_v2(torch.softmax(temp, dim=1), 512, training=self.training)
    coarse_feature = sampling_features(temp, points, align_corners=True)
    fine_feature = sampling_features(refine, points, align_corners=True)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend = self.mlp_refine(feature_representation)
    B, C, H, W = coarse5.shape
    points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
    coarse5 = (coarse5.reshape(B, C, -1)
               .scatter_(2, points_idx, rend)
               .view(B, C, H, W))

    return {"fine": coarse5}
def forward(self, refine, x0, x1, x2, x3, coarse):
    if not self.training:
        return self.inference(refine, x0, x1, x2, x3, coarse)

    # coarse size: 48x48
    # rend stage 1 with layer3 (disabled)
    # temp1 = coarse
    # points1 = sampling_points_v2(torch.softmax(temp1, dim=1), N=512, k=3, beta=0.75)
    # coarse_feature = sampling_features(temp1, points1, align_corners=False)
    # fine_feature = sampling_features(x3, points1, align_corners=False)
    # feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    # rend1 = self.mlp3(feature_representation)

    # coarse size: 48x48
    # rend stage 2 with layer2 (disabled)
    # temp2 = coarse
    # points2 = sampling_points_v2(torch.softmax(temp2, dim=1), N=512, k=3, beta=0.75)
    # coarse_feature = sampling_features(temp2, points2, align_corners=False)
    # fine_feature = sampling_features(x2, points2, align_corners=False)
    # feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    # rend2 = self.mlp2(feature_representation)

    # coarse size: 96x96
    # rend stage 3 with layer1
    temp3 = F.interpolate(coarse, scale_factor=2, mode='bilinear', align_corners=True)
    points3 = sampling_points_v2(torch.softmax(temp3, dim=1), N=2048, k=3, beta=0.75)
    coarse_feature = sampling_features(temp3, points3, align_corners=True)
    fine_feature = sampling_features(x1, points3, align_corners=True)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend3 = self.mlp1(feature_representation)

    # coarse size: 192x192
    # rend stage 4 with layer0
    temp4 = F.interpolate(temp3, scale_factor=2, mode='bilinear', align_corners=True)
    points4 = sampling_points_v2(torch.softmax(temp4, dim=1), N=2048, k=3, beta=0.75)
    coarse_feature = sampling_features(temp4, points4, align_corners=True)
    fine_feature = sampling_features(x0, points4, align_corners=True)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend4 = self.mlp0(feature_representation)

    # coarse size: 384x384
    # rend stage 5 with the refined layer
    temp5 = F.interpolate(temp4, scale_factor=2, mode='bilinear', align_corners=True)
    points5 = sampling_points_v2(torch.softmax(temp5, dim=1), N=2048, k=3, beta=0.75)
    coarse_feature = sampling_features(temp5, points5, align_corners=True)
    fine_feature = sampling_features(refine, points5, align_corners=True)
    feature_representation = torch.cat([coarse_feature, fine_feature], dim=1)
    rend5 = self.mlp_refine(feature_representation)

    return {
        "coarse": coarse,
        # "stage1": [points1, rend1],
        # "stage2": [points2, rend2],
        "stage3": [points3, rend3],
        "stage4": [points4, rend4],
        "stage5": [points5, rend5],
    }
    # (continuation of the RendUNet class: end of __init__, then forward and weight init)
    self.rend = RendNet(n_class=n_class)

def forward(self, x):
    refine, x0, x1, x2, x3, coarse = self.seg(x)
    res = self.rend(refine, x0, x1, x2, x3, coarse)
    return res

def _init_weight(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            torch.nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


if __name__ == "__main__":
    model = RendUNet()
    img = torch.rand(4, 3, 512, 512)
    mask = torch.rand(4, 3, 512, 512)
    model.train()
    print('# parameters:', sum(param.numel() for param in model.parameters()))
    res = model(img)
    for k, v in res.items():
        if k == "coarse":
            print(k, v.shape)
        else:
            print(k, v[0].shape, v[1].shape, sampling_features(mask, v[0]).shape)
class RendDANet(BaseNet):
    def __init__(self, nclass, backbone, norm_layer=nn.BatchNorm2d):
        super(RendDANet, self).__init__(nclass, backbone, norm_layer=norm_layer)
        self.head = DANetHead(2048, 512, norm_layer=norm_layer)
        self.seg1 = nn.Sequential(nn.Dropout(0.1), nn.Conv2d(512, nclass, 1))
        self.rend_head = PointHead(in_c=527, num_classes=nclass)

    def forward(self, x):
        _, c2, _, c4 = self.base_forward(x)
        mask = self.seg1(self.head(c4))
        result = self.rend_head(x, c2, mask)
        return result


if __name__ == "__main__":
    net = RendDANet(backbone='resnet101', nclass=15)
    img = torch.rand(4, 3, 384, 384)
    mask = torch.rand(4, 8, 384, 384)
    net.train()
    output = net(img)
    for k, v in output.items():
        print(k, v.shape)
    test = sampling_features(mask, output['points'], mode='nearest')
    print(test.shape)
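# How the pieces above fit together during training: a minimal single-step sketch.
# Assumptions: `criterion` is an instance of one of the loss modules defined earlier
# (the variant that consumes the {"rend", "points", "coarse"} dict), `net` is e.g.
# RendDANet, and `img` / `mask` come from a dataloader with `mask` shaped
# (B, nclass, H, W). The function name `train_step` is hypothetical, not from this repo.

def train_step(net, criterion, optimizer, img, mask):
    """Run one forward/backward pass and return the scalar loss value."""
    net.train()
    output = net(img)                  # {"rend": ..., "points": ..., "coarse": ...}
    loss = criterion(output, mask)     # segmentation loss + point loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()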