def axioms(features, labels, training=False):
    """Return the satisfaction level of the class-membership knowledge base.

    For each class index k in (0, 1, 2), the rows of ``features`` whose label
    equals k are grounded as an LTN variable (labels "x_A", "x_B", "x_C"), and
    the axiom ``forall x: p(x, class_k)`` is stated for it. The three axioms are
    combined with ``formula_aggregator`` (defined elsewhere in the file).

    Args:
        features: samples, indexed with the boolean mask ``labels == k``.
        labels: per-sample class indices compared against 0, 1 and 2.
        training: forwarded unchanged to the predicate ``p``.

    Returns:
        The ``.tensor`` of the aggregated formula.
    """
    # (variable label, class index, class constant) for each class, in order.
    groundings = (
        ("x_A", 0, class_A),
        ("x_B", 1, class_B),
        ("x_C", 2, class_C),
    )
    # Ground every class variable first, then state the axioms.
    members = [
        (ltn.Variable(var_label, features[labels == class_index]), class_const)
        for var_label, class_index, class_const in groundings
    ]
    class_axioms = [
        Forall(member, p([member, class_const], training=training))
        for member, class_const in members
    ]
    return formula_aggregator(class_axioms).tensor
def axioms(images_x, images_y, labels_z, p_schedule=tf.constant(2.)):
    """Return the satisfaction of the single-digit addition axiom.

    The batch elements are paired with ``ltn.diag`` and the statement is::

        forall (x, y, z) paired:
            exists (d1, d2) with add(d1, d2) == z:
                Digit(x, d1) and Digit(y, d2)

    Args:
        images_x: first-operand images.
        images_y: second-operand images.
        labels_z: the target sums.
        p_schedule: exponent ``p`` for the existential aggregation.

    Returns:
        The ``.tensor`` of the outer ``Forall`` (aggregated with ``p=2``).
    """
    var_x = ltn.Variable("x", images_x)
    var_y = ltn.Variable("y", images_y)
    var_z = ltn.Variable("z", labels_z)

    # Pair the three variables element-wise. This runs before any sub-formula
    # is built from them, matching the original's left-to-right argument
    # evaluation.
    paired = ltn.diag(var_x, var_y, var_z)

    both_digits = And(Digit([var_x, d1]), Digit([var_y, d2]))
    sum_matches = equals([add([d1, d2]), var_z])
    some_digit_pair = Exists((d1, d2), both_digits, mask=sum_matches, p=p_schedule)

    return Forall(paired, some_digit_pair, p=2).tensor
def axioms(features, labels_sex, labels_color):
    """Return the satisfaction level of the multi-label (color / sex) knowledge base.

    Axioms:
      * every sample labelled "B" is blue, every "O" sample is orange;
      * every sample labelled "M" is male, every "F" sample is female;
      * no sample is both blue and orange;
      * no sample is both male and female.

    The axioms are combined with ``formula_aggregator`` (defined elsewhere in
    the file).

    Args:
        features: samples, indexed with boolean masks built from the label arrays.
        labels_sex: per-sample sex labels compared against "M" and "F".
        labels_color: per-sample color labels compared against "B" and "O".

    Returns:
        The ``.tensor`` of the aggregated formula.
    """
    x = ltn.Variable("x", features)
    x_blue = ltn.Variable("x_blue", features[labels_color == "B"])
    x_orange = ltn.Variable("x_orange", features[labels_color == "O"])
    # Each variable gets a distinct label. LTN identifies free variables by
    # label, so reusing "x_blue" here would conflate these variables with
    # x_blue in any formula that mixes them.
    x_male = ltn.Variable("x_male", features[labels_sex == "M"])
    x_female = ltn.Variable("x_female", features[labels_sex == "F"])
    axioms = [
        Forall(x_blue, p([x_blue, class_blue])),
        Forall(x_orange, p([x_orange, class_orange])),
        Forall(x_male, p([x_male, class_male])),
        Forall(x_female, p([x_female, class_female])),
        # Mutual exclusivity within each label group.
        Forall(x, Not(And(p([x, class_blue]), p([x, class_orange])))),
        Forall(x, Not(And(p([x, class_male]), p([x, class_female])))),
    ]
    sat_level = formula_aggregator(axioms).tensor
    return sat_level
def axioms(images_x1, images_x2, images_y1, images_y2, labels_z, p_schedule):
    """Return the satisfaction of the multi-digit addition axiom.

    The batch elements are paired with ``ltn.diag`` and the statement is::

        forall (x1, x2, y1, y2, z) paired:
            exists (d1, d2, d3, d4) with
                z == add(two_digit_number(d1, d2), two_digit_number(d3, d4)):
                Digit(x1, d1) and Digit(x2, d2) and Digit(y1, d3) and Digit(y2, d4)

    ``two_digit_number`` is defined elsewhere (presumably 10*a + b; confirm).

    Args:
        images_x1, images_x2: digit images of the first number.
        images_y1, images_y2: digit images of the second number.
        labels_z: the target sums.
        p_schedule: exponent ``p`` for the existential aggregation.

    Returns:
        The ``.tensor`` of the outer ``Forall`` (aggregated with ``p=2``).
    """
    var_x1 = ltn.Variable("x1", images_x1)
    var_x2 = ltn.Variable("x2", images_x2)
    var_y1 = ltn.Variable("y1", images_y1)
    var_y2 = ltn.Variable("y2", images_y2)
    var_z = ltn.Variable("z", labels_z)

    # Pair all five variables element-wise before any sub-formula uses them,
    # mirroring the original evaluation order.
    paired = ltn.diag(var_x1, var_x2, var_y1, var_y2, var_z)

    first_number = And(Digit([var_x1, d1]), Digit([var_x2, d2]))
    second_number = And(Digit([var_y1, d3]), Digit([var_y2, d4]))
    all_digits = And(first_number, second_number)

    sum_matches = equals([
        var_z,
        add([two_digit_number([d1, d2]), two_digit_number([d3, d4])]),
    ])
    some_digits = Exists(
        (d1, d2, d3, d4),
        all_digits,
        mask=sum_matches,
        p=p_schedule,
    )
    return Forall(paired, some_digits, p=2).tensor
def axioms(data, labels):
    """Return the satisfaction level of the binary-classification knowledge base.

    Rows of ``data`` selected by ``labels`` must satisfy ``A``, and the
    remaining rows must satisfy ``Not(A)``. Both axioms are combined with
    ``formula_aggregator`` (defined elsewhere in the file).

    Args:
        data: samples, indexed with ``labels`` and its logical negation.
        labels: boolean mask of the positive samples.

    Returns:
        The ``.tensor`` of the aggregated formula.
    """
    positives = ltn.Variable("x_A", data[labels])
    negatives = ltn.Variable("x_not_A", data[tf.logical_not(labels)])
    knowledge_base = [
        Forall(positives, A(positives)),
        Forall(negatives, Not(A(negatives))),
    ]
    return formula_aggregator(knowledge_base).tensor
def axioms(x_data, y_data):
    """Build the regression axiom ``forall (x, y) paired: eq(f(x), y)``.

    Unlike the other ``axioms`` variants in this file, this returns the
    quantified formula object itself, not its ``.tensor``.

    Args:
        x_data: inputs, grounded as variable "x".
        y_data: targets, grounded as variable "y" and paired with ``x_data``
            element-wise via ``ltn.diag``.
    """
    inputs = ltn.Variable("x", x_data)
    targets = ltn.Variable("y", y_data)
    # Pair first, then build the body, matching the original evaluation order.
    paired = ltn.diag(inputs, targets)
    fit = eq([f(inputs), targets])
    return Forall(paired, fit)
""" DATASET """ ds_train, ds_test = data.get_mnist_op_dataset( count_train=n_examples_train, count_test=n_examples_test, buffer_size=10000, batch_size=batch_size, n_operands=2, op=lambda args: args[0]+args[1]) """ LTN MODEL AND LOSS """ ### Predicates logits_model = baselines.SingleDigit() Digit = ltn.Predicate(ltn.utils.LogitsToPredicateModel(logits_model)) ### Variables d1 = ltn.Variable("digits1", range(10)) d2 = ltn.Variable("digits2", range(10)) ### Operators Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std()) And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod()) Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum()) Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach()) Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(),semantics="forall") Exists = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMean(),semantics="exists") # mask add = ltn.Function.Lambda(lambda inputs: inputs[0]+inputs[1]) equals = ltn.Predicate.Lambda(lambda inputs: inputs[0] == inputs[1]) ### Axioms
def sat_phi3(features):
    """Query the satisfaction of ``forall x: p(x, blue) -> p(x, male)``.

    The universal quantifier is aggregated with ``p=5``.

    Args:
        features: samples grounded as the variable "x".

    Returns:
        The ``.tensor`` of the quantified formula.
    """
    sample = ltn.Variable("x", features)
    blue_implies_male = Implies(p([sample, class_blue]), p([sample, class_male]))
    return Forall(sample, blue_implies_male, p=5).tensor
def sat_phi2(features):
    """Query the satisfaction of ``forall x: p(x, blue) -> p(x, orange)``.

    The universal quantifier is aggregated with ``p=5``.

    Args:
        features: samples grounded as the variable "x".

    Returns:
        The ``.tensor`` of the quantified formula.
    """
    sample = ltn.Variable("x", features)
    blue_implies_orange = Implies(p([sample, class_blue]), p([sample, class_orange]))
    return Forall(sample, blue_implies_orange, p=5).tensor
p_type = args["p"] """ DATASET """ ds_train, ds_test = data.get_mnist_op_dataset( count_train=n_examples_train, count_test=n_examples_test, buffer_size=10000, batch_size=batch_size, n_operands=4, op=lambda args: 10 * args[0] + args[1] + 10 * args[2] + args[3]) """ LTN MODEL AND LOSS """ ### Predicates logits_model = baselines.SingleDigit() Digit = ltn.Predicate(ltn.utils.LogitsToPredicateModel(logits_model)) ### Variables d1 = ltn.Variable("digits1", range(10)) d2 = ltn.Variable("digits2", range(10)) d3 = ltn.Variable("digits3", range(10)) d4 = ltn.Variable("digits4", range(10)) ### Operators Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std()) And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod()) Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum()) Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach()) Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(), semantics="forall") Exists = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMean(), semantics="exists") # mask add = ltn.Function.Lambda(lambda inputs: inputs[0] + inputs[1])