def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
    """Refine initial contours with one snake step and keep 4 vertices each.

    Every contour vertex is described by the CNN feature sampled at that
    vertex, concatenated with the feature sampled at the centre of the
    contour's axis-aligned bounding box. The concatenation is passed
    through ``self.fuse``; the channel-first ``c_it_poly`` coordinates are
    appended, and ``snake`` predicts a per-vertex offset that is added to
    ``i_it_poly``. Finally every ``snake_config.init_poly_num // 4``-th
    refined vertex is kept.

    Args:
        snake: graph network, called as ``snake(input, adj)``.
        cnn_feature: feature map; dims 2 and 3 are passed as height/width.
        i_it_poly: initial contours, presumably (num_polys, num_points, 2).
        c_it_poly: extra per-vertex coordinates; permuted with
            ``permute(0, 2, 1)`` and concatenated along dim 1.
        ind: forwarded unchanged to ``get_gcn_feature`` (presumably the
            image index of each polygon -- confirm).

    Returns:
        The subsampled refined contours, or a ``(0, 4, 2)`` zero tensor
        matching ``i_it_poly``'s dtype/device when there are no contours.
    """
    num_polys = len(i_it_poly)
    if num_polys == 0:
        return torch.zeros([0, 4, 2]).to(i_it_poly)

    feat_h = cnn_feature.shape[2]
    feat_w = cnn_feature.shape[3]
    vertex_feat = snake_gcn_utils.get_gcn_feature(
        cnn_feature, i_it_poly, ind, feat_h, feat_w)

    # Bounding-box centre: (min corner + max corner) / 2 over the vertex axis.
    lo = i_it_poly.min(dim=1).values
    hi = i_it_poly.max(dim=1).values
    bbox_center = (lo + hi) * 0.5
    center_feat = snake_gcn_utils.get_gcn_feature(
        cnn_feature, bbox_center.unsqueeze(1), ind, feat_h, feat_w)

    # Broadcast the single centre feature to every vertex before fusing.
    fused = self.fuse(
        torch.cat((vertex_feat, center_feat.expand_as(vertex_feat)), dim=1))
    gcn_input = torch.cat((fused, c_it_poly.permute(0, 2, 1)), dim=1)
    adjacency = snake_gcn_utils.get_adj_ind(
        snake_config.adj_num, gcn_input.size(2), gcn_input.device)

    offsets = snake(gcn_input, adjacency).permute(0, 2, 1)
    refined = i_it_poly + offsets
    stride = snake_config.init_poly_num // 4
    return refined[:, ::stride]
def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
    """Refine initial contours with one snake step and keep 4 key points each.

    Args:
        snake: graph network, called as ``snake(input, adj)``; its output is
            treated as per-point offsets.
        cnn_feature: feature map; dims 2 and 3 are passed as height/width.
        i_it_poly: initial contours, presumably (num_polys, num_points, 2).
        c_it_poly: extra per-point coordinates, appended channel-first to the
            fused features.
        ind: forwarded unchanged to ``get_gcn_feature`` (presumably the image
            index of each polygon -- confirm).

    Returns:
        Every ``snake_config.init_poly_num // 4``-th refined point, or an empty
        ``(0, 4, 2)`` tensor matching ``i_it_poly``'s dtype/device when there
        are no contours.
    """
    # Nothing to refine: return an empty tensor shaped like the usual output.
    if len(i_it_poly) == 0:
        return torch.zeros([0, 4, 2]).to(i_it_poly)

    h = cnn_feature.size(2)
    w = cnn_feature.size(3)

    def sample(points):
        # Sample CNN features at the given point coordinates.
        return snake_gcn_utils.get_gcn_feature(cnn_feature, points, ind, h, w)

    point_feat = sample(i_it_poly)

    # Centre of the axis-aligned bounding box: (min corner + max corner) / 2.
    # torch.min / torch.max with dim= return (values, indices); only the
    # values are used, giving one (x, y) per polygon.
    min_xy, _ = torch.min(i_it_poly, dim=1)
    max_xy, _ = torch.max(i_it_poly, dim=1)
    center = (min_xy + max_xy) * 0.5
    center_feat = sample(center[:, None])

    # Give every contour point both its own feature and the shape-centre feature.
    point_feat = self.fuse(
        torch.cat([point_feat, center_feat.expand_as(point_feat)], dim=1))

    snake_input = torch.cat([point_feat, c_it_poly.permute(0, 2, 1)], dim=1)
    adj = snake_gcn_utils.get_adj_ind(
        snake_config.adj_num, snake_input.size(2), snake_input.device)

    # The snake predicts offsets that are added to the input contour.
    refined = i_it_poly + snake(snake_input, adj).permute(0, 2, 1)

    # Keep every (init_poly_num // 4)-th point.
    return refined[:, ::snake_config.init_poly_num // 4]
def loc_cls_head(self, cnn_feature, loc_snake_net, cls_pred_net, i_it_poly,
                 c_it_poly, ind, batch=None):
    """Regress per-vertex offsets and, optionally, classify each polygon.

    Vertex features are sampled from ``cnn_feature`` at ``i_it_poly``. When
    ``cfg.bpoint_feat_enhance == 'aster'`` or ``cfg.evolve_ct_feat`` is set,
    the feature sampled at the vertex centroid (``mean`` over dim 1) is
    appended to every vertex. 'aster' mode also appends the vertex-minus-
    centroid difference. The result is then passed through ``self.fuse``.
    The ``snake_config.ro``-scaled ``c_it_poly`` is appended channel-first,
    and ``loc_snake_net`` predicts offsets relative to the
    ``snake_config.ro``-scaled ``i_it_poly``.

    Args:
        cnn_feature: feature map; dims 2 and 3 are passed as height/width.
        loc_snake_net: graph network, called as ``loc_snake_net(input, adj)``.
        cls_pred_net: polygon classifier, used only if ``cfg.poly_cls_branch``.
        i_it_poly: input contours, presumably (num_polys, num_points, 2).
        c_it_poly: extra per-vertex coordinates (scaled by ``snake_config.ro``).
        ind: forwarded to feature sampling and the classifier (presumably the
            image index of each polygon -- confirm).
        batch: forwarded unchanged to ``cls_pred_net``.

    Returns:
        ``(poly_cls, i_poly, init_feature)``. ``poly_cls`` is ``None`` when
        ``cfg.poly_cls_branch`` is disabled. For empty input all three slots
        are still returned.
    """
    if len(i_it_poly) == 0:
        if cfg.poly_cls_branch:
            # Use the input's device rather than hard-coding ``.cuda()`` so
            # CPU inference does not crash on an empty batch.
            poly_cls = torch.empty((0, 2), device=i_it_poly.device)
        else:
            poly_cls = None
        # Always return three values so callers can unpack uniformly
        # (previously only two were returned when the cls branch was off).
        return poly_cls, torch.zeros_like(i_it_poly), torch.zeros_like(i_it_poly)

    h, w = cnn_feature.size(2), cnn_feature.size(3)
    init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, h, w)

    use_aster = cfg.bpoint_feat_enhance == 'aster'
    if use_aster or cfg.evolve_ct_feat:
        # Sample the feature at the vertex centroid and broadcast it to all vertices.
        center = torch.mean(i_it_poly, dim=1)
        ct_feature = snake_gcn_utils.get_gcn_feature(
            cnn_feature, center[:, None], ind, h, w)
        ct_feature = ct_feature.expand_as(init_feature)
        parts = [init_feature, ct_feature]
        if use_aster:
            # 'aster' additionally appends the vertex-minus-centroid feature.
            parts.append(init_feature - ct_feature)
        init_feature = self.fuse(torch.cat(parts, dim=1))

    c_it_poly = c_it_poly * snake_config.ro
    init_input = torch.cat([init_feature, c_it_poly.permute(0, 2, 1)], dim=1)
    adj = snake_gcn_utils.get_adj_ind(
        snake_config.adj_num, init_input.size(2), init_input.device)
    loc_offsets = loc_snake_net(init_input, adj).permute(0, 2, 1)
    i_poly = i_it_poly * snake_config.ro + loc_offsets

    # The classifier is fed the unscaled input contours, as before.
    poly_cls = (cls_pred_net(cnn_feature, i_it_poly, ind, batch)
                if cfg.poly_cls_branch else None)
    return poly_cls, i_poly, init_feature
def evolve_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
    """Run one snake refinement step at the ``snake_config.ro`` scale.

    Features are sampled from ``cnn_feature`` at the ``i_it_poly`` vertices.
    The ``snake_config.ro``-scaled ``c_it_poly`` is appended channel-first,
    and ``snake`` predicts offsets that are added to the
    ``snake_config.ro``-scaled ``i_it_poly``.

    Args:
        snake: graph network, called as ``snake(input, adj)``.
        cnn_feature: feature map; dims 2 and 3 are passed as height/width.
        i_it_poly: contours to refine, presumably (num_polys, num_points, 2).
        c_it_poly: extra per-vertex coordinates.
        ind: forwarded unchanged to ``get_gcn_feature`` (presumably the image
            index of each polygon -- confirm).

    Returns:
        The refined contours, or ``zeros_like(i_it_poly)`` for empty input.
    """
    if not len(i_it_poly):
        return torch.zeros_like(i_it_poly)

    vertex_feat = snake_gcn_utils.get_gcn_feature(
        cnn_feature, i_it_poly, ind, cnn_feature.size(2), cnn_feature.size(3))
    scaled_coords = (c_it_poly * snake_config.ro).permute(0, 2, 1)
    gcn_input = torch.cat([vertex_feat, scaled_coords], dim=1)
    adjacency = snake_gcn_utils.get_adj_ind(
        snake_config.adj_num, gcn_input.size(2), gcn_input.device)

    offsets = snake(gcn_input, adjacency).permute(0, 2, 1)
    return i_it_poly * snake_config.ro + offsets
def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
    """Refine initial contours once; return a 4-vertex subset and the full contour.

    Each vertex feature is concatenated with the feature sampled at the centre
    of the contour's axis-aligned bounding box. When
    ``cfg.bpoint_feat_enhance == 'aster'``, the vertex-minus-centre difference
    is appended as well. The result goes through ``self.fuse``; ``c_it_poly``
    is appended channel-first, and ``snake`` predicts offsets added to
    ``i_it_poly``.

    Args:
        snake: graph network, called as ``snake(input, adj)``.
        cnn_feature: feature map; dims 2 and 3 are passed as height/width.
        i_it_poly: initial contours, presumably (num_polys, num_points, 2).
        c_it_poly: extra per-vertex coordinates.
        ind: forwarded unchanged to ``get_gcn_feature`` (presumably the image
            index of each polygon -- confirm).

    Returns:
        ``(i_4poly, i_poly)``: every ``snake_config.init_poly_num // 4``-th
        refined vertex, and the full refined contour. For empty input these
        are zero tensors of shape ``(0, 4, 2)`` and
        ``(0, snake_config.init_poly_num, 2)``, matching ``i_it_poly``'s
        dtype and device.
    """
    if len(i_it_poly) == 0:
        empty_quad = torch.zeros([0, 4, 2]).to(i_it_poly)
        empty_full = torch.zeros([0, snake_config.init_poly_num, 2]).to(i_it_poly)
        return empty_quad, empty_full

    if DEBUG:
        print('-----------------')
        print('cnn_feature.shape:', cnn_feature.shape)
        print('i_it_poly.shape:', i_it_poly.shape)

    fh, fw = cnn_feature.size(2), cnn_feature.size(3)
    vertex_feat = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, fh, fw)

    # Bounding-box centre: (min corner + max corner) / 2 over the vertex axis.
    bbox_center = (i_it_poly.min(dim=1).values + i_it_poly.max(dim=1).values) * 0.5
    center_feat = snake_gcn_utils.get_gcn_feature(
        cnn_feature, bbox_center.unsqueeze(1), ind, fh, fw)
    center_feat = center_feat.expand_as(vertex_feat)

    pieces = [vertex_feat, center_feat]
    if cfg.bpoint_feat_enhance == 'aster':
        # Extra "edge" channel: each vertex feature relative to the centre.
        pieces.append(vertex_feat - center_feat)
    fused = self.fuse(torch.cat(pieces, dim=1))

    gcn_input = torch.cat([fused, c_it_poly.permute(0, 2, 1)], dim=1)
    adjacency = snake_gcn_utils.get_adj_ind(
        snake_config.adj_num, gcn_input.size(2), gcn_input.device)
    if DEBUG:
        print('adj.shape:', adjacency.shape)
        print('init_input.shape:', gcn_input.shape)
        print('i_it_poly.shape:', i_it_poly.shape)

    refined = i_it_poly + snake(gcn_input, adjacency).permute(0, 2, 1)
    quad = refined[:, ::snake_config.init_poly_num // 4]
    if DEBUG:
        print('i_poly.shape:', refined.shape)
        print("i_4py.shape:", quad.shape)
    return quad, refined
def evolve_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind,
                ply_cls_flag=False):
    """Run one snake refinement step, optionally also returning class outputs.

    Vertex features are sampled from ``cnn_feature`` at ``i_it_poly``. When
    ``cfg.bpoint_feat_enhance == 'aster'`` or ``cfg.evolve_ct_feat`` is set,
    the feature sampled at the vertex centroid (``mean`` over dim 1) is
    appended to every vertex. 'aster' mode also appends the vertex-minus-
    centroid difference. The result is then passed through ``self.fuse``.
    The ``snake_config.ro``-scaled ``c_it_poly`` is appended channel-first,
    and ``snake`` predicts offsets relative to the ``snake_config.ro``-scaled
    ``i_it_poly``.

    Args:
        snake: graph network, called as ``snake(input, adj)``. It must return
            ``(offsets, cls)`` when ``ply_cls_flag`` is true, else offsets only.
        cnn_feature: feature map; dims 2 and 3 are passed as height/width.
        i_it_poly: contours to refine, presumably (num_polys, num_points, 2).
        c_it_poly: extra per-vertex coordinates (scaled by ``snake_config.ro``).
        ind: forwarded unchanged to ``get_gcn_feature`` (presumably the image
            index of each polygon -- confirm).
        ply_cls_flag: whether ``snake`` also produces a classification output.

    Returns:
        ``(cls, polys, init_feature)`` if ``ply_cls_flag`` else
        ``(polys, init_feature)``. For empty input the same arity is kept,
        with an empty ``(0, 2)`` tensor for ``cls`` on the input's device.
    """
    if len(i_it_poly) == 0:
        if ply_cls_flag:
            # Input's device instead of a hard-coded ``.cuda()``, so CPU
            # inference does not crash on an empty batch.
            return (torch.empty((0, 2), device=i_it_poly.device),
                    torch.zeros_like(i_it_poly), torch.zeros_like(i_it_poly))
        return torch.zeros_like(i_it_poly), torch.zeros_like(i_it_poly)

    h, w = cnn_feature.size(2), cnn_feature.size(3)
    init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, h, w)

    use_aster = cfg.bpoint_feat_enhance == 'aster'
    if use_aster or cfg.evolve_ct_feat:
        # Sample the feature at the vertex centroid and broadcast it to all vertices.
        center = torch.mean(i_it_poly, dim=1)
        ct_feature = snake_gcn_utils.get_gcn_feature(
            cnn_feature, center[:, None], ind, h, w)
        ct_feature = ct_feature.expand_as(init_feature)
        parts = [init_feature, ct_feature]
        if use_aster:
            # 'aster' additionally appends the vertex-minus-centroid feature.
            parts.append(init_feature - ct_feature)
        init_feature = self.fuse(torch.cat(parts, dim=1))

    c_it_poly = c_it_poly * snake_config.ro
    init_input = torch.cat([init_feature, c_it_poly.permute(0, 2, 1)], dim=1)
    adj = snake_gcn_utils.get_adj_ind(
        snake_config.adj_num, init_input.size(2), init_input.device)

    if ply_cls_flag:
        offsets, evolve_polys_cls = snake(init_input, adj)
        evolve_polys = i_it_poly * snake_config.ro + offsets.permute(0, 2, 1)
        return evolve_polys_cls, evolve_polys, init_feature

    evolve_py = snake(init_input, adj).permute(0, 2, 1)
    i_poly = i_it_poly * snake_config.ro + evolve_py
    return i_poly, init_feature