def fg_inference(compile_fg):
    """Run NumbSkull inference on a compiled factor graph and harden variables.

    Prints a banner, then loads the six-part compiled graph into a
    ``NumbSkull`` sampler and calls ``inference()``. No learning step is
    invoked here. Finally it writes a 0/1 ``'initialValue'`` back onto every
    entry of ``variable``, depending on whether its marginal exceeds 0.5.

    Args:
        compile_fg: sequence of exactly six items, unpacked in order as
            ``(weight, variable, factor, ftv, domain_mask, n_edges)`` and
            forwarded unchanged to ``NumbSkull.loadFactorGraph``.

    Returns:
        ``factorGraphs[0].weight_value[0]`` of the sampler. This is presumably
        the weight values of the first weight copy; TODO confirm against the
        NumbSkull version in use.

    Raises:
        ValueError: if ``compile_fg`` does not unpack into exactly six items.

    Side effects:
        Mutates ``variable`` (the second item of ``compile_fg``) in place.
    """
    separator = '----------------------------------------------------------------'
    for banner_line in (separator, 'Weight inference', separator):
        print(banner_line)

    (weight, variable, factor,
     ftv, domain_mask, n_edges) = compile_fg

    # The learning-related settings are passed through as configured, but
    # only inference() is called in this function.
    sampler = NumbSkull(
        n_inference_epoch=1000,
        n_learning_epoch=1000,
        stepsize=0.01,
        decay=0.95,
        reg_param=1e-6,
        regularization=2,
        truncation=10,
        quiet=True,
        verbose=False,
        learn_non_evidence=False,  # TODO: still needs testing
        sample_evidence=False,
        burn_in=10,
        nthreads=1,
    )
    sampler.loadFactorGraph(weight, variable, factor, ftv, domain_mask,
                            n_edges)
    sampler.inference(out=True)

    graph = sampler.factorGraphs[0]
    probs = graph.marginals
    # Hard-threshold each marginal; a marginal of exactly 0.5 maps to 0.
    for idx in range(len(variable)):
        variable[idx]['initialValue'] = 1 if probs[idx] > 0.5 else 0

    return graph.weight_value[0]
def marginals(self, V, cardinality, L, L_offset, deps=(), init_acc=1.0,
              init_deps=1.0, init_class_prior=-1.0, epochs=100,
              step_size=None, decay=0.99, verbose=False, burn_in=50,
              timer=None):
    """Return marginal probabilities computed with the already-trained weights.

    Compiles a factor graph via ``self._compile`` using ``self.weights`` and
    ``self.dep_weights``, with no labels (``y=None``). It then runs NumbSkull
    inference only (``n_learning_epoch=0``) and returns the first
    ``V.shape[0]`` entries of the resulting marginals. Presumably these are
    one per row of ``V``; confirm against ``_compile``.

    Args:
        V, cardinality, L, L_offset, deps: forwarded to ``self._compile``.
        init_acc, init_deps, init_class_prior, timer: accepted for interface
            compatibility; not used by this method.
        epochs: number of inference epochs (``n_inference_epoch``).
        step_size: forwarded as NumbSkull ``stepsize``.
        decay: forwarded as NumbSkull ``decay``.
        verbose: enables NumbSkull verbose output (and disables ``quiet``).
        burn_in: forwarded as NumbSkull ``burn_in``.

    Returns:
        ``factorGraphs[0].marginals[:V.shape[0]]`` from the sampler.

    Raises:
        ValueError: if ``self.weights`` is None, i.e. the model is untrained.
    """
    if self.weights is None:
        raise ValueError(
            "Must fit model with train() before computing marginal probabilities."
        )

    (weight, variable, factor,
     ftv, domain_mask, n_edges) = self._compile(
        V, cardinality, L, L_offset, None, deps,
        self.weights, self.dep_weights)

    sampler = NumbSkull(
        n_inference_epoch=epochs,
        n_learning_epoch=0,
        stepsize=step_size,
        decay=decay,
        quiet=not verbose,
        verbose=verbose,
        learn_non_evidence=True,
        burn_in=burn_in,
        sample_evidence=False,
    )
    sampler.loadFactorGraph(weight, variable, factor, ftv, domain_mask,
                            n_edges)
    sampler.inference(out=False)

    n_rows = V.shape[0]
    return sampler.factorGraphs[0].marginals[:n_rows]
nthreads=1) subgraph = weight, variable, factor, fmap, domain_mask, edges ns_learing.loadFactorGraph(*subgraph) # 因子图参数学习 ns_learing.learning() # 因子图推理 # 参数学习完成后将weight的isfixed属性置为true for index, w in enumerate(weight): w["isFixed"] = True w['initialValue'] = ns_learing.factorGraphs[0].weight[index][ 'initialValue'] ns_inference = NumbSkull( n_inference_epoch=1000, n_learning_epoch=1000, stepsize=0.001, decay=0.95, reg_param=1e-6, regularization=2, truncation=10, quiet=(not False), verbose=False, learn_non_evidence=False, # need to test sample_evidence=False, burn_in=10, nthreads=1) ns_inference.loadFactorGraph(*subgraph) # 因子图推理 ns_inference.inference() #获取变量推理结果 print(ns_inference.factorGraphs[0].marginals)