def train(self, tasks, browser):
    """Explore ``tasks`` with a per-action score table and return found tasklets.

    Each episode resets ``env`` and repeatedly:

    * collects the task's preferred actions (or all possible actions if there
      are none);
    * keeps the actions whose stored score (``self._get_action_score``) is
      above 0.1, falling back to the 5 top-scored ones (``Utils.top_n``);
    * steps with a random candidate and stores the reward of that step as the
      action's new score (``self._set_action_score``).

    An episode ends when the task is done, the last step reward is below -10,
    or the state offers no action at all.

    :param tasks: tasks handed to ``WebBotEnv``.
    :param browser: browser handed to ``WebBotEnv``.
    :return: dict mapping each tasklet to ``(total_reward, episode, screenshot)``
        recorded the first time it was reached. Any exception aborts training;
        it is printed and logged, and the tasklets found so far are returned.
    """
    import traceback

    env = WebBotEnv(tasks=tasks, browser=browser)
    found_tasklets = {}
    try:
        for episode in range(1, self.n_episodes + 1):
            env.reset()
            task = env.current_task.snapshot()
            self.logger.info("Episode %d/%d, task: %s" % (episode, self.n_episodes, task.task_str))
            max_reward = 0
            while True:
                if task.done or task.reward < -10:
                    break
                env.render()
                actions = task.get_preferred_actions()
                if len(actions) == 0:
                    actions = task.state.possible_actions

                candidate_actions = []
                action_scores = {}
                for action in actions:
                    action_score = self._get_action_score(task, action)
                    action_scores[action] = action_score
                    if action_score > 0.1:
                        candidate_actions.append(action)
                if len(candidate_actions) == 0:
                    candidate_actions = Utils.top_n(action_scores, 5, reverse=True)
                if len(candidate_actions) == 0:
                    # Dead end: no action at all in this state. Ending the episode
                    # avoids random.choice([]) aborting the whole training run.
                    break

                action = random.choice(candidate_actions)
                env.step(action)
                task_ = env.current_task.snapshot()
                # Score the action taken in the *pre-step* state with the reward
                # *this* step produced (task_.reward). The pre-step snapshot's
                # reward belongs to the previous action.
                self._set_action_score(task, action, score=task_.reward)
                task = task_
                self.logger.info("\taction:%s, %s" % (action, task.get_reward_str()))

                tasklet = task.get_tasklet()
                if tasklet not in found_tasklets:
                    found_tasklets[tasklet] = (task.total_reward, episode, task.state.screenshot)
                if task.total_reward > max_reward:
                    max_reward = task.total_reward
            self.logger.info("Episode %d/%d, max_reward %.2f" % (episode, self.n_episodes, max_reward))
        env.destroy()
    except Exception as e:
        traceback.print_exc()
        self.logger.info("failed with error: %s" % e)
    return found_tasklets
def _load_samples(self, tasks, browser):
    """Lazily yield ``(task_snapshot, action, label)`` samples from replayed tasks.

    Every task is replayed in its own ``WebBotEnv``. For each recorded step
    that has a state and an action bound to an element:

    * the recorded action is yielded with label ``1``;
    * every other possible action of that state (compared by ``action_str``)
      is yielded with label ``0``.
    """
    from environment import WebBotEnv

    for demo_task in tasks:
        replay_env = WebBotEnv(tasks=[demo_task], browser=browser)
        replay_env.replay()
        n_steps = len(demo_task.state_history)
        for step in range(n_steps):
            snapshot = demo_task.snapshot(step=step)
            demo_action = demo_task.action_history[step]
            # Skip steps that cannot produce a usable sample.
            if not (snapshot.state and demo_action and demo_action.element is not None):
                continue
            yield snapshot, demo_action, 1
            yield from (
                (snapshot, other_action, 0)
                for other_action in snapshot.state.possible_actions
                if other_action.action_str != demo_action.action_str
            )
def execute(self, tasks, browser):
    """Run each task to completion using the trained model's action probabilities.

    For every task, actions are sampled with ``Utils.weighted_choice`` from the
    probabilities returned by ``self.model.predict``. No learning happens here.

    :param tasks: tasks to execute, one after another in the same ``WebBotEnv``.
    :param browser: browser handed to ``WebBotEnv``.
    """
    env = WebBotEnv(tasks=tasks, browser=browser)
    try:
        for new_task in tasks:
            env.reset(new_task=new_task)
            task = env.current_task.snapshot()
            while not task.done:
                env.render()
                actions = task.get_preferred_actions()
                if len(actions) == 0:
                    # Same fallback as the training loop: use every possible action.
                    actions = task.state.possible_actions
                if len(actions) == 0:
                    # Nothing to predict over; stop this task instead of
                    # sampling from an empty distribution.
                    break
                action2p = self.model.predict(task, actions)
                action = Utils.weighted_choice(action2p)
                env.step(action)
                task = env.current_task.snapshot()
                self.logger.info("\tExploit, action:%s, reward:%.2f, done:%s" % (action, task.reward, task.done))
            self.logger.info("Got total_reward %.2f in task." % task.total_reward)
    finally:
        # Release the environment even when a step raises.
        env.destroy()
    self.logger.info("Done testing tasks.")
def execute(self, tasks, browser, visualize=False):
    """Run each task greedily with ``self.choose_action_with_model`` (no learning).

    :param tasks: tasks to execute, one after another in the same ``WebBotEnv``.
    :param browser: browser handed to ``WebBotEnv``.
    :param visualize: forwarded to ``WebBotEnv``.
    """
    env = WebBotEnv(tasks=tasks, browser=browser, visualize=visualize)
    try:
        for new_task in tasks:
            # initial observation
            env.reset(new_task=new_task)
            task = env.current_task.snapshot()
            self.logger.info("Executing task: %s" % task.task_str)
            while not task.done:
                env.render()
                action, _ = self.choose_action_with_model(task)
                env.step(action)
                task = env.current_task.snapshot()
                self.logger.info("\tExploit, action:%s, reward:%.2f, done:%s" % (action, task.reward, task.done))
            self.logger.info("Got total_reward %.2f in task: %s" % (task.total_reward, task.task_str))
    finally:
        # The environment was never released before; destroy it like the
        # other execute/train loops do, also when a step raises.
        env.destroy()
    self.logger.info("Done executing tasks.")
def train_and_test(train_tasks, test_tasks, args, model_param, output_dir=None):
    """Run every experiment phase listed in ``args.phases``.

    Phases run in this fixed order, each only if named in ``args.phases``:

    * ``train`` -- train a supervised ``Qda2pModel`` on ``train_tasks``
      (optionally down-sampled to ``args.num_train_tasks``).
    * ``replay`` -- replay each test task and log the resulting tasklet.
    * ``eval_reward`` -- replay the traces stored as ``*.json`` in
      ``args.task_dir`` offline and write per-task rewards to
      ``<task_dir>/results/<args.reward_weights>.json``.
    * ``collect_pv_traces`` -- explore each replayable test task with several
      policies and save correct/incorrect traces under ``output/pv_traces``.
    * ``crawl`` -- run the agent selected by ``args.algorithm`` on each test
      task, post-process its best tasklets and report the success rate.
    * ``analyze_task_complexity`` -- analyze each test task's demonstration.
    * ``build_cache`` -- explore each test task, save its UTG and zip it into
      ``output/utg_zips``.

    :param train_tasks: tasks for the ``train`` phase.
    :param test_tasks: tasks for the other phases, except ``eval_reward``,
        which loads its tasks from ``args.task_dir``.
    :param args: parsed command-line options.
    :param model_param: keyword arguments forwarded to the model/agent
        constructors. NOTE: mutated in place (``n_episodes`` and possibly
        ``disable_text_embed``, ``explore_policy``, ``explore_rate``).
    :param output_dir: optional directory for crawl reports and screenshots.
    """
    from agent import TestQTable, TestDQN
    from model import Qda2pModel
    from environment import Task
    from environment import ChromeBrowser, CacheBrowser, WebBotEnv, Utils, UTG

    def make_browser(allow_cache=True):
        """Create a phase browser.

        Returns a CacheBrowser over output/utg_zips when ``args.use_cache`` is
        set and caching is allowed, otherwise a live ChromeBrowser.
        """
        if allow_cache and args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            return CacheBrowser(utgs=utgs)
        return ChromeBrowser(wait=args.wait, headless=args.headless, proxy=args.proxy,
                             restart_reset=args.restart_reset, chrome_path=args.chrome_path,
                             extension_path=args.extension_path)

    # Training on other tasks is currently not supported
    logging.info("Test on %d tasks:\n" % len(test_tasks) + "\n".join(task.task_str for task in test_tasks))
    if args.disable_text_embed:
        model_param["disable_text_embed"] = args.disable_text_embed

    # ---------------------------------------------------------------- train
    supervised_model = None
    if train_tasks and "train" in args.phases:
        if len(train_tasks) > args.num_train_tasks > 0:
            train_tasks = random.sample(train_tasks, args.num_train_tasks)
        browser = make_browser()
        model_param["n_episodes"] = model_param.get("n_train_episodes", 200)
        supervised_model = Qda2pModel(data_dir=args.data_dir, log_dir=args.log_dir,
                                      model_dir=args.model_dir, **model_param)
        supervised_model.train(tasks=train_tasks, browser=browser)
        browser.close()

    # --------------------------------------------------------------- replay
    if test_tasks and "replay" in args.phases:
        browser = make_browser()
        for test_task in test_tasks:
            env = WebBotEnv(tasks=[test_task], browser=browser)
            env.replay()
            replay_tasklet = test_task.get_tasklet()
            logging.info(" replay finished:\n%s" % replay_tasklet)

    # ---------------------------------------------------------- eval_reward
    if "eval_reward" in args.phases:
        from environment import State, Action, ACTION_RE
        import numpy as np

        trace_dir = args.task_dir
        states_dir = os.path.join(trace_dir, "states")
        results_dir = os.path.join(trace_dir, "results")
        if not os.path.exists(states_dir):
            # fall back to a zipped states archive
            states_dir = os.path.join(trace_dir, "states.zip")
        os.makedirs(results_dir, exist_ok=True)

        def load_action(state, action_line):
            """Parse an ACTION_RE line into an Action on an element of ``state``."""
            m = re.match(ACTION_RE, action_line)
            action_type, value, target_locator = m.group(1), m.group(2), m.group(3)
            target_ele = state.get_element_by_locator(target_locator)
            return Action(element=target_ele, action_type=action_type, value=value)

        def compute_reward(task, trace_lines):
            """Replay one trace offline against ``task``'s reward function.

            Each line is ``"<state_str>: <action line>"``; an action line of
            ``"END"`` stops the replay. Returns ``(max reward over states with
            target achieved, max reward over the other states)``; the first
            list starts at 0, the second at the reward after reset.
            """
            state_action_lines = []
            for line in trace_lines:
                sep = line.find(": ")
                state_action_lines.append((line[:sep], line[sep + 2:]))

            current_state_str, action_line = state_action_lines[0]
            current_state = State.load(states_dir, current_state_str)
            task.reset(current_state, update_utg=False)
            last_action = load_action(current_state, action_line)
            correct_rewards = [0]
            incorrect_rewards = [task.total_reward]
            for state_str, action_line in state_action_lines[1:]:
                current_state = State.load(states_dir, state_str)
                task.update(last_action, current_state, update_utg=False)
                if task.target_achieved:
                    correct_rewards.append(task.total_reward)
                else:
                    incorrect_rewards.append(task.total_reward)
                if action_line == "END":
                    break
                last_action = load_action(current_state, action_line)
            max_correct_reward = max(correct_rewards)
            max_incorrect_reward = max(incorrect_rewards)
            logging.info(
                f" task got correct reward {max_correct_reward:6.3f}"
                f" and incorrect reward {max_incorrect_reward:3.3f}: {task.name}"
            )
            return max_correct_reward, max_incorrect_reward

        def compute_rewards(task, traces):
            """Apply compute_reward to each non-empty trace.

            Traces that fail to replay are skipped (best effort).
            """
            correct_rewards = []
            incorrect_rewards = []
            for trace in traces:
                if not trace:
                    continue
                try:
                    correct_reward, incorrect_reward = compute_reward(task, trace.splitlines())
                except Exception:
                    logging.debug(f"compute_reward failed for {task.task_str}", exc_info=True)
                    continue
                correct_rewards.append(correct_reward)
                incorrect_rewards.append(incorrect_reward)
            return correct_rewards, incorrect_rewards

        all_traces = {}
        for task_name in os.listdir(trace_dir):
            if not task_name.endswith(".json"):
                continue
            task_file_path = os.path.join(trace_dir, task_name)
            try:
                with open(task_file_path) as trace_file:
                    trace_dict = json.load(trace_file)
            except (OSError, ValueError):
                logging.warning(f"unable to load task: {task_name}")
                continue
            correct_traces = trace_dict["correct_traces"]
            incorrect_traces = trace_dict["incorrect_traces"]
            if len(correct_traces) > 3:
                # keep 3 correct traces, ranked by Utils.top_n on their line count
                trace2len = {trace: len(trace.splitlines()) for trace in correct_traces}
                correct_traces = Utils.top_n(trace2len, 3)
            # at most 10 incorrect traces per kept correct trace
            incorrect_traces = incorrect_traces[:len(correct_traces) * 10]
            # drop incorrect traces that appear verbatim inside a correct trace
            correct_trace_str = "\n\n".join(correct_traces)
            incorrect_traces = [trace for trace in incorrect_traces if trace not in correct_trace_str]
            if len(correct_traces) > 0 and len(incorrect_traces) > 0:
                all_traces[task_name] = {
                    "correct_traces": correct_traces,
                    "incorrect_traces": incorrect_traces,
                }

        def get_reward_results(tag):
            """Score every loaded trace and write the rewards to results_dir/<tag>.json."""
            logging.info(f"evaluating reward function for configuration: {tag}")
            all_rewards = {}
            for task_name, task_traces in all_traces.items():
                logging.info(f" computing reward for {task_name}")
                test_task = Task.load(os.path.join(trace_dir, task_name))
                correct_traces = task_traces["correct_traces"]
                incorrect_traces = task_traces["incorrect_traces"]
                if len(correct_traces) == 0 or len(incorrect_traces) == 0:
                    continue
                correct_rewards, incorrect_rewards = compute_rewards(
                    test_task, correct_traces + incorrect_traces)
                if len(correct_rewards) > 0 and len(incorrect_rewards) > 0:
                    all_rewards[task_name] = {
                        "correct_rewards": correct_rewards,
                        "incorrect_rewards": incorrect_rewards,
                    }
                    logging.info(test_task.task_str)
                    logging.info(f"correct: {len(correct_traces)} " f"max: {np.max(correct_rewards)}")
                    logging.info(f"incorrect: {len(incorrect_traces)} " f"max: {np.max(incorrect_rewards)}")
            result_file_path = os.path.join(results_dir, f"{tag}.json")
            with open(result_file_path, "w") as f:
                json.dump(all_rewards, f, indent=2)

        get_reward_results(args.reward_weights)

    # ---------------------------------------------------- collect_pv_traces
    if test_tasks and "collect_pv_traces" in args.phases:
        browser = make_browser()
        pv_trace_dir = os.path.join("output", "pv_traces")
        pv_states_dir = os.path.join(pv_trace_dir, "states")
        os.makedirs(pv_states_dir, exist_ok=True)

        def get_replay_trace(task, states_dir=None):
            """Serialize a task's history as ``"<state_str>: <replay_api>"`` lines.

            The final line is ``"<state_str>: END"``. When ``states_dir`` is given,
            every state is also saved there.
            """
            replay_trace = []
            for i, action in enumerate(task.action_history):
                state = task.state_history[i]
                replay_trace.append(f"{state.state_str}: {action.replay_api}")
                if states_dir:
                    state.save(states_dir)
            replay_trace.append(f"{task.state.state_str}: END")
            if states_dir:
                task.state.save(states_dir)
            return "\n".join(replay_trace)

        def choose_explore_action(task, policy, preferred_actions, possible_actions):
            """Pick the next exploration action for ``policy``.

            Unknown policies behave like ``"random"``.
            """
            if policy == "similarity":
                action_scores = {action: task.get_action_usefulness(action) for action in preferred_actions}
                return Utils.weighted_choice(action_scores)
            if policy == "random_restricted":
                return random.choice(preferred_actions)
            if policy == "demo_biased":
                actions_in_demo = [action for action in preferred_actions
                                   if action.replay_api in task.demonstration]
                rand = random.uniform(0, 1)
                if len(actions_in_demo) > 0 and rand < 0.5:
                    return random.choice(actions_in_demo)
                return random.choice(preferred_actions)
            return random.choice(possible_actions)

        def explore_task(env, n_episodes, policy, correct_traces, incorrect_traces, states_dir=None):
            """Run ``n_episodes`` exploration episodes and record one trace per episode.

            Each trace goes into ``correct_traces`` if any visited state had the
            target achieved, otherwise into ``incorrect_traces``.
            """
            for _ in range(n_episodes):
                env.reset()
                task = env.current_task.snapshot()
                target_achieved = False
                while True:
                    # Check every visited state, including the one reached by
                    # the final step (the old pre-step check missed it).
                    if task.target_achieved:
                        target_achieved = True
                    if task.done:
                        break
                    env.render()
                    preferred_actions = env.browser._filter_actions(
                        task.state, task.get_preferred_actions())
                    possible_actions = env.browser._filter_actions(
                        task.state, task.state.possible_actions)
                    if len(possible_actions) == 0:
                        break
                    if len(preferred_actions) == 0:
                        preferred_actions = possible_actions
                    action = choose_explore_action(task, policy, preferred_actions, possible_actions)
                    env.step(action)
                    task = env.current_task.snapshot()
                replay_trace = get_replay_trace(task, states_dir)
                if target_achieved:
                    correct_traces.append(replay_trace)
                else:
                    incorrect_traces.append(replay_trace)

        for test_task in test_tasks:
            assert isinstance(test_task, Task)
            if not test_task.replayable:
                continue
            task_file_path = os.path.join(pv_trace_dir, test_task.name + "task.json")
            if os.path.exists(task_file_path):
                continue
            correct_traces = []
            incorrect_traces = []
            try:
                env = WebBotEnv(tasks=[test_task], browser=browser)
                env.replay()
                if test_task.target_achieved:
                    correct_traces.append(get_replay_trace(test_task, pv_states_dir))
                for policy in ("similarity", "random", "random_restricted", "demo_biased"):
                    explore_task(env, 10, policy, correct_traces, incorrect_traces, pv_states_dir)
                # One more similarity pass with semantic similarity on; always
                # restore the global flag so a failure cannot leak it into
                # later tasks.
                GLOBAL_CONFIGS["semantic_similarity"] = True
                try:
                    explore_task(env, 10, "similarity", correct_traces, incorrect_traces, pv_states_dir)
                finally:
                    GLOBAL_CONFIGS["semantic_similarity"] = False
                # de-duplicate, keeping first-seen order (deterministic output)
                correct_traces = list(dict.fromkeys(correct_traces))
                incorrect_traces = list(dict.fromkeys(incorrect_traces))
                task_dict = test_task.to_dict(as_demo=True)
                task_dict["correct_traces"] = correct_traces
                task_dict["incorrect_traces"] = incorrect_traces
                with open(task_file_path, "w") as task_file:
                    json.dump(task_dict, task_file, indent=2)
                logging.info(
                    f" collected {len(correct_traces)} correct traces and"
                    f" {len(incorrect_traces)} incorrect traces"
                    f" in task {test_task.name}")
            except Exception:
                logging.info(f" failed to collect traces in task {test_task.name}", exc_info=True)

    # ---------------------------------------------------------------- crawl
    if test_tasks and "crawl" in args.phases:
        browser = make_browser()
        model_param["n_episodes"] = model_param.get("n_test_episodes", 100)
        if args.explore_policy:
            model_param["explore_policy"] = args.explore_policy
        if args.explore_rate:
            model_param["explore_rate"] = args.explore_rate

        test_results = []
        for test_task in test_tasks:
            env = WebBotEnv(tasks=[test_task], browser=browser)
            env.replay()
            demo_tasklet = test_task.get_tasklet()
            logging.info(" demonstration:\n%s" % demo_tasklet)

            if args.algorithm == "dqn":
                agent = TestDQN(data_dir=args.data_dir, log_dir=args.log_dir, model_dir=args.model_dir,
                                supervised_model=supervised_model, **model_param)
            elif args.algorithm == "hill_climbing":
                from baseline_agents import HillClimbing
                agent = HillClimbing(data_dir=args.data_dir, log_dir=args.log_dir,
                                     model_dir=args.model_dir, **model_param)
            elif args.algorithm == "monte_carlo":
                from baseline_agents import MonteCarlo
                agent = MonteCarlo(data_dir=args.data_dir, log_dir=args.log_dir,
                                   model_dir=args.model_dir, **model_param)
            elif args.algorithm == "random":
                model_param["explore_policy"] = "full_rand"
                model_param["explore_rate"] = 1.0
                agent = TestDQN(data_dir=args.data_dir, log_dir=args.log_dir, model_dir=args.model_dir,
                                supervised_model=supervised_model, **model_param)
            else:
                agent = TestQTable(data_dir=args.data_dir, log_dir=args.log_dir, model_dir=args.model_dir,
                                   supervised_model=supervised_model, **model_param)

            found_tasklets = agent.train(tasks=[test_task], browser=browser)
            top_tasklets = Utils.top_n(found_tasklets, 10, reverse=True)
            post_processed_tasklets = {}
            for tasklet in top_tasklets:
                original_total_reward, episode, _ = found_tasklets[tasklet]
                new_tasklet, total_reward, final_screen = env.post_process(test_task, tasklet)
                if total_reward > original_total_reward:
                    logging.debug("post-processing improved the total reward.")
                post_processed_tasklets[new_tasklet] = (total_reward, episode, final_screen)
            top_tasklets = Utils.top_n(post_processed_tasklets, 5, reverse=True)
            test_results.append(top_tasklets)

            task_output_dir = None
            if output_dir:
                task_output_dir = os.path.join(output_dir, test_task.name)
                os.makedirs(task_output_dir, exist_ok=True)
            test_report = "\n" + "=" * 50 + "\n demonstration:\n%s\n" % demo_tasklet
            for rank, tasklet in enumerate(top_tasklets):
                _, episode, final_screen = post_processed_tasklets[tasklet]
                test_report += "-" * 50 + "\n tasklet-%d (episode %d):\n%s\n" % (rank, episode, tasklet)
                if task_output_dir:
                    final_screen_path = os.path.join(task_output_dir, "final_state_tasklet-%d.png" % rank)
                    try:
                        final_screen.save(final_screen_path)
                    except Exception as e:
                        # screenshots are optional; never abort the report for them
                        logging.debug(f"unable to save {final_screen_path}: {e}")
            logging.info(test_report)
            if task_output_dir:
                task_report_path = os.path.join(task_output_dir, "report.txt")
                try:
                    with open(task_report_path, "w") as f:
                        f.write(test_report)
                except OSError as e:
                    logging.warning(f"unable to write {task_report_path}: {e}")

        result_lines = []
        success_ids = []
        for test_task, top_tasklets in zip(test_tasks, test_results):
            # index of the successful tasklet (-1 when none succeeded)
            succeed_id = -1
            for rank, tasklet in enumerate(top_tasklets):
                if "success:Y" in tasklet:
                    succeed_id = rank
            success_ids.append(succeed_id)
            result_lines.append(
                "success:{:s} tasklet:{:d} #episodes:{:3d} #demo_steps:{:2d} task:{:60s} {:s}"
                .format(
                    "Y" if succeed_id >= 0 else "N",
                    succeed_id,
                    -1,  # TODO get the episode number
                    len(test_task.demonstration),
                    test_task.name,
                    test_task.task_str))
        # A success at rank 0 is a success too (the old `> 0` dropped it).
        success_count = sum(1 for succeed_id in success_ids if succeed_id >= 0)
        success_rate = (success_count / len(test_results)) if len(test_results) > 0 else 0
        result_lines.append("Success rate: %.3f" % success_rate)
        overall_report = "Result:\n" + "\n".join(result_lines)
        logging.info(overall_report)
        if output_dir:
            overall_report_path = os.path.join(output_dir, "overall.txt")
            try:
                with open(overall_report_path, "w") as f:
                    f.write(overall_report)
            except OSError as e:
                logging.warning(f"unable to write {overall_report_path}: {e}")

    # ---------------------------------------------- analyze_task_complexity
    if test_tasks and "analyze_task_complexity" in args.phases:
        browser = make_browser(allow_cache=False)
        for test_task in test_tasks:
            env = WebBotEnv(browser=browser, tasks=[test_task])
            env.analyze_task_complexity(test_task, test_task.demonstration)

    # ---------------------------------------------------------- build_cache
    if test_tasks and "build_cache" in args.phases:
        browser = make_browser(allow_cache=False)
        for test_task in test_tasks:
            task_category = test_task.name.split("_")[0]
            task_host = Utils.get_host(test_task.start_url)
            utg_name = "%s_%s_utg" % (task_category, task_host)
            logging.info("building cache for %s" % test_task.task_str)
            env = WebBotEnv(tasks=[test_task], browser=browser)
            # save UTG to dir
            utg_dir_path = os.path.join("output", "utgs", utg_name)
            test_task.utg = UTG.load_from_dir(utg_dir_path)
            test_task.utg.save_states = True
            test_task.utg.start_url = test_task.start_url
            logging.info("replaying the demonstration")
            env.replay()
            test_task.utg.save()
            for policy in ("demo_biased", "similarity", "random"):
                logging.info("exploring other paths: %s strategy" % policy)
                env.explore(policy=policy, n_episodes=50)
                test_task.utg.save()
            utg_zip_path = os.path.join("output", "utg_zips", utg_name)
            os.makedirs(os.path.dirname(utg_zip_path), exist_ok=True)
            shutil.make_archive(base_name=utg_zip_path, format="zip", root_dir=utg_dir_path)
            logging.info("done building cache for %s" % test_task.task_str)
            logging.info("UTG saved to %s" % utg_zip_path)
            env.destroy()
def train(self, tasks, browser):
    """Train the DQN agent on ``tasks`` and return every tasklet it reached.

    Before the episode loop:

    * if ``self.resume``: reload the model, the replay memory and the training
      stats from ``self.model_dir``;
    * if ``self.demo_dir``: load demonstrations into ``self.demo_memory``,
      update their rewards for each task and run
      ``self.demo_pretrain_steps`` learning steps on them.

    Each episode acts epsilon-greedily (epsilon from ``self.et``). Every
    transition goes to ``self.replay_memory``. Once past
    ``self.et.n_explore_episodes``, one ``self._learn()`` update runs at the
    end of the episode.

    Stats rows are ``[episode, epsilon, max_reward, max_q, q_error]``. They are
    saved every ``self.n_backup_episodes`` episodes; the model is saved at the
    end.

    :return: dict mapping each tasklet to ``(total_reward, episode, screenshot)``
        from the first time it was reached.
    """
    env = WebBotEnv(tasks=tasks, browser=browser)
    stats = []

    def save_progress(save_stats=True, save_fig=True, save_model=False, save_memory=False):
        # Best effort: a failed checkpoint is logged, never fatal.
        try:
            if save_stats:
                stats_path = os.path.join(self.model_dir, "training_stats.json")
                with open(stats_path, "w") as stats_file:
                    json.dump(stats, stats_file, indent=2)
            if save_fig:
                stats_png_path = os.path.join(self.log_dir, "training_stats.png")
                self._plot_training_stats(stats, self.et.n_explore_episodes, stats_png_path)
            if save_model:
                self.save_model()
            if save_memory:
                self.replay_memory.save(self.model_dir)
        except Exception as e:
            self.logger.warning(e)

    def resume_progress():
        # resume model
        self.load_model()
        # resume memory
        self.replay_memory.load(self.model_dir)
        # resume stats
        stats_path = os.path.join(self.model_dir, "training_stats.json")
        if os.path.exists(stats_path):
            with open(stats_path) as stats_file:
                # The file holds the list of rows written by save_progress.
                # extend (not append) so the rows are not nested as one element.
                stats.extend(json.load(stats_file))

    if self.resume:
        resume_progress()

    if self.demo_dir:
        self.demo_memory.load(self.demo_dir)
        for task in tasks:
            self.demo_memory.update_rewards(task)
        for _ in range(self.demo_pretrain_steps):
            self._learn(memory_source="demo")
        self.logger.info("Done pre-training on demos.")

    found_tasklets = {}
    for episode in range(1, self.n_episodes + 1):
        # initial observation
        env.reset()
        task = env.current_task.snapshot()
        self.logger.info("Episode %d/%d, task: %s" % (episode, self.n_episodes, task.task_str))
        max_reward = 0
        while True:
            # break while loop when end of this episode
            if task.done or task.reward < -10:
                break
            env.render()
            epsilon = self.et.get_epsilon(episode, task)
            # RL choose action based on current task snapshot
            if np.random.uniform() < epsilon:
                action_type = "Explore"
                action = self.et.choose_action_to_explore(task)
            else:
                action_type = "Exploit"
                action, _ = self.choose_action_with_model(task, q_func=self.q_eval)
            env.step(action)
            task_ = env.current_task.snapshot()
            self.replay_memory.store_transition(Transition(task=task, action=action, task_=task_))
            # swap observation
            task = task_
            self.logger.info("\t%s, epsilon:%.3f, action:%s, %s"
                             % (action_type, epsilon, action, task.get_reward_str()))
            tasklet = task.get_tasklet()
            if tasklet not in found_tasklets:
                found_tasklets[tasklet] = (task.total_reward, episode, task.state.screenshot)
            if task.total_reward > max_reward:
                max_reward = task.total_reward

        if episode > self.et.n_explore_episodes:
            max_q, q_error = self._learn()
        else:
            max_q, q_error = None, None
        epsilon = self.et.get_epsilon(episode=episode)
        stats.append([episode, epsilon, max_reward, max_q, q_error])
        # Show nan only when no value exists; a genuine 0.0 must print as 0.000.
        self.logger.info(
            "Episode %d/%d, epsilon %.3f, max_reward %.2f, max_q %.3f, q_error %.3f"
            % (episode, self.n_episodes, epsilon, max_reward,
               np.nan if max_q is None else max_q,
               np.nan if q_error is None else q_error))
        if episode % self.n_backup_episodes == 0:
            save_progress(save_fig=True, save_model=False, save_memory=False)
    save_progress(save_fig=True, save_model=True, save_memory=False)
    env.destroy()
    return found_tasklets
def train(self, tasks, browser):
    """Train the form-aware Q agent on ``tasks`` and return the best tasklets.

    Each episode picks among ``self.get_candidate_actions(task)``:

    * forms already tried in the episode are skipped;
    * it exploits ``self.choose_action_with_model`` when ``rand > epsilon``;
    * it explores with ``self.choose_action_to_explore`` otherwise, or when the
      model returned ``q == 0``.

    A ``FormAction`` is expanded by ``form.try_solve(epsilon)`` into individual
    steps. Each action is followed by ``self._learn(before, action, after)``.

    At episode end:

    * the highest-reward snapshot's tasklet is recorded;
    * tried forms get the episode's max reward via
      ``store_actions_actual_reward``;
    * ``self.form_manager.learn()`` runs.

    Stats rows ``[episode, epsilon, max_reward]`` are saved every
    ``self.n_backup_episodes`` episodes; the model is saved at the end.

    :return: dict mapping each tasklet to ``(total_reward, episode, screenshot)``.
        Any exception aborts training; it is printed and logged, and the
        tasklets found so far are returned.
    """
    env = WebBotEnv(tasks=tasks, browser=browser)
    stats = []

    def save_progress(save_stats=True, save_fig=True, save_model=True):
        # Best effort: a failed checkpoint is logged, never fatal.
        try:
            if save_stats:
                stats_path = os.path.join(self.model_dir, "training_stats.json")
                with open(stats_path, "w") as stats_file:
                    json.dump(stats, stats_file, indent=2)
            if save_fig:
                stats_png_path = os.path.join(self.log_dir, "training_stats.png")
                self._plot_training_stats(stats, stats_png_path)
            if save_model:
                self.save_model()
        except Exception as e:
            self.logger.warning(e)

    def resume_progress():
        # resume model
        self.load_model()
        stats_path = os.path.join(self.model_dir, "training_stats.json")
        if os.path.exists(stats_path):
            with open(stats_path) as stats_file:
                # The file holds the list of rows written by save_progress.
                # extend (not append) so the rows are not nested as one element.
                stats.extend(json.load(stats_file))

    if self.resume:
        resume_progress()

    found_tasklets = {}
    try:
        for episode in range(1, self.n_episodes + 1):
            env.reset()
            task = env.current_task.snapshot()
            self.logger.info("Episode %d/%d, task: %s" % (episode, self.n_episodes, task.task_str))
            max_reward = 0
            max_reward_task_snapshot = None
            tried_form_actions = []  # (form, form_actions) tried in this episode
            while True:
                if task.done or task.reward < -10:
                    break
                env.render()
                epsilon = self.et.get_epsilon(episode, task)

                # Do not offer a form that was already tried in this episode.
                interacted_form_ids = [form.unique_id for form, _ in tried_form_actions]
                candidate_actions = [
                    candidate_action
                    for candidate_action in self.get_candidate_actions(task)
                    if not (isinstance(candidate_action, FormAction)
                            and candidate_action.form.unique_id in interacted_form_ids)
                ]
                if len(candidate_actions) == 0:
                    break

                rand = np.random.uniform()
                action_category, action, q = "Unknown", None, 0
                if rand > epsilon:
                    action_category = "Exploit"
                    action, q = self.choose_action_with_model(task, candidate_actions)
                # Also explore when the model has no opinion (q == 0).
                if rand <= epsilon or q == 0:
                    action_category = "Explore"
                    action = self.choose_action_to_explore(task, candidate_actions)
                if action is None:
                    break

                init_task = task.snapshot()
                if isinstance(action, FormAction):
                    form = action.form
                    form_actions, action_categories = form.try_solve(epsilon)
                    init_reward = task.total_reward
                    for i, form_action in enumerate(form_actions):
                        form_action_element = task.state.get_element_by_locator(form_action.element.locator)
                        if form_action_element is None:
                            # element vanished from the page: mark the step as skipped
                            form_action.value = None
                        if form_action.value is None:
                            continue
                        env.step(form_action)
                        task = env.current_task.snapshot()
                        self.logger.info("\t%s, epsilon:%.3f, action:%s, %s"
                                         % (action_categories[i], epsilon, form_action, task.get_reward_str()))
                    tried_form_actions.append((form, form_actions))
                    self.logger.info(f" {action} achieved {task.total_reward - init_reward:.2f}")
                else:
                    env.step(action)
                    task = env.current_task.snapshot()
                    self.logger.info("\t%s, epsilon:%.3f, action:%s, %s"
                                     % (action_category, epsilon, action, task.get_reward_str()))
                self._learn(init_task, action, task)
                if task.total_reward > max_reward:
                    max_reward = task.total_reward
                    max_reward_task_snapshot = task.snapshot()

            if max_reward_task_snapshot is not None:
                max_reward_tasklet = max_reward_task_snapshot.get_tasklet()
                if max_reward_tasklet not in found_tasklets:
                    found_tasklets[max_reward_tasklet] = (max_reward_task_snapshot.total_reward, episode,
                                                          max_reward_task_snapshot.state.screenshot)
            # learn form
            for form, form_actions in tried_form_actions:
                form.store_actions_actual_reward(form_actions, max_reward)
            self.form_manager.learn()

            epsilon = self.et.get_epsilon(episode=episode)
            stats.append([episode, epsilon, max_reward])
            if episode % self.n_backup_episodes == 0:
                save_progress(save_fig=True, save_model=False)
            self.logger.info("Episode %d/%d, epsilon %.3f, max_reward %.2f"
                             % (episode, self.n_episodes, epsilon, max_reward))
        save_progress(save_fig=True, save_model=True)
        env.destroy()
    except Exception as e:
        import traceback
        traceback.print_exc()
        self.logger.info("failed with error: %s" % e)
    return found_tasklets