def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
        """Deform initial contours with ``snake`` and keep a strided subset.

        Every contour point is described by the CNN feature sampled at its
        location, fused (via ``self.fuse``) with the feature sampled at the
        centre of the contour's axis-aligned bounding box. The canonical
        coordinates ``c_it_poly`` are appended as extra channels. ``snake``
        predicts one offset per point, and the shifted contour is subsampled
        with stride ``snake_config.init_poly_num // 4``.

        Args:
            snake: network invoked as ``snake(features, adj)``.
            cnn_feature: feature map whose dims 2 and 3 are used as h and w.
            i_it_poly: input contours; presumably (num_polys, num_points, 2)
                -- TODO confirm against callers.
            c_it_poly: canonical contour coordinates, permuted to channel-first
                and concatenated onto the fused features.
            ind: per-contour index forwarded to ``get_gcn_feature``.

        Returns:
            The subsampled deformed contours, or an empty ``(0, 4, 2)`` tensor
            with ``i_it_poly``'s dtype/device when there are no contours.
        """
        if not len(i_it_poly):
            return torch.zeros([0, 4, 2]).to(i_it_poly)

        feat_h, feat_w = cnn_feature.size(2), cnn_feature.size(3)
        point_feat = snake_gcn_utils.get_gcn_feature(
            cnn_feature, i_it_poly, ind, feat_h, feat_w)

        # Centre of each contour's bounding box: per-axis (min + max) / 2.
        box_lo = i_it_poly.min(dim=1)[0]
        box_hi = i_it_poly.max(dim=1)[0]
        box_center = (box_lo + box_hi) * 0.5
        center_feat = snake_gcn_utils.get_gcn_feature(
            cnn_feature, box_center.unsqueeze(1), ind, feat_h, feat_w)

        # Give every contour point both its own and its contour's centre feature.
        fused = self.fuse(
            torch.cat((point_feat, center_feat.expand_as(point_feat)), dim=1))

        snake_input = torch.cat((fused, c_it_poly.permute(0, 2, 1)), dim=1)
        adj = snake_gcn_utils.get_adj_ind(
            snake_config.adj_num, snake_input.size(2), snake_input.device)
        offsets = snake(snake_input, adj).permute(0, 2, 1)
        stride = snake_config.init_poly_num // 4
        return (i_it_poly + offsets)[:, ::stride]
Example #2
0
    def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
        """Deform initial contours with ``snake`` and keep a strided subset.

        Args:
            snake: network invoked as ``snake(features, adj)``.
            cnn_feature: feature map whose dims 2 and 3 are used as h and w.
            i_it_poly: input contours, (num_polys, num_points, 2) per the
                original author's note.
            c_it_poly: canonical contour coordinates, appended channel-first.
            ind: per-contour index forwarded to ``get_gcn_feature``.

        Returns:
            Every ``snake_config.init_poly_num // 4``-th point of the deformed
            contours, or an empty ``(0, 4, 2)`` tensor matching
            ``i_it_poly``'s dtype/device when there are no contours.
        """
        if not len(i_it_poly):
            return torch.zeros([0, 4, 2]).to(i_it_poly)

        rows, cols = cnn_feature.size(2), cnn_feature.size(3)
        contour_feat = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, rows, cols)

        # Bounding-box centre: average of the per-axis minimum and maximum
        # over each contour's points (torch.min/max(...)[0] are the values).
        mins = i_it_poly.min(dim=1)[0]
        maxs = i_it_poly.max(dim=1)[0]
        mid = (mins + maxs) * 0.5
        mid_feat = snake_gcn_utils.get_gcn_feature(cnn_feature, mid.unsqueeze(1), ind, rows, cols)

        # Each contour point carries both its own feature and the centre feature.
        combined = torch.cat((contour_feat, mid_feat.expand_as(contour_feat)), dim=1)
        combined = self.fuse(combined)

        gcn_in = torch.cat((combined, c_it_poly.permute(0, 2, 1)), dim=1)
        adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num, gcn_in.size(2), gcn_in.device)

        # Predicted per-point offsets are added to the input contour.
        moved = i_it_poly + snake(gcn_in, adj).permute(0, 2, 1)
        # Keep every (init_poly_num // 4)-th point.
        return moved[:, ::snake_config.init_poly_num // 4]
Example #3
0
    def loc_cls_head(self,
                     cnn_feature,
                     loc_snake_net,
                     cls_pred_net,
                     i_it_poly,
                     c_it_poly,
                     ind,
                     batch=None):
        """Refine contours with ``loc_snake_net`` and optionally classify them.

        Point features are sampled from ``cnn_feature`` at every contour
        point. When ``cfg.bpoint_feat_enhance == 'aster'`` or
        ``cfg.evolve_ct_feat`` is set, the feature at the contour's mean point
        is sampled too and fused in via ``self.fuse``. The 'aster' mode also
        appends the difference (point - centre). Canonical coordinates
        (scaled by ``snake_config.ro``) are appended as extra channels.

        Args:
            cnn_feature: feature map whose dims 2 and 3 are used as h and w.
            loc_snake_net: network called as ``loc_snake_net(input, adj)``
                that returns per-point offsets.
            cls_pred_net: classifier called only when ``cfg.poly_cls_branch``.
            i_it_poly: input contours; presumably (num_polys, num_points, 2)
                -- TODO confirm.
            c_it_poly: canonical contour coordinates.
            ind: per-contour index forwarded to ``get_gcn_feature``.
            batch: forwarded unchanged to ``cls_pred_net``.

        Returns:
            Always a 3-tuple ``(poly_cls, i_poly, init_feature)``:
            ``poly_cls`` is ``None`` unless ``cfg.poly_cls_branch`` is set,
            ``i_poly`` is ``i_it_poly * ro`` plus the predicted offsets, and
            ``init_feature`` is the feature tensor fed to the snake.
        """
        if len(i_it_poly) == 0:
            # Keep the same 3-tuple arity as the non-empty path so callers can
            # always unpack three values. Allocate on the input's device
            # instead of forcing CUDA.
            if cfg.poly_cls_branch:
                empty_cls = torch.empty((0, 2), device=i_it_poly.device)
            else:
                empty_cls = None
            return empty_cls, torch.zeros_like(i_it_poly), torch.zeros_like(i_it_poly)

        h, w = cnn_feature.size(2), cnn_feature.size(3)
        init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly,
                                                       ind, h, w)

        use_aster = cfg.bpoint_feat_enhance == 'aster'
        if use_aster or cfg.evolve_ct_feat:
            # Centre feature sampled at the mean of the contour points.
            center = torch.mean(i_it_poly, dim=1)
            ct_feature = snake_gcn_utils.get_gcn_feature(
                cnn_feature, center[:, None], ind, h, w).expand_as(init_feature)
            parts = [init_feature, ct_feature]
            if use_aster:
                # 'aster' additionally encodes the point-to-centre difference.
                parts.append(init_feature - ct_feature)
            init_feature = self.fuse(torch.cat(parts, dim=1))

        c_it_poly = c_it_poly * snake_config.ro
        init_input = torch.cat(
            [init_feature, c_it_poly.permute(0, 2, 1)], dim=1)
        adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num,
                                          init_input.size(2),
                                          init_input.device)

        loc_offsets = loc_snake_net(init_input, adj).permute(0, 2, 1)
        i_poly = i_it_poly * snake_config.ro + loc_offsets

        # The classifier receives the unscaled input contours.
        if cfg.poly_cls_branch:
            poly_cls = cls_pred_net(cnn_feature, i_it_poly, ind, batch)
        else:
            poly_cls = None

        return poly_cls, i_poly, init_feature
Example #4
0
 def evolve_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
     """Shift contours by the per-point offsets predicted by ``snake``.

     Both ``i_it_poly`` and ``c_it_poly`` are scaled by ``snake_config.ro``
     before use. Returns ``i_it_poly * ro`` plus the predicted offsets, or
     ``zeros_like(i_it_poly)`` when there are no contours.
     """
     if not len(i_it_poly):
         return torch.zeros_like(i_it_poly)

     rows, cols = cnn_feature.size(2), cnn_feature.size(3)
     point_feat = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, rows, cols)

     scaled_canon = c_it_poly * snake_config.ro
     gcn_in = torch.cat([point_feat, scaled_canon.permute(0, 2, 1)], dim=1)
     adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num, gcn_in.size(2), gcn_in.device)

     offsets = snake(gcn_in, adj).permute(0, 2, 1)
     return i_it_poly * snake_config.ro + offsets
Example #5
0
    def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
        """Deform initial contours and return both a 4-point and full version.

        Each contour point's feature is fused with the feature sampled at the
        contour's bounding-box centre. With ``cfg.bpoint_feat_enhance ==
        'aster'`` the point-minus-centre difference is fused in as well.
        ``snake`` predicts per-point offsets that are added to ``i_it_poly``.

        Returns:
            ``(i_4poly, i_poly)``: every ``snake_config.init_poly_num // 4``-th
            point of the deformed contour, and the full deformed contour.
            For empty input, zero tensors shaped ``(0, 4, 2)`` and
            ``(0, init_poly_num, 2)`` matching ``i_it_poly``'s dtype/device.
        """
        if not len(i_it_poly):
            empty_quad = torch.zeros([0, 4, 2]).to(i_it_poly)
            empty_full = torch.zeros([0, snake_config.init_poly_num,
                                      2]).to(i_it_poly)
            return empty_quad, empty_full

        if DEBUG:
            print('-----------------')
            print('cnn_feature.shape:', cnn_feature.shape)
            print('i_it_poly.shape:', i_it_poly.shape)

        rows, cols = cnn_feature.size(2), cnn_feature.size(3)
        point_feat = snake_gcn_utils.get_gcn_feature(
            cnn_feature, i_it_poly, ind, rows, cols)

        # Bounding-box centre of each contour: per-axis (min + max) / 2.
        box_center = (i_it_poly.min(dim=1)[0] +
                      i_it_poly.max(dim=1)[0]) * 0.5
        center_feat = snake_gcn_utils.get_gcn_feature(
            cnn_feature, box_center.unsqueeze(1), ind, rows, cols)
        center_feat = center_feat.expand_as(point_feat)

        feat_parts = [point_feat, center_feat]
        if cfg.bpoint_feat_enhance == 'aster':
            feat_parts.append(point_feat - center_feat)
        fused = self.fuse(torch.cat(feat_parts, dim=1))

        gcn_in = torch.cat([fused, c_it_poly.permute(0, 2, 1)], dim=1)
        adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num,
                                          gcn_in.size(2), gcn_in.device)
        if DEBUG:
            print('adj.shape:', adj.shape)
            print('init_input.shape:', gcn_in.shape)
            print('i_it_poly.shape:', i_it_poly.shape)

        full_poly = i_it_poly + snake(gcn_in, adj).permute(0, 2, 1)
        quad_poly = full_poly[:, ::snake_config.init_poly_num // 4]

        if DEBUG:
            print('i_poly.shape:', full_poly.shape)
            print("i_4py.shape:", quad_poly.shape)

        return quad_poly, full_poly
Example #6
0
    def evolve_poly(self,
                    snake,
                    cnn_feature,
                    i_it_poly,
                    c_it_poly,
                    ind,
                    ply_cls_flag=False):
        """Shift contours by ``snake``-predicted offsets, optionally classifying.

        Point features are sampled at every contour point. When
        ``cfg.bpoint_feat_enhance == 'aster'`` or ``cfg.evolve_ct_feat`` is
        set, the feature at the contour's mean point is fused in via
        ``self.fuse``. The 'aster' mode also adds the point-minus-centre
        difference. Canonical coordinates (scaled by ``snake_config.ro``)
        are appended as extra channels.

        Args:
            snake: network called as ``snake(input, adj)``. When
                ``ply_cls_flag`` is true it must return
                ``(offsets, class_scores)``.
            cnn_feature: feature map whose dims 2 and 3 are used as h and w.
            i_it_poly: input contours; presumably (num_polys, num_points, 2)
                -- TODO confirm.
            c_it_poly: canonical contour coordinates.
            ind: per-contour index forwarded to ``get_gcn_feature``.
            ply_cls_flag: whether ``snake`` also produces class scores.

        Returns:
            ``(cls, polys, init_feature)`` when ``ply_cls_flag`` is true,
            otherwise ``(polys, init_feature)``. Here ``polys`` is
            ``i_it_poly * ro`` plus the predicted offsets.
        """
        if len(i_it_poly) == 0:
            if ply_cls_flag:
                # Allocate on the input's device instead of forcing CUDA,
                # which crashed on CPU-only runs.
                return (torch.empty((0, 2), device=i_it_poly.device),
                        torch.zeros_like(i_it_poly),
                        torch.zeros_like(i_it_poly))
            return torch.zeros_like(i_it_poly), torch.zeros_like(i_it_poly)

        h, w = cnn_feature.size(2), cnn_feature.size(3)
        init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly,
                                                       ind, h, w)

        use_aster = cfg.bpoint_feat_enhance == 'aster'
        if use_aster or cfg.evolve_ct_feat:
            # Centre feature sampled at the mean of the contour points.
            center = torch.mean(i_it_poly, dim=1)
            ct_feature = snake_gcn_utils.get_gcn_feature(
                cnn_feature, center[:, None], ind, h, w).expand_as(init_feature)
            parts = [init_feature, ct_feature]
            if use_aster:
                # 'aster' additionally encodes the point-to-centre difference.
                parts.append(init_feature - ct_feature)
            init_feature = self.fuse(torch.cat(parts, dim=1))

        c_it_poly = c_it_poly * snake_config.ro
        init_input = torch.cat(
            [init_feature, c_it_poly.permute(0, 2, 1)], dim=1)
        adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num,
                                          init_input.size(2),
                                          init_input.device)

        if ply_cls_flag:
            offsets, evolve_polys_cls = snake(init_input, adj)
            evolve_polys = i_it_poly * snake_config.ro + offsets.permute(0, 2, 1)
            return evolve_polys_cls, evolve_polys, init_feature

        evolve_py = snake(init_input, adj).permute(0, 2, 1)
        i_poly = i_it_poly * snake_config.ro + evolve_py
        return i_poly, init_feature