Example #1
0
    def __init__(self,
                 data,
                 graph,
                 config,
                 attr_encoder,
                 is_uncertain=cmd_args.prob_dataset):
        """Build an environment over private copies of ``data`` and ``graph``.

        Args:
            data: Datapoint for this episode; stored as a deep copy.
            graph: Graph object exposing a ``scene`` mapping; stored as a
                deep copy.
            config: Configuration handed to ``SceneInterp``.
            attr_encoder: Encoder object, stored without copying.
            is_uncertain: Selects which scene the interpreter is built on.
                The default is ``cmd_args.prob_dataset`` as evaluated when
                this function is defined, not at call time.
        """
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)

        # The scene is read from the caller's graph, not from the copy above.
        scene = graph.scene
        self.obj_nums = len(scene["objects"])

        self.clauses = []
        self.config = config
        self.attr_encoder = attr_encoder

        # Seeded with the total object count of the scene.
        self.obj_poss_left = [self.obj_nums]
        self.success, self.possible = False, True

        # Uncertain mode interprets the whole scene; otherwise only its
        # "ground_truth" entry is handed to the interpreter.
        uncertain = bool(is_uncertain)
        interp_scene = scene if uncertain else scene["ground_truth"]
        self.interp = SceneInterp(interp_scene, config, is_uncertain=uncertain)

        self.state = self.interp.get_init_state()
        self.is_uncertain = is_uncertain
    def get_option_num(self, clauses, scene, target):
        """Query ``scene`` with ``clauses`` and summarise the outcome.

        Args:
            clauses: Program passed to ``SceneInterp.fast_query``.
            scene: Scene mapping; its ``"objects"`` entry is only consulted
                when the query result has no ``"var_0"`` binding.
            target: Value forwarded to the module-level ``check_success``
                and ``check_possible`` helpers.

        Returns:
            Tuple ``(option_num, success, possible)`` where ``option_num`` is
            the number of ``"var_0"`` bindings, or the number of scene
            objects when ``"var_0"`` is unbound.
        """
        bindings = SceneInterp(scene, self.config).fast_query(clauses)

        success = check_success(bindings, target)
        possible = check_possible(bindings, target)

        if "var_0" in bindings:
            option_num = len(bindings["var_0"])
        else:
            option_num = len(scene["objects"])

        return option_num, success, possible
Example #3
0
def eu_solve_prog(scene, config, target, max_depth):
    """Synthesize a program for ``target`` starting from a fresh scene state.

    Args:
        scene: Scene passed to ``SceneInterp``.
        config: Configuration passed to ``SceneInterp``.
        target: Target forwarded to ``eu_synthesize``.
        max_depth: Depth bound forwarded to ``eu_synthesize``.

    Returns:
        Whatever ``eu_synthesize`` returns for the given arguments.
    """
    scene_interp = SceneInterp(scene, config)
    initial_state = scene_interp.get_init_state()

    # Synthesis runs on the wrapped ``interpreter`` attribute, always in
    # the non-uncertain mode.
    return eu_synthesize(
        scene_interp.interpreter,
        target,
        initial_state,
        is_uncertain=False,
        max_depth=max_depth,
    )
Example #4
0
    def __init__(self, data, graph, config, attr_encoder):
        """Build an environment over private copies of ``data`` and ``graph``.

        Args:
            data: Datapoint for this episode; stored as a deep copy.
            graph: Graph object exposing a ``scene`` mapping; stored as a
                deep copy.
            config: Configuration handed to ``SceneInterp``.
            attr_encoder: Encoder object, stored without copying.
        """
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)

        # The scene is read from the caller's graph, not from the copy above.
        scene = graph.scene
        self.obj_nums = len(scene["objects"])

        self.clauses = []
        self.config = config
        self.attr_encoder = attr_encoder

        # Seeded with the total object count of the scene.
        self.obj_poss_left = [self.obj_nums]
        self.success, self.possible = False, True
        self.state = None

        # The interpreter state lives in ``interp_state``; ``state`` stays None.
        self.interp = SceneInterp(scene, config)
        self.interp_state = self.interp.get_init_state()
Example #5
0
def add_attr(prog, scene):
    """Append colour and shape clauses for every referenced variable.

    The interpreter is built with a ``config`` that is not a parameter of
    this function, so it must exist at module level.

    Args:
        prog: Program to extend. It is extended in place via ``+=`` and the
            same name is returned.
        scene: Scene mapping with an ``"objects"`` entry indexable by the
            object ids found in the valid binding.

    Returns:
        ``prog`` after the new ``["color", ...]`` / ``["shape", ...]``
        clauses have been added.
    """
    interp = SceneInterp(scene, config)
    binding = interp.get_valid_binding(interp.get_state(prog))
    referenced = get_ref(prog)

    extra = []
    for var_idx, obj_idx in enumerate(binding):
        # Variables that the program never refers to get no attributes.
        if var_idx in referenced:
            obj = scene["objects"][obj_idx]
            extra += [
                ["color", obj["color"], var_idx],
                ["shape", obj["shape"], var_idx],
            ]

    prog += extra
    return prog
    def __init__(self,
                 data,
                 graph,
                 config,
                 attr_encoder,
                 state=None,
                 ref=False,
                 is_uncertain=False):
        """Build an action-indexed environment for one datapoint.

        Args:
            data: Datapoint for this episode; stored as a deep copy.
            graph: Graph object exposing a ``scene`` mapping; stored as a
                deep copy.
            config: Configuration used to enumerate the clauses and to build
                the ``SceneInterp``.
            attr_encoder: Encoder object, stored without copying.
            state: Interpreter state to resume from. When ``None`` the
                interpreter's initial state is used.
            ref: When truthy, the referenced-variable list starts as ``[0]``;
                otherwise it holds every index below
                ``cmd_args.max_var_num``.
            is_uncertain: When truthy the interpreter is built on the whole
                scene, otherwise on its ``"ground_truth"`` entry.
        """
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)

        # The scene is read from the caller's graph, not from the copy above.
        scene = graph.scene
        self.obj_nums = len(scene["objects"])
        self.clauses = []
        self.idx_selected = []

        self.config = config
        self.attr_encoder = attr_encoder

        # The action dictionary is derived from ``self.actions``, so the
        # actions have to be in place before it is created.
        self.actions = get_all_clauses(config)
        self.create_action_dict()

        self.ref_flag = ref
        self.ref = [0] if ref else list(range(cmd_args.max_var_num))

        # Seeded with the total object count of the scene.
        self.obj_poss_left = [self.obj_nums]
        self.success, self.possible = False, True
        self.unreachable = []
        reachability = get_reachable_dict(self.actions)
        self.reachable_dict, self.unreachable_dict = reachability

        self.is_uncertain = is_uncertain
        uncertain = bool(is_uncertain)
        interp_scene = scene if uncertain else scene["ground_truth"]
        self.interp = SceneInterp(interp_scene, config, is_uncertain=uncertain)

        # Identity check against None, so ``state`` is never compared by value.
        self.state = self.interp.get_init_state() if state is None else state
Example #7
0
def shrink(prog, scene, config):
    """Repeatedly apply ``remove_one`` until it reports nothing more to drop.

    Args:
        prog: Program to minimise.
        scene: Scene passed to ``SceneInterp``.
        config: Configuration passed to ``SceneInterp``.

    Returns:
        The last program for which ``remove_one`` still produced a result,
        or ``prog`` itself when the very first call returns ``None``.
    """
    interp = SceneInterp(scene, config)
    smallest = prog

    while True:
        candidate = remove_one(interp, smallest)
        # ``None`` signals that ``smallest`` is already minimal.
        if candidate is None:
            return smallest
        smallest = candidate
Example #8
0
def cont_single(file_name,
                datapoint,
                model,
                graphs,
                save_dir,
                attr_encoder,
                config,
                save=True,
                n=10):
    """Search for a program solving one datapoint, then optionally save it.

    Runs up to ``cmd_args.test_iter`` episodes on a private copy of
    ``model.policy``. If none succeeds, the failed attempts are tried in
    ascending order of their ``obj_left`` value and completed with
    ``SceneInterp.interp_enum`` (``max_depth=2``).

    Args:
        file_name: Task name; ``str(file_name)`` is the result file name.
        datapoint: Task datapoint; ``graph_id`` and ``y`` are read from it.
        model: Object whose ``policy`` attribute is deep-copied for training.
        graphs: Container indexed by ``datapoint.graph_id``.
        save_dir: Directory holding the result files.
        attr_encoder: Encoder forwarded to ``episode``.
        config: Configuration forwarded to ``episode`` and ``SceneInterp``.
        save: When truthy, the result is written as JSON to the result file.
        n: Unused in this function.

    Returns:
        The list of successful programs (possibly empty), or ``None`` when
        the result file already exists and the task is skipped.
    """
    file_path = os.path.join(save_dir, str(file_name))
    exist = os.path.exists(file_path)
    logging.info(f"task: {file_name}, {exist}")
    if exist:
        logging.info("skip")
        return

    memory = ReplayMemory(10000)
    method = None

    # Train a private copy; the frozen second copy is passed to ``episode``
    # alongside it.
    policy_net = copy.deepcopy(model.policy)
    target_net = copy.deepcopy(policy_net)
    target_net.eval()

    optimizer = optim.RMSprop(policy_net.parameters(), lr=cmd_args.lr)

    success_progs = []
    all_progs = []

    eps = cmd_args.eps  # read here but not used further in this function
    graph = graphs[datapoint.graph_id]

    # The loop index doubles as the running episode counter given to
    # ``episode``.
    for episode_idx in range(cmd_args.test_iter):
        suc, prog, obj_left = episode(policy_net, target_net, datapoint,
                                      graph, config, attr_encoder, memory,
                                      episode_idx, optimizer)
        if suc:
            success_progs.append(prog)
            method = "model"
            break
        all_progs.append((prog, obj_left))

    logging.info(f"success: {success_progs}")

    # Fallback: exploit the failed attempts, best (fewest objects left) first.
    scene_interpreter = SceneInterp(graph.ground_truth_scene, config)

    if not success_progs:
        logging.info("start exploit")

        for prog, obj_left in sorted(all_progs, key=lambda item: item[1]):
            # An empty program means no progress was made at all.
            if len(prog) == 0:
                logging.info("empty prog to exploit, stop")
                break

            logging.info(f"prog: {prog}")
            logging.info(f"obj left: {obj_left}")
            prog_left = scene_interpreter.interp_enum(prog,
                                                      datapoint.y,
                                                      max_depth=2)
            if prog_left is None:
                continue

            print(prog_left)
            # The completion is appended as a single element of ``prog``.
            prog.append(prog_left)
            success_progs.append(prog)
            method = f"exploit: {len(prog_left)}"
            logging.info(
                f"We found: {success_progs} at depth {len(prog_left)}")
            break

    if save:
        with open(file_path, 'w') as suc_prog_file:
            res = {'prog': success_progs, 'method': method}
            json.dump(res, suc_prog_file)

    return success_progs
Example #9
0
class Env():
    """Environment tracking how a growing clause list narrows ``var_0``.

    The caller appends clauses to ``self.clauses`` and calls ``step``; the
    environment queries its ``SceneInterp``, records how many candidates
    remain and keeps the ``success`` / ``possible`` flags up to date.
    """

    def __init__(self,
                 data,
                 graph,
                 config,
                 attr_encoder,
                 is_uncertain=cmd_args.prob_dataset):
        """Build an environment over private copies of ``data`` and ``graph``.

        Args:
            data: Datapoint for this episode; stored as a deep copy.
            graph: Graph object exposing a ``scene`` mapping; stored as a
                deep copy.
            config: Configuration handed to ``SceneInterp``.
            attr_encoder: Encoder object, stored without copying.
            is_uncertain: Selects which scene the interpreter is built on.
                The default is ``cmd_args.prob_dataset`` as evaluated when
                the class body is executed, not at call time.
        """
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)

        # The scene is read from the caller's graph, not from the copy above.
        scene = graph.scene
        self.obj_nums = len(scene["objects"])
        self.clauses = []
        self.config = config
        self.attr_encoder = attr_encoder

        # One entry per step; seeded with the total object count.
        self.obj_poss_left = [self.obj_nums]
        self.success, self.possible = False, True

        uncertain = bool(is_uncertain)
        interp_scene = scene if uncertain else scene["ground_truth"]
        self.interp = SceneInterp(interp_scene, config, is_uncertain=uncertain)

        self.state = self.interp.get_init_state()
        self.is_uncertain = is_uncertain

    # TODO: check the effect of update data on the whole dataset
    def update_data(self, binding_dict):
        """Push ``binding_dict`` into the graph and rebuild ``self.data``.

        The new ``Data`` object keeps the previous ``y`` and ``batch``.
        """
        self.graph.update_binding(binding_dict)

        node_names = [node.name for node in self.graph.nodes]
        embedding = self.attr_encoder.get_embedding(node_names)
        raw_edge_index, raw_edge_types = self.graph.get_edge_info()

        edge_attr = torch.tensor(raw_edge_types, dtype=torch.float)
        edge_index = torch.tensor(raw_edge_index)
        features = torch.tensor(embedding)
        previous_batch = self.data.batch

        self.data = Data(x=features,
                         y=self.data.y,
                         edge_index=edge_index,
                         edge_attr=edge_attr)
        self.data.batch = previous_batch

    def check_success(self, binding_dict):
        """Set ``self.success`` when the single remaining binding is the target.

        Only acts when the latest ``obj_poss_left`` entry is 1 and ``"var_0"``
        is bound; the first binding is compared with ``str(int(data.y))``.
        """
        exactly_one_left = self.obj_poss_left[-1] == 1
        if exactly_one_left and "var_0" in binding_dict:
            if binding_dict["var_0"][0] == str(int(self.data.y)):
                self.success = True

    def check_possible(self, binding_dict):
        """Clear ``self.possible`` when the target dropped out of ``var_0``.

        Does nothing once ``success`` is set or while ``"var_0"`` is unbound.
        """
        if self.success or "var_0" not in binding_dict:
            return

        if str(int(self.data.y)) not in binding_dict["var_0"]:
            self.possible = False

    def step(self):
        """Apply the most recent clause and return the reward as a tensor.

        Returns:
            A float32 scalar tensor: 0 when no clause has been added yet,
            otherwise the value of ``get_reward(self)``.
        """
        if not self.clauses:
            return torch.tensor(0, dtype=torch.float32)

        binding_dict, self.state = self.interp.state_query(
            self.state, self.clauses[-1])

        # update the success and possible fields
        self.check_possible(binding_dict)

        if not self.possible:
            self.obj_poss_left.append(0)
        else:
            # Without a "var_0" binding every object is still a candidate.
            if "var_0" in binding_dict:
                remaining = len(binding_dict["var_0"])
            else:
                remaining = self.obj_nums
            self.obj_poss_left.append(remaining)
            self.update_data(binding_dict)

        self.check_success(binding_dict)

        # TODO: the reward function need to be updated
        return torch.tensor(get_reward(self), dtype=torch.float32)

    def is_finished(self):
        """Tell whether the episode is over.

        Returns:
            ``False`` before the first recorded step; afterwards ``True``
            when 0 or 1 candidates remain, on success, or when the target
            is no longer possible.
        """
        # edge case handling: nothing has been recorded beyond the seed value
        if len(self.obj_poss_left) == 1:
            return False

        return bool(self.obj_poss_left[-1] in (0, 1)
                    or self.success
                    or not self.possible)
def cont_single(file_name,
                datapoint,
                policy,
                graphs,
                save_dir,
                attr_encoder,
                config,
                save=True,
                n=10):
    """Fit a private copy of ``policy`` on one datapoint and collect programs.

    Runs up to ``cmd_args.episode_iter`` calls of ``fit_one``, stopping at
    the first success. Without a success, failed attempts are tried in
    ascending order of their remaining-object count and completed with
    ``SceneInterp.interp_enum``.

    Args:
        file_name: Task name; ``str(file_name)`` is the result file name.
        datapoint: Task datapoint; ``graph_id`` and ``y`` are read from it.
        policy: Policy network; deep-copied, the caller's object is untouched
            by the optimizer created here.
        graphs: Container indexed by ``datapoint.graph_id``.
        save_dir: Directory holding the result files.
        attr_encoder: Encoder forwarded to ``fit_one``.
        config: Configuration forwarded to ``fit_one`` and ``SceneInterp``.
        save: When truthy, the successful programs are written as JSON.
        n: Unused in this function.

    Returns:
        The list of successful programs (possibly empty).
    """
    global sub_ct

    policy = deepcopy(policy)
    policy.train()
    optimizer = torch.optim.Adam(policy.parameters(), lr=cmd_args.lr)

    success_progs = []
    all_progs = []

    eps = cmd_args.eps
    graph = graphs[datapoint.graph_id]

    for it in range(cmd_args.episode_iter):
        # The module-level counter is reset before every fit.
        sub_ct = 0
        env, loss = fit_one(policy, datapoint, graph, eps, attr_encoder,
                            config)
        if env.success:
            success_progs.append(env.clauses)

        if not env.possible:
            # Keep the program without its last clause, paired with the
            # candidate count recorded before that clause.
            all_progs.append((env.clauses[:-1], env.obj_poss_left[-2]))

        # NOTE(review): gradients are cleared right before the only backward
        # pass that is followed by a step, so gradients from the other
        # iterations appear to be discarded - confirm this is intended.
        on_batch_boundary = it % cmd_args.batch_size == 0
        if on_batch_boundary:
            optimizer.zero_grad()

        loss.backward()

        if on_batch_boundary:
            optimizer.step()

        if success_progs:
            break

    # NOTE(review): the file is written before the exploit phase below, so
    # programs found by exploitation are returned but not saved - confirm.
    if save:
        file_path = os.path.join(save_dir, str(file_name))
        with open(file_path, 'w') as suc_prog_file:
            json.dump(success_progs, suc_prog_file)

    logging.info(f"success: {success_progs}")

    # Fallback: exploit the failed attempts, best (fewest objects left) first.
    scene_interpreter = SceneInterp(graph.scene, config)

    if not success_progs:
        logging.info("start exploit")

        for prog, obj_left in sorted(all_progs, key=lambda item: item[1]):
            # An empty program means no progress was made at all.
            if len(prog) == 0:
                logging.info("empty prog to exploit, stop")
                break

            logging.info(f"prog: {prog}")
            logging.info(f"obj left: {obj_left}")
            prog_left = scene_interpreter.interp_enum(prog, datapoint.y)
            if prog_left is None:
                continue

            print(prog_left)
            # The completion is appended as a single element of ``prog``.
            prog.append(prog_left)
            success_progs.append(prog)
            logging.info(f"We found: {success_progs}")
            break

    return success_progs
Example #11
0
class Env():
    """Environment tracking how a growing clause list narrows ``var_0``.

    The caller appends clauses to ``self.clauses`` and calls ``step``; the
    environment queries its ``SceneInterp``, records how many candidates
    remain and keeps the ``success`` / ``possible`` flags up to date.
    """

    def __init__(self, data, graph, config, attr_encoder):
        """Build an environment over private copies of ``data`` and ``graph``.

        Args:
            data: Datapoint for this episode; stored as a deep copy.
            graph: Graph object exposing a ``scene`` mapping; stored as a
                deep copy.
            config: Configuration handed to ``SceneInterp``.
            attr_encoder: Encoder object, stored without copying.
        """
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)

        # The scene is read from the caller's graph, not from the copy above.
        scene = graph.scene
        self.obj_nums = len(scene["objects"])
        self.clauses = []
        self.config = config
        self.attr_encoder = attr_encoder

        # One entry per step; seeded with the total object count.
        self.obj_poss_left = [self.obj_nums]
        self.success, self.possible = False, True
        self.state = None

        # The interpreter state lives in ``interp_state``; ``state`` stays None.
        self.interp = SceneInterp(scene, config)
        self.interp_state = self.interp.get_init_state()

    def check_success(self, binding_dict):
        """Set ``self.success`` when the single remaining binding is the target.

        Only acts when the latest ``obj_poss_left`` entry is 1 and ``"var_0"``
        is bound; the first binding is compared with ``int(data.y)``.
        """
        exactly_one_left = self.obj_poss_left[-1] == 1
        if exactly_one_left and "var_0" in binding_dict:
            if list(binding_dict["var_0"])[0] == int(self.data.y):
                self.success = True

    def check_possible(self, binding_dict):
        """Clear ``self.possible`` when the target dropped out of ``var_0``.

        Does nothing once ``success`` is set or while ``"var_0"`` is unbound.
        """
        if self.success or "var_0" not in binding_dict:
            return

        if int(self.data.y) not in binding_dict["var_0"]:
            self.possible = False

    def step(self):
        """Apply the most recent clause and return the reward as a tensor.

        Returns:
            A float32 scalar tensor: 0 when no clause has been added yet,
            -1 when ``useful_check`` rejects the clause, otherwise the value
            of ``get_reward(self)``.
        """
        if not self.clauses:
            return torch.tensor(0, dtype=torch.float32)

        latest = self.clauses[-1]
        # Usefulness is judged against the state *before* the clause applies.
        useful = self.interp.useful_check(self.interp_state, latest)
        logging.info(f"useful: {useful}")

        binding_dict, self.interp_state = self.interp.state_query(
            self.interp_state, latest)

        # update the success and possible fields
        self.check_possible(binding_dict)

        if not self.possible:
            remaining = 0
        elif "var_0" in binding_dict:
            remaining = len(binding_dict["var_0"])
        else:
            # Without a "var_0" binding every object is still a candidate.
            remaining = self.obj_nums
        self.obj_poss_left.append(remaining)

        self.check_success(binding_dict)

        reward_value = get_reward(self) if useful else -1
        return torch.tensor(reward_value, dtype=torch.float32)

    def is_finished(self):
        """Tell whether the episode is over.

        Returns:
            ``False`` before the first recorded step; afterwards ``True``
            when 0 or 1 candidates remain, on success, or when the target
            is no longer possible.
        """
        # edge case handling: nothing has been recorded beyond the seed value
        if len(self.obj_poss_left) == 1:
            return False

        return bool(self.obj_poss_left[-1] in (0, 1)
                    or self.success
                    or not self.possible)
def check_prog_correct(scene, prog, target, config):
    """Check ``prog`` against the ground-truth part of ``scene``.

    Args:
        scene: Mapping whose ``"ground_truth"`` entry is interpreted.
        prog: Program passed to ``SceneInterp.fast_query``.
        target: Target forwarded to the module-level ``check_success``.
        config: Configuration passed to ``SceneInterp``.

    Returns:
        The result of ``check_success`` for the query bindings and ``target``.
    """
    ground_truth_interp = SceneInterp(scene["ground_truth"],
                                      config,
                                      is_uncertain=False)
    bindings = ground_truth_interp.fast_query(prog)
    return check_success(bindings, target)
class Env():
    """Action-indexed environment over the clauses enumerated from a config.

    Each ``step`` picks one clause by index, applies it through the
    ``SceneInterp`` and reports ``(state, reward, done, info)``.
    """

    def __init__(self,
                 data,
                 graph,
                 config,
                 attr_encoder,
                 state=None,
                 ref=False):
        """Build the environment for one datapoint.

        Args:
            data: Datapoint for this episode; stored as a deep copy.
            graph: Graph object exposing a ``scene`` mapping; stored as a
                deep copy.
            config: Configuration used to enumerate the clauses and to build
                the ``SceneInterp``.
            attr_encoder: Encoder object, stored without copying.
            state: Interpreter state to resume from. When ``None`` the
                interpreter's initial state is used.
            ref: When truthy, the referenced-variable list starts as ``[0]``
                and grows in ``step``; otherwise it holds every index below
                ``cmd_args.max_var_num``.
        """
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)

        # The scene is read from the caller's graph, not from the copy above.
        scene = graph.scene
        self.obj_nums = len(scene["objects"])
        self.clauses = []
        self.idx_selected = []

        self.config = config
        self.attr_encoder = attr_encoder

        # The action dictionary is derived from ``self.actions``, so the
        # actions have to be in place before it is created.
        self.actions = get_all_clauses(config)
        self.create_action_dict()

        self.ref_flag = ref
        self.ref = [0] if ref else list(range(cmd_args.max_var_num))

        # One entry per step; seeded with the total object count.
        self.obj_poss_left = [self.obj_nums]
        self.success, self.possible = False, True
        self.unreachable = []
        reachability = get_reachable_dict(self.actions)
        self.reachable_dict, self.unreachable_dict = reachability

        self.interp = SceneInterp(scene, config)
        # Identity check against None, so ``state`` is never compared by value.
        self.state = self.interp.get_init_state() if state is None else state

    def create_action_dict(self):
        """Map ``str(action)`` to the action's index in ``self.actions``."""
        self.action_dict = {
            str(action): index for index, action in enumerate(self.actions)
        }

    # TODO: check the effect of update data on the whole dataset
    def update_data(self, binding_dict):
        """Push ``binding_dict`` into the graph and rebuild ``self.data``.

        The new ``Data`` object keeps the previous ``y``; its ``batch`` is an
        int64 zero vector with one entry per row of the node features.
        """
        self.graph.update_binding(binding_dict)

        embedding = self.attr_encoder.get_embedding(self.graph.get_nodes())
        raw_edge_index, raw_edge_types = self.graph.get_edge_info()

        edge_attr = torch.tensor(raw_edge_types)
        edge_index = torch.tensor(raw_edge_index)
        features = torch.tensor(embedding)
        batch = torch.zeros(features.shape[0], dtype=torch.int64)

        self.data = Data(x=features,
                         y=self.data.y,
                         edge_index=edge_index,
                         edge_attr=edge_attr)
        self.data.batch = batch

    def check_success(self, binding_dict):
        """Set ``self.success`` when the single remaining binding is the target.

        Only acts when the latest ``obj_poss_left`` entry is 1 and ``"var_0"``
        is bound; the first binding is compared with ``int(data.y)``.
        """
        exactly_one_left = self.obj_poss_left[-1] == 1
        if exactly_one_left and "var_0" in binding_dict:
            if list(binding_dict["var_0"])[0] == int(self.data.y):
                self.success = True

    def check_possible(self, binding_dict):
        """Clear ``self.possible`` when the target dropped out of ``var_0``.

        Does nothing once ``success`` is set or while ``"var_0"`` is unbound.
        """
        if self.success or "var_0" not in binding_dict:
            return

        if int(self.data.y) not in binding_dict["var_0"]:
            self.possible = False

    def get_state(self):
        """Refresh ``self.unreachable`` and wrap the current status in a State.

        The unreachable list handed to ``State`` is the already selected
        action indices followed by the entries looked up for the current
        (sorted) reference list.
        """
        self.unreachable = self.unreachable_dict[str(sorted(self.ref))]
        blocked = self.idx_selected + self.unreachable
        return State(self.actions, self.data, self.graph, self.state,
                     self.action_dict, blocked)

    # not support batching yet
    def step(self, action_idx):
        """Apply the action at ``action_idx`` and report the outcome.

        Args:
            action_idx: Index into ``self.actions``.

        Returns:
            Tuple ``(state, reward, done, info)``: the new ``State``, a
            float32 scalar tensor (-1 when ``useful_check`` rejects the
            clause, else ``get_reward(self)``), whether the episode ended by
            success or impossibility, and an empty dict.
        """
        chosen = self.actions[action_idx]

        if self.ref_flag:
            # Exact ``int`` elements of the clause are recorded as
            # referenced variables.
            for token in chosen:
                if type(token) is int and token not in self.ref:
                    self.ref.append(token)

            self.unreachable = self.unreachable_dict[str(sorted(self.ref))]

        self.idx_selected.append(action_idx)
        self.clauses.append(chosen)

        # Usefulness is judged against the state *before* the clause applies.
        useful = self.interp.useful_check(self.state, chosen)
        logging.info(f"useful: {useful}")
        binding_dict, self.state = self.interp.state_query(
            self.state, chosen)

        # update the success and possible fields
        self.check_possible(binding_dict)

        if not self.possible:
            self.obj_poss_left.append(0)
        else:
            # Without a "var_0" binding every object is still a candidate.
            if "var_0" in binding_dict:
                remaining = len(binding_dict["var_0"])
            else:
                remaining = self.obj_nums
            self.obj_poss_left.append(remaining)
            self.update_data(binding_dict)

        # update whether done or not
        self.check_success(binding_dict)
        done = self.success or (not self.possible)

        reward_value = get_reward(self) if useful else -1
        reward = torch.tensor(reward_value, dtype=torch.float32)

        # NOTE(review): this refresh runs unconditionally, i.e. a second time
        # on the possible path and once on the impossible path - confirm that
        # the repeated update is intended.
        self.update_data(binding_dict)
        info = {}

        logging.info(f"selected: {self.idx_selected}")
        logging.info(f"done: {done}")
        logging.info(f"success: {self.success}")

        return self.get_state(), reward, done, info

    def is_finished(self):
        """Tell whether the episode is over.

        Returns:
            ``False`` before the first recorded step; afterwards ``True``
            when 0 or 1 candidates remain, on success, or when the target
            is no longer possible.
        """
        # edge case handling: nothing has been recorded beyond the seed value
        if len(self.obj_poss_left) == 1:
            return False

        return bool(self.obj_poss_left[-1] in (0, 1)
                    or self.success
                    or not self.possible)