Example #1
0
 def __init__(self, url):
     """Load the data from an XBN file and detect its format version.

     The document returned by ``parse(url)`` is probed for two layouts,
     XBN 1.0 first and then XBN 0.2.  When both probes succeed, the 0.2
     result is the one kept in ``self.version``.

     :param url: passed unchanged to ``parse``.
     :raises ValueError: if the document matches neither layout.
     """
     # empty BNet
     self.G = BNet()

     self.xbn = parse(url)
     self.version = ""
     node = self.xbn.childNodes
     ok = False
     # Errors a probe raises when the DOM does not have the shape it expects
     # (missing child, text node instead of element, no attributes, ...).
     layout_errors = (IndexError, AttributeError, TypeError, KeyError)

     # XBN version 1.0
     try:
         # get basic info on the BN
         bnmodel = node[0].childNodes        # <BNMODEL>
         statdynvar = bnmodel[1].childNodes  # children of bnmodel:
                                             # <STATICPROPERTIES>, <DYNAMICPROPERTIES>, <VARIABLES>
         stat = statdynvar[0].childNodes     # <STATICPROPERTIES>
         self.version = stat[2].childNodes[0].nodeValue  # <VERSION> text
         ok = True
     except layout_errors:
         pass

     # XBN version 0.2
     try:
         bnmodel = node[1].childNodes        # <BNMODEL>
         statdynvar = bnmodel[1].childNodes  # children of bnmodel:
                                             # <STATICPROPERTIES>, <DYNAMICPROPERTIES>, <VARIABLES>
         stat = statdynvar[1].childNodes     # <STATICPROPERTIES>
         # stat[3] is presumably the <VERSION> element; the version number is
         # the value of its first attribute.
         attrs = stat[3].attributes
         # list(): keys() is a non-indexable view on Python 3.
         self.version = attrs.get(list(attrs.keys())[0]).nodeValue
         ok = True
     except layout_errors:
         pass

     if not ok:
         # A bare string cannot be raised (TypeError on Python 2.6+ and 3).
         raise ValueError('Neither version 1.0 or 0.2, verify your xbn file...')
 def setUp(self):
     """Build the water-sprinkler fixture network.

     The four vertices c, s, r and w are each created with
     ``DiscreteVertex(label, 2)``. They are wired c->r, c->s, r->w and
     s->w, the network is finalized, and each CPT is then filled. The
     vertices are exposed as ``self.c``, ``self.s``, ``self.r`` and
     ``self.w``, and the network as ``self.network``.
     """
     network = BNet('Water Sprinkler Bayesian Network')
     labels = ('c', 's', 'r', 'w')
     # Vertices are added in label order.
     vertex = dict((label, network.add_vertex(DiscreteVertex(label, 2)))
                   for label in labels)

     arcs = (('c', 'r'), ('c', 's'), ('r', 'w'), ('s', 'w'))
     for parent, child in arcs:
         network.add_edge((vertex[parent], vertex[child]))
     network.finalize()

     # CPT values can only be set once the network is finalized.
     cpts = (
         ('c', [0.5, 0.5]),
         ('s', [0.5, 0.9, 0.5, 0.1]),
         ('r', [0.8, 0.2, 0.2, 0.8]),
         ('w', [1, 0.1, 0.1, 0.01, 0.0, 0.9, 0.9, 0.99]),
     )
     for label, values in cpts:
         vertex[label].cpt.set_values(values)

     for label in labels:
         setattr(self, label, vertex[label])
     self.network = network
Example #3
0
class LoadXBN:
    """ Loads the data from an XBN file (XBN format 1.0 or 0.2).

        >>> xbn = LoadXBN('WetGrass.xbn')
        >>> BNet = xbn.Load()

        BNet is an openbayes.bayesnet.BNet instance
    """

    # Class-level defaults only: every getter rebinds a fresh list on the
    # instance before filling it, so these shared lists are never mutated.
    variablesList = []
    structureList = []
    distributionList = []

    # Errors a layout probe raises when the DOM does not have the shape it
    # expects (missing child, text node instead of element, no attributes...).
    _LAYOUT_ERRORS = (IndexError, AttributeError, TypeError, KeyError)

    # Logical order of the model sections read by the getters (see _section).
    _STATIC, _DYNAMIC, _VARIABLES, _STRUCTURE, _DISTRIBUTIONS = range(5)

    def __init__(self, url):
        """Parse the XBN document *url* and detect its format version.

        The document is probed for the 1.0 layout first and then for the
        0.2 layout.  When both probes succeed, the 0.2 result is the one
        kept in ``self.version``.

        :param url: passed unchanged to ``parse``.
        :raises ValueError: if the document matches neither layout.
        """
        # empty BNet, filled in by Load()
        self.G = BNet()

        self.xbn = parse(url)
        self.version = ""
        node = self.xbn.childNodes
        ok = False

        # XBN version 1.0
        try:
            bnmodel = node[0].childNodes        # <BNMODEL>
            statdynvar = bnmodel[1].childNodes  # <STATICPROPERTIES>, <DYNAMICPROPERTIES>, <VARIABLES>, ...
            stat = statdynvar[0].childNodes     # <STATICPROPERTIES>
            self.version = stat[2].childNodes[0].nodeValue  # <VERSION> text
            ok = True
        except self._LAYOUT_ERRORS:
            pass

        # XBN version 0.2
        try:
            bnmodel = node[1].childNodes        # <BNMODEL>
            statdynvar = bnmodel[1].childNodes
            stat = statdynvar[1].childNodes     # <STATICPROPERTIES>
            # stat[3] is presumably the <VERSION> element; its first
            # attribute holds the version number.
            self.version = self._first_attr_value(stat[3])
            ok = True
        except self._LAYOUT_ERRORS:
            pass

        if not ok:
            # A bare string cannot be raised (TypeError on Python 2.6+ and 3).
            raise ValueError('Neither version 1.0 or 0.2, verify your xbn file...')

    def Load(self):
        """Read every part of the XBN document into ``self.G`` and return it."""
        self.getBnInfos()
        self.getStaticProperties()
        self.getDynamicProperties()
        self.getVariablesXbn()
        self.getStructureXbn()
        self.getDistribution()

        return self.G

    # ---- layout helpers -------------------------------------------------

    def _is_v1(self):
        """Return True when ``self.version`` selects the XBN 1.0 layout."""
        return self.version == "1.0"

    def _root(self):
        """Return the top-level document node holding the model.

        This is child 0 of the document for 1.0 and child 1 otherwise.
        """
        return self.xbn.childNodes[0 if self._is_v1() else 1]

    def _section(self, position):
        """Return the child nodes of the *position*-th model section.

        *position* is one of the ``_STATIC`` .. ``_DISTRIBUTIONS`` constants.
        Sections sit at even child indexes (0, 2, ...) in 1.0 and at odd
        ones (1, 3, ...) in 0.2.
        """
        sections = self._root().childNodes[1].childNodes
        offset = 0 if self._is_v1() else 1
        return sections[2 * position + offset].childNodes

    @staticmethod
    def _first_attr_value(elem):
        """Return the value of the first attribute of DOM element *elem*."""
        attrs = elem.attributes
        # list(): keys() is a non-indexable view on Python 3.
        return attrs.get(list(attrs.keys())[0]).nodeValue

    @staticmethod
    def _parse_values(text):
        """Turn a whitespace-separated string of numbers into a Float32 array."""
        return array([float(da) for da in text.split()], type='Float32')

    # ---- public getters -------------------------------------------------

    def getBnInfos(self):
        """Read the NAME and ROOT attributes of the top-level node.

        ROOT is also copied to ``self.G.name``.

        :returns: a ``BnInfos`` instance.
        """
        bn = BnInfos()
        attrs = self._root().attributes

        for attrName in attrs.keys():
            attrValue = attrs.get(attrName).nodeValue
            if attrName == "NAME":
                bn.name = attrValue  # not used in BNet class
            elif attrName == "ROOT":
                bn.root = attrValue
                self.G.name = attrValue

        return bn

    def getStaticProperties(self):
        """Read FORMAT, VERSION and CREATOR from the static-properties section.

        In 1.0 each value is the element's text, and in 0.2 it is the
        value of the element's first attribute.

        :returns: a ``StaticProperties`` instance.
        """
        prop = StaticProperties()
        fields = {"FORMAT": "format", "VERSION": "version", "CREATOR": "creator"}
        v1 = self._is_v1()

        for elem in self._section(self._STATIC):
            if elem.nodeType != Node.ELEMENT_NODE:
                continue
            field = fields.get(elem.nodeName)
            if field is None:
                continue
            if v1:
                value = elem.childNodes[0].nodeValue
            else:
                value = self._first_attr_value(elem)
            setattr(prop, field, value)

        return prop

    def getDynamicProperties(self):
        """Read PROPERTYTYPE, PROPERTY and PROPXML entries.

        :returns: a ``DynamicProperties`` instance whose ``dynPropType``
            list holds one dict per PROPERTYTYPE, and whose ``dynProperty``
            and ``dynPropXml`` dicts map the first attribute value to the
            element's value text.
        """
        prop = DynamicProperties()

        # Best effort: empty the containers before filling them.
        # NOTE(review): presumably they are class-level mutables shared by
        # every DynamicProperties instance; confirm in that class.
        try:
            del prop.dynPropType[:]
            prop.dynProperty.clear()
            prop.dynPropXml.clear()
        except (AttributeError, TypeError):
            pass

        v1 = self._is_v1()
        for elem in self._section(self._DYNAMIC):
            if elem.nodeType != Node.ELEMENT_NODE:
                continue

            if elem.nodeName == "PROPERTYTYPE":
                dictDyn = {}
                attrs = elem.attributes
                for attrName in attrs.keys():
                    if attrName in ("NAME", "TYPE", "ENUMSET"):
                        dictDyn[attrName] = attrs.get(attrName).nodeValue

                if v1:
                    for info in elem.childNodes:
                        if info.nodeName == "COMMENT":
                            dictDyn["COMMENT"] = info.childNodes[0].nodeValue
                else:
                    # 0.2: the comment element is child 1 (child 0 is
                    # presumably whitespace text).
                    dictDyn["COMMENT"] = elem.childNodes[1].childNodes[0].nodeValue

                prop.dynPropType.append(dictDyn)

            elif elem.nodeName in ("PROPERTY", "PROPXML"):
                if elem.nodeName == "PROPERTY":
                    target = prop.dynProperty
                else:
                    target = prop.dynPropXml
                key = self._first_attr_value(elem)
                # The value element is child 0 in 1.0 and child 1 in 0.2.
                value_node = elem.childNodes[0 if v1 else 1]
                target[key] = value_node.childNodes[0].nodeValue

        return prop

    def getVariablesXbn(self):
        """Read the variables and add one vertex per variable to ``self.G``.

        :returns: ``self.variablesList``, a list of ``Variables`` records in
            document order.
        """
        self.variablesList = []
        attr_fields = {"NAME": "name", "TYPE": "type", "XPOS": "xpos", "YPOS": "ypos"}
        v1 = self._is_v1()

        for var in self._section(self._VARIABLES):
            if var.nodeType != Node.ELEMENT_NODE:
                continue

            v = Variables()
            v.stateName = []
            v.propertyNameValue = {}

            attrs = var.attributes
            for attrName in attrs.keys():
                field = attr_fields.get(attrName)
                if field is not None:
                    setattr(v, field, attrs.get(attrName).nodeValue)

            for info in var.childNodes:
                if info.nodeType != Node.ELEMENT_NODE:
                    continue
                if info.nodeName in ("DESCRIPTION", "FULLNAME"):
                    # An empty element has no text child.
                    try:
                        v.description = info.childNodes[0].nodeValue
                    except IndexError:
                        v.description = ""
                elif info.nodeName == "STATENAME":
                    v.stateName.append(info.childNodes[0].nodeValue)
                elif info.nodeName == "PROPERTY":
                    key = self._first_attr_value(info)
                    # The value element is child 0 in 1.0 and child 1 in 0.2.
                    value_node = info.childNodes[0 if v1 else 1]
                    v.propertyNameValue[key] = value_node.childNodes[0].nodeValue

            self.variablesList.append(v)

        # create the corresponding nodes into the BNet class
        for v in self.variablesList:
            # TODO: Discrete or Continuous. Here True means always discrete
            bv = BVertex(v.name, True, len(v.stateName))
            bv.state_names = v.stateName
            self.G.add_vertex(bv)

        return self.variablesList

    def getStructureXbn(self):
        """Read the arcs, add them as edges of ``self.G`` and initialise
        its distributions.

        :returns: ``self.structureList``, a flat list
            ``[parent0, child0, parent1, child1, ...]``.
        """
        self.structureList = []

        for arc in self._section(self._STRUCTURE):
            if arc.nodeType != Node.ELEMENT_NODE:
                continue
            # Read by name: attribute order is whatever the file uses, so
            # PARENT is not guaranteed to come first.  An arc lacking either
            # end is skipped so that later pairs stay aligned.
            if arc.hasAttribute("PARENT") and arc.hasAttribute("CHILD"):
                self.structureList.append(arc.getAttribute("PARENT"))
                self.structureList.append(arc.getAttribute("CHILD"))

        ends = iter(self.structureList)
        for par, child in zip(ends, ends):
            self.G.add_edge((par, child))

        # initialize the distributions
        self.G.init_distributions()

        return self.structureList

    def getDistribution(self):
        """Read every distribution and copy its values into ``self.G``.

        :returns: ``self.distributionList``, a list of ``Distribution``
            records in document order.
        """
        self.distributionList = []
        dpis_per_dist = []

        for elem in self._section(self._DISTRIBUTIONS):
            if elem.nodeType != Node.ELEMENT_NODE:
                continue
            d, dpis = self._read_distribution(elem)
            self.distributionList.append(d)
            dpis_per_dist.append(dpis)

        for d, dpis in zip(self.distributionList, dpis_per_dist):
            self._fill_distribution(d, dpis)

        return self.distributionList

    def _read_distribution(self, elem):
        """Parse one distribution element.

        :returns: ``(d, dpis)``, where *d* is a ``Distribution`` record and
            *dpis* is a list of ``(indexes or None, data)`` string pairs, one
            per DPI in document order.
        """
        d = Distribution()
        d.condelem = []
        d.dpiIndex = []
        d.dpiData = []
        dpis = []

        if elem.hasAttribute("TYPE"):
            d.type = elem.getAttribute("TYPE")

        for info in elem.childNodes:
            if info.nodeType != Node.ELEMENT_NODE:
                continue
            if info.nodeName == "CONDSET":
                for cond in info.childNodes:
                    if cond.nodeType == Node.ELEMENT_NODE:
                        d.condelem.append(self._first_attr_value(cond))
            elif info.nodeName == "PRIVATE":
                d.name = self._first_attr_value(info)
            elif info.nodeName == "DPIS":
                for dpi in info.childNodes:
                    if dpi.nodeName != "DPI":
                        continue
                    data = dpi.childNodes[0].nodeValue
                    d.dpiData.append(data)
                    index = None
                    if dpi.hasAttribute("INDEXES"):
                        index = dpi.getAttribute("INDEXES")
                        d.dpiIndex.append(index)
                    # Keep each DPI's data paired with its own INDEXES;
                    # dpiData and dpiIndex diverge when some DPIs have none.
                    dpis.append((index, data))

        return d, dpis

    def _fill_distribution(self, d, dpis):
        """Copy the values parsed for *d* into the matching vertex of ``self.G``."""
        target = self.G.v[d.name].distribution  # the distribution class into the BNet

        # TODO: what about gaussians ???
        target.distribution_type = 'Multinomial'

        if d.type == 'ci':
            # Broadcast the first DPI over every parent configuration by
            # appending one axis per parent, repeated nvalues times.
            new = self._parse_values(d.dpiData[0])
            for pa in dist_family_parents(target) if False else target.family[1:]:
                new = new[..., newaxis]
                new = concatenate([new] * pa.nvalues, axis=-1)
            target[:] = new

        indexed = [(index, data) for index, data in dpis if index is not None]
        if indexed:
            # nodes with parents: each DPI fills the slot named by INDEXES
            for index, data in indexed:
                ii = tuple(int(i) for i in index.split())
                # e.g. {'Alternator': 1, 'FanBelt': 0}
                dictin = dict(zip(d.condelem, ii))
                target[dictin] = self._parse_values(data)
        elif d.type != 'ci':
            # nodes with no parents: simply insert the data into the matrix.
            # A 'ci' table was already filled above; overwriting it with the
            # un-broadcast first DPI would discard that fill.
            target[:] = self._parse_values(d.dpiData[0])