def inference_ad3(unary_potentials, pairwise_potentials, edges, relaxed=False,
                  verbose=0, return_energy=False):
    """Run inference on a pairwise graph with AD3's ``general_graph`` solver.

    Parameters
    ----------
    unary_potentials : ndarray
        Node potentials. All axes except the last are treated as node
        axes: the array is flattened to ``(-1, n_states)`` before solving,
        and the hard labeling is reshaped back to ``shape[:-1]``.
        (Assumes the last axis holds one entry per state -- TODO confirm
        against ``_validate_params``.)
    pairwise_potentials : array-like
        Edge potentials; replaced by the second value returned from
        ``_validate_params`` before being passed to AD3.
    edges : array-like
        Graph edges, forwarded unchanged to ``AD3.general_graph``.
    relaxed : bool, default False
        If True, return the solver's marginals instead of a hard labeling.
    verbose : int, default 0
        Forwarded to ``AD3.general_graph``.
    return_energy : bool, default False
        If True, also return the negated energy reported by AD3.

    Returns
    -------
    y, or (y, energy) when ``return_energy`` is True
        When ``relaxed`` is True, ``y`` is ``(unary_marginals,
        pairwise_marginals)``, with the unary marginals reshaped to
        ``unary_potentials.shape``. Otherwise ``y`` is the per-node argmax
        of the unary marginals, reshaped to ``unary_potentials.shape[:-1]``.
    """
    # Imported inside the function so AD3 is only needed when this
    # solver is actually called.
    import AD3

    node_shape = unary_potentials.shape[:-1]
    n_states, pairwise_potentials = _validate_params(
        unary_potentials, pairwise_potentials, edges)
    flat_unaries = unary_potentials.reshape(-1, n_states)

    unary_marginals, pairwise_marginals, energy = AD3.general_graph(
        flat_unaries, edges, pairwise_potentials, verbose=verbose)

    if not relaxed:
        # Hard labeling: highest-marginal state for every node.
        labels = np.argmax(unary_marginals, axis=-1)
        y = labels.reshape(node_shape)
    else:
        y = (unary_marginals.reshape(unary_potentials.shape),
             pairwise_marginals)

    # AD3's energy is reported with the opposite sign convention.
    return (y, -energy) if return_energy else y
def _inference_ad3(x, unary_params, pairwise_params, edges, relaxed=False,
                   verbose=0):
    """Unimplemented grid-based AD3 inference stub.

    This function always raises and performs no inference. The previous
    body after the ``raise`` was unreachable (and referenced ``AD3``
    without importing it), so it has been removed. For inference on
    general graphs, see ``inference_ad3`` in this module.

    Parameters
    ----------
    x, unary_params, pairwise_params, edges, relaxed, verbose
        Accepted for signature compatibility only; all are ignored.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("AD3 doesn't work on graphs yet!")