Example #1
    def get_action_vector(self, action, task):
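        # Build the feature vector for an action: optionally a one-hot action-type
        # block, plus (for input actions) the mean similarity of the typed text to
        # the non-parameter query words and one similarity slot per parameter value.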
        action_vecs = []
        if self.use_action_type:
            action_type_vec = np.zeros(len(self.ACTION_TYPES))
            if action.action_type in self.ACTION_TYPE_MAPPING:
                action_type = self.ACTION_TYPE_MAPPING[action.action_type]
                action_type_index = self.ACTION_TYPES.index(action_type)
                action_type_vec[action_type_index] = 1.
            action_vecs.append(action_type_vec)
        if self.use_action_query_sim:
            action_sim_non_para = 0
            if action.is_input:
                action_sim_non_paras = []
                for non_para in task.non_parameters_parsed:
                    action_sim_non_paras.append(Utils.text_similarity(non_para, action.value_text_parsed))
                action_sim_non_para = np.mean(action_sim_non_paras)
            action_vecs.append(np.array([action_sim_non_para]))

            action_sim_para_val = np.zeros(self.n_para_dim)
            if action.is_input:
                for i, para_val in enumerate(task.parameter_values_parsed):
                    if i >= self.n_para_dim:  # stay within the fixed-size similarity vector
                        break
                    action_sim_para_val[i] = Utils.text_similarity(para_val, action.value_text_parsed)
            action_vecs.append(action_sim_para_val)
        return np.concatenate(action_vecs)
    @staticmethod
    def extract_input_candidates(task):
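        # Collect candidate values for every (action_type, element) input or select;
        # SELECT options are kept only if they resemble a word from the task, and
        # each task word keeps just its best-matching value. None marks "no input".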
        input_candidates = {}
        selectable_values = set()
        for action in task.state.possible_actions:
            action_type, action_ele, action_value = action.action_type, action.element, action.value
            if action_type not in [Action.INPUT_TEXT, Action.SELECT]:
                continue
            if (action_type, action_ele) not in input_candidates:
                input_candidates[(action_type, action_ele)] = [None]
            if action_value in input_candidates[(action_type, action_ele)]:
                continue
            if action_type == Action.SELECT:
                action_value_useful = False
                for word in task.all_words_parsed:
                    word_sim = Utils.text_similarity(word, action.value_text_parsed)
                    if word_sim > 0.5:
                        selectable_values.add(word)
                        action_value_useful = True
                if not action_value_useful:
                    continue
            input_candidates[(action_type, action_ele)].append(action_value)
        input_candidates = OrderedDict(sorted(input_candidates.items(), key=lambda x: x[0][1].id))

        for (action_type, action_ele) in input_candidates:
            values = input_candidates[(action_type, action_ele)]
            values_parsed = [
                None if value is None else Action(action_ele, action_type, value).value_text_parsed
                for value in values
            ]

            # keep the max-similarity value for each parameter
            filtered_values = [None]
            for word in task.all_words_parsed:
                if action_type == Action.INPUT_TEXT and word in selectable_values:
                    continue
                max_score = 0
                max_score_value = None
                for i, value_parsed in enumerate(values_parsed):
                    if value_parsed is None:
                        continue
                    value_score = Utils.text_similarity(word, value_parsed)
                    if value_score > max_score:
                        max_score = value_score
                        max_score_value = values[i]
                if max_score_value is not None and max_score_value not in filtered_values:
                    filtered_values.append(max_score_value)
            values = filtered_values

            values = sorted(values, key=lambda x: str(x))
            input_candidates[(action_type, action_ele)] = values
        return input_candidates
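Both functions above lean heavily on Utils.text_similarity to compare parsed strings. That helper is not part of this excerpt; as a rough, purely illustrative stand-in (assuming a token-overlap score in [0, 1]; the real Utils implementation may use edit distance or embeddings instead), it could look like:

    def text_similarity(text_a, text_b):
        # Illustrative stand-in only: Jaccard overlap of lowercase token sets.
        tokens_a = set(str(text_a).lower().split())
        tokens_b = set(str(text_b).lower().split())
        if not tokens_a or not tokens_b:
            return 0.0
        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
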
Example #3
    def train(self, tasks, browser):
        env = WebBotEnv(tasks=tasks, browser=browser)
        stats = []

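        # Maps each discovered tasklet to (total_reward, episode, final screenshot).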
        found_tasklets = {}
        try:
            for episode in range(1, self.n_episodes + 1):
                env.reset()
                task = env.current_task.snapshot()
                self.logger.info("Episode %d/%d, task: %s" %
                                 (episode, self.n_episodes, task.task_str))
                max_reward = 0
                while True:
                    if task.done or task.reward < -10:
                        break
                    env.render()

                    actions = task.get_preferred_actions()
                    if len(actions) == 0:
                        actions = task.state.possible_actions
                    candidate_actions = []
                    action_scores = {}
                    for action in actions:
                        action_score = self._get_action_score(task, action)
                        action_scores[action] = action_score
                        if action_score > 0.1:
                            candidate_actions.append(action)
                    if len(candidate_actions) == 0:
                        candidate_actions = Utils.top_n(action_scores,
                                                        5,
                                                        reverse=True)
                    action = random.choice(candidate_actions)

                    env.step(action)
                    self._set_action_score(task, action, score=task.reward)

                    task = env.current_task.snapshot()
                    self.logger.info("\taction:%s, %s" %
                                     (action, task.get_reward_str()))

                    tasklet = task.get_tasklet()
                    if tasklet not in found_tasklets:
                        found_tasklets[tasklet] = (task.total_reward, episode,
                                                   task.state.screenshot)
                    if task.total_reward > max_reward:
                        max_reward = task.total_reward

                stats.append([episode, max_reward])
                self.logger.info("Episode %d/%d, max_reward %.2f" %
                                 (episode, self.n_episodes, max_reward))
            env.destroy()
        except Exception as e:
            import traceback
            traceback.print_exc()
            self.logger.info("failed with error: %s" % e)
        return found_tasklets
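When no action clears the 0.1 score threshold, train() falls back to Utils.top_n over the score dictionary, and the same helper later ranks discovered tasklets. A minimal sketch of such a helper, assuming reverse=True means "the n keys with the largest values" (the actual Utils signature may differ):

    def top_n(score_dict, n, reverse=False):
        # Illustrative stand-in only: keys of the n smallest values, or the
        # n largest values when reverse=True.
        ranked = sorted(score_dict, key=lambda key: score_dict[key], reverse=reverse)
        return ranked[:n]
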
Example #4
 def choose_action_to_explore(self, task, candidate_actions):
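     # With probability 0.5 sample an action weighted by its usefulness (or, for a
     # FormAction, by the size of its form's input_candidates); otherwise pick one
     # uniformly at random.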
     if np.random.uniform() < 0.5:
         action_scores = {}
         for action in candidate_actions:
             action_scores[action] = len(action.form.input_candidates) \
                 if isinstance(action, FormAction) else task.get_action_usefulness(action)
         return Utils.weighted_choice(action_scores)
     else:
         return random.choice(candidate_actions)
    def __init__(self, task, input_candidates):
        self.task = task.snapshot()
        self.input_candidates = input_candidates
        self.input_candidates_strs = Form.get_input_candidates_strs(input_candidates)
        self.unique_id = Utils.md5("\n".join(self.input_candidates_strs))

        self.form_actor = FormActor(self)
        self.form_critic = FormCritic(self)

        self.tried_input_combs = {}
        self.best_input_comb, _ = self.generate_input_comb()
        self.best_reward = -np.inf
 def simulate_actions(self, actions):
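     # Estimate the reward gained by a sequence of input actions: apply them to
     # fake successor states on a task snapshot and re-evaluate, without stepping
     # the real environment.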
     task = self.task.snapshot()
     init_reward = task.total_reward
     for action in actions:
         if action.value is None:
             continue
         fake_state = Utils.create_fake_state(current_state=task.state, action=action)
         task.state_history.append(task.state)
         task.action_history.append(action)
         task.state = fake_state
     task._evaluate()
     final_reward = task.total_reward
     return final_reward - init_reward
Example #7
    def choose_action_to_explore(self, task):
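        # Pick an exploration action according to explore_policy: mix supervised
        # predictions or similarity scores with uniform random choice, or fall back
        # to fully random selection over all possible actions.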
        preferred_actions = task.get_preferred_actions()
        if len(preferred_actions) == 0:
            return random.choice(task.state.possible_actions)

        if "supervised" in self.explore_policy:
            if np.random.uniform() < 0.5:
                action_probabilities = self.supervised_model.predict(
                    task, preferred_actions)
                return Utils.weighted_choice(action_probabilities)
            else:
                return np.random.choice(preferred_actions)
        elif "similarity" in self.explore_policy:
            if np.random.uniform() < 0.5:
                action_scores = {}
                for action in preferred_actions:
                    action_scores[action] = self._get_action_usefulness(
                        task, action)
                return Utils.weighted_choice(action_scores)
            else:
                return np.random.choice(preferred_actions)
        elif "full_sim" in self.explore_policy:
            action_scores = {}
            for action in task.state.possible_actions:
                action_scores[action] = self._get_action_usefulness(
                    task, action)
            return Utils.weighted_choice(action_scores)
        elif "full_rand" in self.explore_policy:
            return np.random.choice(task.state.possible_actions)
        elif "half_sim_half_rand" in self.explore_policy:
            if np.random.uniform() < 0.5:
                action_scores = {}
                for action in task.state.possible_actions:
                    action_scores[action] = self._get_action_usefulness(
                        task, action)
                return Utils.weighted_choice(action_scores)
            else:
                return np.random.choice(task.state.possible_actions)
        return np.random.choice(task.state.possible_actions)
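Each policy above that builds a score dictionary hands it to Utils.weighted_choice to pick one action. A minimal sketch of such a sampler, assuming scores are treated as non-negative weights with a uniform fallback when they are all zero (the real helper may normalize or temper the scores differently):

    import random

    def weighted_choice(score_dict):
        # Illustrative stand-in only: sample one key with probability
        # proportional to its score.
        keys = list(score_dict)
        weights = [max(score_dict[key], 0.0) for key in keys]
        if sum(weights) <= 0:
            return random.choice(keys)  # uniform fallback when all scores are zero
        return random.choices(keys, weights=weights, k=1)[0]
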
Example #8
 def get_query_vector(self, task):
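     # Build the query feature vector: the mean embedding of the non-parameter
     # words, one (possibly zero) embedding per parameter's annotations, and
     # optionally the similarity scores already achieved for the query and its
     # parameters.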
     query_vecs = []
     if self.use_query_embed:
         non_para_embeds = []
         for non_para in task.non_parameters_parsed:
             non_para_embed = Utils.vec(non_para)[:self.text_embed_dim]
             non_para_embeds.append(non_para_embed)
         query_vecs.append(np.mean(non_para_embeds, axis=0))
         for i in range(self.n_para_dim):
             para_embed = np.zeros(self.text_embed_dim)
             if i < len(task.parameter_annotations_parsed):
                 para_annos = task.parameter_annotations_parsed[i]
                 para_embed = np.mean([Utils.vec(anno)[:self.text_embed_dim] for anno in para_annos], axis=0)
             query_vecs.append(para_embed)
     if self.use_query_score:
         sim_non_para, sim_paras = task.query_achieved_scores()
         query_score = np.zeros(1 + self.n_para_dim)
         query_score[0] = sim_non_para
         for i, sim_para in enumerate(sim_paras):
             if i + 1 < len(query_score):
                 query_score[i+1] = sim_para
         query_vecs.append(query_score)
     return np.concatenate(query_vecs)
Example #9
 def explore_task(n_episodes, policy, states_dir=None):
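     # Collect replay traces by exploring the task under the given policy; note
     # that env, correct_traces, incorrect_traces and get_replay_trace come from
     # the enclosing scope (see the nested copy of this function in Example #15).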
     for episode in range(1, n_episodes + 1):
         env.reset()
         task = env.current_task.snapshot()
         target_achieved = False
         while True:
             if task.done:
                 break
             env.render()
             preferred_actions = env.browser._filter_actions(
                 task.state, task.get_preferred_actions())
             possible_actions = env.browser._filter_actions(
                 task.state, task.state.possible_actions)
             if len(possible_actions) == 0:
                 break
             if len(preferred_actions) == 0:
                 preferred_actions = possible_actions
             if policy == "similarity":
                 action_scores = {}
                 for action in preferred_actions:
                     action_scores[action] = task.get_action_usefulness(
                         action)
                 action = Utils.weighted_choice(action_scores)
             elif policy == "random":
                 action = random.choice(possible_actions)
             elif policy == "random_restricted":
                 action = random.choice(preferred_actions)
             elif policy == "demo_biased":
                 actions_in_demo = []
                 for action in preferred_actions:
                     if action.replay_api in task.demonstration:
                         actions_in_demo.append(action)
                 rand = random.uniform(0, 1)
                 if len(actions_in_demo) > 0 and rand < 0.5:
                     action = random.choice(actions_in_demo)
                 else:
                     action = random.choice(preferred_actions)
             else:
                 action = random.choice(possible_actions)
             env.step(action)
             if task.target_achieved:
                 target_achieved = True
             task = env.current_task.snapshot()
         replay_trace = get_replay_trace(task, states_dir)
         if target_achieved:
             correct_traces.append(replay_trace)
         else:
             incorrect_traces.append(replay_trace)
Example #10
 def _get_text_similarity_with_words(self, text, words, n_dim=1, sub_scores=None):
     if not sub_scores:
         sub_scores = [0] * len(words)
     word_similarities = [max(Utils.text_similarity(word, text) - sub_scores[i], 0)
                          for i, word in enumerate(words)]
     if n_dim > 1:
         similarity_array = np.zeros(n_dim)
         if len(words) > 0 and len(text) > 0:
             for i in range(n_dim):
                 if i >= len(words):
                     break
                 similarity_array[i] = word_similarities[i]
         return similarity_array
     else:
         return np.max(word_similarities) if word_similarities else 0.0  # empty word list yields zero similarity
 def generate_input_comb(self, epsilon=1.0, eval_func=None):
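     # Build one input value per form field, epsilon-greedily: with probability
     # epsilon explore (uniformly or weighted by past greedy rewards), otherwise
     # exploit the value with the highest q-score from eval_func, so eval_func
     # must be supplied whenever epsilon < 1.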
     input_comb = OrderedDict()
     action_categories = []
     previous_actions = []
     for action_type, action_ele in self.input_candidates:
         previous_values = [
             None if previous_action.value is None else previous_action.value_text_parsed
             for previous_action in previous_actions
         ]
         value_candidates = []
         for value in self.input_candidates[(action_type, action_ele)]:
             if value is not None:
                 value_text = Action(action_ele, action_type, value).value_text_parsed
                 if value_text in previous_values:
                     continue
             value_candidates.append(value)
         if np.random.uniform() <= epsilon:
             action_category = "Explore"
             if np.random.uniform() <= 0.5:
                 action_value = random.choice(value_candidates)
             else:
                 greedy_input_comb_rewards = self.greedy_input_comb_rewards[(action_type, action_ele)]
                 value_candidate_rewards = {
                     value_candidate: greedy_input_comb_rewards[value_candidate]
                     for value_candidate in value_candidates
                 }
                 action_value = Utils.weighted_choice(value_candidate_rewards)
         else:
             action_category = "Exploit"
             max_q_score = -1
             best_value = random.choice(value_candidates)
             for value_candidate in value_candidates:
                 action_candidate = Action(element=action_ele, action_type=action_type, value=value_candidate)
                 q_score = eval_func(self, previous_actions, action_candidate)
                 if q_score > max_q_score:
                     max_q_score = q_score
                     best_value = value_candidate
             action_value = best_value
         previous_actions.append(Action(element=action_ele, action_type=action_type, value=action_value))
         input_comb[(action_type, action_ele)] = action_value
         action_categories.append(action_category)
     return input_comb, action_categories
Example #12
    def execute(self, tasks, browser):
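        # Execute each task by repeatedly sampling the next action from the model's
        # predicted probabilities over the preferred actions until the task is done.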
        env = WebBotEnv(tasks=tasks, browser=browser)
        for task in tasks:
            env.reset(new_task=task)
            task = env.current_task.snapshot()

            while True:
                if task.done:
                    break
                env.render()

                actions = task.get_preferred_actions()
                action2p = self.model.predict(task, actions)
                action = Utils.weighted_choice(action2p)
                env.step(action)
                task = env.current_task.snapshot()
                self.logger.info("\tExploit, action:%s, reward:%.2f, done:%s" %
                                 (action, task.reward, task.done))
            self.logger.info("Got total_reward %.2f in task." %
                             task.total_reward)
        env.destroy()
        self.logger.info("Done testing tasks.")
 def compute_feature(element, text_parsed):
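     # Similarity between the element's own text and the given parsed text, or
     # None when either side is empty.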
     if not element.own_text_parsed or not text_parsed:
         return None
     return Utils.text_similarity(text_parsed, element.own_text_parsed)
Example #14
 def compute_feature(element):
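     # Presumably a nested helper closing over self.text_embed_dim from an
     # enclosing method; returns the element's truncated text embedding, or None.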
     if element.own_text_parsed:
         return Utils.vec(element.own_text_parsed)[:self.text_embed_dim]
     else:
         return None
Example #15
def train_and_test(train_tasks,
                   test_tasks,
                   args,
                   model_param,
                   output_dir=None):
    from agent import TestQTable, TestDQN
    from model import Qda2pModel
    from environment import Task
    from environment import ChromeBrowser, CacheBrowser, WebBotEnv, Utils, UTG

    # Training on other tasks is currently not supported
    # logging.info("Train on %d tasks:\n" % len(train_tasks) + "\n".join([task.task_str for task in train_tasks]))
    logging.info("Test on %d tasks:\n" % len(test_tasks) +
                 "\n".join(task.task_str for task in test_tasks))

    if args.disable_text_embed:
        model_param["disable_text_embed"] = args.disable_text_embed

    supervised_model = None
    if train_tasks and "train" in args.phases:
        if len(train_tasks) > args.num_train_tasks > 0:
            train_tasks = random.sample(train_tasks, args.num_train_tasks)
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)
        n_train_episodes = model_param.get("n_train_episodes", 200)
        model_param["n_episodes"] = n_train_episodes

        supervised_model = Qda2pModel(data_dir=args.data_dir,
                                      log_dir=args.log_dir,
                                      model_dir=args.model_dir,
                                      **model_param)
        supervised_model.train(tasks=train_tasks, browser=browser)
        browser.close()
    if test_tasks and "replay" in args.phases:
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)
        for test_task in test_tasks:
            env = WebBotEnv(tasks=[test_task], browser=browser)
            env.replay()
            replay_tasklet = test_task.get_tasklet()
            logging.info(" replay finished:\n%s" % replay_tasklet)

    if "eval_reward" in args.phases:
        from environment import State, Action, ACTION_RE
        import traceback
        import numpy as np

        # utgs = UTG.load_utgs_from_dir("output/pv_utg_zips")
        # browser = CacheBrowser(utgs=utgs)

        trace_dir = args.task_dir
        states_dir = os.path.join(trace_dir, "states")
        results_dir = os.path.join(trace_dir, "results")
        if not os.path.exists(states_dir):
            states_dir = os.path.join(trace_dir, "states.zip")
        if not os.path.exists(results_dir):
            os.makedirs(results_dir)

        def load_action(state, action_line):
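            # Parse one recorded action line back into an Action bound to the
            # element located in the given state.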
            m = re.match(ACTION_RE, action_line)
            action_type, value, target_locator = m.group(1), m.group(2), m.group(3)
            target_ele = state.get_element_by_locator(target_locator)
            action = Action(element=target_ele,
                            action_type=action_type,
                            value=value)
            return action

        def compute_reward(task, trace_lines):
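            # Replay a recorded trace against the task's reward function and return
            # the best reward observed while the target was achieved and while it
            # was not.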
            # logging.info(f"compute_reward starts at {datetime.now()}")
            states = []
            actions = []
            # browser.reset(task.start_url)
            state_action_lines = [(line[:(line.find(": "))],
                                   line[(line.find(": ") + 2):])
                                  for line in trace_lines]
            current_state_str, action_line = state_action_lines[0]
            current_state = State.load(states_dir, current_state_str)
            actions.append("RESET")
            states.append(current_state)
            task.reset(current_state, update_utg=False)
            last_action = load_action(current_state, action_line)
            actions.append(action_line)
            end_reached = False
            correct_rewards = [0]
            incorrect_rewards = [task.total_reward]
            for state_str, action_line in state_action_lines[1:]:
                current_state = State.load(states_dir, state_str)
                states.append(current_state)
                task.update(last_action, current_state, update_utg=False)
                if task.target_achieved:
                    correct_rewards.append(task.total_reward)
                else:
                    incorrect_rewards.append(task.total_reward)
                if action_line == "END":
                    end_reached = True
                    break
                else:
                    last_action = load_action(current_state, action_line)
            max_correct_reward = max(correct_rewards)
            max_incorrect_reward = max(incorrect_rewards)
            logging.info(
                f"  task got correct reward {max_correct_reward:6.3f}"
                f" and incorrect reward {max_incorrect_reward:3.3f}: {task.name}"
            )
            return max_correct_reward, max_incorrect_reward

        def compute_rewards(task, traces):
            correct_rewards = []
            incorrect_rewards = []
            new_traces = []
            for trace in traces:
                if len(trace) == 0:
                    continue
                try:
                    correct_reward, incorrect_reward = compute_reward(
                        task, trace.splitlines())
                    correct_rewards.append(correct_reward)
                    incorrect_rewards.append(incorrect_reward)
                except:
                    # traceback.print_exc()
                    # logging.warning(f"compute_reward failed for {task.task_str}:\n{trace}")
                    pass
            return correct_rewards, incorrect_rewards

        all_traces = {}
        for task_name in os.listdir(trace_dir):
            if not task_name.endswith(".json"):
                continue
            task_file_path = os.path.join(trace_dir, task_name)
            trace_file = open(task_file_path)
            try:
                trace_dict = json.load(trace_file)
            except:
                logging.warning(f"unable to load task: {task_name}")
                continue
            correct_traces = trace_dict["correct_traces"]
            incorrect_traces = trace_dict["incorrect_traces"]
            if len(correct_traces) > 3:
                trace2len = {
                    trace: len(trace.splitlines())
                    for trace in correct_traces
                }
                correct_traces = Utils.top_n(trace2len, 3)
            num_incorrect_traces = len(correct_traces) * 10
            if len(incorrect_traces) > num_incorrect_traces:
                incorrect_traces = incorrect_traces[:num_incorrect_traces]

            correct_trace_str = "\n\n".join(correct_traces)
            filtered_incorrect_traces = []
            for trace in incorrect_traces:
                if trace not in correct_trace_str:
                    filtered_incorrect_traces.append(trace)
            incorrect_traces = filtered_incorrect_traces

            if len(correct_traces) > 0 and len(incorrect_traces) > 0:
                all_traces[task_name] = {}
                all_traces[task_name]["correct_traces"] = correct_traces
                all_traces[task_name]["incorrect_traces"] = incorrect_traces

        def get_reward_results(tag):
            logging.info(
                f"evaluating reward function for configuration: {tag}")
            all_rewards = {}
            for task_name in all_traces:
                logging.info(f" computing reward for {task_name}")
                task_file_path = os.path.join(trace_dir, task_name)
                test_task = Task.load(task_file_path)
                correct_traces = all_traces[task_name]["correct_traces"]
                incorrect_traces = all_traces[task_name]["incorrect_traces"]
                if len(correct_traces) == 0 or len(incorrect_traces) == 0:
                    continue
                correct_rewards, incorrect_rewards = compute_rewards(
                    test_task, correct_traces + incorrect_traces)
                if len(correct_rewards) > 0 and len(incorrect_rewards) > 0:
                    all_rewards[task_name] = {}
                    all_rewards[task_name]["correct_rewards"] = correct_rewards
                    all_rewards[task_name][
                        "incorrect_rewards"] = incorrect_rewards

                    logging.info(test_task.task_str)
                    logging.info(f"correct: {len(correct_traces)} "
                                 f"max: {np.max(correct_rewards)}")
                    logging.info(f"incorrect: {len(incorrect_traces)} "
                                 f"max: {np.max(incorrect_rewards)}")
            result_file_path = os.path.join(results_dir, f"{tag}.json")
            with open(result_file_path, "w") as f:
                json.dump(all_rewards, f, indent=2)

        get_reward_results(args.reward_weights)

    if test_tasks and "collect_pv_traces" in args.phases:
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)

        pv_trace_dir = os.path.join("output", "pv_traces")
        pv_states_dir = os.path.join(pv_trace_dir, "states")
        if not os.path.exists(pv_trace_dir):
            os.mkdir(pv_trace_dir)
        if not os.path.exists(pv_states_dir):
            os.mkdir(pv_states_dir)

        def get_replay_trace(task, states_dir=None):
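            # Serialize the task's trajectory as "state_str: replay_api" lines
            # terminated by an END marker, optionally saving each state to disk.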
            replay_trace = []
            for i, action in enumerate(task.action_history):
                state = task.state_history[i]
                replay_trace.append(f"{state.state_str}: {action.replay_api}")
                if states_dir:
                    state.save(states_dir)
            replay_trace.append(f"{task.state.state_str}: END")
            if states_dir:
                task.state.save(states_dir)
            return "\n".join(replay_trace)

        def explore_task(n_episodes, policy, states_dir=None):
            for episode in range(1, n_episodes + 1):
                env.reset()
                task = env.current_task.snapshot()
                target_achieved = False
                while True:
                    if task.done:
                        break
                    env.render()
                    preferred_actions = env.browser._filter_actions(
                        task.state, task.get_preferred_actions())
                    possible_actions = env.browser._filter_actions(
                        task.state, task.state.possible_actions)
                    if len(possible_actions) == 0:
                        break
                    if len(preferred_actions) == 0:
                        preferred_actions = possible_actions
                    if policy == "similarity":
                        action_scores = {}
                        for action in preferred_actions:
                            action_scores[action] = task.get_action_usefulness(
                                action)
                        action = Utils.weighted_choice(action_scores)
                    elif policy == "random":
                        action = random.choice(possible_actions)
                    elif policy == "random_restricted":
                        action = random.choice(preferred_actions)
                    elif policy == "demo_biased":
                        actions_in_demo = []
                        for action in preferred_actions:
                            if action.replay_api in task.demonstration:
                                actions_in_demo.append(action)
                        rand = random.uniform(0, 1)
                        if len(actions_in_demo) > 0 and rand < 0.5:
                            action = random.choice(actions_in_demo)
                        else:
                            action = random.choice(preferred_actions)
                    else:
                        action = random.choice(possible_actions)
                    env.step(action)
                    if task.target_achieved:
                        target_achieved = True
                    task = env.current_task.snapshot()
                replay_trace = get_replay_trace(task, states_dir)
                if target_achieved:
                    correct_traces.append(replay_trace)
                else:
                    incorrect_traces.append(replay_trace)

        task_traces = {}
        for test_task in test_tasks:
            assert isinstance(test_task, Task)
            if not test_task.replayable:
                continue
            task_file_path = os.path.join(pv_trace_dir,
                                          test_task.name + "task.json")
            if os.path.exists(task_file_path):
                continue
            try:
                correct_traces = []
                incorrect_traces = []

                env = WebBotEnv(tasks=[test_task], browser=browser)
                env.replay()
                if test_task.target_achieved:
                    correct_traces.append(
                        get_replay_trace(test_task, pv_states_dir))
                explore_task(10, "similarity", pv_states_dir)
                explore_task(10, "random", pv_states_dir)
                explore_task(10, "random_restricted", pv_states_dir)
                explore_task(10, "demo_biased", pv_states_dir)
                GLOBAL_CONFIGS["semantic_similarity"] = True
                explore_task(10, "similarity", pv_states_dir)
                GLOBAL_CONFIGS["semantic_similarity"] = False
                correct_traces = list(set(correct_traces))
                incorrect_traces = list(set(incorrect_traces))
                task_traces[test_task] = [correct_traces, incorrect_traces]
                task_dict = test_task.to_dict(as_demo=True)
                task_dict["correct_traces"] = correct_traces
                task_dict["incorrect_traces"] = incorrect_traces
                with open(task_file_path, "w") as task_file:
                    json.dump(task_dict, task_file, indent=2)
                logging.info(
                    f" collected {len(correct_traces)} correct traces and"
                    f" {len(incorrect_traces)} incorrect traces"
                    f" in task {test_task.name}")
            except:
                logging.info(
                    f" failed to collect traces in task {test_task.name}")

    if test_tasks and "crawl" in args.phases:
        if args.use_cache:
            utgs = UTG.load_utgs_from_dir("output/utg_zips")
            browser = CacheBrowser(utgs=utgs)
        else:
            browser = ChromeBrowser(wait=args.wait,
                                    headless=args.headless,
                                    proxy=args.proxy,
                                    restart_reset=args.restart_reset,
                                    chrome_path=args.chrome_path,
                                    extension_path=args.extension_path)
        n_test_episodes = model_param.get("n_test_episodes", 100)

        model_param["n_episodes"] = n_test_episodes
        if args.explore_policy:
            model_param["explore_policy"] = args.explore_policy
        if args.explore_rate:
            model_param["explore_rate"] = args.explore_rate

        test_results = []
        for test_task in test_tasks:
            # if not test_task.demonstration:
            #     continue
            env = WebBotEnv(tasks=[test_task], browser=browser)
            env.replay()
            demo_tasklet = test_task.get_tasklet()
            logging.info(" demonstration:\n%s" % demo_tasklet)

            if args.algorithm == "dqn":
                agent = TestDQN(data_dir=args.data_dir,
                                log_dir=args.log_dir,
                                model_dir=args.model_dir,
                                supervised_model=supervised_model,
                                **model_param)
            elif args.algorithm == "hill_climbing":
                from baseline_agents import HillClimbing
                agent = HillClimbing(data_dir=args.data_dir,
                                     log_dir=args.log_dir,
                                     model_dir=args.model_dir,
                                     **model_param)
            elif args.algorithm == "monte_carlo":
                from baseline_agents import MonteCarlo
                agent = MonteCarlo(data_dir=args.data_dir,
                                   log_dir=args.log_dir,
                                   model_dir=args.model_dir,
                                   **model_param)
            elif args.algorithm == "random":
                model_param["explore_policy"] = "full_rand"
                model_param["explore_rate"] = 1.0
                agent = TestDQN(data_dir=args.data_dir,
                                log_dir=args.log_dir,
                                model_dir=args.model_dir,
                                supervised_model=supervised_model,
                                **model_param)
            else:
                agent = TestQTable(data_dir=args.data_dir,
                                   log_dir=args.log_dir,
                                   model_dir=args.model_dir,
                                   supervised_model=supervised_model,
                                   **model_param)

            found_tasklets = agent.train(tasks=[test_task], browser=browser)

            top_tasklets = Utils.top_n(found_tasklets, 10, reverse=True)
            post_processed_tasklets = {}
            for i, tasklet in enumerate(top_tasklets):
                original_total_reward, episode, original_final_screen = found_tasklets[
                    tasklet]
                tasklet, total_reward, final_screen = env.post_process(
                    test_task, tasklet)
                if total_reward > original_total_reward:
                    logging.debug("post-processing improved the total reward.")
                post_processed_tasklets[tasklet] = (total_reward, episode,
                                                    final_screen)
            top_tasklets = Utils.top_n(post_processed_tasklets,
                                       5,
                                       reverse=True)
            test_results.append(top_tasklets)

            task_output_dir = None
            if output_dir:
                task_output_dir = os.path.join(output_dir, test_task.name)
                if not os.path.exists(task_output_dir):
                    os.mkdir(task_output_dir)
            test_report = "\n" + "=" * 50 + "\n demonstration:\n%s\n" % demo_tasklet
            for i, tasklet in enumerate(top_tasklets):
                total_reward, episode, final_screen = post_processed_tasklets[
                    tasklet]
                test_report += "-" * 50 + "\n tasklet-%d (episode %d):\n%s\n" % (
                    i, episode, tasklet)
                if task_output_dir:
                    try:
                        final_screen_path = os.path.join(
                            task_output_dir, "final_state_tasklet-%d.png" % i)
                        final_screen.save(final_screen_path)
                    except:
                        pass
            logging.info(test_report)
            if task_output_dir:
                try:
                    task_report_path = os.path.join(task_output_dir,
                                                    "report.txt")
                    with open(task_report_path, "w") as f:
                        f.write(test_report)
                except:
                    pass

        result_lines = []
        success_ids = []
        for i in range(len(test_results)):
            top_tasklets = test_results[i]
            test_task = test_tasks[i]
            succeed_id = -1
            for j, tasklet in enumerate(top_tasklets):
                if "success:Y" in tasklet:
                    succeed_id = j
            success_ids.append(succeed_id)
            result_lines.append(
                "success:{:s} tasklet:{:d} #episodes:{:3d}  #demo_steps:{:2d}  task:{:60s} {:s}"
                .format(
                    "Y" if succeed_id >= 0 else "N",
                    succeed_id,
                    -1,  # TODO get the episode number
                    len(test_task.demonstration),
                    test_task.name,
                    test_task.task_str))

        success_count = sum([(1 if succeed_id >= 0 else 0)  # index 0 also counts as success
                             for succeed_id in success_ids])
        success_rate = (float(success_count) /
                        len(test_results)) if len(test_results) > 0 else 0
        result_lines.append("Success rate: %.3f" % success_rate)
        overall_report = "Result:\n" + "\n".join(result_lines)
        logging.info(overall_report)
        if output_dir:
            try:
                overall_report_path = os.path.join(output_dir, "overall.txt")
                with open(overall_report_path, "w") as f:
                    f.write(overall_report)
            except:
                pass

    if test_tasks and "analyze_task_complexity" in args.phases:
        browser = ChromeBrowser(wait=args.wait,
                                headless=args.headless,
                                proxy=args.proxy,
                                restart_reset=args.restart_reset,
                                chrome_path=args.chrome_path,
                                extension_path=args.extension_path)
        for test_task in test_tasks:
            env = WebBotEnv(browser=browser, tasks=[test_task])
            env.analyze_task_complexity(test_task, test_task.demonstration)

    if test_tasks and "build_cache" in args.phases:
        from environment import UTG
        browser = ChromeBrowser(wait=args.wait,
                                headless=args.headless,
                                proxy=args.proxy,
                                restart_reset=args.restart_reset,
                                chrome_path=args.chrome_path,
                                extension_path=args.extension_path)
        for test_task in test_tasks:
            task_category = test_task.name.split("_")[0]
            task_host = Utils.get_host(test_task.start_url)
            logging.info("building cache for %s" % test_task.task_str)
            env = WebBotEnv(tasks=[test_task], browser=browser)

            # save UTG to dir
            utg_dir_path = os.path.join(
                "output", "utgs", "%s_%s_utg" % (task_category, task_host))
            test_task.utg = UTG.load_from_dir(utg_dir_path)
            test_task.utg.save_states = True
            test_task.utg.start_url = test_task.start_url

            logging.info("replaying the demonstration")
            env.replay()
            test_task.utg.save()

            logging.info("exploring other paths: demo_biased strategy")
            env.explore(policy="demo_biased", n_episodes=50)
            test_task.utg.save()

            logging.info("exploring other paths: similarity strategy")
            env.explore(policy="similarity", n_episodes=50)
            test_task.utg.save()

            logging.info("exploring other paths: random strategy")
            env.explore(policy="random", n_episodes=50)
            test_task.utg.save()

            utg_zip_path = os.path.join(
                "output", "utg_zips", "%s_%s_utg" % (task_category, task_host))
            if not os.path.exists(os.path.dirname(utg_zip_path)):
                os.makedirs(os.path.dirname(utg_zip_path))
            shutil.make_archive(base_name=utg_zip_path,
                                format="zip",
                                root_dir=utg_dir_path)

            logging.info("done building cache for %s" % test_task.task_str)
            logging.info("UTG saved to %s" % utg_zip_path)
            # pv_utg_zip_dir = os.path.join("output", "pv_utg_zips")
            # shutil.copy(utg_zip_path, pv_utg_zip_dir)
            env.destroy()