def get_matches(self, x, epsilon=0.0):
        """Yield the argument tuple of every match of every hypothesis on ``x``.

        For each hypothesis ``h`` returned by ``self.learner.get_hset()`` a
        ``Rule`` operator is built (head ``('Rule',) + self.args``,
        conditions ``h``) and matched against an index of the grounded
        example; ``epsilon`` is forwarded unchanged to ``Operator.match``.
        Every match yields a tuple of ``unground(subst(m, ele))`` values.
        Results are not de-duplicated across hypotheses.

        NOTE(review): this ``def`` sits at column 0 while its body is
        indented 8 spaces, and the ``for ele in[1:]`` comprehension below is
        missing the sequence being sliced, so the block does not parse as
        shown. It presumably iterated the operator head's arguments (the
        elements after 'Rule') -- TODO confirm and restore.
        """

        # print("GETTING MATCHES")
        # pprint(self.learner.get_hset())
        # Ground the example so it can be indexed for matching.
        grounded = self.ground_example(x)
        # grounded = [(ground(a), x[a]) for a in x if (isinstance(a, tuple))]
        # print("FACTS")

        # pprint(grounded)

        index = build_index(grounded)

        # print("INDEX")
        # pprint(index)

        # print("OPERATOR")
        # pprint(self.operator)
        # print(self.learner.get_hset())

        for h in self.learner.get_hset():
            # print('h', h)
            # Update to include initial args
            # The head carries self.args so matches bind them.
            operator = Operator(tuple(('Rule', ) + self.args), h, [])
            # print("OPERATOR", h)

            for m in operator.match(index, epsilon=epsilon):
                # print('match', m,
                # NOTE(review): syntax error -- the iterable before "[1:]"
                # is missing. TODO restore.
                result = tuple(
                    [unground(subst(m, ele)) for ele in[1:]])
                # result = tuple(['?' + subst(m, ele)
                #                 for ele in[1:]])
                # result = tuple(['?' + m[e] for e in self.target_types])
                # print('GET MATCHES T', result)
                yield result
    def check_match(self, t, x):
        """Return whether tuple ``t`` is a stored tuple that matches ``x``.

        Each element of ``t`` is passed through ``ground`` and bound, position
        by position, to ``self.args``. A ``t`` that is not in ``self.tuples``
        is rejected immediately. Otherwise the constraint operator (conditions
        ``self.constraints``) is matched against an index of the grounded
        example, and the answer is True as soon as one match exists.
        """
        grounded_t = tuple(ground(ele) for ele in t)
        # Built before the membership test, exactly as before, so an
        # IndexError for a too-short t is still raised at this point.
        binding = {arg: grounded_t[pos] for pos, arg in enumerate(self.args)}

        if grounded_t not in self.tuples:
            return False

        fact_index = build_index(self.ground_example(x))
        rule = Operator(tuple(('Rule', ) + self.args),
                        frozenset().union(self.constraints), [])
        # Only existence matters; any() stops at the first match.
        matches = rule.match(fact_index, initial_mapping=binding)
        return any(True for _ in matches)
def covers(h, x, initial_mapping):
    """Return True if hypothesis ``h`` covers example ``x``.

    ``x`` is indexed with ``build_index`` and ``h`` is used as the
    conditions of an anonymous ``Rule`` operator. ``h`` covers ``x`` when
    that operator has at least one match consistent with
    ``initial_mapping``.

    :param h: collection of literals used as the operator's conditions.
    :param x: facts describing the example, passed to ``build_index``.
    :param initial_mapping: variable bindings the match must respect.
    :returns: ``True`` if at least one match exists, else ``False``.
    """
    index = build_index(x)
    operator = Operator(tuple(['Rule']), h, [])
    # Only existence matters: stop at the first match instead of
    # enumerating all of them.
    for _ in operator.match(index, initial_mapping=initial_mapping):
        return True
    return False
    def get_matches(self, x, epsilon=0.0):
        """Yield the stored tuples of ``self.tuples`` that match example ``x``.

        ``x`` is converted with ``x.get_view("flat_ungrounded")``, then
        grounded and indexed. For every ``t`` in ``self.tuples``, its elements
        are bound position by position to ``self.args``, and an operator whose
        conditions are ``self.constraints`` is matched against the index.
        Each match yields ``t`` with ``"QM"`` replaced by ``'?'``.

        NOTE(review): the ``operator.match(index,`` call below is missing
        its remaining arguments and closing ``):``, so this method does not
        parse as shown. ``mapping`` and ``epsilon`` are otherwise unused, so
        the call presumably passed ``initial_mapping=mapping`` (and maybe
        ``epsilon``) -- TODO confirm. The yielded value does not depend on
        ``m``, so ``t`` is yielded once per match unless a ``break`` was
        also lost.
        """
        x = x.get_view("flat_ungrounded")

        grounded = self.ground_example(x)

        index = build_index(grounded)

        for t in self.tuples:
            # Bind each rule argument to the element of t at the same index.
            mapping = {a: t[i] for i, a in enumerate(self.args)}
            operator = Operator(tuple(('Rule', ) + self.args),
                                frozenset().union(self.constraints), [])

            # NOTE(review): unclosed call -- see docstring.
            for m in operator.match(index,
                # Undo the '?' -> "QM" escaping done by check_match.
                result = tuple(ele.replace("QM", '?') for ele in t)
                yield result
    def is_specialization(self, s, h):
        """Return True if hypothesis ``s`` is a strict specialization of ``h``.

        ``s`` specializes ``h`` when ``h``, used as the conditions of a
        ``Rule`` operator, matches the literals of ``s`` (with their
        variables removed). Equal hypotheses are not specializations of each
        other, so ``False`` is returned when ``s == h``.

        :param s: candidate specialization (collection of literals).
        :param h: hypothesis to compare against (collection of literals).
        :returns: ``True`` if ``s`` strictly specializes ``h``.
        """
        if s == h:
            return False

        # Remove variables from s so the unification only binds h's
        # variables instead of going in both directions.
        s = {remove_vars(literal) for literal in s}

        # If h matches s, then s specializes h.
        index = build_index(s)
        operator = Operator(tuple(['Rule']), h, [])
        for _ in operator.match(index):
            return True
        return False
    def check_match(self, t, x):
        """Return whether any learned hypothesis matches ``x`` for tuple ``t``.

        Every ``'?'`` in the elements of ``t`` is escaped as ``"QM"``, and the
        escaped elements are bound, position by position, to ``self.args``.
        Each hypothesis from ``self.learner.get_hset()`` is then tried as the
        conditions of a ``Rule`` operator against an index of the grounded
        example. The result is True on the first match found.
        """
        escaped = tuple(ele.replace('?', "QM") for ele in t)
        binding = {arg: escaped[pos] for pos, arg in enumerate(self.args)}

        fact_index = build_index(self.ground_example(x))
        head = tuple(('Rule', ) + self.args)

        # Lazily build one operator per hypothesis; any() stops at the
        # first match of the first hypothesis that matches.
        return any(
            True
            for hyp in self.learner.get_hset()
            for _ in Operator(head, hyp, []).match(fact_index,
                                                   initial_mapping=binding))
    def ifit(self, t, x, y):
        """Update the learned rule from example ``x`` for tuple ``t``.

        Steps visible in this method:

        1. drop attributes named in ``self.remove_attrs`` from ``x``;
        2. rename each distinct field of ``t`` to ``'foa<j>'``, where ``j``
           is the field's first position in ``t``;
        3. standardize names and structure-map ``x`` onto ``self.concept``;
        4. build positive conditions from ``self.pos_concept`` and negated
           conditions from ``self.neg_concept``;
        5. rename ``foa<j>`` to ``?foa<j>`` and store the result as
           ``self.target_types`` and ``self.operator``.

        :param t: sequence of fields; its length fixes the number of
            ``?foa<j>`` target variables.
        :param x: mapping of attribute -> value describing the example.
        :param y: label; the visible code only compares it with ``1``.

        NOTE(review): several statements below have lost their bodies or
        closing brackets (marked NOTE(review)), so this method does not
        parse as shown. Presumably the missing ``if y == 1:`` body fed ``x``
        into the positive/negative concepts -- TODO confirm and restore.
        """
        # print("IFIT T", t)
        # if y == 0:
        #     return

        # Drop attributes listed in self.remove_attrs: tuple attributes are
        # checked by their first element, other attributes by themselves.
        # NOTE(review): the closing "}" of this comprehension is missing.
        x = {
            a: x[a]
            for a in x
            if (isinstance(a, tuple) and a[0] not in self.remove_attrs) or (
                not isinstance(a, tuple) and a not in self.remove_attrs)

        # x = {a: x[a] for a in x if self.is_structural_feature(a, x[a])}
        # x = {a: x[a] for a in x}

        # eles = set([field for field in t])
        # prior_count = 0
        # while len(eles) - prior_count > 0:
        #     prior_count = len(eles)
        #     for a in x:
        #         if isinstance(a, tuple) and a[0] == 'haselement':
        #             if a[2] in eles:
        #                 eles.add(a[1])
        #             # if self.matches(eles, a):
        #             #     names = get_attribute_components(a)
        #             #     eles.update(names)

        # x = {a: x[a] for a in x
        #      if self.matches(eles, a)}

        # foa_mapping = {field: 'foa%s' % j for j, field in enumerate(t)}
        # Map each distinct field of t to 'foa<j>', using the position of
        # its first occurrence so repeated fields share one name.
        foa_mapping = {}
        for j, field in enumerate(t):
            if field not in foa_mapping:
                foa_mapping[field] = 'foa%s' % j

        # for j,field in enumerate(t):
        #     x[('foa%s' % j, field)] = True
        x = rename_flat(x, foa_mapping)
        # pprint(x)

        # print("adding:")

        ns = NameStandardizer()
        sm = StructureMapper(self.concept)
        x = sm.transform(ns.transform(x))
        # pprint(x)

        if y == 1:
            # NOTE(review): the body of this ``if`` is missing, so the next
            # statement raises IndentationError. TODO restore it.

        # print()
        # print('POSITIVE')
        # pprint(self.pos_concept.av_counts)
        # print('NEGATIVE')
        # pprint(self.neg_concept.av_counts)

        # pprint(self.concept.av_counts)

        # Positive conditions: attributes present in every positive
        # instance (their value counts sum to self.pos_concept.count).
        pos_instance = {}
        pos_args = set()
        for attr in self.pos_concept.av_counts:
            attr_count = 0
            for val in self.pos_concept.av_counts[attr]:
                attr_count += self.pos_concept.av_counts[attr][val]
            if attr_count == self.pos_concept.count:
                if len(self.pos_concept.av_counts[attr]) == 1:
                    args = get_vars(attr)
                    # val is left over from the counting loop; with exactly
                    # one observed value it is that value.
                    pos_instance[attr] = val
                    # NOTE(review): the three lines below unconditionally
                    # overwrite that value with a fresh value_gensym(), making
                    # the two lines above dead. This looks like an ``else:``
                    # branch (attributes with several values) whose header
                    # was lost -- TODO confirm.
                    args = get_vars(attr)
                    val_gensym = value_gensym()
                    pos_instance[attr] = val_gensym

            # if len(self.pos_concept.av_counts[attr]) == 1:
            #     for val in self.pos_concept.av_counts[attr]:
            #         if ((self.pos_concept.av_counts[attr][val] ==
            #              self.pos_concept.count)):
            #             args = get_vars(attr)
            #             pos_args.update(args)
            #             pos_instance[attr] = val

        # print('POS ARGS', pos_args)

        # NOTE(review): pos_args is never filled in the live code (only the
        # commented-out variant above updates it), so the issubset test
        # below passes only for attributes where get_vars returns nothing.
        neg_instance = {}
        for attr in self.neg_concept.av_counts:
            # print("ATTR", attr)
            args = set(get_vars(attr))
            if not args.issubset(pos_args):
                # NOTE(review): body missing (presumably ``continue``) --
                # TODO restore.

            for val in self.neg_concept.av_counts[attr]:
                # print("VAL", val)
                # Negate values never seen with this attribute among the
                # positives; neg_instance is keyed by attr, so only the last
                # such value per attribute is kept.
                if ((attr not in self.pos_concept.av_counts
                     or val not in self.pos_concept.av_counts[attr])):
                    neg_instance[attr] = val

        # Turn the foa<j> placeholders into variables ?foa<j>.
        foa_mapping = {'foa%s' % j: '?foa%s' % j for j in range(len(t))}
        pos_instance = rename_flat(pos_instance, foa_mapping)
        neg_instance = rename_flat(neg_instance, foa_mapping)

        # Positive (attr, value) literals followed by negated
        # ('not', (attr, value)) literals.
        conditions = ([(a, pos_instance[a])
                       for a in pos_instance] + [('not', (a, neg_instance[a]))
                                                 for a in neg_instance])

        # print("========CONDITIONS======")
        # pprint(conditions)
        # print("========CONDITIONS======")

        self.target_types = ['?foa%s' % i for i in range(len(t))]
        self.operator = Operator(tuple(['Rule'] + self.target_types),
                                 conditions, [])
def optimize_clause(h, constraints, seed, pset, nset):
    # NOTE(review): the next two lines are a docstring that has lost its
    # triple quotes (a syntax error as shown). Its text also does not match
    # the code: there is no ``x`` parameter, and the function returns
    # ``build_clause(...)`` for the first solution yielded by
    # ``simulated_annealing``, starting from an encoding of ``h`` over the
    # seed's bottom clause. TODO restore the quotes and rewrite the text.
    Returns the set of most specific generalization of h that do NOT
    cover x.
    c_length = clause_length(h)
    print('# POS', len(pset))
    print('# NEG', len(nset))
    print('length', c_length)

    # Score the starting clause h on the positive/negative examples.
    p_covered, n_covered = test_coverage(h, constraints, pset, nset)
    p_uncovered = [p for p in pset if p not in p_covered]
    n_uncovered = [n for n in nset if n not in n_covered]
    print("P COVERED", len(p_covered))
    print("N COVERED", len(n_covered))
    initial_score = clause_score(clause_accuracy_weight, len(p_covered),
                                 len(p_uncovered), len(n_covered),
                                 len(n_uncovered), c_length)

    # seed is an (example, mapping) pair; build its bottom clause.
    p, pm = seed
    pos_partial = list(compute_bottom_clause(p, pm))
    # print('POS PARTIAL', pos_partial)

    # TODO if we wanted we could add the introduction of new variables to the
    # get_variablizations function.
    # For each bottom-clause position: index 0 = literal absent (None),
    # index 1 = the literal itself, then each of its variablizations.
    possible_literals = {}
    for i, l in enumerate(pos_partial):
        possible_literals[i] = [None, l] + [v for v in get_variablizations(l)]
    partial_literals = set(
        [l for i in possible_literals for l in possible_literals[i]])

    # Literals of h that the encoding above cannot represent.
    additional_literals = h - partial_literals

    # NOTE(review): this section does not parse as shown -- the ``for
    # add_m`` loop has no indented body and the dict comprehension is never
    # closed. new_l is unused, and the bare ``import time`` lines look like
    # leftovers of a debugging pause. TODO restore.
    if len(additional_literals) > 0:
        p_index = build_index(p)
        operator = Operator(tuple(('Rule', )), h.union(constraints), [])
        for add_m in operator.match(p_index, initial_mapping=pm):
        additional_lit_mapping = {
            rename(add_m, l): l
            for l in additional_literals
        for l in additional_lit_mapping:
            new_l = additional_lit_mapping[l]
            reverse_m = {pm[a]: a for a in pm}
            l = rename(reverse_m, l)
            if l not in pos_partial:
                print("ERROR l not in pos_partial")
                import time

    # pprint(possible_literals)
    # Map each candidate literal to its (position, choice index).
    # NOTE(review): the closing "}" of this comprehension is missing.
    reverse_pl = {
        l: (i, j)
        for i in possible_literals for j, l in enumerate(possible_literals[i])

    # Encode h as one choice index per position (0 means "absent").
    clause_vector = [0 for i in range(len(possible_literals))]
    for l in h:
        if l not in reverse_pl:
            print("MISSING LITERAL!!!")
            import time
            # NOTE(review): execution falls through, so the lookup below
            # raises KeyError for a missing literal.

        i, j = reverse_pl[l]
        clause_vector[i] = j
    clause_vector = tuple(clause_vector)

    # print(clause_vector)

    # (number of alternative choices, position) for every position; the
    # search space size is the product of the choice counts.
    flip_weights = [(len(possible_literals[i]) - 1, i)
                    for i in possible_literals]
    size = 1
    for w, _ in flip_weights:
        size *= (w + 1)
    print("SIZE OF SEARCH SPACE:", size)

    num_successors = sum([w for w, c in flip_weights])
    print("NUM SUCCESSORS", num_successors)
    # NOTE(review): the first assignment is immediately overwritten by the
    # hard-coded 10.
    temp_length = num_successors
    temp_length = 10
    # initial_temp = 0.15
    # initial_temp = 0.0
    print("TEMP LENGTH", temp_length)
    print('INITIAL SCORE', initial_score)
    # The problem minimizes cost, so the score is negated.
    # NOTE(review): this call and the simulated_annealing call below are
    # not closed -- their remaining arguments are missing. TODO restore.
    problem = ClauseOptimizationProblem(clause_vector,
                                        initial_cost=-1 * initial_score,
                                        extra=(possible_literals, flip_weights,
                                               constraints, pset, nset,
    # for sol in hill_climbing(problem):
    for sol in simulated_annealing(
            # initial_temp=initial_temp,
        # print("SOLUTION FOUND", sol.state)
        # Return on the first solution yielded.
        print('FINAL SCORE', -1 * sol.cost())
        return build_clause(sol.state, possible_literals)
				  	(<func>, '?<value_1>', ...)
			     ), ...]
			example: [(('value', ('Add', ('value', '?x'), ('value', '?y'))),
	                     (int_float_add, '?xv', '?yv'))]
	Full Example: 
	def int_float_add(x, y):
	    z = float(x) + float(y)
	    if z.is_integer():
	        z = int(z)
	    return str(z)
	add_rule = Operator(('Add', '?x', '?y'),
			            [(('value', '?x'), '?xv'),
			             (('value', '?y'), '?yv'),
			             (lambda x, y: x <= y, '?x', '?y')],
			            [(('value', ('Add', ('value', '?x'), ('value', '?y'))),
			              (int_float_add, '?xv', '?yv'))])
	Note: You should explicitly register your operators so that you can
			 refer to them by name in your training.json; otherwise the
			 registered name will be the same as the local variable name.
			example: Operator.register("Add", add_rule)

vvvvvvvvvvvvvvvvvvvv WRITE YOUR OPERATORS BELOW vvvvvvvvvvvvvvvvvvvvvvv '''

# ^^^^^^^^^^^^^^ DEFINE ALL YOUR OPERATORS ABOVE THIS LINE ^^^^^^^^^^^^^^^^
# Register every module-level Operator under its variable name. At module
# scope locals() is the module namespace; it is copied first because the
# loop itself rebinds ``name`` and ``op`` in that namespace.
for name, op in dict(locals()).items():
    if not isinstance(op, Operator):
        continue
    Operator.register(name, op)
from pprint import pprint
from tabulate import tabulate
import argparse

from planners.fo_planner import Operator
from agents.RLAgent import RLAgent
from agents.ModularAgent import ModularAgent
from agents.Memo import Memo

# Derives ('horizontal_adj', s1, s2) = True. Both row conditions bind the
# same variable '?s1r', so (assuming the matcher unifies repeated
# variables) the two cells share a row; their columns must differ by 1.
ttt_horizontal_adj = Operator(('horizontal_adj', '?s1', '?s2'),
                              [(('row', '?s1'), '?s1r'),
                               # Reuses ?s1r on purpose: same row as ?s1.
                               (('row', '?s2'), '?s1r'),
                               (('col', '?s1'), '?s1c'),
                               (('col', '?s2'), '?s2c'),
                               (lambda x, y: abs(x - y) == 1, '?s1c', '?s2c')],
                              [(('horizontal_adj', '?s1', '?s2'), True)])

# Derives ('vertical_adj', s1, s2) = True. Both column conditions bind the
# same variable '?s1c', so (assuming the matcher unifies repeated
# variables) the two cells share a column; their rows must differ by 1.
ttt_vertical_adj = Operator(
    ('vertical_adj', '?s1', '?s2'),
    [(('row', '?s1'), '?s1r'), (('row', '?s2'), '?s2r'),
     # ?s1c is reused on purpose: same column as ?s1.
     (('col', '?s1'), '?s1c'), (('col', '?s2'), '?s1c'),
     (lambda x, y: abs(x - y) == 1, '?s1r', '?s2r')],
    [(('vertical_adj', '?s1', '?s2'), True)])

# Derives ('diag_adj', s1, s2) = True when both the rows and the columns of
# the two cells differ by exactly 1 (either diagonal direction, since abs()
# is used).
ttt_diag_adj = Operator(('diag_adj', '?s1', '?s2'),
                        [(('row', '?s1'), '?s1r'), (('row', '?s2'), '?s2r'),
                         (('col', '?s1'), '?s1c'), (('col', '?s2'), '?s2c'),
                         (lambda x, y: abs(x - y) == 1, '?s1r', '?s2r'),
                         (lambda x, y: abs(x - y) == 1, '?s1c', '?s2c')],
                        [(('diag_adj', '?s1', '?s2'), True)])
  z = z % 10
  if z.is_integer():
    z = int(z)
  return str(z)

def int3_float_add_then_tens(x, y, w):
  """Add three numbers and return floor(sum / 10) as a string.

  Each argument is converted with ``float``. The sum is floor-divided by
  10 and rendered without a trailing ``.0`` when the result is integral.
  """
  tens = (float(x) + float(y) + float(w)) // 10
  return str(int(tens)) if tens.is_integer() else str(tens)

# Conditions bind '?xv' / '?yv' to the ('value', ...) facts of '?x' / '?y'.
# The effect is a ('value', ('Add', ...)) fact whose value is the pair
# (int_float_add, '?xv', '?yv') -- presumably applied by the planner.
# The conditions list was previously missing its closing "],", which made
# the effects list parse as a third condition (a syntax error).
add_rule = Operator(('Add', '?x', '?y'),
                    [(('value', '?x'), '?xv'),
                     (('value', '?y'), '?yv'),
                     # (lambda x, y: x <= y, '?x', '?y')
                     ],
                    [(('value', ('Add', ('value', '?x'), ('value', '?y'))),
                      (int_float_add, '?xv', '?yv'))])
Operator.register("add", add_rule)

# Like add_rule, but the effect value comes from
# int2_float_add_then_ones(xv, yv). The conditions list was previously
# missing its closing "]," (a syntax error).
# NOTE(review): unlike add_rule this operator is not registered explicitly
# here -- confirm it is registered elsewhere.
add_then_ones = Operator(('Add_Then_Ones', '?x', '?y'),
                         [(('value', '?x'), '?xv'),
                          (('value', '?y'), '?yv'),
                          # (lambda x, y: x <= y, '?x', '?y')
                          ],
                         [(('value', ('Add_Then_Ones', ('value', '?x'),
                                      ('value', '?y'))),
                           (int2_float_add_then_ones, '?xv', '?yv'))])

add_then_tens = Operator(('Add_Then_Tens', '?x', '?y'),
                    [(('value', '?x'), '?xv'),
    def successors(self, node):
        """Yield successor search nodes of ``node`` by refining its hypothesis.

        ``node.state`` is the current hypothesis ``h`` (presumably a
        frozenset of literals, like the successors built below), and
        ``node.extra`` unpacks to ``(args, constraints, pset, neg,
        neg_mapping, gensym)``. A random ``(example, mapping)`` pair from
        ``pset`` is matched, and so is the negative example ``neg``. Each
        successor costs ``node.cost() + 1`` and shares ``node.extra``:

        * ``'specializing'`` -- bind one variable of ``h`` to its value in
          the positive match;
        * ``'adding current'`` -- add a generalized literal, found only in
          the positive example, that mentions a variable already in ``h``;
        * ``'adding'`` -- add any generalized literal found only in the
          positive example.

        The two "adding" kinds are only produced when ``h`` or the
        constraints contain at least one variable.

        NOTE(review): several ``if`` statements below have lost their bodies
        (marked NOTE(review)), so this method does not parse as shown.
        """
        h = node.state
        # print("EXPANDING H", h)
        args, constraints, pset, neg, neg_mapping, gensym = node.extra

        # Every variable mentioned in the hypothesis or its constraints.
        all_args = set(s for x in h.union(constraints)
                       for s in extract_strings(x) if is_variable(s))

        if len(pset) == 0:
            # NOTE(review): body missing (presumably ``return``).

        p, pm = choice(pset)
        p_index = build_index(p)

        operator = Operator(tuple(('Rule', ) + tuple(all_args)),
                            h.union(constraints), [])

        # operator = Operator(tuple(('Rule',) + args), h, [])

        # Re-express the positive example in terms of h's variables; only
        # the last match is kept (pos_partial and m are overwritten).
        found = False
        for m in operator.match(p_index, initial_mapping=pm):
            reverse_m = {m[a]: a for a in m}
            pos_partial = set([rename(reverse_m, x) for x in p])
            found = True

        if not found:
            # NOTE(review): body missing (presumably ``return``).

        # Same for the negative example.
        n_index = build_index(neg)
        found = False
        for nm in operator.match(n_index, initial_mapping=neg_mapping):
            # print(nm)
            reverse_nm = {nm[a]: a for a in nm}
            neg_partial = set([rename(reverse_nm, x) for x in neg])
            found = True

        if not found:
            # NOTE(review): body missing (presumably ``return``).

        # Literals found only in the positive / only in the negative
        # example. (unique_neg is not used in the live code below.)
        unique_pos = pos_partial - neg_partial
        unique_neg = neg_partial - pos_partial

        # print("UNIQUE POS", unique_pos)
        # print("UNIQUE NEG", unique_neg)

        # Yield all minimum specializations of current vars
        for a in m:
            # TODO make sure m[a] is a minimum specialization
            sub_m = {a: m[a]}
            new_h = frozenset([subst(sub_m, ele) for ele in h])
            # print("SPECIALIZATION", new_h, sub_m)
            # print()
            yield Node(new_h, node, ('specializing', (a, m[a])),
                       node.cost() + 1, node.extra)

        # Add Negations for all neg specializations
        # for a in nm:
        #     sub_nm = {a: nm[a]}
        #     new_nh = set()
        #     for ele in h:
        #         new = subst(sub_nm, ele)
        #         if new != ele and new not in h:
        #             new_nh.add(('not', new))
        #     new_h = h.union(new_nh)
        #     print("NEGATION SPECIALIZATION", new_nh)
        #     yield Node(new_h, node, ('negation specialization', (a, nm[a])),
        #                node.cost()+1, node.extra)

        # if current vars then add all relations that include current vars
        if len(all_args) > 0:
            # NOTE(review): ``added`` is never updated in the live code
            # (an ``added.add(key)`` was presumably lost), so the
            # de-duplication checks below cannot fire.
            added = set()
            for literal in unique_pos:
                if literal in h or literal in constraints:
                    # NOTE(review): body missing (presumably ``continue``).
                # Note: this shadows ``args`` unpacked from node.extra.
                args = set(s for s in extract_strings(literal)
                           if is_variable(s))
                if len(args.intersection(all_args)) > 0:
                    # Key: relation name plus its variables, with every
                    # non-variable argument collapsed to '?'.
                    key = (literal[0], ) + tuple(
                        ele if is_variable(ele) else '?'
                        for ele in literal[1:])
                    if key in added:
                        # NOTE(review): body missing (presumably
                        # ``continue``).

                    literal = generalize_literal(literal, gensym)
                    new_h = h.union(frozenset([literal]))
                    # print("ADD CURRENT", new_h)
                    # print()
                    yield Node(new_h, node, ('adding current', literal),
                               node.cost() + 1, node.extra)

            # Add any positive-only literal, not just ones mentioning
            # current variables (still inside ``if len(all_args) > 0``).
            added = set()
            for literal in unique_pos:
                if literal in h or literal in constraints:
                    # NOTE(review): body missing (presumably ``continue``).
                if literal[0] in added:
                    # NOTE(review): body missing (presumably ``continue``).
                literal = generalize_literal(literal, gensym)
                new_h = h.union(frozenset([literal]))
                # print("ADD NEW", new_h)
                # print()
                yield Node(new_h, node, ('adding', literal),
                           node.cost() + 1, node.extra)
from pprint import pprint
from tabulate import tabulate
import argparse

from planners.fo_planner import Operator
from agents.RLAgent import RLAgent
from agents.WhereWhenHowNoFoa import WhereWhenHowNoFoa
from agents.Memo import Memo

# Applies to any cell '?s' with a value, a row > 0 and a column > 0.
# The effect value for ('available', '?s') is a (function, argument) pair
# testing whether the cell's value is the empty string -- presumably
# evaluated by the planner when the operator fires.
ttt_available = Operator(('available', '?s'),
                         [(('value', '?s'), '?sv'),
                          (('row', '?s'), '?sr'),
                          (('col', '?s'), '?sc'),
                          (lambda x: x > 0, '?sr'),
                          (lambda x: x > 0, '?sc')],
                         [(('available', '?s'), (lambda x: x == "", '?sv'))])

# Derives ('horizontal_adj', s1, s2) = True. Both row conditions bind the
# same variable '?s1r', so (assuming the matcher unifies repeated
# variables) the two cells share a row; their columns must differ by 1.
ttt_horizontal_adj = Operator(('horizontal_adj', '?s1', '?s2'),
                              [(('row', '?s1'), '?s1r'),
                               # Reuses ?s1r on purpose: same row as ?s1.
                               (('row', '?s2'), '?s1r'),
                               (('col', '?s1'), '?s1c'),
                               (('col', '?s2'), '?s2c'),
                               (lambda x, y: abs(x-y) == 1, '?s1c', '?s2c')],
                              [(('horizontal_adj', '?s1', '?s2'), True)])

ttt_vertical_adj = Operator(('vertical_adj', '?s1', '?s2'),
                            [(('row', '?s1'), '?s1r'),
                             (('row', '?s2'), '?s2r'),
                             (('col', '?s1'), '?s1c'),
                             (('col', '?s2'), '?s1c'),