Example #1
    def train(self, tasks, browser):
        env = WebBotEnv(tasks=tasks, browser=browser)
        stats = []

        found_tasklets = {}
        try:
            for episode in range(1, self.n_episodes + 1):
                env.reset()
                task = env.current_task.snapshot()
                self.logger.info("Episode %d/%d, task: %s" %
                                 (episode, self.n_episodes, task.task_str))
                max_reward = 0
                while True:
                    if task.done or task.reward < -10:
                        break
                    env.render()

                    actions = task.get_preferred_actions()
                    if len(actions) == 0:
                        actions = task.state.possible_actions
                    candidate_actions = []
                    action_scores = {}
                    for action in actions:
                        action_score = self._get_action_score(task, action)
                        action_scores[action] = action_score
                        if action_score > 0.1:
                            candidate_actions.append(action)
                    if len(candidate_actions) == 0:
                        candidate_actions = Utils.top_n(action_scores,
                                                        5,
                                                        reverse=True)
                    action = random.choice(candidate_actions)

                    env.step(action)
                    self._set_action_score(task, action, score=task.reward)

                    task_ = env.current_task.snapshot()
                    task = task_
                    self.logger.info("\taction:%s, %s" %
                                     (action, task.get_reward_str()))

                    tasklet = task.get_tasklet()
                    if tasklet not in found_tasklets:
                        found_tasklets[tasklet] = (task.total_reward, episode,
                                                   task.state.screenshot)
                    if task.total_reward > max_reward:
                        max_reward = task.total_reward

                stats.append([episode, max_reward])
                self.logger.info("Episode %d/%d, max_reward %.2f" %
                                 (episode, self.n_episodes, max_reward))
            env.destroy()
        except Exception as e:
            import traceback
            traceback.print_exc()
            self.logger.info("failed with error: %s" % e)
        return found_tasklets
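Example #1 keeps only the actions whose score clears a threshold, falls back to the top-scored ones when nothing does, and then picks among the candidates at random. Below is a minimal, self-contained sketch of that selection step; top_n here is a hypothetical stand-in for Utils.top_n, whose exact behaviour is assumed.

import random

def top_n(scores, n, reverse=True):
    # hypothetical stand-in for Utils.top_n: the n keys with the highest values
    return sorted(scores, key=scores.get, reverse=reverse)[:n]

def pick_action(action_scores, threshold=0.1, fallback_n=5):
    # keep actions above the threshold, otherwise fall back to the top-N scored ones
    candidates = [a for a, s in action_scores.items() if s > threshold]
    if not candidates:
        candidates = top_n(action_scores, fallback_n, reverse=True)
    return random.choice(candidates)

# example usage with made-up scores
pick_action({"click_search": 0.4, "type_query": 0.05, "scroll_down": 0.02})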
Example #2
    def _load_samples(self, tasks, browser):
        from environment import WebBotEnv
        for task in tasks:
            env = WebBotEnv(tasks=[task], browser=browser)
            env.replay()
            for i in range(len(task.state_history)):
                task_i = task.snapshot(step=i)
                action_i = task.action_history[i]
                if not task_i.state or not action_i or action_i.element is None:
                    continue
                # self.fe.plot_feature(task_i, action_i)
                # feature = self.fe.get_new_feature([task_i], [action_i])
                # feature_shape = self.fe.get_new_feature_shape()
                yield task_i, action_i, 1
                for action in task_i.state.possible_actions:
                    if action.action_str != action_i.action_str:
                        yield task_i, action, 0
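The generator above labels the demonstrated action in each replayed state as a positive sample and every other possible action in that state as a negative one. A standalone sketch of that labeling step, using plain strings in place of the Action objects (hypothetical example data):

def label_actions(demonstrated, possible):
    # one positive sample for the demonstrated action, negatives for the rest
    yield demonstrated, 1
    for action in possible:
        if action != demonstrated:
            yield action, 0

# example usage
list(label_actions("click_submit", ["click_submit", "type_name", "scroll"]))
# -> [("click_submit", 1), ("type_name", 0), ("scroll", 0)]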
Example #3
    def execute(self, tasks, browser):
        env = WebBotEnv(tasks=tasks, browser=browser)
        for task in tasks:
            env.reset(new_task=task)
            task = env.current_task.snapshot()

            while True:
                if task.done:
                    break
                env.render()

                actions = task.get_preferred_actions()
                action2p = self.model.predict(task, actions)
                action = Utils.weighted_choice(action2p)
                env.step(action)
                task_ = env.current_task.snapshot()
                task = task_
                self.logger.info("\tExploit, action:%s, reward:%.2f, done:%s" %
                                 (action, task.reward, task.done))
            self.logger.info("Got total_reward %.2f in task." %
                             task.total_reward)
        env.destroy()
        self.logger.info("Done testing tasks.")
Example #4
    def execute(self, tasks, browser, visualize=False):
        env = WebBotEnv(tasks=tasks, browser=browser, visualize=visualize)
        for task in tasks:
            # initial observation
            env.reset(new_task=task)
            task = env.current_task.snapshot()
            self.logger.info("Executing task: %s" % task.task_str)
            while True:
                if task.done:
                    break
                env.render()
                action, q = self.choose_action_with_model(task)
                env.step(action)
                # self.fe.plot_feature(task, action)
                task_ = env.current_task.snapshot()
                task = task_
                self.logger.info("\tExploit, action:%s, reward:%.2f, done:%s" %
                                 (action, task.reward, task.done))
            self.logger.info("Got total_reward %.2f in task: %s" %
                             (task.total_reward, task.task_str))
        self.logger.info("Done executing tasks.")
Example #5
def train_and_test(train_tasks,
                   test_tasks,
                   args,
                   model_param,
                   output_dir=None):
    from agent import TestQTable, TestDQN
    from model import Qda2pModel
    from environment import Task
    from environment import ChromeBrowser, CacheBrowser, WebBotEnv, Utils, UTG

    # Training on other tasks is currently not supported
    # logging.info("Train on %d tasks:\n" % len(train_tasks) + "\n".join([task.task_str for task in train_tasks]))
    logging.info("Test on %d tasks:\n" % len(test_tasks) +
                 "\n".join(task.task_str for task in test_tasks))

    if args.disable_text_embed:
        model_param["disable_text_embed"] = args.disable_text_embed

    supervised_model = None
    if train_tasks and "train" in args.phases:
        if len(train_tasks) > args.num_train_tasks > 0:
            train_tasks = random.sample(train_tasks, args.num_train_tasks)
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)
        n_train_episodes = model_param.get("n_train_episodes", 200)
        model_param["n_episodes"] = n_train_episodes

        supervised_model = Qda2pModel(data_dir=args.data_dir,
                                      log_dir=args.log_dir,
                                      model_dir=args.model_dir,
                                      **model_param)
        supervised_model.train(tasks=train_tasks, browser=browser)
        browser.close()
    if test_tasks and "replay" in args.phases:
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)
        for test_task in test_tasks:
            env = WebBotEnv(tasks=[test_task], browser=browser)
            env.replay()
            replay_tasklet = test_task.get_tasklet()
            logging.info(" replay finished:\n%s" % replay_tasklet)

    if "eval_reward" in args.phases:
        from environment import State, Action, ACTION_RE
        import traceback
        import numpy as np

        # utgs = UTG.load_utgs_from_dir("output/pv_utg_zips")
        # browser = CacheBrowser(utgs=utgs)

        trace_dir = args.task_dir
        states_dir = os.path.join(trace_dir, "states")
        results_dir = os.path.join(trace_dir, "results")
        if not os.path.exists(states_dir):
            states_dir = os.path.join(trace_dir, "states.zip")
        if not os.path.exists(results_dir):
            os.makedirs(results_dir)

        def load_action(state, action_line):
            m = re.match(ACTION_RE, action_line)
            action_type, value, target_locator = m.group(1), m.group(2), m.group(3)
            target_ele = state.get_element_by_locator(target_locator)
            action = Action(element=target_ele,
                            action_type=action_type,
                            value=value)
            return action

        def compute_reward(task, trace_lines):
            # logging.info(f"compute_reward starts at {datetime.now()}")
            states = []
            actions = []
            # browser.reset(task.start_url)
            state_action_lines = [(line[:(line.find(": "))],
                                   line[(line.find(": ") + 2):])
                                  for line in trace_lines]
            current_state_str, action_line = state_action_lines[0]
            current_state = State.load(states_dir, current_state_str)
            actions.append("RESET")
            states.append(current_state)
            task.reset(current_state, update_utg=False)
            last_action = load_action(current_state, action_line)
            actions.append(action_line)
            end_reached = False
            correct_rewards = [0]
            incorrect_rewards = [task.total_reward]
            for state_str, action_line in state_action_lines[1:]:
                current_state = State.load(states_dir, state_str)
                states.append(current_state)
                task.update(last_action, current_state, update_utg=False)
                if task.target_achieved:
                    correct_rewards.append(task.total_reward)
                else:
                    incorrect_rewards.append(task.total_reward)
                if action_line == "END":
                    end_reached = True
                    break
                else:
                    last_action = load_action(current_state, action_line)
            max_correct_reward = max(correct_rewards)
            max_incorrect_reward = max(incorrect_rewards)
            logging.info(
                f"  task got correct reward {max_correct_reward:6.3f}"
                f" and incorrect reward {max_incorrect_reward:3.3f}: {task.name}"
            )
            return max_correct_reward, max_incorrect_reward

        def compute_rewards(task, traces):
            correct_rewards = []
            incorrect_rewards = []
            new_traces = []
            for trace in traces:
                if len(trace) == 0:
                    continue
                try:
                    correct_reward, incorrect_reward = compute_reward(
                        task, trace.splitlines())
                    correct_rewards.append(correct_reward)
                    incorrect_rewards.append(incorrect_reward)
                except:
                    # traceback.print_exc()
                    # logging.warning(f"compute_reward failed for {task.task_str}:\n{trace}")
                    pass
            return correct_rewards, incorrect_rewards

        all_traces = {}
        for task_name in os.listdir(trace_dir):
            if not task_name.endswith(".json"):
                continue
            task_file_path = os.path.join(trace_dir, task_name)
            trace_file = open(task_file_path)
            try:
                trace_dict = json.load(trace_file)
            except:
                logging.warning(f"unable to load task: {task_name}")
                continue
            correct_traces = trace_dict["correct_traces"]
            incorrect_traces = trace_dict["incorrect_traces"]
            if len(correct_traces) > 3:
                trace2len = {
                    trace: len(trace.splitlines())
                    for trace in correct_traces
                }
                correct_traces = Utils.top_n(trace2len, 3)
            num_incorrect_traces = len(correct_traces) * 10
            if len(incorrect_traces) > num_incorrect_traces:
                incorrect_traces = incorrect_traces[:num_incorrect_traces]

            correct_trace_str = "\n\n".join(correct_traces)
            filtered_incorrect_traces = []
            for trace in incorrect_traces:
                if trace not in correct_trace_str:
                    filtered_incorrect_traces.append(trace)
            incorrect_traces = filtered_incorrect_traces

            if len(correct_traces) > 0 and len(incorrect_traces) > 0:
                all_traces[task_name] = {}
                all_traces[task_name]["correct_traces"] = correct_traces
                all_traces[task_name]["incorrect_traces"] = incorrect_traces

        def get_reward_results(tag):
            logging.info(
                f"evaluating reward function for configuration: {tag}")
            all_rewards = {}
            for task_name in all_traces:
                logging.info(f" computing reward for {task_name}")
                task_file_path = os.path.join(trace_dir, task_name)
                test_task = Task.load(task_file_path)
                correct_traces = all_traces[task_name]["correct_traces"]
                incorrect_traces = all_traces[task_name]["incorrect_traces"]
                if len(correct_traces) == 0 or len(incorrect_traces) == 0:
                    continue
                correct_rewards, incorrect_rewards = compute_rewards(
                    test_task, correct_traces + incorrect_traces)
                if len(correct_rewards) > 0 and len(incorrect_rewards) > 0:
                    all_rewards[task_name] = {}
                    all_rewards[task_name]["correct_rewards"] = correct_rewards
                    all_rewards[task_name][
                        "incorrect_rewards"] = incorrect_rewards

                    logging.info(test_task.task_str)
                    logging.info(f"correct: {len(correct_traces)} "
                                 f"max: {np.max(correct_rewards)}")
                    logging.info(f"incorrect: {len(incorrect_traces)} "
                                 f"max: {np.max(incorrect_rewards)}")
            result_file_path = os.path.join(results_dir, f"{tag}.json")
            with open(result_file_path, "w") as f:
                json.dump(all_rewards, f, indent=2)

        get_reward_results(args.reward_weights)

    if test_tasks and "collect_pv_traces" in args.phases:
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)

        pv_trace_dir = os.path.join("output", "pv_traces")
        pv_states_dir = os.path.join(pv_trace_dir, "states")
        if not os.path.exists(pv_trace_dir):
            os.mkdir(pv_trace_dir)
        if not os.path.exists(pv_states_dir):
            os.mkdir(pv_states_dir)

        def get_replay_trace(task, states_dir=None):
            replay_trace = []
            for i, action in enumerate(task.action_history):
                state = task.state_history[i]
                replay_trace.append(f"{state.state_str}: {action.replay_api}")
                if states_dir:
                    state.save(states_dir)
            replay_trace.append(f"{task.state.state_str}: END")
            if states_dir:
                task.state.save(states_dir)
            return "\n".join(replay_trace)

        def explore_task(n_episodes, policy, states_dir=None):
            for episode in range(1, n_episodes + 1):
                env.reset()
                task = env.current_task.snapshot()
                target_achieved = False
                while True:
                    if task.done:
                        break
                    env.render()
                    preferred_actions = env.browser._filter_actions(
                        task.state, task.get_preferred_actions())
                    possible_actions = env.browser._filter_actions(
                        task.state, task.state.possible_actions)
                    if len(possible_actions) == 0:
                        break
                    if len(preferred_actions) == 0:
                        preferred_actions = possible_actions
                    if policy == "similarity":
                        action_scores = {}
                        for action in preferred_actions:
                            action_scores[action] = task.get_action_usefulness(
                                action)
                        action = Utils.weighted_choice(action_scores)
                    elif policy == "random":
                        action = random.choice(possible_actions)
                    elif policy == "random_restricted":
                        action = random.choice(preferred_actions)
                    elif policy == "demo_biased":
                        actions_in_demo = []
                        for action in preferred_actions:
                            if action.replay_api in task.demonstration:
                                actions_in_demo.append(action)
                        rand = random.uniform(0, 1)
                        if len(actions_in_demo) > 0 and rand < 0.5:
                            action = random.choice(actions_in_demo)
                        else:
                            action = random.choice(preferred_actions)
                    else:
                        action = random.choice(possible_actions)
                    env.step(action)
                    if task.target_achieved:
                        target_achieved = True
                    task_ = env.current_task.snapshot()
                    task = task_
                replay_trace = get_replay_trace(task, states_dir)
                if target_achieved:
                    correct_traces.append(replay_trace)
                else:
                    incorrect_traces.append(replay_trace)

        task_traces = {}
        for test_task in test_tasks:
            assert isinstance(test_task, Task)
            if not test_task.replayable:
                continue
            task_file_path = os.path.join(pv_trace_dir,
                                          test_task.name + "task.json")
            if os.path.exists(task_file_path):
                continue
            try:
                correct_traces = []
                incorrect_traces = []

                env = WebBotEnv(tasks=[test_task], browser=browser)
                env.replay()
                if test_task.target_achieved:
                    correct_traces.append(
                        get_replay_trace(test_task, pv_states_dir))
                explore_task(10, "similarity", pv_states_dir)
                explore_task(10, "random", pv_states_dir)
                explore_task(10, "random_restricted", pv_states_dir)
                explore_task(10, "demo_biased", pv_states_dir)
                GLOBAL_CONFIGS["semantic_similarity"] = True
                explore_task(10, "similarity", pv_states_dir)
                GLOBAL_CONFIGS["semantic_similarity"] = False
                correct_traces = list(set(correct_traces))
                incorrect_traces = list(set(incorrect_traces))
                task_traces[test_task] = [correct_traces, incorrect_traces]
                task_dict = test_task.to_dict(as_demo=True)
                task_dict["correct_traces"] = correct_traces
                task_dict["incorrect_traces"] = incorrect_traces
                with open(task_file_path, "w") as task_file:
                    json.dump(task_dict, task_file, indent=2)
                logging.info(
                    f" collected {len(correct_traces)} correct traces and"
                    f" {len(incorrect_traces)} incorrect traces"
                    f" in task {test_task.name}")
            except:
                logging.info(
                    f" failed to collect traces in task {test_task.name}")

    if test_tasks and "crawl" in args.phases:
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)
        n_test_episodes = model_param.get("n_test_episodes", 100)

        model_param["n_episodes"] = n_test_episodes
        if args.explore_policy:
            model_param["explore_policy"] = args.explore_policy
        if args.explore_rate:
            model_param["explore_rate"] = args.explore_rate

        test_results = []
        for test_task in test_tasks:
            # if not test_task.demonstration:
            #     continue
            env = WebBotEnv(tasks=[test_task], browser=browser)
            env.replay()
            demo_tasklet = test_task.get_tasklet()
            logging.info(" demonstration:\n%s" % demo_tasklet)

            if args.algorithm == "dqn":
                agent = TestDQN(data_dir=args.data_dir,
                                log_dir=args.log_dir,
                                model_dir=args.model_dir,
                                supervised_model=supervised_model,
                                **model_param)
            elif args.algorithm == "hill_climbing":
                from baseline_agents import HillClimbing
                agent = HillClimbing(data_dir=args.data_dir,
                                     log_dir=args.log_dir,
                                     model_dir=args.model_dir,
                                     **model_param)
            elif args.algorithm == "monte_carlo":
                from baseline_agents import MonteCarlo
                agent = MonteCarlo(data_dir=args.data_dir,
                                   log_dir=args.log_dir,
                                   model_dir=args.model_dir,
                                   **model_param)
            elif args.algorithm == "random":
                model_param["explore_policy"] = "full_rand"
                model_param["explore_rate"] = 1.0
                agent = TestDQN(data_dir=args.data_dir,
                                log_dir=args.log_dir,
                                model_dir=args.model_dir,
                                supervised_model=supervised_model,
                                **model_param)
            else:
                agent = TestQTable(data_dir=args.data_dir,
                                   log_dir=args.log_dir,
                                   model_dir=args.model_dir,
                                   supervised_model=supervised_model,
                                   **model_param)

            found_tasklets = agent.train(tasks=[test_task], browser=browser)

            top_tasklets = Utils.top_n(found_tasklets, 10, reverse=True)
            post_processed_tasklets = {}
            for i, tasklet in enumerate(top_tasklets):
                original_total_reward, episode, original_final_screen = found_tasklets[
                    tasklet]
                tasklet, total_reward, final_screen = env.post_process(
                    test_task, tasklet)
                if total_reward > original_total_reward:
                    logging.debug("post-processing improved the total reward.")
                post_processed_tasklets[tasklet] = (total_reward, episode,
                                                    final_screen)
            top_tasklets = Utils.top_n(post_processed_tasklets,
                                       5,
                                       reverse=True)
            test_results.append(top_tasklets)

            task_output_dir = None
            if output_dir:
                task_output_dir = os.path.join(output_dir, test_task.name)
                if not os.path.exists(task_output_dir):
                    os.mkdir(task_output_dir)
            test_report = "\n" + "=" * 50 + "\n demonstration:\n%s\n" % demo_tasklet
            for i, tasklet in enumerate(top_tasklets):
                total_reward, episode, final_screen = post_processed_tasklets[
                    tasklet]
                test_report += "-" * 50 + "\n tasklet-%d (episode %d):\n%s\n" % (
                    i, episode, tasklet)
                if task_output_dir:
                    try:
                        final_screen_path = os.path.join(
                            task_output_dir, "final_state_tasklet-%d.png" % i)
                        final_screen.save(final_screen_path)
                    except:
                        pass
            logging.info(test_report)
            if task_output_dir:
                try:
                    task_report_path = os.path.join(task_output_dir,
                                                    "report.txt")
                    with open(task_report_path, "w") as f:
                        f.write(test_report)
                except:
                    pass

        result_lines = []
        success_ids = []
        for i in range(len(test_results)):
            top_tasklets = test_results[i]
            test_task = test_tasks[i]
            succeed_id = -1
            # use a separate index so the outer loop variable is not shadowed
            for j, tasklet in enumerate(top_tasklets):
                if "success:Y" in tasklet:
                    succeed_id = j
            success_ids.append(succeed_id)
            result_lines.append(
                "success:{:s} tasklet:{:d} #episodes:{:3d}  #demo_steps:{:2d}  task:{:60s} {:s}"
                .format(
                    "Y" if succeed_id >= 0 else "N",
                    succeed_id,
                    -1,  # TODO get the episode number
                    len(test_task.demonstration),
                    test_task.name,
                    test_task.task_str))

        # succeed_id is -1 when no successful tasklet was found, so count >= 0
        success_count = sum(1 for succeed_id in success_ids
                            if succeed_id >= 0)
        success_rate = (float(success_count) /
                        len(test_results)) if len(test_results) > 0 else 0
        result_lines.append("Success rate: %.3f" % success_rate)
        overall_report = "Result:\n" + "\n".join(result_lines)
        logging.info(overall_report)
        if output_dir:
            try:
                overall_report_path = os.path.join(output_dir, "overall.txt")
                with open(overall_report_path, "w") as f:
                    f.write(overall_report)
            except:
                pass

    if test_tasks and "analyze_task_complexity" in args.phases:
        browser = ChromeBrowser(wait=args.wait,
                                headless=args.headless,
                                proxy=args.proxy,
                                restart_reset=args.restart_reset,
                                chrome_path=args.chrome_path,
                                extension_path=args.extension_path)
        for test_task in test_tasks:
            env = WebBotEnv(browser=browser, tasks=[test_task])
            env.analyze_task_complexity(test_task, test_task.demonstration)

    if test_tasks and "build_cache" in args.phases:
        from environment import UTG
        browser = ChromeBrowser(wait=args.wait,
                                headless=args.headless,
                                proxy=args.proxy,
                                restart_reset=args.restart_reset,
                                chrome_path=args.chrome_path,
                                extension_path=args.extension_path)
        for test_task in test_tasks:
            task_category = test_task.name.split("_")[0]
            task_host = Utils.get_host(test_task.start_url)
            logging.info("building cache for %s" % test_task.task_str)
            env = WebBotEnv(tasks=[test_task], browser=browser)

            # save UTG to dir
            utg_dir_path = os.path.join(
                "output", "utgs", "%s_%s_utg" % (task_category, task_host))
            test_task.utg = UTG.load_from_dir(utg_dir_path)
            test_task.utg.save_states = True
            test_task.utg.start_url = test_task.start_url

            logging.info("replaying the demonstration")
            env.replay()
            test_task.utg.save()

            logging.info("exploring other paths: demo_biased strategy")
            env.explore(policy="demo_biased", n_episodes=50)
            test_task.utg.save()

            logging.info("exploring other paths: similarity strategy")
            env.explore(policy="similarity", n_episodes=50)
            test_task.utg.save()

            logging.info("exploring other paths: random strategy")
            env.explore(policy="random", n_episodes=50)
            test_task.utg.save()

            utg_zip_path = os.path.join(
                "output", "utg_zips", "%s_%s_utg" % (task_category, task_host))
            if not os.path.exists(os.path.dirname(utg_zip_path)):
                os.makedirs(os.path.dirname(utg_zip_path))
            shutil.make_archive(base_name=utg_zip_path,
                                format="zip",
                                root_dir=utg_dir_path)

            logging.info("done building cache for %s" % test_task.task_str)
            logging.info("UTG saved to %s" % utg_zip_path)
            # pv_utg_zip_dir = os.path.join("output", "pv_utg_zips")
            # shutil.copy(utg_zip_path, pv_utg_zip_dir)
            env.destroy()
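In the phases above, get_replay_trace() writes one line per step in the form "<state_str>: <action>" and ends each trace with an END marker, and compute_reward() parses that same format back. A small sketch of the convention (the state and action strings are made up):

def parse_trace_line(line):
    # split on the first ": " only; action strings may themselves contain colons
    state_str, _, action_line = line.partition(": ")
    return state_str, action_line

trace_lines = ["s0: click(#search)", "s1: setValue(#q, 'shoes')", "s2: END"]
for state_str, action_line in map(parse_trace_line, trace_lines):
    if action_line == "END":
        break  # the final state has no action to replay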
Example #6
    def train(self, tasks, browser):
        env = WebBotEnv(tasks=tasks, browser=browser)
        stats = []

        def save_progress(save_stats=True,
                          save_fig=True,
                          save_model=False,
                          save_memory=False):
            try:
                if save_stats:
                    stats_path = os.path.join(self.model_dir,
                                              "training_stats.json")
                    json.dump(stats, open(stats_path, "w"), indent=2)
                if save_fig:
                    stats_png_path = os.path.join(self.log_dir,
                                                  "training_stats.png")
                    self._plot_training_stats(stats,
                                              self.et.n_explore_episodes,
                                              stats_png_path)
                if save_model:
                    self.save_model()
                if save_memory:
                    self.replay_memory.save(self.model_dir)
            except Exception as e:
                self.logger.warning(e)

        def resume_progress():
            # resume model
            self.load_model()
            # resume memory
            self.replay_memory.load(self.model_dir)
            # resume stats
            stats_path = os.path.join(self.model_dir, "training_stats.json")
            if os.path.exists(stats_path):
                # training_stats.json holds a list of per-episode rows
                stats.extend(json.load(open(stats_path)))

        if self.resume:
            resume_progress()

        if self.demo_dir:
            self.demo_memory.load(self.demo_dir)
            for task in tasks:
                self.demo_memory.update_rewards(task)
            for i in range(self.demo_pretrain_steps):
                self._learn(memory_source="demo")
            self.logger.info("Done pre-training on demos.")

        found_tasklets = {}
        for episode in range(1, self.n_episodes + 1):
            # initial observation
            env.reset()
            task = env.current_task.snapshot()
            self.logger.info("Episode %d/%d, task: %s" %
                             (episode, self.n_episodes, task.task_str))

            max_reward = 0
            while True:
                # break while loop when end of this episode
                if task.done or task.reward < -10:
                    break
                env.render()
                epsilon = self.et.get_epsilon(episode, task)

                # RL choose action based on current task snapshot
                if np.random.uniform() < epsilon:
                    action_type = "Explore"
                    action = self.et.choose_action_to_explore(task)
                else:
                    action_type = "Exploit"
                    action, q = self.choose_action_with_model(
                        task, q_func=self.q_eval)
                env.step(action)

                # self.fe.plot_feature(task, action)
                task_ = env.current_task.snapshot()
                self.replay_memory.store_transition(
                    Transition(task=task, action=action, task_=task_))
                # swap observation
                task = task_
                self.logger.info(
                    "\t%s, epsilon:%.3f, action:%s, %s" %
                    (action_type, epsilon, action, task.get_reward_str()))

                tasklet = task.get_tasklet()
                if tasklet not in found_tasklets:
                    found_tasklets[tasklet] = (task.total_reward, episode,
                                               task.state.screenshot)
                if task.total_reward > max_reward:
                    max_reward = task.total_reward

            if episode > self.et.n_explore_episodes:
                max_q, q_error = self._learn()
            else:
                max_q, q_error = None, None
            epsilon = self.et.get_epsilon(episode=episode)
            stats.append([episode, epsilon, max_reward, max_q, q_error])
            self.logger.info(
                "Episode %d/%d, epsilon %.3f, max_reward %.2f, max_q %.3f, q_error %.3f"
                % (episode, self.n_episodes, epsilon, max_reward, max_q
                   or np.nan, q_error or np.nan))
            if episode % self.n_backup_episodes == 0:
                save_progress(save_fig=True,
                              save_model=False,
                              save_memory=False)
        save_progress(save_fig=True, save_model=True, save_memory=False)
        env.destroy()
        return found_tasklets
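The DQN-style train() above stores a Transition(task=..., action=..., task_=...) in a replay memory after every step and learns from it later. A minimal stand-in for those two pieces, assuming a simple bounded buffer (the real Transition and replay-memory classes live elsewhere in this codebase and may differ):

import random
from collections import deque, namedtuple

# hypothetical stand-ins for the real Transition / replay-memory classes
Transition = namedtuple("Transition", ["task", "action", "task_"])

class ReplayMemory:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are dropped first

    def store_transition(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size=32):
        return random.sample(list(self.buffer), min(batch_size, len(self.buffer)))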
Example #7
    def train(self, tasks, browser):
        env = WebBotEnv(tasks=tasks, browser=browser)
        stats = []

        def save_progress(save_stats=True, save_fig=True, save_model=True):
            try:
                if save_stats:
                    stats_path = os.path.join(self.model_dir,
                                              "training_stats.json")
                    json.dump(stats, open(stats_path, "w"), indent=2)
                if save_fig:
                    stats_png_path = os.path.join(self.log_dir,
                                                  "training_stats.png")
                    self._plot_training_stats(stats, stats_png_path)
                if save_model:
                    self.save_model()
            except Exception as e:
                self.logger.warning(e)

        def resume_progress():
            # resume model
            self.load_model()
            stats_path = os.path.join(self.model_dir, "training_stats.json")
            if os.path.exists(stats_path):
                # training_stats.json holds a list of per-episode rows
                stats.extend(json.load(open(stats_path)))

        if self.resume:
            resume_progress()

        found_tasklets = {}
        try:
            for episode in range(1, self.n_episodes + 1):
                env.reset()
                task = env.current_task.snapshot()
                self.logger.info("Episode %d/%d, task: %s" %
                                 (episode, self.n_episodes, task.task_str))

                max_reward = 0
                max_reward_task_snapshot = None
                tried_form_actions = []
                while True:
                    if task.done or task.reward < -10:
                        break
                    env.render()
                    epsilon = self.et.get_epsilon(episode, task)
                    # if episode == 1 and self.et.supervised_model and self.et.explore_policy == "supervised":
                    #     action_type = "Guided"
                    #     action = self.choose_action_with_supervised_model(task)
                    # el
                    candidate_actions = []
                    interacted_form_ids = [
                        form.unique_id for form, _ in tried_form_actions
                    ]
                    for candidate_action in self.get_candidate_actions(task):
                        if isinstance(candidate_action, FormAction) and \
                                candidate_action.form.unique_id in interacted_form_ids:
                            continue
                        candidate_actions.append(candidate_action)

                    if len(candidate_actions) == 0:
                        break
                    rand = np.random.uniform()
                    action_category, action, q = "Unknown", None, 0
                    if rand > epsilon:
                        action_category = "Exploit"
                        action, q = self.choose_action_with_model(
                            task, candidate_actions)
                    if rand <= epsilon or q == 0:
                        action_category = "Explore"
                        action = self.choose_action_to_explore(
                            task, candidate_actions)
                    # self.fe.plot_feature(task, action)
                    if action is None:
                        break

                    init_task = task.snapshot()
                    if isinstance(action, FormAction):
                        form = action.form
                        form_actions, action_categories = form.try_solve(
                            epsilon)
                        init_reward = task.total_reward
                        for i, form_action in enumerate(form_actions):
                            form_action_element = task.state.get_element_by_locator(
                                form_action.element.locator)
                            if form_action_element is None:
                                form_action.value = None
                            if form_action.value is None:
                                continue
                            env.step(form_action)
                            task = env.current_task.snapshot()
                            self.logger.info(
                                "\t%s, epsilon:%.3f, action:%s, %s" %
                                (action_categories[i], epsilon, form_action,
                                 task.get_reward_str()))
                        tried_form_actions.append((form, form_actions))
                        self.logger.info(
                            f" {action} achieved {task.total_reward - init_reward:.2f}"
                        )
                    else:
                        env.step(action)
                        task = env.current_task.snapshot()
                        self.logger.info("\t%s, epsilon:%.3f, action:%s, %s" %
                                         (action_category, epsilon, action,
                                          task.get_reward_str()))
                    self._learn(init_task, action, task)

                    if task.total_reward > max_reward:
                        max_reward = task.total_reward
                        max_reward_task_snapshot = task.snapshot()

                if max_reward_task_snapshot is not None:
                    max_reward_tasklet = max_reward_task_snapshot.get_tasklet()
                    if max_reward_tasklet not in found_tasklets:
                        found_tasklets[max_reward_tasklet] = \
                            (max_reward_task_snapshot.total_reward, episode, max_reward_task_snapshot.state.screenshot)

                # learn form
                for (form, form_actions) in tried_form_actions:
                    form.store_actions_actual_reward(form_actions, max_reward)
                self.form_manager.learn()

                epsilon = self.et.get_epsilon(episode=episode)
                stats.append([episode, epsilon, max_reward])
                if episode % self.n_backup_episodes == 0:
                    save_progress(save_fig=True, save_model=False)
                self.logger.info(
                    "Episode %d/%d, epsilon %.3f, max_reward %.2f" %
                    (episode, self.n_episodes, epsilon, max_reward))
            save_progress(save_fig=True, save_model=True)
            env.destroy()
        except Exception as e:
            import traceback
            traceback.print_exc()
            self.logger.info("failed with error: %s" % e)
        return found_tasklets
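Both train() methods follow the same epsilon-greedy scheme: draw a uniform random number, explore when it falls below epsilon, otherwise exploit the model's value estimates (the second variant additionally falls back to exploring when the exploit value is zero). A self-contained sketch of that decision, with hypothetical names:

import random

def epsilon_greedy(candidates, q_values, epsilon):
    # candidates: list of actions; q_values: mapping action -> estimated value
    if random.random() < epsilon:
        return random.choice(candidates)                        # explore
    return max(candidates, key=lambda a: q_values.get(a, 0.0))  # exploit

# example usage
epsilon_greedy(["click", "type"], {"click": 0.8, "type": 0.3}, epsilon=0.1)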