def __init__(self, data, graph, config, attr_encoder, is_uncertain=cmd_args.prob_dataset):
    """Set up a fresh episode for one task.

    Args:
        data: per-task data object; deep-copied so that later changes to
            ``self.data`` never reach the caller's instance.
        graph: task graph; deep-copied for the same reason. ``graph.scene``
            must provide an ``"objects"`` collection and, when
            ``is_uncertain`` is falsy, a ``"ground_truth"`` entry.
        config: interpreter configuration forwarded to ``SceneInterp``.
        attr_encoder: attribute encoder, stored for later use.
        is_uncertain: when truthy, the interpreter is built on ``graph.scene``
            itself with ``is_uncertain=True``; otherwise on
            ``graph.scene["ground_truth"]`` with ``is_uncertain=False``.
    """
    source_scene = graph.scene
    self.data = deepcopy(data)
    self.graph = deepcopy(graph)
    self.obj_nums = len(source_scene["objects"])
    self.clauses = []
    self.config = config
    self.attr_encoder = attr_encoder
    # Candidate-count history, seeded with the total number of objects.
    self.obj_poss_left = [self.obj_nums]
    self.success = False
    self.possible = True

    if is_uncertain:
        interp_scene, uncertain_flag = source_scene, True
    else:
        interp_scene, uncertain_flag = source_scene["ground_truth"], False
    self.interp = SceneInterp(interp_scene, config, is_uncertain=uncertain_flag)
    self.state = self.interp.get_init_state()
    self.is_uncertain = is_uncertain
def get_option_num(self, clauses, scene, target):
    """Evaluate *clauses* on *scene* and summarise the outcome.

    Returns:
        A tuple ``(option_num, success, possible)``. ``option_num`` is the
        number of ``"var_0"`` bindings when the query produced that key, and
        otherwise the number of objects in ``scene["objects"]``. ``success``
        and ``possible`` are whatever ``check_success`` and ``check_possible``
        return for the bindings and *target*.
    """
    bindings = SceneInterp(scene, self.config).fast_query(clauses)
    success = check_success(bindings, target)
    possible = check_possible(bindings, target)
    if "var_0" in bindings:
        return len(bindings["var_0"]), success, possible
    return len(scene["objects"]), success, possible
def eu_solve_prog(scene, config, target, max_depth):
    """Call ``eu_synthesize`` for *target* starting from a fresh interpreter state.

    The synthesiser receives the interpreter wrapped by ``SceneInterp``, the
    initial state, ``is_uncertain=False`` and *max_depth*. Its result is
    returned unchanged.
    """
    scene_interp = SceneInterp(scene, config)
    initial_state = scene_interp.get_init_state()
    return eu_synthesize(scene_interp.interpreter, target, initial_state,
                         is_uncertain=False, max_depth=max_depth)
def __init__(self, data, graph, config, attr_encoder):
    """Set up a fresh episode over ``graph.scene``.

    Args:
        data: per-task data object (deep-copied).
        graph: task graph (deep-copied); ``graph.scene["objects"]`` gives the
            object count.
        config: interpreter configuration forwarded to ``SceneInterp``.
        attr_encoder: attribute encoder, stored for later use.
    """
    scene = graph.scene
    self.data, self.graph = deepcopy(data), deepcopy(graph)
    self.obj_nums = len(scene["objects"])
    self.clauses = []
    self.config = config
    self.attr_encoder = attr_encoder
    # Candidate-count history, seeded with the total number of objects.
    self.obj_poss_left = [self.obj_nums]
    self.success = False
    self.possible = True
    # Left as None here; the interpreter's own state is kept in `interp_state`.
    self.state = None
    self.interp = SceneInterp(scene, config)
    self.interp_state = self.interp.get_init_state()
def add_attr(prog, scene, config=None):
    """Extend *prog* with colour and shape clauses for its referenced variables.

    The program is evaluated with ``SceneInterp``. For every position ``var``
    of the sequence returned by ``get_valid_binding`` that also appears in
    ``get_ref(prog)``, the object index ``obj`` bound at that position
    contributes ``["color", <color>, var]`` and ``["shape", <shape>, var]``.

    Args:
        prog: list of clauses. It is extended IN PLACE with the new clauses.
        scene: mapping whose ``"objects"`` entries carry ``"color"`` and
            ``"shape"`` keys.
        config: interpreter configuration. The original body read a free
            name ``config`` that could only resolve to a module global, so
            that global remains the fallback when the argument is omitted.

    Returns:
        The same ``prog`` list object, after extension.

    Raises:
        NameError: if no *config* is given and no module-level ``config`` exists.
    """
    if config is None:
        config = globals().get("config")
        if config is None:
            raise NameError(
                "add_attr() needs a `config` argument "
                "(no module-level `config` is defined)")
    scene_interp = SceneInterp(scene, config)
    state = scene_interp.get_state(prog)
    bd_list = scene_interp.get_valid_binding(state)
    ref = get_ref(prog)
    objects = scene["objects"]
    new_clauses = []
    for var, obj in enumerate(bd_list):
        if var not in ref:
            continue
        new_clauses.append(["color", objects[obj]["color"], var])
        new_clauses.append(["shape", objects[obj]["shape"], var])
    # In-place extension is kept: callers may hold a reference to `prog`.
    prog += new_clauses
    return prog
def __init__(self, data, graph, config, attr_encoder, state=None, ref=False, is_uncertain=False):
    """Set up an action-indexed episode for one task.

    Args:
        data: per-task data object (deep-copied).
        graph: task graph (deep-copied); ``graph.scene`` supplies the objects
            and, when ``is_uncertain`` is falsy, a ``"ground_truth"`` scene.
        config: configuration used to enumerate actions and build the interpreter.
        attr_encoder: attribute encoder, stored for later use.
        state: optional interpreter state to resume from. When None, the
            interpreter's initial state is used.
        ref: when truthy, the reference set starts as ``[0]``; otherwise it
            starts as every variable index below ``cmd_args.max_var_num``.
        is_uncertain: choose between the probabilistic scene and its
            ground-truth counterpart for the interpreter.
    """
    scene = graph.scene
    self.data = deepcopy(data)
    self.graph = deepcopy(graph)
    self.obj_nums = len(scene["objects"])
    self.clauses = []
    self.idx_selected = []
    self.config = config
    self.attr_encoder = attr_encoder
    self.actions = get_all_clauses(config)
    self.create_action_dict()
    self.ref_flag = ref
    self.ref = [0] if ref else list(range(cmd_args.max_var_num))
    # Candidate-count history, seeded with the total number of objects.
    self.obj_poss_left = [self.obj_nums]
    self.success = False
    self.possible = True
    self.unreachable = []
    self.reachable_dict, self.unreachable_dict = get_reachable_dict(self.actions)
    self.is_uncertain = is_uncertain
    if is_uncertain:
        interp = SceneInterp(scene, config, is_uncertain=True)
    else:
        interp = SceneInterp(scene["ground_truth"], config, is_uncertain=False)
    self.interp = interp
    self.state = interp.get_init_state() if state is None else state
def shrink(prog, scene, config):
    """Repeatedly apply ``remove_one`` until it returns None.

    Each successful call produces a smaller candidate program. The last
    program for which ``remove_one`` returned a value (or *prog* itself, if
    the first call returns None) is returned.
    """
    interp = SceneInterp(scene, config)
    smallest = prog
    candidate = remove_one(interp, smallest)
    while candidate is not None:
        smallest = candidate
        candidate = remove_one(interp, smallest)
    return smallest
def cont_single(file_name, datapoint, model, graphs, save_dir, attr_encoder, config, save=True, n=10):
    """Search for a program for one datapoint with a fine-tuned policy, then by enumeration.

    Phase 1 runs up to ``cmd_args.test_iter`` calls to ``episode`` with
    deep copies of ``model.policy`` (one trained with RMSprop, one frozen in
    eval mode as ``target``) and stops at the first successful episode.
    Phase 2 runs only if nothing succeeded. Failed partial programs are tried
    in ascending order of their remaining-candidate count, and
    ``SceneInterp.interp_enum`` (``max_depth=2``) is asked to complete each
    one.

    Args:
        file_name: output file name inside *save_dir*. If that file already
            exists, the task is skipped.
        datapoint: task; ``datapoint.graph_id`` selects the graph and
            ``datapoint.y`` is the target handed to ``interp_enum``.
        model: object exposing ``.policy``; it is deep-copied, not modified.
        graphs: indexable collection of graphs.
        save_dir: directory for the JSON result file.
        attr_encoder, config: forwarded to ``episode`` / ``SceneInterp``.
        save: when true, write ``{"prog": ..., "method": ...}`` as JSON.
        n: not used in this body.

    Returns:
        None when the output file already exists; otherwise the list of
        successful programs (possibly empty).
    """
    file_path = os.path.join(save_dir, str(file_name))
    exist = os.path.exists(file_path)
    logging.info(f"task: {file_name}, {exist}")
    if exist:
        logging.info("skip")
        # NOTE(review): returns None here but a list elsewhere; callers must handle both.
        return
    total_ct = 0
    memory = ReplayMemory(10000)
    # Records how the program was found: "model", "exploit: <depth>", or None.
    method = None
    policy = copy.deepcopy(model.policy)
    target = copy.deepcopy(policy)
    target.eval()
    optimizer = optim.RMSprop(policy.parameters(), lr=cmd_args.lr)
    success_progs = []
    # (partial program, remaining-candidate count) for every failed episode.
    all_progs = []
    # NOTE(review): `eps` is read but not used below.
    eps = cmd_args.eps
    graph = graphs[datapoint.graph_id]
    for it in range(cmd_args.test_iter):
        suc, prog, obj_left = episode(policy, target, datapoint, graph, config, attr_encoder, memory, total_ct, optimizer)
        total_ct += 1
        if suc:
            success_progs.append(prog)
            method = "model"
            break
        # Always true here because of the `break` above.
        if not suc:
            all_progs.append((prog, obj_left))
    logging.info(f"success: {success_progs}")
    # If the policy never succeeded, exploit the best partial selections.
    scene_interpreter = SceneInterp(graph.ground_truth_scene, config)
    if len(success_progs) == 0:
        logging.info("start exploit")
        # Sort by the second tuple element (remaining-candidate count), ascending.
        candidates = sorted(all_progs, key=lambda x: x[1])
        for prog, obj_left in candidates:
            # An empty partial program means no progress; stop exploiting entirely.
            if len(prog) == 0:
                logging.info("empty prog to exploit, stop")
                break
            logging.info(f"prog: {prog}")
            logging.info(f"obj left: {obj_left}")
            prog_left = scene_interpreter.interp_enum(prog, datapoint.y, max_depth=2)
            if not type(prog_left) == type(None):
                print(prog_left)
                # NOTE(review): appends prog_left as one nested element rather
                # than extending prog with its clauses. Confirm the format
                # expected by readers of the saved file.
                prog.append(prog_left)
                success_progs.append(prog)
                method = f"exploit: {len(prog_left)}"
                logging.info(
                    f"We found: {success_progs} at depth {len(prog_left)}")
                break
    if save:
        file_path = os.path.join(save_dir, str(file_name))
        with open(file_path, 'w') as suc_prog_file:
            res = {}
            res['prog'] = success_progs
            res['method'] = method
            json.dump(res, suc_prog_file)
    return success_progs
class Env():
    """Clause-by-clause environment that re-encodes the scene graph after each step.

    Clauses are appended to ``self.clauses`` by the caller. ``step``
    evaluates only the newest one against the running interpreter state.
    ``obj_poss_left`` records, after each step, how many candidates for
    ``var_0`` remain (0 once the target is impossible).
    """

    def __init__(self, data, graph, config, attr_encoder, is_uncertain=cmd_args.prob_dataset):
        # NOTE: the default for `is_uncertain` is evaluated once, when the
        # class is defined, not on each call.
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)
        self.obj_nums = len(graph.scene["objects"])
        self.clauses = []
        self.config = config
        self.attr_encoder = attr_encoder
        # self.obj_poss_left = [ self.obj_nums ** cmd_args.max_var_num ]
        self.obj_poss_left = [self.obj_nums]
        self.success = False
        self.possible = True
        # The probabilistic scene is interpreted directly; otherwise its
        # "ground_truth" entry is used.
        if is_uncertain:
            self.interp = SceneInterp(graph.scene, config, is_uncertain=True)
        else:
            self.interp = SceneInterp(graph.scene["ground_truth"], config, is_uncertain=False)
        self.state = self.interp.get_init_state()
        self.is_uncertain = is_uncertain

    # TODO: check the effect of update data on the whole dataset
    def update_data(self, binding_dict):
        """Apply *binding_dict* to the graph and rebuild ``self.data`` from it.

        ``y`` and ``batch`` are carried over from the previous ``self.data``.
        ``Data`` is presumably ``torch_geometric.data.Data`` (confirm).
        """
        self.graph.update_binding(binding_dict)
        # self.data.update_data(self.graph, self.attr_encoder)
        x = self.attr_encoder.get_embedding(
            [node.name for node in self.graph.nodes])
        edge_index, edge_types = self.graph.get_edge_info()
        # edge_attrs = torch.tensor(self.attr_encoder.get_embedding(edge_types))
        edge_attr = torch.tensor(edge_types, dtype=torch.float)
        edge_index = torch.tensor(edge_index)
        x = torch.tensor(x)
        batch = self.data.batch
        # print(f"previous edge num: {len(self.data.edge_attr)}")
        self.data = Data(x=x, y=self.data.y, edge_index=edge_index, edge_attr=edge_attr)
        self.data.batch = batch
        # print(f"previous edge num: {len(self.data.edge_attr)}")

    def check_success(self, binding_dict):
        """Set ``success`` when exactly one candidate remains and it is the target."""
        # print(binding_dict)
        if not self.obj_poss_left[-1] == 1:
            return
        if "var_0" not in binding_dict.keys():
            return
        # Compares with the string form of the target; presumably bindings
        # hold string object ids in this variant (confirm).
        if binding_dict["var_0"][0] == str(int(self.data.y)):
            # print("success!")
            self.success = True

    def check_possible(self, binding_dict):
        """Clear ``possible`` when the target is no longer among the ``var_0`` bindings."""
        if self.success:
            return
        # No constraint on var_0 yet: every object is still possible.
        if "var_0" not in binding_dict.keys():
            return
        if "var_0" in binding_dict.keys() and str(int(
                self.data.y)) in binding_dict["var_0"]:
            return
        self.possible = False

    def step(self):
        """Evaluate the newest clause and return the reward as a float32 tensor.

        Returns a zero reward when no clause has been added. Otherwise it
        updates the interpreter state, ``possible``, ``obj_poss_left``, the
        re-encoded graph data and ``success`` (in that order), then returns
        ``get_reward(self)``.
        """
        if len(self.clauses) == 0:
            return torch.tensor(0, dtype=torch.float32)
        binding_dict, new_state = self.interp.state_query(
            self.state, self.clauses[-1])
        self.state = new_state
        # update the success and possible fields
        self.check_possible(binding_dict)
        if not self.possible:
            self.obj_poss_left.append(0)
        elif not "var_0" in binding_dict.keys():
            self.obj_poss_left.append(self.obj_nums)
            self.update_data(binding_dict)
        else:
            self.obj_poss_left.append(len(binding_dict["var_0"]))
            self.update_data(binding_dict)
        # check_success reads obj_poss_left[-1], so it must run after the append above.
        self.check_success(binding_dict)
        # TODO: the reward function need to be updated
        reward = get_reward(self)
        # print(reward)
        return torch.tensor(reward, dtype=torch.float32)

    def is_finished(self):
        """Return True after at least one step once 0 or 1 candidates remain, or on success/impossibility."""
        # edge case handling: no step has been recorded yet
        if len(self.obj_poss_left) == 1:
            return False
        # no possibilities (0) or a unique candidate (1)
        if self.obj_poss_left[-1] == 0 or self.obj_poss_left[-1] == 1:
            return True
        # already succeed
        if self.success:
            return True
        # not possible
        if not self.possible:
            return True
        return False
def cont_single(file_name, datapoint, policy, graphs, save_dir, attr_encoder, config, save=True, n=10):
    """Fine-tune a private copy of *policy* on one datapoint and collect programs.

    Episodes are run with ``fit_one`` for up to ``cmd_args.episode_iter``
    iterations, stopping at the first episode whose environment reports
    success. Gradients from every episode are accumulated and applied once
    every ``cmd_args.batch_size`` episodes. If no episode succeeds, the
    partial programs of failed episodes are tried in ascending order of their
    remaining-candidate count, and ``SceneInterp.interp_enum`` is asked to
    complete each one.

    Args:
        file_name: output file name inside *save_dir*.
        datapoint: task; ``datapoint.graph_id`` selects the graph and
            ``datapoint.y`` is the target handed to ``interp_enum``.
        policy: model to fine-tune; it is deep-copied, so the caller's
            instance is not modified.
        graphs: indexable collection of graphs.
        save_dir: directory for the JSON result file.
        attr_encoder, config: forwarded to ``fit_one`` / ``SceneInterp``.
        save: when true, write the programs found by the policy as JSON.
        n: not used.

    Returns:
        The list of successful programs (possibly empty).
    """
    global sub_ct
    policy = deepcopy(policy)
    policy.train()
    optimizer = torch.optim.Adam(policy.parameters(), lr=cmd_args.lr)
    success_progs = []
    # (partial program, remaining-candidate count) for every failed episode.
    all_progs = []
    eps = cmd_args.eps
    graph = graphs[datapoint.graph_id]

    optimizer.zero_grad()
    for it in range(cmd_args.episode_iter):
        sub_ct = 0
        env, loss = fit_one(policy, datapoint, graph, eps, attr_encoder, config)
        if env.success:
            success_progs.append(env.clauses)
        if not env.possible:
            # Drop the clause that made the target impossible and keep the
            # candidate count recorded just before it.
            all_progs.append((env.clauses[:-1], env.obj_poss_left[-2]))
        # Accumulate every episode's gradient. Previously zero_grad() and
        # backward() ran only on iterations divisible by batch_size, so the
        # other episodes' losses were computed and then discarded. A loss
        # without an autograd graph has nothing to contribute.
        if loss.requires_grad:
            loss.backward()
        if (it + 1) % cmd_args.batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()
        if len(success_progs) > 0:
            break

    if save:
        # NOTE(review): this runs before the exploit phase below, so programs
        # found by enumeration are returned but not written to disk. Confirm
        # this is intended.
        file_path = os.path.join(save_dir, str(file_name))
        with open(file_path, 'w') as suc_prog_file:
            json.dump(success_progs, suc_prog_file)

    logging.info(f"success: {success_progs}")
    # If the policy never succeeded, exploit the best partial selections.
    scene_interpreter = SceneInterp(graph.scene, config)
    if len(success_progs) == 0:
        logging.info("start exploit")
        # Fewest remaining candidates first.
        candidates = sorted(all_progs, key=lambda x: x[1])
        for prog, obj_left in candidates:
            # An empty partial program means no progress; stop exploiting entirely.
            if len(prog) == 0:
                logging.info("empty prog to exploit, stop")
                break
            logging.info(f"prog: {prog}")
            logging.info(f"obj left: {obj_left}")
            prog_left = scene_interpreter.interp_enum(prog, datapoint.y)
            if prog_left is not None:
                print(prog_left)
                # NOTE(review): appends prog_left as one nested element rather
                # than extending prog with its clauses. Confirm the expected format.
                prog.append(prog_left)
                success_progs.append(prog)
                logging.info(f"We found: {success_progs}")
                break
    return success_progs
class Env():
    """Clause-by-clause environment that tracks a running interpreter state.

    ``step`` never appends to ``self.clauses``; the caller does. ``step``
    evaluates the most recent entry against ``interp_state``.
    ``obj_poss_left`` records, after each step, how many candidates for
    ``var_0`` remain (0 once the target is impossible).
    """

    def __init__(self, data, graph, config, attr_encoder):
        scene = graph.scene
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)
        self.obj_nums = len(scene["objects"])
        self.clauses = []
        self.config = config
        self.attr_encoder = attr_encoder
        # Candidate-count history, seeded with the total number of objects.
        self.obj_poss_left = [self.obj_nums]
        self.success = False
        self.possible = True
        # Left as None here; the interpreter's own state lives in `interp_state`.
        self.state = None
        self.interp = SceneInterp(scene, config)
        self.interp_state = self.interp.get_init_state()

    def check_success(self, binding_dict):
        """Set ``success`` when exactly one candidate remains and it equals the target."""
        if self.obj_poss_left[-1] != 1 or "var_0" not in binding_dict:
            return
        if list(binding_dict["var_0"])[0] == int(self.data.y):
            self.success = True

    def check_possible(self, binding_dict):
        """Clear ``possible`` when the target is missing from the ``var_0`` bindings."""
        if self.success or "var_0" not in binding_dict:
            return
        if int(self.data.y) not in binding_dict["var_0"]:
            self.possible = False

    def step(self):
        """Evaluate the newest clause and return the reward as a float32 tensor.

        The reward is 0 when no clause exists yet, -1 when the interpreter
        reports the clause as not useful, and ``get_reward(self)`` otherwise.
        """
        if not self.clauses:
            return torch.tensor(0, dtype=torch.float32)

        clause = self.clauses[-1]
        useful = self.interp.useful_check(self.interp_state, clause)
        logging.info(f"useful: {useful}")
        binding_dict, self.interp_state = self.interp.state_query(
            self.interp_state, clause)

        self.check_possible(binding_dict)
        if not self.possible:
            remaining = 0
        elif "var_0" in binding_dict:
            remaining = len(binding_dict["var_0"])
        else:
            remaining = self.obj_nums
        self.obj_poss_left.append(remaining)
        # Reads obj_poss_left[-1], so it must follow the append.
        self.check_success(binding_dict)

        if useful:
            return torch.tensor(get_reward(self), dtype=torch.float32)
        return torch.tensor(-1, dtype=torch.float32)

    def is_finished(self):
        """Return whether the episode is over.

        This is False before any step has been recorded. After that it is True
        when 0 or 1 candidates remain, the episode succeeded, or the target
        became impossible.
        """
        history = self.obj_poss_left
        if len(history) == 1:
            return False
        return bool(history[-1] in (0, 1) or self.success or not self.possible)
def check_prog_correct(scene, prog, target, config):
    """Evaluate *prog* on ``scene["ground_truth"]`` and return ``check_success`` for *target*.

    The interpreter is built with ``is_uncertain=False``. The bindings from
    ``fast_query(prog)`` and *target* are passed to ``check_success``, whose
    result is returned unchanged.
    """
    ground_truth_interp = SceneInterp(scene["ground_truth"], config, is_uncertain=False)
    return check_success(ground_truth_interp.fast_query(prog), target)
class Env():
    """Environment driven by action indices into a fixed clause vocabulary.

    ``step(action_idx)`` appends ``self.actions[action_idx]`` to the program,
    evaluates it against the running interpreter state, re-encodes the graph
    and returns ``(state, reward, done, info)``. When ``ref`` is true, the
    set of referenced variable indices grows as clauses mention new integers.
    That set selects which actions are unreachable.
    """

    def __init__(self, data, graph, config, attr_encoder, state=None, ref=False):
        self.data = deepcopy(data)
        self.graph = deepcopy(graph)
        self.obj_nums = len(graph.scene["objects"])
        self.clauses = []
        # Indices (into self.actions) of the clauses chosen so far.
        self.idx_selected = []
        self.config = config
        self.attr_encoder = attr_encoder
        self.actions = get_all_clauses(config)
        self.create_action_dict()
        self.ref_flag = ref
        if ref:
            self.ref = [0]
        else:
            self.ref = list(range(cmd_args.max_var_num))
        # self.obj_poss_left = [ self.obj_nums ** cmd_args.max_var_num ]
        self.obj_poss_left = [self.obj_nums]
        self.success = False
        self.possible = True
        self.unreachable = []
        # Lookup tables keyed by str(sorted(ref)) (see get_state / step).
        self.reachable_dict, self.unreachable_dict = get_reachable_dict(
            self.actions)
        self.interp = SceneInterp(graph.scene, config)
        # Resume from a supplied interpreter state, or start fresh.
        if type(state) == type(None):
            self.state = self.interp.get_init_state()
        else:
            self.state = state

    # def reset(self, graph):
    #     self.update_data(graph, self.attr_encoder)

    def create_action_dict(self):
        """Map ``str(action)`` to its index in ``self.actions``."""
        self.action_dict = {}
        for action_id, action in enumerate(self.actions):
            self.action_dict[str(action)] = action_id

    # TODO: check the effect of update data on the whole dataset
    def update_data(self, binding_dict):
        """Apply *binding_dict* to the graph and rebuild ``self.data`` from it.

        ``y`` is carried over. ``batch`` is reset to all zeros (one graph).
        ``Data`` is presumably ``torch_geometric.data.Data`` (confirm).
        """
        self.graph.update_binding(binding_dict)
        # self.data.update_data(self.graph, self.attr_encoder)
        x = self.attr_encoder.get_embedding(self.graph.get_nodes())
        edge_index, edge_types = self.graph.get_edge_info()
        # edge_attrs = torch.tensor(self.attr_encoder.get_embedding(edge_types))
        # NOTE(review): dtype is inferred from edge_types here; confirm it
        # matches what the model expects for edge_attr.
        edge_attr = torch.tensor(edge_types)
        edge_index = torch.tensor(edge_index)
        x = torch.tensor(x)
        batch = torch.zeros(x.shape[0], dtype=torch.int64)
        # print(f"previous edge num: {len(self.data.edge_attr)}")
        self.data = Data(x=x, y=self.data.y, edge_index=edge_index, edge_attr=edge_attr)
        self.data.batch = batch
        # print(f"previous edge num: {len(self.data.edge_attr)}")

    def check_success(self, binding_dict):
        """Set ``success`` when exactly one candidate remains and it is the target."""
        if not self.obj_poss_left[-1] == 1:
            return
        if "var_0" not in binding_dict.keys():
            return
        if list(binding_dict["var_0"])[0] == int(self.data.y):
            # print("success!")
            self.success = True

    def check_possible(self, binding_dict):
        """Clear ``possible`` when the target is no longer among the ``var_0`` bindings."""
        if self.success:
            return
        if "var_0" not in binding_dict.keys():
            return
        if "var_0" in binding_dict.keys() and int(
                self.data.y) in binding_dict["var_0"]:
            return
        self.possible = False

    def get_state(self):
        """Build a ``State`` snapshot.

        The snapshot's unreachable list is the already-selected action indices
        followed by the actions unreachable for the current reference set.
        """
        self.unreachable = self.unreachable_dict[str(sorted(self.ref))]
        unreachable = self.idx_selected + self.unreachable
        return State(self.actions, self.data, self.graph, self.state, self.action_dict, unreachable)

    # not support batching yet
    def step(self, action_idx):
        """Apply action *action_idx* and return ``(state, reward, done, info)``.

        ``reward`` is -1 when the interpreter reports the clause as not
        useful, and ``get_reward(self)`` otherwise. ``done`` is
        ``success or not possible``. ``info`` is an empty dict.
        """
        next_clause = self.actions[action_idx]
        if self.ref_flag:
            # Integer elements of a clause are variable indices; add new ones
            # to the reference set.
            for element in next_clause:
                if type(element) == int:
                    if element not in self.ref:
                        self.ref.append(element)
            self.unreachable = self.unreachable_dict[str(sorted(self.ref))]
        # if 'red' in next_clause or 'blue' in next_clause:
        #     print('here')
        self.idx_selected.append(action_idx)
        self.clauses.append(next_clause)
        # Usefulness is judged against the state before the clause is applied.
        useful = self.interp.useful_check(self.state, next_clause)
        logging.info(f"useful: {useful}")
        binding_dict, new_state = self.interp.state_query(
            self.state, next_clause)
        self.state = new_state
        # update the success and possible fields
        self.check_possible(binding_dict)
        if not self.possible:
            self.obj_poss_left.append(0)
        elif not "var_0" in binding_dict.keys():
            self.obj_poss_left.append(self.obj_nums)
            self.update_data(binding_dict)
        else:
            self.obj_poss_left.append(len(binding_dict["var_0"]))
            self.update_data(binding_dict)
        # update whether done or not
        self.check_success(binding_dict)
        done = self.success or (not self.possible)
        if not useful:
            reward = torch.tensor(-1, dtype=torch.float32)
        else:
            reward = torch.tensor(get_reward(self), dtype=torch.float32)
        # NOTE(review): update_data runs again here unconditionally, i.e. a
        # second time on the possible branches and once even when the target
        # became impossible. get_reward above saw the first update. Confirm
        # that graph.update_binding tolerates the repeat.
        self.update_data(binding_dict)
        info = {}
        logging.info(f"selected: {self.idx_selected}")
        logging.info(f"done: {done}")
        logging.info(f"success: {self.success}")
        state = self.get_state()
        return state, reward, done, info

    def is_finished(self):
        """Return True after at least one step once 0 or 1 candidates remain, or on success/impossibility."""
        # edge case handling: no step has been recorded yet
        if len(self.obj_poss_left) == 1:
            return False
        # # duplicate clauses
        # dupes = [x for n, x in enumerate(self.clauses) if x in self.clauses[:n]]
        # if len(dupes) > 0:
        #     return True
        # no possibilities (0) or a unique candidate (1)
        if self.obj_poss_left[-1] == 0 or self.obj_poss_left[-1] == 1:
            return True
        # already succeed
        if self.success:
            return True
        # not possible
        if not self.possible:
            return True
        return False