def __function_type(node,cpt):
    """Classify a cpt given as a python function.

    The classification compares the number of parameters of `cpt` that have
    no default value with the number of parents of `node`:
      - 'function':   as many such parameters as parents, or only *args
      - 'constraint': one more such parameter than parents, or *args plus
                      exactly one parameter without a default

    Args:
        node: node whose cpt is being specified (uses `name` and `parents`).
        cpt: python function (FunctionType) that specifies the cpt.

    Returns:
        The string 'function' or 'constraint'.

    Raises:
        AssertionError: if `cpt` is not a FunctionType, declares **kwargs,
            declares *args after a parameter without default, or has a
            parameter count matching neither case above.
    """
    assert type(cpt) is FunctionType
    # functions that specify cpts cannot have free variables
    freevars = cpt.__code__.co_freevars
    u.check(not freevars,
        f'cpt function for node {node.name} has free variables, {u.unpack(freevars)}',
        f'specifying TBN cpt using python code')
    # let's find out whether this is a functional or constrained cpt
    # (inspect the parameter kind/default directly: testing str(p) for '*' or
    #  '=' misfires on defaults or annotations that contain those characters)
    free_parameter_count = 0
    has_star_args        = False
    for p in signature(cpt).parameters.values():
        assert p.kind != p.VAR_KEYWORD       # **kwargs is not supported
        if p.kind == p.VAR_POSITIONAL:
            assert free_parameter_count == 0 # *args must be first
            has_star_args = True
        elif p.default is p.empty:
            free_parameter_count += 1
    # with *args the parents are absorbed by args, so a functional cpt has
    # no other parameter without default; otherwise it has one per parent
    function_arity = 0 if has_star_args else len(node.parents)
    if free_parameter_count == function_arity:
        return 'function'
    # a constraint cpt takes exactly one extra parameter
    assert free_parameter_count == function_arity + 1
    return 'constraint'
def expand(node,cpt,cpt_type):
    """Convert a cpt specification into a read-only, normalized np array.

    Args:
        node: node that owns the cpt (uses `name`, `parents` and `shape()`).
        cpt: a list, an np array, or a python function specifying the cpt.
        cpt_type: label of the cpt, used only in error messages.

    Returns:
        A pair (cpt, tabular): the cpt as an np array marked read-only, and a
        flag that is True when the original cpt was a list or np array and
        False when it was a python function.

    Note:
        An np array argument is not copied, so the caller's array itself is
        marked read-only.
    """
    tabular = True # whether original cpt is a list or np array

    if type(cpt) is list:
        cpt = np.array(cpt)
    elif type(cpt) is FunctionType:
        tabular = False
        fn_type = __function_type(node,cpt)
        if fn_type == 'function':
            cpt = __expand_fcpt(node,cpt,node.parents)
        else:
            assert fn_type == 'constraint'
            cpt = __expand_ccpt(node,cpt,node.parents)
    else:
        # np.array is a factory function, not a type: the type of an array is
        # np.ndarray (comparing against np.array rejected every valid array)
        u.check(type(cpt) is np.ndarray,
            f'{cpt_type} of node {node.name} is not a list, np array or python function:\n  {cpt}',
            f'specifying TBN cpt')

    assert type(cpt) is np.ndarray
    assert cpt.shape == node.shape()
    # each distribution (last axis) must sum to 1
    is_normalized = np.allclose(1.,np.sum(cpt,axis=-1))
    u.check(is_normalized,
            f'{cpt_type} of node {node.name} is not normalized:\n  {cpt}',
            f'specifying TBN cpt')

    cpt.flags.writeable = False # read only
    return cpt, tabular # np array
Esempio n. 3
0
    def train(self, model: nn.Module,
              data_dict: Dict[str, BaseModel.Dataset]) -> NoReturn:
        """Fit `model` for up to `self.epoch` epochs, then reload the saved model.

        After every epoch the model is evaluated on data_dict['dev'] and
        data_dict['test'] with `self.topk[:1]` and `self.metrics`. The model is
        saved whenever the latest dev value of `self.main_metric` is the best
        so far (or while `model.stage == 1`), and training stops early when
        `self.early_stop` is set and `self.eval_termination` returns true.
        A KeyboardInterrupt stops training and asks whether to exit at once.

        If no epoch was completed (interrupted during the first epoch, or
        `self.epoch` is 0) there is no result to report and nothing is
        reloaded.

        Args:
            model: model to train; must provide `check_list`, `save_model()`
                and `load_model()`.
            data_dict: datasets stored under the keys 'train', 'dev', 'test'.

        NOTE(review): the NoReturn annotation is kept from the original
        signature although this method does return normally.
        """
        main_metric_results, dev_results, test_results = [], [], []
        self._check_time(start=True)
        try:
            for epoch in range(self.epoch):
                # Fit
                self._check_time()
                loss = self.fit(model, data_dict['train'], epoch=epoch + 1)
                training_time = self._check_time()

                # Observe selected tensors
                if len(model.check_list) > 0 and self.check_epoch > 0 \
                        and epoch % self.check_epoch == 0:
                    utils.check(model.check_list)

                # Record dev and test results
                dev_result = self.evaluate(model, data_dict['dev'],
                                           self.topk[:1], self.metrics)
                test_result = self.evaluate(model, data_dict['test'],
                                            self.topk[:1], self.metrics)
                testing_time = self._check_time()
                dev_results.append(dev_result)
                test_results.append(test_result)
                main_metric_results.append(dev_result[self.main_metric])

                logging.info(
                    "Epoch {:<5} loss={:<.4f} [{:<.1f} s]\t dev=({}) test=({}) [{:<.1f} s] "
                    .format(epoch + 1, loss, training_time,
                            utils.format_metric(dev_result),
                            utils.format_metric(test_result), testing_time))

                # Save model and early stop
                if max(main_metric_results) == main_metric_results[-1] or \
                        (hasattr(model, 'stage') and model.stage == 1):
                    model.save_model()
                if self.early_stop and self.eval_termination(
                        main_metric_results):
                    logging.info("Early stop at %d based on dev result.",
                                 epoch + 1)
                    break
        except KeyboardInterrupt:
            logging.info("Early stop manually")
            exit_here = input(
                "Exit completely without evaluation? (y/n) (default n):")
            if exit_here.lower().startswith('y'):
                logging.info(os.linesep + '-' * 45 + ' END: ' +
                             utils.get_time() + ' ' + '-' * 45)
                exit(1)

        # Without a single finished epoch there is no best result to select
        # (max() of an empty list raises ValueError) and no model was saved.
        if not main_metric_results:
            logging.info("No epoch was completed; "
                         "skipping best-result report and model reload.")
            return

        # Find the best dev result across iterations
        best_epoch = main_metric_results.index(max(main_metric_results))
        logging.info(
            os.linesep +
            "Best Iter(dev)={:>5}\t dev=({}) test=({}) [{:<.1f} s] ".format(
                best_epoch + 1, utils.format_metric(dev_results[best_epoch]),
                utils.format_metric(test_results[best_epoch]), self.time[1] -
                self.time[0]))
        model.load_model()
    def __prepare_for_inference(self):
        """Initialize inference attributes, process the cpts, then sort the family.

        The order of the steps matters: the original values and cardinality
        are recorded before `tbn.cpt.set_cpts` runs, and parents/family are
        sorted only after the cpts have been processed.
        """
        # bookkeeping maintained by decouple.py: it replicates functional cpts
        # and handles nodes with hard evidence by cloning nodes into a
        # separate 'decoupled' network
        self._original = None  # the tbn node that this node is a clone of
        self._master = None  # exactly one clone is declared as master
        self._clamped = False  # True when the tbn node has hard evidence

        # snapshot of the node before pruning (pruning is done while
        # processing _cpt, _cpt1, _cpt2)
        self._values_org = self.values  # node values before pruning
        self._card_org = self.card  # node cardinality before pruning
        self._values_idx = None  # indices of unpruned values, if pruned

        # process the node and its cpts: prunes node values & parents and
        # expands/prunes the cpts into tabular form
        tbn.cpt.set_cpts(self)

        # attributes that are filled in below
        self._all01_cpt = None  # is the cpt 0/1? (not applicable for testing nodes)
        self._cpt_label = None  # used when saving to file

        # decide whether the cpt contains only zeros and ones
        if self.testing:
            # the selected cpt need not be 0/1 even when cpt1 and cpt2 are
            self._all01_cpt = False
        else:
            is_zero = self.cpt == 0
            is_one = self.cpt == 1
            self._all01_cpt = np.all(np.logical_or(is_zero, is_one))
            # a fixed cpt of a functional node must itself be 0/1
            consistent = (not self.fixed_cpt or not self._functional
                          or self._all01_cpt)
            u.check(
                consistent,
                f'node {self.name} is declared functional but its fixed cpt is not functional',
                f'specifying TBN node')

        # pruning node values or parents changes the cpt shape, so cpts that
        # were tied may no longer share a shape; the tie is refined so that
        # only cpts that still have the same shape stay grouped
        # (this is not really proper and needs to be updated)
        tie = self.cpt_tie
        if tie is not None:
            self._cpt_tie = f'{tie}__{self.shape()}'

        self.__set_cpt_labels()

        # parents & family are sorted, and the cpt adjusted accordingly, only
        # now because processing the cpts may have pruned parents
        self.__sort()
        assert u.sorted(u.map('id', self.parents))
        assert u.sorted(u.map('id', self.family))
 def elm_order(self, solver, wait):
     """Compute an elimination order by running an external treewidth solver.

     The graph is written to 'decompose/tmp/graph.gr', fed to the solver on
     stdin, and the solver's stdout is captured in 'decompose/tmp/tree.td',
     which is then parsed as a tree decomposition.

     Args:
         solver: 'flow cutter', 'tamaki heuristic' or 'tamaki exact'.
         wait: seconds to let an online solver ('flow cutter' and
             'tamaki heuristic') run before it is stopped with SIGTERM;
             not used by 'tamaki exact', which runs to completion.

     Returns:
         A triple (order, width, stats): the result of
         `self.vertices2nodes` on the vertex order, the width of the
         decomposition tree, and a one-line summary string.

     Note:
         The exit code of a solver that was stopped by SIGTERM is not
         judged; a solver that exits on its own must exit with code 0.
     """
     u.show(f'    calling {solver}...', end='')
     graph_fname = 'decompose/tmp/graph.gr'
     tree_fname = 'decompose/tmp/tree.td'
     # solver name -> (program under ./decompose/solvers, online flag);
     # online solvers are stopped with SIGTERM after `wait` seconds
     solvers = {
         'flow cutter': ('flow_cutter_pace17', True),
         'tamaki heuristic': ('tamaki/tw-heuristic', True),
         'tamaki exact': ('tamaki/tw-exact', False),
     }
     # an unknown name previously surfaced as an unbound `cmd` variable
     u.check(solver in solvers,
             f'unknown treewidth solver {solver}',
             f'using treewidth solver')
     program, online = solvers[solver]
     cmd = [f'./decompose/solvers/{program}']
     # write graph to file
     self.write(graph_fname)
     # call tree decomposition program
     stopped_by_us = False
     with open(graph_fname, "r") as graph_file, \
             open(tree_fname, "w") as tree_file:
         process = subprocess.Popen(cmd, stdin=graph_file, stdout=tree_file)
         if online:
             u.show(f'waiting {wait} sec...', end='', flush=True)
             sleep(wait)
             # only signal a solver that is still running
             if process.poll() is None:
                 process.send_signal(signal.SIGTERM)
                 stopped_by_us = True
         # block until the solver has exited: returncode is only set once
         # the process was waited on, and the tree file must be complete
         code = process.wait()
     # stderr is not captured (it goes to the console), so report the code
     u.check(stopped_by_us or code == 0,
             f'failed to execute {solver} because\n  it exited with code {code}',
             f'using treewidth solver')
     u.show('done')
     # read decomposition tree from file
     tree = TreeD(tree_fname)
     # convert decomposition tree to elimination order (vertices)
     vertex_order = tree.elm_order()
     # return elimination order of tbn nodes
     stats = f'elm order: cls max {tree.width}'
     return self.vertices2nodes(vertex_order), tree.width, stats