def inference_method(self, value):
     '''
     Record ``value`` as the inference method in use.

     When ``value`` is 'sample_db' the sample database is opened
     first and stored on ``self.sample_db``; the database is only
     initialized when its sqlite file does not exist yet.
     '''
     if value == 'sample_db':
         db_filename = self.sample_db_filename
         ensure_data_dir_exists(db_filename)
         ordering = self.discover_sample_ordering()
         domains = dict((var, var.domain) for var, _ in ordering)
         # A missing file means this is a brand new database
         # which has to be initialized.
         is_new_db = not os.path.isfile(db_filename)
         self.sample_db = SampleDB(db_filename,
                                   domains,
                                   initialize=is_new_db)
     self._inference_method = value
 def inference_method(self, value):
     '''
     Store ``value`` as the active inference method.

     Choosing 'sample_db' also opens the sample database and keeps
     it on ``self.sample_db``. The database is initialized only if
     its sqlite file is not present on disk yet.
     '''
     if value == 'sample_db':
         db_filename = self.sample_db_filename
         ensure_data_dir_exists(db_filename)
         ordering = self.discover_sample_ordering()
         domains = dict((var, var.domain) for var, _ in ordering)
         # No file on disk yet means the database is new and
         # must be initialized.
         is_new_db = not os.path.isfile(db_filename)
         self.sample_db = SampleDB(
             db_filename,
             domains,
             initialize=is_new_db)
     self._inference_method = value
class FactorGraph(object):
    '''
    A factor graph built from variable nodes and factor nodes.

    Marginals are computed with one of three inference methods
    selected through the ``inference_method`` property:

    * 'sumproduct' - message passing, see ``propagate``. Chosen by
      the constructor when no cycle is detected.
    * 'sample' - estimates from samples generated on demand.
    * 'sample_db' - estimates from samples previously stored in
      the sqlite sample database, see ``generate_samples``.
    '''

    def __init__(self, nodes, name=None, n_samples=100):
        '''
        Create the graph and choose a default inference method.

        Args:
            nodes: the VariableNode and FactorNode instances that
                make up the graph.
            name: optional model name, used to derive the name of
                the sample database file.
            n_samples: number of samples used by the sampling
                based query methods.
        '''
        self.nodes = nodes
        self._inference_method = 'sumproduct'
        # We need to divine the domains for Factor nodes here...
        # First record the domain of every variable by name.
        arg_domains = dict()
        for node in self.nodes:
            if isinstance(node, VariableNode):
                arg_domains[node.name] = node.domain
        # Now if the domains for the factor functions have not
        # been explicitly set we create them based on the values
        # their variables can take.
        for node in self.nodes:
            if not isinstance(node, FactorNode):
                continue
            if hasattr(node.func, 'domains'):
                continue
            domains = dict()
            for arg in get_args(node.func):
                if arg not in arg_domains:
                    print('WARNING: missing variable for arg:%s' % arg)
                else:
                    domains[arg] = arg_domains[arg]
            node.func.domains = domains
        self.name = name
        self.n_samples = n_samples
        # Now try to set the mode of inference..
        try:
            cyclic = self.has_cycles()
        except Exception:
            # Best effort: when the check itself fails we fall
            # back to sampling, which is the method used for
            # cyclic graphs.
            print('Failed to determine if graph has cycles, '
                  'setting inference to sample.')
            cyclic = True
        # Currently only sampling is supported for cyclic graphs.
        # The sumproduct method will give exact likelihoods but
        # only if the graph contains no cycles.
        self.inference_method = 'sample' if cyclic else 'sumproduct'
        self.enforce_minimum_samples = False

    @property
    def inference_method(self):
        '''The name of the inference method currently in use.'''
        return self._inference_method

    @inference_method.setter
    def inference_method(self, value):
        '''
        Select the inference method.

        If the value is being set to 'sample_db' the sample
        database is opened and stored on ``self.sample_db``. The
        value itself is not validated here; ``query`` raises for
        values it does not know.
        '''
        if value == 'sample_db':
            db_filename = self.sample_db_filename
            ensure_data_dir_exists(db_filename)
            sample_ordering = self.discover_sample_ordering()
            domains = dict(
                (var, var.domain) for var, _ in sample_ordering)
            # If the sqlite file does not exist this is a new
            # database so we need to initialize it.
            self.sample_db = SampleDB(
                db_filename,
                domains,
                initialize=not os.path.isfile(db_filename))
        self._inference_method = value

    @property
    def sample_db_filename(self):
        '''
        Get the name of the sqlite sample
        database for external sample
        generation and querying.
        The default location for now
        will be in the users home
        directory under ~/.pypgm/data/[name].sqlite
        where [name] is the name of the
        model. If the model has
        not been given an explicit name
        it will be "default".
        '''
        home = os.path.expanduser('~')
        return os.path.join(
            home, '.pypgm',
            'data',
            '%s.sqlite' % (self.name or 'default'))

    def reset(self):
        '''
        Reset all nodes back to their initial state.
        We should do this before or after adding
        or removing evidence.
        '''
        for node in self.nodes:
            node.reset()

    def has_cycles(self):
        '''
        Check if the graph has cycles or not.

        The graph is traversed breadth first, starting from a leaf
        node when there is one and from the first node otherwise,
        recording both the edges traversed and the nodes
        discovered. If an unexplored edge leads to a previously
        found node then the graph has cycles.

        Only the nodes reachable from the starting node are
        examined. An empty graph has no cycles.

        Returns:
            True if a cycle was found, False otherwise.
        '''
        start_node = None
        for node in self.nodes:
            if node.is_leaf():
                start_node = node
                break
            if start_node is None:
                # Remember the first node as a fallback so that
                # a graph without leaves can still be traversed.
                start_node = node
        if start_node is None:
            return False
        discovered_nodes = set()
        traversed_edges = set()
        q = Queue()
        q.put(start_node)
        while not q.empty():
            current_node = q.get()
            if DEBUG:
                print('Current Node:  %s' % (current_node,))
                print('Discovered Nodes before adding Current Node:  %s'
                      % (discovered_nodes,))
            if current_node.name in discovered_nodes:
                # We have a cycle!
                if DEBUG:
                    print('Dequeued node already processed: %s'
                          % (current_node,))
                return True
            discovered_nodes.add(current_node.name)
            if DEBUG:
                print('Discovered Nodes after adding Current Node:  %s'
                      % (discovered_nodes,))
            for neighbour in current_node.neighbours:
                # Since this is undirected and we want to record
                # the edges we have traversed we sort the edge
                # alphabetically.
                edge = tuple(sorted([current_node.name, neighbour.name]))
                already_found = neighbour.name in discovered_nodes
                if already_found and edge not in traversed_edges:
                    # A new edge leading to a known node.
                    return True
                if not already_found:
                    if DEBUG:
                        print('Enqueuing: %s' % neighbour)
                    q.put(neighbour)
                traversed_edges.add(edge)
        return False

    def verify(self):
        '''
        Check several properties of the Factor Graph
        that should hold.

        Returns:
            True when every check passes, False when a neighbour
            type check or a variable/argument match check fails.

        Raises:
            InvalidGraphException: for a node of an unsupported
                type, a factor function with missing or empty
                domains, or a graph without leaf nodes.
        '''
        # Check that all nodes are instances of classes derived
        # from VariableNode or FactorNode. It is a very common
        # error to instantiate the graph with the factor function
        # instead of the corresponding factor node.
        for node in self.nodes:
            if not isinstance(node, (VariableNode, FactorNode)):
                print('Factor Graph does not '
                      'support nodes of type: %s' % node.__class__)
                raise InvalidGraphException
        # First check that each node only connects to nodes of
        # the other type.
        print('Checking neighbour node types...')
        for node in self.nodes:
            if not node.verify_neighbour_types():
                print('%s has invalid neighbour type.' % node)
                return False
        print('Checking that all factor functions have domains...')
        for node in self.nodes:
            if isinstance(node, FactorNode):
                if not hasattr(node.func, 'domains'):
                    print('%s has no domains.' % node)
                    raise InvalidGraphException
                elif not node.func.domains:
                    # Also check for an empty domain dict!
                    print('%s has empty domains.' % node)
                    raise InvalidGraphException
        print('Checking that all variables are accounted for'
              ' by at least one function...')
        variables = set(
            vn.name for vn in self.nodes if isinstance(vn, VariableNode))
        # Collect the argument names of every factor function.
        # Building the set incrementally also copes with a graph
        # that has no factor nodes at all.
        args = set()
        for fn in self.nodes:
            if isinstance(fn, FactorNode):
                args.update(get_args(fn.func))
        if not variables.issubset(args):
            print('These variables are not used in any factors nodes: ')
            print(variables.difference(args))
            return False
        print('Checking that all arguments have matching variable nodes...')
        if not args.issubset(variables):
            print('These arguments have missing variables:')
            print(args.difference(variables))
            return False
        print('Checking that graph has at least one leaf node...')
        if not self.get_leaves():
            print('Graph has no leaf nodes.')
            raise InvalidGraphException
        return True

    def get_leaves(self):
        '''Return the list of nodes for which ``is_leaf()`` is true.'''
        return [node for node in self.nodes if node.is_leaf()]

    def get_eligible_senders(self):
        '''
        Return a list of nodes that are
        eligible to send messages at this
        round. Only nodes that have received
        messages from all but one neighbour
        may send at any round.
        '''
        return [node for node in self.nodes if node.get_target()]

    def propagate(self):
        '''
        This is the heart of the sum-product
        Message Passing Algorithm.

        Rounds are repeated until no node is eligible to send; in
        each round every eligible node constructs and sends its
        message.
        '''
        while True:
            eligible_senders = self.get_eligible_senders()
            if not eligible_senders:
                break
            for node in eligible_senders:
                message = node.construct_message()
                node.send(message)

    def variable_nodes(self):
        '''Return the nodes that are instances of VariableNode.'''
        return [n for n in self.nodes if isinstance(n, VariableNode)]

    def factor_nodes(self):
        '''Return the nodes that are instances of FactorNode.'''
        return [n for n in self.nodes if isinstance(n, FactorNode)]

    def get_normalizer(self):
        '''
        Return the marginal of the first variable node that has a
        value set, evaluated at that value, or 1 when no variable
        node has a value.
        '''
        for node in self.variable_nodes():
            if node.value is not None:
                return node.marginal(node.value)
        return 1

    def status(self, omit=(False, 0)):
        '''
        Return the normalized marginal of every value of every
        variable node.

        Args:
            omit: currently unused, kept for backward
                compatibility.

        Returns:
            A dict mapping (node name, value) to the marginal.
        '''
        normalizer = self.get_normalizer()
        retval = dict()
        for node in self.variable_nodes():
            for value in node.domain:
                retval[(node.name, value)] = node.marginal(
                    value, normalizer)
        return retval

    def query_by_propagation(self, **kwds):
        '''
        Answer a query with the sum-product algorithm.

        The graph is reset, each keyword argument is added as
        evidence on the variable node of the same name, messages
        are propagated and the resulting ``status()`` is returned.
        '''
        self.reset()
        for k, v in kwds.items():
            for node in self.variable_nodes():
                if node.name == k:
                    add_evidence(node, v)
        self.propagate()
        return self.status()

    def query(self, **kwds):
        '''
        Answer a query with the current inference method.

        Keyword arguments are the observed values by variable
        name.

        Raises:
            InvalidInferenceMethod: if ``inference_method`` is not
                one of 'sample_db', 'sample' or 'sumproduct'.
        '''
        if self.inference_method == 'sample_db':
            return self.query_by_external_samples(**kwds)
        elif self.inference_method == 'sample':
            return self.query_by_sampling(**kwds)
        elif self.inference_method == 'sumproduct':
            return self.query_by_propagation(**kwds)
        raise InvalidInferenceMethod

    def q(self, **kwds):
        '''Wrapper around query

        This method formats the query
        result in a nice human readable format
        for interactive use. Rows matching the supplied evidence
        are marked with an asterisk. The table is printed and
        nothing is returned.
        '''
        result = self.query(**kwds)
        tab = PrettyTable(['Node', 'Value', 'Marginal'], sortby='Node')
        tab.align = 'l'
        tab.align['Marginal'] = 'r'
        tab.float_format = '%8.6f'
        for (node, value), prob in result.items():
            if kwds.get(node, '') == value:
                tab.add_row(['%s*' % node,
                             '%s%s*%s' % (GREEN, value, NORMAL),
                             '%8.6f' % prob])
            else:
                tab.add_row([node, value, '%8.6f' % prob])
        print(tab)

    def discover_sample_ordering(self):
        '''Delegate to the module level ``discover_sample_ordering``.'''
        return discover_sample_ordering(self)

    def get_sample(self, evidence=None):
        '''
        We need to allow for setting
        certain observed variables and
        discarding mismatching
        samples as we generate them.

        Args:
            evidence: optional dict of observed values; a fresh
                empty dict is used when omitted.

        The sample ordering is discovered on first use and cached
        on ``self.sample_ordering``.
        '''
        if evidence is None:
            evidence = {}
        if not hasattr(self, 'sample_ordering'):
            self.sample_ordering = self.discover_sample_ordering()
        return get_sample(self.sample_ordering, evidence)

    def query_by_sampling(self, **kwds):
        '''
        Estimate marginals from ``n_samples`` samples generated
        with the keyword arguments as evidence.

        Returns:
            A dict mapping (variable name, value) to the fraction
            of samples in which the variable took that value.
        '''
        counts = defaultdict(int)
        valid_samples = 0
        while valid_samples < self.n_samples:
            print("%s of %s" % (valid_samples, self.n_samples))
            try:
                sample = self.get_sample(kwds)
            except Exception:
                # Keep trying on any sampling error, but let
                # KeyboardInterrupt and SystemExit through so the
                # loop can be interrupted.
                print('Failed to get a valid sample...')
                print('continuing...')
                continue
            valid_samples += 1
            for var in sample:
                counts[(var.name, var.value)] += 1
        # Now normalize. float() guards against integer division.
        return dict(
            (key, float(count) / valid_samples)
            for key, count in counts.items())

    def generate_samples(self, n):
        '''
        Generate and save samples to
        the SQLite sample db for this
        model.

        Args:
            n: number of valid samples to store.

        Raises:
            IncorrectInferenceMethodError: if the inference method
                is not 'sample_db'.
        '''
        if self.inference_method != 'sample_db':
            raise IncorrectInferenceMethodError(
                'generate_samples() not support for inference method: %s' %
                self.inference_method)
        if not hasattr(self, 'sample_ordering'):
            self.sample_ordering = self.discover_sample_ordering()
        sdb = self.sample_db
        valid_samples = 0
        while valid_samples < n:
            try:
                sample = self.get_sample()
            except InvalidSampleException:
                # TODO: Need to figure
                # out why we get invalid
                # samples.
                continue
            sdb.save_sample([(v.name, v.value) for v in sample])
            valid_samples += 1
        sdb.commit()
        print('%s samples stored in %s' % (n, self.sample_db_filename))

    def query_by_external_samples(self, **kwds):
        '''
        Estimate marginals from samples stored in the sample
        database, using the keyword arguments as evidence.

        Returns:
            A dict mapping (variable name, value) to the fraction
            of the retrieved samples with that value.

        Raises:
            NoSamplesInDB: if no samples were retrieved.
            InsufficientSamplesException: if fewer than
                ``n_samples`` samples were retrieved and
                ``enforce_minimum_samples`` is set.
        '''
        counts = defaultdict(int)
        samples = self.sample_db.get_samples(self.n_samples, **kwds)
        if len(samples) == 0:
            raise NoSamplesInDB(
                'There are no samples in the database. '
                'Generate some with graph.generate_samples(N).')
        if len(samples) < self.n_samples and self.enforce_minimum_samples:
            raise InsufficientSamplesException(
                'There are less samples in the sampling '
                'database than are required by this graph. '
                'Either generate more samples '
                '(graph.generate_samples(N) or '
                'decrease the number of samples '
                'required for querying (graph.n_samples). ')
        for sample in samples:
            for name, val in sample.items():
                counts[(name, val)] += 1
        # float() guards against integer division.
        total = float(len(samples))
        return dict((key, count / total) for key, count in counts.items())

    def loadToNeo(self, connection=None):
        '''
        Load the graph into a Neo4j database through py2neo.

        WARNING: everything already stored in the target database
        is deleted before the graph is loaded.

        Args:
            connection: passed to ``py2neo.Graph`` when given;
                otherwise ``py2neo.Graph()`` is called without
                arguments.
        '''
        if connection:
            neoGraph = py2neo.Graph(connection)
        else:
            neoGraph = py2neo.Graph()

        # WARNING WARNING
        neoGraph.delete_all()

        # Used to capture the Neo nodes by name.
        vertices = dict()
        for node in self.nodes:
            if isinstance(node, FactorNode):
                label = "FactorNode"
            else:
                label = "VariableNode"
            vertices[node.name] = py2neo.Node(label, name=node.name)
        # Build the edges.
        # NOTE(review): one relationship is created per
        # (node, neighbour) pair; if neighbour lists are symmetric
        # every edge is stored once in each direction - confirm
        # this is intended.
        for node in self.nodes:
            for neighbour in node.neighbours:
                edge = py2neo.Relationship(
                    vertices[node.name],
                    "INFLUENCES",
                    vertices[neighbour.name],
                    since=1999)
                neoGraph.create(edge)

    def export(self, filename=None, format='graphviz'):
        '''
        Export the graph in GraphViz dot language.

        Args:
            filename: file to write to; standard output is used
                when omitted.
            format: only 'graphviz' is supported.

        Raises:
            ValueError: if ``format`` is not supported. The check
                is made before the output file is opened.
        '''
        if format != 'graphviz':
            raise ValueError('Unsupported Export Format.')
        if filename:
            fh = open(filename, 'w')
        else:
            fh = sys.stdout
        try:
            fh.write('graph G {\n')
            fh.write(
                '  graph [ dpi = 300 bgcolor="transparent" rankdir="LR"];\n')
            for node in self.nodes:
                if isinstance(node, FactorNode):
                    fh.write(
                        '  %s [ shape="rectangle" color="red"];\n' % node.name)
                else:
                    fh.write(
                        '  %s [ shape="ellipse" color="blue"];\n' % node.name)
            # Edges are undirected so each one is recorded once,
            # with its end points in sorted order.
            edges = set()
            for node in self.nodes:
                for neighbour in node.neighbours:
                    edges.add(tuple(sorted([node.name, neighbour.name])))
            # Sorted so that the output is deterministic.
            for source, target in sorted(edges):
                fh.write('  %s -- %s;\n' % (source, target))
            fh.write('}\n')
        finally:
            # Only close what we opened, never sys.stdout.
            if filename:
                fh.close()
class FactorGraph(object):
    '''
    A factor graph built from variable nodes and factor nodes.

    Marginals are computed with one of three inference methods
    selected through the ``inference_method`` property:

    * 'sumproduct' - message passing, see ``propagate``. Chosen by
      the constructor when no cycle is detected.
    * 'sample' - estimates from samples generated on demand.
    * 'sample_db' - estimates from samples previously stored in
      the sqlite sample database, see ``generate_samples``.
    '''

    def __init__(self, nodes, name=None, n_samples=100):
        '''
        Create the graph and choose a default inference method.

        Args:
            nodes: the VariableNode and FactorNode instances that
                make up the graph.
            name: optional model name, used to derive the name of
                the sample database file.
            n_samples: number of samples used by the sampling
                based query methods.
        '''
        self.nodes = nodes
        self._inference_method = 'sumproduct'
        # We need to divine the domains for Factor nodes here...
        # First record the domain of every variable by name.
        arg_domains = dict()
        for node in self.nodes:
            if isinstance(node, VariableNode):
                arg_domains[node.name] = node.domain
        # Now if the domains for the factor functions have not
        # been explicitly set we create them based on the values
        # their variables can take.
        for node in self.nodes:
            if not isinstance(node, FactorNode):
                continue
            if hasattr(node.func, 'domains'):
                continue
            domains = dict()
            for arg in get_args(node.func):
                if arg not in arg_domains:
                    print('WARNING: missing variable for arg:%s' % arg)
                else:
                    domains[arg] = arg_domains[arg]
            node.func.domains = domains
        self.name = name
        self.n_samples = n_samples
        # Now try to set the mode of inference..
        try:
            cyclic = self.has_cycles()
        except Exception:
            # Best effort: when the check itself fails we fall
            # back to sampling, which is the method used for
            # cyclic graphs.
            print('Failed to determine if graph has cycles, '
                  'setting inference to sample.')
            cyclic = True
        # Currently only sampling is supported for cyclic graphs.
        # The sumproduct method will give exact likelihoods but
        # only if the graph contains no cycles.
        self.inference_method = 'sample' if cyclic else 'sumproduct'
        self.enforce_minimum_samples = False

    @property
    def inference_method(self):
        '''The name of the inference method currently in use.'''
        return self._inference_method

    @inference_method.setter
    def inference_method(self, value):
        '''
        Select the inference method.

        If the value is being set to 'sample_db' the sample
        database is opened and stored on ``self.sample_db``. The
        value itself is not validated here; ``query`` raises for
        values it does not know.
        '''
        if value == 'sample_db':
            db_filename = self.sample_db_filename
            ensure_data_dir_exists(db_filename)
            sample_ordering = self.discover_sample_ordering()
            domains = dict(
                (var, var.domain) for var, _ in sample_ordering)
            # If the sqlite file does not exist this is a new
            # database so we need to initialize it.
            self.sample_db = SampleDB(
                db_filename,
                domains,
                initialize=not os.path.isfile(db_filename))
        self._inference_method = value

    @property
    def sample_db_filename(self):
        '''
        Get the name of the sqlite sample
        database for external sample
        generation and querying.
        The default location for now
        will be in the users home
        directory under ~/.pypgm/data/[name].sqlite
        where [name] is the name of the
        model. If the model has
        not been given an explicit name
        it will be "default".
        '''
        home = os.path.expanduser('~')
        return os.path.join(home, '.pypgm', 'data',
                            '%s.sqlite' % (self.name or 'default'))

    def reset(self):
        '''
        Reset all nodes back to their initial state.
        We should do this before or after adding
        or removing evidence.
        '''
        for node in self.nodes:
            node.reset()

    def has_cycles(self):
        '''
        Check if the graph has cycles or not.

        The graph is traversed breadth first, starting from a leaf
        node when there is one and from the first node otherwise,
        recording both the edges traversed and the nodes
        discovered. If an unexplored edge leads to a previously
        found node then the graph has cycles.

        Only the nodes reachable from the starting node are
        examined. An empty graph has no cycles.

        Returns:
            True if a cycle was found, False otherwise.
        '''
        start_node = None
        for node in self.nodes:
            if node.is_leaf():
                start_node = node
                break
            if start_node is None:
                # Remember the first node as a fallback so that
                # a graph without leaves can still be traversed.
                start_node = node
        if start_node is None:
            return False
        discovered_nodes = set()
        traversed_edges = set()
        q = Queue()
        q.put(start_node)
        while not q.empty():
            current_node = q.get()
            if DEBUG:
                print('Current Node:  %s' % (current_node,))
                print('Discovered Nodes before adding Current Node:  %s'
                      % (discovered_nodes,))
            if current_node.name in discovered_nodes:
                # We have a cycle!
                if DEBUG:
                    print('Dequeued node already processed: %s'
                          % (current_node,))
                return True
            discovered_nodes.add(current_node.name)
            if DEBUG:
                print('Discovered Nodes after adding Current Node:  %s'
                      % (discovered_nodes,))
            for neighbour in current_node.neighbours:
                # Since this is undirected and we want to record
                # the edges we have traversed we sort the edge
                # alphabetically.
                edge = tuple(sorted([current_node.name, neighbour.name]))
                already_found = neighbour.name in discovered_nodes
                if already_found and edge not in traversed_edges:
                    # A new edge leading to a known node.
                    return True
                if not already_found:
                    if DEBUG:
                        print('Enqueuing: %s' % neighbour)
                    q.put(neighbour)
                traversed_edges.add(edge)
        return False

    def verify(self):
        '''
        Check several properties of the Factor Graph
        that should hold.

        Returns:
            True when every check passes, False when a neighbour
            type check or a variable/argument match check fails.

        Raises:
            InvalidGraphException: for a node of an unsupported
                type, a factor function with missing or empty
                domains, or a graph without leaf nodes.
        '''
        # Check that all nodes are instances of classes derived
        # from VariableNode or FactorNode. It is a very common
        # error to instantiate the graph with the factor function
        # instead of the corresponding factor node.
        for node in self.nodes:
            if not isinstance(node, (VariableNode, FactorNode)):
                print('Factor Graph does not '
                      'support nodes of type: %s' % node.__class__)
                raise InvalidGraphException
        # First check that each node only connects to nodes of
        # the other type.
        print('Checking neighbour node types...')
        for node in self.nodes:
            if not node.verify_neighbour_types():
                print('%s has invalid neighbour type.' % node)
                return False
        print('Checking that all factor functions have domains...')
        for node in self.nodes:
            if isinstance(node, FactorNode):
                if not hasattr(node.func, 'domains'):
                    print('%s has no domains.' % node)
                    raise InvalidGraphException
                elif not node.func.domains:
                    # Also check for an empty domain dict!
                    print('%s has empty domains.' % node)
                    raise InvalidGraphException
        print('Checking that all variables are accounted for'
              ' by at least one function...')
        variables = set(
            vn.name for vn in self.nodes if isinstance(vn, VariableNode))
        # Collect the argument names of every factor function.
        # Building the set incrementally needs no ``reduce`` and
        # also copes with a graph that has no factor nodes.
        args = set()
        for fn in self.nodes:
            if isinstance(fn, FactorNode):
                args.update(get_args(fn.func))
        if not variables.issubset(args):
            print('These variables are not used in any factors nodes: ')
            print(variables.difference(args))
            return False
        print('Checking that all arguments have matching variable nodes...')
        if not args.issubset(variables):
            print('These arguments have missing variables:')
            print(args.difference(variables))
            return False
        print('Checking that graph has at least one leaf node...')
        if not self.get_leaves():
            print('Graph has no leaf nodes.')
            raise InvalidGraphException
        return True

    def get_leaves(self):
        '''Return the list of nodes for which ``is_leaf()`` is true.'''
        return [node for node in self.nodes if node.is_leaf()]

    def get_eligible_senders(self):
        '''
        Return a list of nodes that are
        eligible to send messages at this
        round. Only nodes that have received
        messages from all but one neighbour
        may send at any round.
        '''
        return [node for node in self.nodes if node.get_target()]

    def propagate(self):
        '''
        This is the heart of the sum-product
        Message Passing Algorithm.

        Rounds are repeated until no node is eligible to send; in
        each round every eligible node constructs and sends its
        message.
        '''
        while True:
            eligible_senders = self.get_eligible_senders()
            if not eligible_senders:
                break
            for node in eligible_senders:
                message = node.construct_message()
                node.send(message)

    def variable_nodes(self):
        '''Return the nodes that are instances of VariableNode.'''
        return [n for n in self.nodes if isinstance(n, VariableNode)]

    def factor_nodes(self):
        '''Return the nodes that are instances of FactorNode.'''
        return [n for n in self.nodes if isinstance(n, FactorNode)]

    def get_normalizer(self):
        '''
        Return the marginal of the first variable node that has a
        value set, evaluated at that value, or 1 when no variable
        node has a value.
        '''
        for node in self.variable_nodes():
            if node.value is not None:
                return node.marginal(node.value)
        return 1

    def status(self, omit=(False, 0)):
        '''
        Return the normalized marginal of every value of every
        variable node.

        Args:
            omit: currently unused, kept for backward
                compatibility.

        Returns:
            A dict mapping (node name, value) to the marginal.
        '''
        normalizer = self.get_normalizer()
        retval = dict()
        for node in self.variable_nodes():
            for value in node.domain:
                retval[(node.name, value)] = node.marginal(
                    value, normalizer)
        return retval

    def query_by_propagation(self, **kwds):
        '''
        Answer a query with the sum-product algorithm.

        The graph is reset, each keyword argument is added as
        evidence on the variable node of the same name, messages
        are propagated and the resulting ``status()`` is returned.
        '''
        self.reset()
        for k, v in kwds.items():
            for node in self.variable_nodes():
                if node.name == k:
                    add_evidence(node, v)
        self.propagate()
        return self.status()

    def query(self, **kwds):
        '''
        Answer a query with the current inference method.

        Keyword arguments are the observed values by variable
        name.

        Raises:
            InvalidInferenceMethod: if ``inference_method`` is not
                one of 'sample_db', 'sample' or 'sumproduct'.
        '''
        if self.inference_method == 'sample_db':
            return self.query_by_external_samples(**kwds)
        elif self.inference_method == 'sample':
            return self.query_by_sampling(**kwds)
        elif self.inference_method == 'sumproduct':
            return self.query_by_propagation(**kwds)
        raise InvalidInferenceMethod

    def q(self, **kwds):
        '''Wrapper around query

        This method formats the query
        result in a nice human readable format
        for interactive use. Rows matching the supplied evidence
        are marked with an asterisk. The table is printed and
        nothing is returned.
        '''
        result = self.query(**kwds)
        tab = PrettyTable(['Node', 'Value', 'Marginal'], sortby='Node')
        tab.align = 'l'
        tab.align['Marginal'] = 'r'
        tab.float_format = '%8.6f'
        for (node, value), prob in result.items():
            if kwds.get(node, '') == value:
                tab.add_row([
                    '%s*' % node,
                    '%s%s*%s' % (GREEN, value, NORMAL),
                    '%8.6f' % prob
                ])
            else:
                tab.add_row([node, value, '%8.6f' % prob])
        print(tab)

    def discover_sample_ordering(self):
        '''Delegate to the module level ``discover_sample_ordering``.'''
        return discover_sample_ordering(self)

    def get_sample(self, evidence=None):
        '''
        We need to allow for setting
        certain observed variables and
        discarding mismatching
        samples as we generate them.

        Args:
            evidence: optional dict of observed values; a fresh
                empty dict is used when omitted.

        The sample ordering is discovered on first use and cached
        on ``self.sample_ordering``.
        '''
        if evidence is None:
            evidence = {}
        if not hasattr(self, 'sample_ordering'):
            self.sample_ordering = self.discover_sample_ordering()
        return get_sample(self.sample_ordering, evidence)

    def query_by_sampling(self, **kwds):
        '''
        Estimate marginals from ``n_samples`` samples generated
        with the keyword arguments as evidence.

        Returns:
            A dict mapping (variable name, value) to the fraction
            of samples in which the variable took that value.
        '''
        counts = defaultdict(int)
        valid_samples = 0
        while valid_samples < self.n_samples:
            print("%s of %s" % (valid_samples, self.n_samples))
            try:
                sample = self.get_sample(kwds)
            except Exception:
                # Keep trying on any sampling error, but let
                # KeyboardInterrupt and SystemExit through so the
                # loop can be interrupted.
                print('Failed to get a valid sample...')
                print('continuing...')
                continue
            valid_samples += 1
            for var in sample:
                counts[(var.name, var.value)] += 1
        # Now normalize
        return dict(
            (key, count / valid_samples) for key, count in counts.items())

    def generate_samples(self, n):
        '''
        Generate and save samples to
        the SQLite sample db for this
        model.

        Args:
            n: number of valid samples to store.

        Raises:
            IncorrectInferenceMethodError: if the inference method
                is not 'sample_db'.
        '''
        if self.inference_method != 'sample_db':
            raise IncorrectInferenceMethodError(
                'generate_samples() not support for inference method: %s' %
                self.inference_method)
        if not hasattr(self, 'sample_ordering'):
            self.sample_ordering = self.discover_sample_ordering()
        sdb = self.sample_db
        valid_samples = 0
        while valid_samples < n:
            try:
                sample = self.get_sample()
            except InvalidSampleException:
                # TODO: Need to figure
                # out why we get invalid
                # samples.
                continue
            sdb.save_sample([(v.name, v.value) for v in sample])
            valid_samples += 1
        sdb.commit()
        print('%s samples stored in %s' % (n, self.sample_db_filename))

    def query_by_external_samples(self, **kwds):
        '''
        Estimate marginals from samples stored in the sample
        database, using the keyword arguments as evidence.

        Returns:
            A dict mapping (variable name, value) to the fraction
            of the retrieved samples with that value.

        Raises:
            NoSamplesInDB: if no samples were retrieved.
            InsufficientSamplesException: if fewer than
                ``n_samples`` samples were retrieved and
                ``enforce_minimum_samples`` is set.
        '''
        counts = defaultdict(int)
        samples = self.sample_db.get_samples(self.n_samples, **kwds)
        if len(samples) == 0:
            raise NoSamplesInDB(
                'There are no samples in the database. '
                'Generate some with graph.generate_samples(N).')
        if len(samples) < self.n_samples and self.enforce_minimum_samples:
            raise InsufficientSamplesException(
                'There are less samples in the sampling '
                'database than are required by this graph. '
                'Either generate more samples '
                '(graph.generate_samples(N) or '
                'decrease the number of samples '
                'required for querying (graph.n_samples). ')
        for sample in samples:
            for name, val in sample.items():
                counts[(name, val)] += 1
        total = len(samples)
        return dict((key, count / total) for key, count in counts.items())

    def export(self, filename=None, format='graphviz'):
        '''
        Export the graph in GraphViz dot language.

        Args:
            filename: file to write to; standard output is used
                when omitted.
            format: only 'graphviz' is supported.

        Raises:
            ValueError: if ``format`` is not supported. The check
                is made before the output file is opened.
        '''
        if format != 'graphviz':
            raise ValueError('Unsupported Export Format.')
        if filename:
            fh = open(filename, 'w')
        else:
            fh = sys.stdout
        try:
            fh.write('graph G {\n')
            fh.write(
                '  graph [ dpi = 300 bgcolor="transparent" rankdir="LR"];\n')
            for node in self.nodes:
                if isinstance(node, FactorNode):
                    fh.write(
                        '  %s [ shape="rectangle" color="red"];\n' % node.name)
                else:
                    fh.write(
                        '  %s [ shape="ellipse" color="blue"];\n' % node.name)
            # Edges are undirected so each one is recorded once,
            # with its end points in sorted order.
            edges = set()
            for node in self.nodes:
                for neighbour in node.neighbours:
                    edges.add(tuple(sorted([node.name, neighbour.name])))
            # Sorted so that the output is deterministic.
            for source, target in sorted(edges):
                fh.write('  %s -- %s;\n' % (source, target))
            fh.write('}\n')
        finally:
            # Only close what we opened, never sys.stdout.
            if filename:
                fh.close()