Beispiel #1
0
    def _define_sampler_blueprints(self, bn):
        """
        Build the local and shared sampler blueprints of every floating node.

        For each node in ``FloatingNodes`` and in the private will-be-floating
        collection, one blueprint is stored in ``LocalSamplers`` and one in
        ``SharedSamplers``:

        - all parents already fixed, none of them in this level's
          ``FixedNodes``: a single ``Frozen`` blueprint serves both maps;
        - all parents already fixed, at least one in this level's
          ``FixedNodes``: ``Frozen`` locally, ``Single`` for the shared map;
        - some parent not fixed: one ``Compound`` blueprint serves both maps,
          splitting the node's minimal requirements into values to read and
          values to sample.

        Afterwards the same is done for every child.

        :param bn: source bayesian network
        """
        graph = bn.DAG
        self.LocalSamplers = {}
        self.SharedSamplers = {}

        # Everything fixed so far: inherited from ancestors plus this level.
        settled = self.__was_fixed.union(self.FixedNodes)

        for node in self.FloatingNodes.union(self.__will_be_floating):
            parents = set(bn[node].Parents)

            if not parents <= settled:
                # At least one parent is still floating: compound sampler.
                needed = dag.minimal_requirements(graph, node, settled)
                readable = [n for n in needed if n in settled or bn.is_exogenous(n)]
                sampled = [n for n in needed if n not in readable]
                blueprint = ActorBlueprint(node, ActorBlueprint.Compound, readable, sampled)
                self.LocalSamplers[node] = blueprint
                self.SharedSamplers[node] = blueprint
                continue

            # Every parent had been fixed before.
            frozen = ActorBlueprint(node, ActorBlueprint.Frozen, parents, None)
            self.LocalSamplers[node] = frozen
            if parents.isdisjoint(self.FixedNodes):
                # No parent was fixed at this level: share the frozen one.
                self.SharedSamplers[node] = frozen
            else:
                self.SharedSamplers[node] = ActorBlueprint(node, ActorBlueprint.Single, parents, None)

        for child in self.__children.values():
            child._define_sampler_blueprints(bn)
Beispiel #2
0
 def is_deterministic(self, node, given=None):
     """
     Tell whether ``node`` is deterministic, optionally given known nodes.

     A node whose entry ``self[node]`` is a ``ValueLoci`` is always
     deterministic. Otherwise, with a non-empty ``given``, the node counts
     as deterministic when every minimal requirement that is not itself in
     ``given`` is a ``ValueLoci``. With no (or empty) ``given`` the answer
     is False.

     :param node: name of the node to check
     :param given: nodes assumed to be known
     :return: True if the node is deterministic, False otherwise
     """
     if isinstance(self[node], ValueLoci):
         return True
     if not given:
         return False

     for dep in minimal_requirements(self.DAG, node, given):
         if dep in given:
             continue
         if not isinstance(self[dep], ValueLoci):
             return False
     return True
Beispiel #3
0
    def has_randomness(self, node, given=None):
        """
        Tell whether ``node`` carries any randomness.

        The node does so when ``self.is_rv`` holds for the node itself or for
        any node it depends on. With a non-empty ``given`` the dependencies are
        the node's minimal requirements minus the given nodes; otherwise they
        are all ancestors of the node in ``self.DAG``.

        :param node: name of the node to check
        :param given: nodes assumed to be known
        :return: True if randomness is involved, False otherwise
        """
        if self.is_rv(node):
            return True

        if given:
            upstream = [
                dep
                for dep in minimal_requirements(self.DAG, node, given)
                if dep not in given
            ]
        else:
            upstream = self.DAG.ancestors(node)

        return any(self.is_rv(dep) for dep in upstream)
Beispiel #4
0
 def test_min_needs(self):
     """Check ``dag.minimal_requirements`` on ``self.G`` with and without given nodes."""
     # (target node, given nodes, expected requirements in any order)
     cases = (
         ('B', [], ['A']),
         ('D', [], ['A', 'B', 'C']),
         ('D', ['B'], ['B', 'C']),
     )
     for target, known, expected in cases:
         found = dag.minimal_requirements(self.G, target, known)
         self.assertCountEqual(found, expected)
Beispiel #5
0
    def _resolve_local_nodes(self, bn, fixed=None):
        """
        Classify this level's nodes as fixed or floating and collect requirements.

        Mediators are the nodes returned by ``dag.minimal_dag(...).order()``
        for the declared fixed and floating nodes that were not declared
        themselves. Mediators and floating nodes are tested with
        ``bn.has_randomness`` against the nodes fixed so far; those without
        randomness end up in ``FixedNodes``. The method sets ``FloatingNodes``,
        ``FixedNodes``, ``ListeningNodes`` and ``ExoNodes`` on this instance,
        then recurses into every child, passing down the accumulated fixed set.

        :param bn: source bayesian network
        :param fixed: nodes already fixed by ancestor levels; ignored when
            this instance has no parent
        """
        g = bn.DAG
        # Node order of the minimal DAG built from the declared fixed and floating nodes.
        mini = dag.minimal_dag(g, set.union(self.__as_fixed, self.__as_floating)).order()

        # Mediators: nodes in that order which were declared neither fixed nor floating.
        med = set(mini)
        med.difference_update(self.__as_fixed)
        med.difference_update(self.__as_floating)
        # presumably a topological ordering of the mediators -- confirm against g.sort
        med = g.sort(med)

        # Start from copies of the declared sets; both are refined below.
        self.FloatingNodes = set(self.__as_floating)
        self.FixedNodes = set(self.__as_fixed)

        # pass down fixed
        if self.__parent:
            self.__was_fixed = fixed if fixed else set()
        else:
            # Without a parent there is nothing inherited, whatever `fixed` holds.
            self.__was_fixed = set()
        # New set: inherited fixed nodes plus those fixed at this level.
        all_fixed = set.union(self.__was_fixed, self.FixedNodes)

        # Classify mediators; all_fixed grows as we go, so later checks see
        # the mediators already found to be fixed.
        for d in med:
            if bn.has_randomness(d, all_fixed):
                self.FloatingNodes.add(d)
            else:
                self.FixedNodes.add(d)
                all_fixed.add(d)

        # Re-check every floating node and move those without randomness to fixed.
        # NOTE(review): FloatingNodes is mutated inside the loop; this assumes
        # bn.sort returns a new sequence rather than a view -- confirm.
        for d in bn.sort(self.FloatingNodes):
            if not bn.has_randomness(d, all_fixed):
                self.FixedNodes.add(d)
                self.FloatingNodes.remove(d)
                all_fixed.add(d)

        # Top level only: roots of the network without randomness are fixed too.
        if not self.__parent:
            for d in bn.Roots:
                if not bn.has_randomness(d):
                    self.FixedNodes.add(d)
                    all_fixed.add(d)

        # NOTE(review): rqs is filled below but never read within this method.
        rqs = dict()
        # requirements for floating nodes (giving values when needed)
        self.ListeningNodes = set()
        # requirements for fixed nodes (giving values at initialisation)
        self.ExoNodes = set()

        for i, node in enumerate(mini):
            # Requirements of each node, given the nodes that precede it in mini.
            par = mini[:i]
            rq = dag.minimal_requirements(g, node, par)

            if node in self.FloatingNodes:
                self.ListeningNodes.update(rq)
            else:
                self.ExoNodes.update(rq)
            rqs[node] = rq
        # Keep only requirements outside mini / outside the fixed nodes.
        self.ListeningNodes.difference_update(mini)
        self.ExoNodes.difference_update(self.FixedNodes)

        # Children inherit everything fixed at or above this level.
        for ch in self.__children.values():
            ch._resolve_local_nodes(bn, all_fixed)