def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        One node type (prefix "abp") is registered on Graph_MultiSinglePageXml.
        Its nodes are found with the ".//pc:TextLine" XPath expression, their
        text with "./pc:TextEquiv", and "DU_row" is set as label attribute.
        """
        all_labels = ['B', 'O', 'I']

        # With a toy collection that lacks some of the expected classes, put
        # the indices of the labels really present here; None keeps them all.
        seen_indices = None

        if seen_indices:
            print("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING")
            ignored_labels = [
                label for idx, label in enumerate(all_labels)
                if idx not in seen_indices
            ]
            kept_labels = [all_labels[idx] for idx in seen_indices]
            print(len(kept_labels), kept_labels)
            print(len(ignored_labels), ignored_labels)
        else:
            kept_labels, ignored_labels = all_labels, None

        # the graph class used by this task
        graph_class = Graph_MultiSinglePageXml

        node_type = NodeType_PageXml_type_woText(
            "abp",  # short prefix, because the labels are prefixed with it
            kept_labels,
            ignored_labels,
            False,  # no label means OTHER
            # we reduce overlap in this way
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        node_type.setLabelAttribute("DU_row")

        nodes_xpath = ".//pc:TextLine"  # how to find the nodes
        text_xpath = "./pc:TextEquiv"  # how to get their text
        node_type.setXpathExpr((nodes_xpath, text_xpath))

        graph_class.addNodeType(node_type)
        return graph_class
Exemple #2
0
    def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        GraphCut_H receives the block and line visibility settings of the
        task class, then two node types, registered in this order:
        - "row": nodes found with ".//pc:TextLine", label attribute "DU_row",
          labels S/I/O;
        - "sepH": nodes found with './/pc:CutSeparator[@orient="0"]', label
          attribute "type", labels S/I/O.
        Only the "row" type is put in the classic node type list.
        """
        graph_class = GraphCut_H

        # forward the visibility settings of the task class to the graph class
        graph_class.iBlockVisibility = cls.iBlockVisibility
        graph_class.iLineVisibility = cls.iLineVisibility

        text_xpath = "./pc:TextEquiv"  # how to get the text of a node

        # --- ROW: the text lines
        row_labels = ['S', 'I', 'O']
        row_type = NodeType_BIESO_to_SIO(
            "row",
            row_labels,
            None,
            False,
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        row_type.setLabelAttribute("DU_row")
        row_type.setXpathExpr((".//pc:TextLine", text_xpath))
        graph_class.addNodeType(row_type)

        # --- CUT: the cut separators with orient="0"
        cut_labels = ['S', 'I', 'O']
        cut_type = NodeType_PageXml_type_woText(
            "sepH",
            cut_labels,
            None,
            False,
            None,  # equiv. to: BBoxDeltaFun=lambda _: 0
        )
        cut_type.setLabelAttribute("type")
        cut_type.setXpathExpr(('.//pc:CutSeparator[@orient="0"]', text_xpath))
        graph_class.addNodeType(cut_type)

        graph_class.setClassicNodeTypeList([row_type])

        return graph_class
Exemple #3
0
    def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        One node type (prefix "hdr") is registered on
        Graph_MultiContinousPageXml: its nodes are found with
        ".//pc:TextLine", its label attribute is "DU_header" and its labels
        are CH / D / O.
        """
        graph_class = Graph_MultiContinousPageXml

        header_labels = ['CH', 'D', 'O']

        # --- HEADER
        header_type = NodeType_PageXml_type_woText(
            "hdr",
            header_labels,
            None,
            False,  # no label means OTHER
            # we reduce overlap in this way
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        header_type.setLabelAttribute("DU_header")

        nodes_xpath = ".//pc:TextLine"  # how to find the nodes
        text_xpath = "./pc:TextEquiv"  # how to get their text
        header_type.setXpathExpr((nodes_xpath, text_xpath))

        graph_class.addNodeType(header_type)
        return graph_class
    def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        One node type (prefix "row") is registered on
        Graph_MultiContinousPageXml: its nodes are found with
        ".//pc:TextLine", its label attribute is "DU_row" and its labels are
        B / I / E / S / O.
        """
        graph_class = Graph_MultiContinousPageXml

        row_labels = ['B', 'I', 'E', 'S', 'O']

        # --- ROW
        row_type = NodeType_PageXml_type_woText(
            "row",
            row_labels,
            None,
            False,  # no label means OTHER
            # we reduce overlap in this way
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        row_type.setLabelAttribute("DU_row")

        nodes_xpath = ".//pc:TextLine"  # how to find the nodes
        text_xpath = "./pc:TextEquiv"  # how to get their text
        row_type.setXpathExpr((nodes_xpath, text_xpath))

        graph_class.addNodeType(row_type)
        return graph_class
Exemple #5
0
    def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        One node type (prefix "sem") is registered on Graph_MultiPageXml: its
        nodes are found with ".//pc:TextRegion", its label attribute is
        "DU_sem" and no label is ignored.
        """
        graph_class = Graph_MultiPageXml

        semantic_labels = [
            'heading',
            'header',
            'page-number',
            'resolution-number',
            'resolution-marginalia',
            'resolution-paragraph',
            'other',
        ]

        # the converter changed to other unlabelled TextRegions or
        # 'marginalia' TRs, so nothing is ignored here
        ignored_labels = None

        semantic_type = NodeType_PageXml_type_woText(
            "sem",  # short prefix, because the labels are prefixed with it
            semantic_labels,
            ignored_labels,
            False,  # no label means OTHER
            # we reduce overlap in this way
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        semantic_type.setLabelAttribute("DU_sem")

        nodes_xpath = ".//pc:TextRegion"  # how to find the nodes
        text_xpath = "./pc:TextEquiv"  # how to get their text
        semantic_type.setXpathExpr((nodes_xpath, text_xpath))

        graph_class.addNodeType(semantic_type)
        return graph_class
    lLabels         = [lLabels[i] for i in lActuallySeen ]
    print (len(lLabels)          , lLabels)
    print (len(lIgnoredLabels)   , lIgnoredLabels)
    nbClass = len(lLabels) + 1  #because the ignored labels will become OTHER

# The graph class used by this task.
DU_GRAPH = Graph_MultiPageXml

# Single node type; lLabels and lIgnoredLabels are defined earlier in the file.
ntA = NodeType_PageXml_type_woText(
    "abp",  # short prefix, because the labels are prefixed with it
    lLabels,
    lIgnoredLabels,
    False,  # no label means OTHER
)

# Nodes are text lines, text regions and separator regions alike.
ntA.setXpathExpr((
    ".//pc:TextLine | .//pc:TextRegion | .//pc:SeparatorRegion",  # how to find the nodes
    "./pc:TextEquiv",  # how to get their text
))




# ===============================================================================================================
    def __init__(self,
                 sModelName,
                 sModelDir,
                 sComment=None,
                 C=None,
                 tol=None,
                 njobs=None,
                 max_iter=None,
                 inference_cache=None):
        """
        Configure the graph class, then initialise the CRF task.

        sModelName, sModelDir and sComment are passed unchanged to
        DU_CRF_Task.__init__. C, tol, njobs, max_iter and inference_cache
        replace the learner defaults (.1, .05, 8, 1000 and 50 respectively)
        when they are not None.
        """
        lLabels = ['RB', 'RI', 'RE', 'RS', 'RO']
        lIgnoredLabels = None

        # If you play with a toy collection, which does not have all expected
        # classes, list here the indices of the labels actually seen.
        lActuallySeen = None
        if lActuallySeen:
            # print() calls instead of Python 2 print statements, so that the
            # module also compiles under Python 3
            print("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING")
            lIgnoredLabels = [
                sLabel for i, sLabel in enumerate(lLabels)
                if i not in lActuallySeen
            ]
            lLabels = [lLabels[i] for i in lActuallySeen]
            print(len(lLabels), lLabels)
            print(len(lIgnoredLabels), lIgnoredLabels)

        # DEFINING THE CLASS OF GRAPH WE USE
        # NOTE(review): the node type is added to the graph class on every
        # instantiation - confirm that the task is only instantiated once.
        DU_GRAPH = Graph_MultiSinglePageXml
        nt = NodeType_PageXml_type_woText(
            "abp",  # short prefix, because the labels are prefixed with it
            lLabels,
            lIgnoredLabels,
            False,  # no label means OTHER
            # we reduce overlap in this way
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        nt.setXpathExpr((
            ".//pc:TextLine",  # how to find the nodes
            "./pc:TextEquiv",  # how to get their text
        ))
        DU_GRAPH.addNodeType(nt)

        DU_CRF_Task.__init__(
            self,
            sModelName,
            sModelDir,
            DU_GRAPH,
            dFeatureConfig={},
            dLearnerConfig={
                'C': .1 if C is None else C,
                'njobs': 8 if njobs is None else njobs,
                'inference_cache': 50 if inference_cache is None else inference_cache,
                'tol': .05 if tol is None else tol,
                'save_every': 50,  # save every 50 iterations, for warm start
                'max_iter': 1000 if max_iter is None else max_iter,
            },
            sComment=sComment,
            cFeatureDefinition=FeatureDefinition_PageXml_NoNodeFeat_v3)

        # use a LR model trained by GridSearch as baseline
        self.bsln_mdl = self.addBaseline_LogisticRegression()
Exemple #8
0
    def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        GraphGrid_H receives the grid step and visibility settings of the
        task class, then two node types, registered in this order:
        - "row": nodes found with ".//pc:TextLine", label attribute "DU_row",
          labels B/I/E/S/O;
        - "gh": nodes found with './/pc:GridSeparator[@orient="0"]', label
          attribute "type", labels B/I/S/O.
        Only the "row" type is put in the classic node type list.
        """
        graph_class = GraphGrid_H

        # copy the settings held by the task class onto the graph class
        for setting in ("iGridStep_H", "iGridStep_V",
                        "iGridVisibility", "iBlockVisibility"):
            setattr(graph_class, setting, getattr(cls, setting))

        text_xpath = "./pc:TextEquiv"  # how to get the text of a node

        # --- ROW: the text lines
        row_labels = ['B', 'I', 'E', 'S', 'O']
        row_type = NodeType_PageXml_type_woText(
            "row",
            row_labels,
            None,
            False,
            # we reduce overlap in this way
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        row_type.setLabelAttribute("DU_row")
        row_type.setXpathExpr((".//pc:TextLine", text_xpath))
        graph_class.addNodeType(row_type)

        # --- GRID: the grid separators with orient="0"
        grid_labels = ['B', 'I', 'S', 'O']
        grid_type = NodeType_PageXml_type_woText(
            "gh",
            grid_labels,
            None,
            False,
            None,  # equiv. to: BBoxDeltaFun=lambda _: 0
        )
        grid_type.setLabelAttribute("type")
        grid_type.setXpathExpr(('.//pc:GridSeparator[@orient="0"]', text_xpath))
        graph_class.addNodeType(grid_type)

        graph_class.setClassicNodeTypeList([row_type])

        return graph_class
Exemple #9
0
    def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        AdHocFactorialGraphCut_H is an ad-hoc class where type1 and type2 are
        factorial, while type3 is an artificial object. Three node types are
        registered, in this order:
        - "row": nodes found with ".//pc:TextLine", label attribute "DU_row";
        - "hdr": nodes found with ".//pc:TextLine", label attribute
          "DU_header";
        - "sepH": nodes found with './/pc:CutSeparator[@orient="0"]', label
          attribute "type".
        """
        graph_class = AdHocFactorialGraphCut_H

        # forward the visibility settings of the task class to the graph class
        graph_class.iBlockVisibility = cls.iBlockVisibility
        graph_class.iLineVisibility = cls.iLineVisibility

        line_xpath = ".//pc:TextLine"  # how to find the text line nodes
        text_xpath = "./pc:TextEquiv"  # how to get the text of a node

        # --- ROW
        row_labels = ['B', 'I', 'E', 'S', 'O']
        row_type = NodeType_PageXml_type_woText(
            "row",
            row_labels,
            None,
            False,
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        row_type.setLabelAttribute("DU_row")
        row_type.setXpathExpr((line_xpath, text_xpath))
        graph_class.addNodeType(row_type)

        # --- HEADER
        header_labels = ['CH', 'D', 'O']
        header_type = NodeType_PageXml_type_woText(
            "hdr",
            header_labels,
            None,
            False,  # no label means OTHER
            # we reduce overlap in this way
            BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
        )
        header_type.setLabelAttribute("DU_header")
        header_type.setXpathExpr((line_xpath, text_xpath))
        graph_class.addNodeType(header_type)

        # --- CUT: the cut separators with orient="0"
        cut_labels = ['S', 'I', 'O']
        cut_type = NodeType_PageXml_type_woText(
            "sepH",
            cut_labels,
            None,
            False,
            None,  # equiv. to: BBoxDeltaFun=lambda _: 0
        )
        cut_type.setLabelAttribute("type")
        cut_type.setXpathExpr(('.//pc:CutSeparator[@orient="0"]', text_xpath))
        graph_class.addNodeType(cut_type)

        # The nodes of the "classic" types are directly extracted from the
        # XML, the other types of nodes are computed.
        graph_class.setClassicNodeTypeList([row_type])
        graph_class.setSpecialNodeTypeList([cut_type])
        # make the header type a factorial of the row type
        graph_class.setFactoredClassicalType(row_type, header_type)

        return graph_class
Exemple #10
0
class DU_BAR_sgm(DU_CRF_Task):
    """
    Typed CRF model for a DU task, with the labels listed below.

    The graph class is configured once, in the class body: a single node
    type (prefix "sgm") is added to Graph_MultiContinousPageXml.
    """
    sLabeledXmlFilenamePattern = "*.du_mpxml"

    # --- the graph class used by this task
    DU_GRAPH = Graph_MultiContinousPageXml

    # we never see any 'S', hence its absence from the list
    lLabels2 = ['B', 'I', 'E']

    # Some TextRegion have no segmentation label at all, and were labelled
    # 'other' by the converter
    lIgnoredLabels2 = None

    nt2 = NodeType_PageXml_type_woText(
        "sgm",  # short prefix, because the labels are prefixed with it
        lLabels2,
        lIgnoredLabels2,
        False,  # no label means OTHER
        # we reduce overlap in this way
        BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
    )
    nt2.setLabelAttribute("DU_sgm")
    nt2.setXpathExpr((
        ".//pc:TextRegion",  # how to find the nodes
        "./pc:TextEquiv",  # how to get their text
    ))
    DU_GRAPH.addNodeType(nt2)

    # === CONFIGURATION ===================================================
    def __init__(self,
                 sModelName,
                 sModelDir,
                 sComment=None,
                 C=None,
                 tol=None,
                 njobs=None,
                 max_iter=None,
                 inference_cache=None):
        """
        Initialise the CRF task on the class-level graph class.

        C, tol, njobs, max_iter and inference_cache replace the learner
        defaults (.1, .05, 8, 1000 and 50 respectively) when not None.
        """
        dLearnerConfig = {
            'C': C if C is not None else .1,
            'njobs': njobs if njobs is not None else 8,
            'inference_cache':
                inference_cache if inference_cache is not None else 50,
            'tol': tol if tol is not None else .05,
            'save_every': 50,  # save every 50 iterations, for warm start
            'max_iter': max_iter if max_iter is not None else 1000,
        }

        DU_CRF_Task.__init__(
            self,
            sModelName,
            sModelDir,
            self.DU_GRAPH,
            dLearnerConfig=dLearnerConfig,
            sComment=sComment,
            cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText)

        traceln("- classes: ", self.DU_GRAPH.getLabelNameList())

        # use a LR model trained by GridSearch as baseline
        self.bsln_mdl = self.addBaseline_LogisticRegression()

    # === END OF CONFIGURATION ============================================

    def predict(self, lsColDir, sDocId):
        """
        Return the list of produced files
        """
        lsProduced = DU_CRF_Task.predict(self, lsColDir, sDocId)
        return lsProduced
class DU_BAR_sem_sgm(DU_FactorialCRF_Task):
    """
    Factorial CRF model, using the Multitype CRF, with the labels below.

    The graph class is configured once, in the class body: the "sem" and
    "sgm" node types are added, in this order, to
    FactorialGraph_MultiContinuousPageXml. Both use the same XPath
    expression to find their nodes.
    """
    sLabeledXmlFilenamePattern = "*.du_mpxml"

    # --- the graph class used by this task
    DU_GRAPH = FactorialGraph_MultiContinuousPageXml

    # --- first node type: "sem" ------------------------------------------
    lLabels1 = [
        'heading',
        'header',
        'page-number',
        'resolution-number',
        'resolution-marginalia',
        'resolution-paragraph',
        'other',
    ]

    nt1 = NodeType_PageXml_type_woText(
        "sem",  # short prefix, because the labels are prefixed with it
        lLabels1,
        None,  # keep this to None, unless you know very well what you do. (FactorialCRF!)
        False,  # no label means OTHER
        # we reduce overlap in this way
        BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
    )
    nt1.setLabelAttribute("DU_sem")
    nt1.setXpathExpr((
        ".//pc:TextRegion",  # how to find the nodes, MUST be same as for other node type!! (FactorialCRF!)
        "./pc:TextEquiv",  # how to get their text
    ))
    DU_GRAPH.addNodeType(nt1)

    # --- second node type: "sgm" -----------------------------------------
    # the full B/I/E/S/O label set is used here
    lLabels2 = ['B', 'I', 'E', 'S', 'O']

    nt2 = NodeType_PageXml_type_woText(
        "sgm",  # short prefix, because the labels are prefixed with it
        lLabels2,
        None,  # keep this to None, unless you know very well what you do. (FactorialCRF!)
        False,  # no label means OTHER
        # we reduce overlap in this way
        BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
    )
    nt2.setLabelAttribute("DU_sgm")
    nt2.setXpathExpr((
        ".//pc:TextRegion",  # how to find the nodes, MUST be same as for other node type!! (FactorialCRF!)
        "./pc:TextEquiv",  # how to get their text
    ))
    DU_GRAPH.addNodeType(nt2)

    # === CONFIGURATION ===================================================
    def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None):
        """
        Initialise the factorial CRF task on the class-level graph class.

        C, tol, njobs, max_iter and inference_cache replace the learner
        defaults (.1, .05, 8, 1000 and 50 respectively) when not None.
        """
        dLearnerConfig = {
            'C': C if C is not None else .1,
            'njobs': njobs if njobs is not None else 8,
            'inference_cache':
                inference_cache if inference_cache is not None else 50,
            'tol': tol if tol is not None else .05,
            'save_every': 50,  # save every 50 iterations, for warm start
            'max_iter': max_iter if max_iter is not None else 1000,
        }

        DU_FactorialCRF_Task.__init__(
            self,
            sModelName,
            sModelDir,
            self.DU_GRAPH,
            dLearnerConfig=dLearnerConfig,
            sComment=sComment,
            cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText)

        traceln("- classes: ", self.DU_GRAPH.getLabelNameList())

        # use a LR model trained by GridSearch as baseline
        self.bsln_mdl = self.addBaseline_LogisticRegression()

    # === END OF CONFIGURATION ============================================

    def predict(self, lsColDir, sDocId):
        """
        Return the list of produced files
        """
        lsProduced = DU_FactorialCRF_Task.predict(self, lsColDir, sDocId)
        return lsProduced
    def getConfiguredGraphClass(cls):
        """
        Build and return the graph class configured for this task.

        GraphGrid_H receives the grid step and visibility settings of the
        task class, then two node types, registered in this order:
        - "row": nodes found with ".//pc:TextLine", label attribute "DU_row",
          labels B/I/E/S/O;
        - "gh": nodes found with './/pc:GridSeparator[@orient="0"]', label
          attribute "type", labels B/I/S/O.
        Only the "row" type is put in the classic node type list.
        """
        graph_class = GraphGrid_H

        # forward the settings of the task class to the graph class
        graph_class.iGridStep_H = cls.iGridStep_H
        graph_class.iGridStep_V = cls.iGridStep_V
        graph_class.iGridVisibility = cls.iGridVisibility
        graph_class.iBlockVisibility = cls.iBlockVisibility

        text_xpath = "./pc:TextEquiv"  # how to get the text of a node

        # --- ROW: the text lines
        # The function below reduces overlap: it returns the amount by which
        # each border of a bounding box is "shifted toward its centre":
        #     w,h = x2-x1, y2-y1
        #     dx = self.BBoxDeltaFun(w)
        #     dy = self.BBoxDeltaFun(h)
        #     x1,y1, x2,y2 = [ int(round(v)) for v in [x1+dx,y1+dy, x2-dx,y2-dy] ]
        # With v / 5 taken off on each side, 3/5 of each dimension is kept.
        # The historical function was:
        #     lambda v: max(v * 0.066, min(5, v/3))
        row_labels = ['B', 'I', 'E', 'S', 'O']
        row_type = NodeType_PageXml_type_woText(
            "row",
            row_labels,
            None,
            False,
            BBoxDeltaFun=lambda v: v / 5.0,
        )
        row_type.setLabelAttribute("DU_row")
        row_type.setXpathExpr((".//pc:TextLine", text_xpath))
        graph_class.addNodeType(row_type)

        # --- GRID: the grid separators with orient="0"
        grid_labels = ['B', 'I', 'S', 'O']
        grid_type = NodeType_PageXml_type_woText(
            "gh",
            grid_labels,
            None,
            False,
            None,  # equiv. to: BBoxDeltaFun=lambda _: 0
        )
        grid_type.setLabelAttribute("type")
        grid_type.setXpathExpr(('.//pc:GridSeparator[@orient="0"]', text_xpath))
        graph_class.addNodeType(grid_type)

        graph_class.setClassicNodeTypeList([row_type])

        return graph_class
Exemple #13
0
class DU_ABPTable_TypedCRF(DU_CRF_Task):
    """
    Typed CRF model for a DU task, with the labels listed below.

    The graph class is configured once, in the class body: the "text" and
    "sprtr" node types are added, in this order, to Graph_MultiSinglePageXml.
    """
    sXmlFilenamePattern = "*.mpxml"

    sLabeledXmlFilenamePattern = "*.mpxml"

    sLabeledXmlFilenameEXT = ".mpxml"

    # === CONFIGURATION ===================================================
    @classmethod
    def getConfiguredGraphClass(cls):
        """
        Return the configured graph class.

        The graph class is configured in the class body and stored in
        DU_GRAPH, so it is simply returned here (the method used to return
        None implicitly, despite its contract).
        """
        return cls.DU_GRAPH

    # --- the graph class used by this task
    DU_GRAPH = Graph_MultiSinglePageXml

    lLabels1 = ['RB', 'RI', 'RE', 'RS', 'RO']
    lIgnoredLabels1 = None

    # --- first node type: the text lines
    nt1 = NodeType_PageXml_type_woText(
        "text",  # short prefix, because the labels are prefixed with it
        lLabels1,
        lIgnoredLabels1,
        False,  # no label means OTHER
        # we reduce overlap in this way
        BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
    )
    nt1.setXpathExpr((
        ".//pc:TextLine",  # how to find the nodes
        "./pc:TextEquiv",  # how to get their text
    ))
    DU_GRAPH.addNodeType(nt1)

    # --- second node type: the separators
    nt2 = NodeType_PageXml_type_woText(
        "sprtr",  # short prefix, because the labels are prefixed with it
        ['SI', 'SO'],
        None,
        False,  # no label means OTHER
        # we reduce overlap in this way
        BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)),
    )
    nt2.setXpathExpr((
        ".//pc:SeparatorRegion",  # how to find the nodes
        "./pc:TextEquiv",  # how to get their text (no text in fact)
    ))
    DU_GRAPH.addNodeType(nt2)

    def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None):
        """
        Register the class-level graph class, then initialise the CRF task.

        C, tol, njobs, max_iter and inference_cache replace the learner
        defaults (.1, .05, 8, 1000 and 50 respectively) when not None.
        """
        # another way to specify the graph class
        # (defining a getConfiguredGraphClass is preferred)
        self.configureGraphClass(self.DU_GRAPH)

        dLearnerConfig = {
            'C': .1 if C is None else C,
            'njobs': 8 if njobs is None else njobs,
            'inference_cache': 50 if inference_cache is None else inference_cache,
            'tol': .05 if tol is None else tol,
            'save_every': 50,  # save every 50 iterations, for warm start
            'max_iter': 1000 if max_iter is None else max_iter,
        }

        dFeatureConfig = {
            # config for the extractor of nodes of each type
            "text": None,
            "sprtr": None,
            # config for the extractor of edges of each type
            "text_text": None,
            "text_sprtr": None,
            "sprtr_text": None,
            "sprtr_sprtr": None,
        }

        DU_CRF_Task.__init__(
            self,
            sModelName,
            sModelDir,
            dLearnerConfig=dLearnerConfig,
            sComment=sComment,
            cFeatureDefinition=FeatureDefinition_T_PageXml_StandardOnes_noText_v3,
            dFeatureConfig=dFeatureConfig)

        traceln("- classes: ", self.DU_GRAPH.getLabelNameList())

        # use a LR model trained by GridSearch as baseline
        self.bsln_mdl = self.addBaseline_LogisticRegression()

    # === END OF CONFIGURATION ============================================

    def predict(self, lsColDir, sDocId):
        """
        Return the list of produced files
        """
        return DU_CRF_Task.predict(self, lsColDir, sDocId)