Ejemplo n.º 1
0
def _next_replay_certificate(dir_name, certificate_filenames):
    """Return ``(index, path)`` of the first missing replay certificate.

    Certificates are probed in order inside *dir_name*; the returned index is
    the number of certificates found before the first missing one.  Returns
    ``None`` when every certificate already exists (nothing left to check).
    """
    for index, filename in enumerate(certificate_filenames):
        path = os.path.join(dir_name, filename)
        if not os.path.isfile(path):
            return index, path
    return None


def _print_error_histogram(errors):
    """Print crash messages with their counts, most frequent first."""
    total = sum(errors.values())
    print("%%%%%%%%%%")
    print("\terrors (%d):" % total)
    for message, count in sorted(errors.items(),
                                 key=lambda kv: kv[1],
                                 reverse=True):
        print("\t(%.2f) (%d)\t%s" % (count / total, count, message))
    print("%%%%%%%%%%")


def _print_replay_stats(total_checks, total_failures, nondet_fails,
                        crash_fails, unsat_fails, json_fails):
    """Print the running success rate and the breakdown of failure causes.

    *total_checks* must be positive (it is incremented before every call).
    """
    print("-------------------------")
    print("Success Rate: %.2f/%.2f = %.3f" %
          (total_checks - total_failures, total_checks,
           float(total_checks - total_failures) / float(total_checks)))
    if total_failures > 0:
        for label, n in (("Non-deterministic failure", nondet_fails),
                         ("Failures by crash", crash_fails),
                         ("Failures by unsatisfied", unsat_fails),
                         ("Failures by json decode error", json_fails)):
            print("%s: %d/%d = %.3f" %
                  (label, n, total_failures, float(n) / total_failures))
    print("-------------------------")


def _dispose_failed_trajectory(args, dir_name, json_file):
    """Record a failed trajectory and optionally relocate and/or delete it.

    Appends *json_file* to ``args.failure_filename`` (when set), moves
    *dir_name* into ``args.move_failed_trajectories`` (when set; deletes it
    instead if the move raises ``shutil.Error``), and deletes it when
    ``args.remove_failed_trajectories`` is set and the directory still exists.
    """
    if args.failure_filename is not None:
        with open(args.failure_filename, 'a') as f:
            f.write("%s\n" % json_file)

    if args.move_failed_trajectories is not None:
        print("Relocating failed trajectory '%s' to '%s'" %
              (dir_name, args.move_failed_trajectories))
        try:
            shutil.move(dir_name, args.move_failed_trajectories)
        except shutil.Error as e:
            print(
                "WARNING: failed to perform move; error follows; deleting instead"
            )
            print(repr(e))
            shutil.rmtree(dir_name)

    # After a move (or the delete fallback above) the directory is already
    # gone; calling rmtree then would raise FileNotFoundError and abort the run.
    if args.remove_failed_trajectories and os.path.exists(dir_name):
        print("Removing failed trajectory '%s'" % dir_name)
        shutil.rmtree(dir_name)


def replay_check(args, thread_num=0):
    """Replay trajectories under ``args.data_path`` and certify successful ones.

    Every ``trial_`` directory (skipping ``raw_images`` / ``pddl_states``
    paths) that contains ``JSON_FILENAME`` is sharded round-robin across
    ``args.total_gpu * args.num_threads`` workers; this worker takes shard
    ``args.gpu_id * args.num_threads + thread_num``.  A trajectory needs
    ``args.num_replays`` successful replays.  Each pass attempts the next
    missing one, and a success (goal satisfied) is recorded by writing the
    step count to ``replay.certificate.<k>``.

    Failures (JSON decode error, exception in ``replay_json``, unsatisfied
    goal) are appended to ``args.failure_filename`` (truncated at start) when
    set.  The trajectory directory is optionally moved to
    ``args.move_failed_trajectories`` and/or deleted when
    ``args.remove_failed_trajectories`` is set.  With
    ``args.remove_certificates`` the function only deletes existing
    certificates.  Running statistics are printed after every check.

    If ``args.in_parallel`` is true the scan repeats every 60 seconds
    forever; otherwise the function returns after a single pass.

    Args:
        args: parsed options (attributes listed above); also forwarded to
            ``env.set_task``.
        thread_num: worker index on this GPU; also selects X display
            ``0.<thread_num % args.total_gpu>``.
    """
    env = ThorEnv(x_display='0.%d' % (thread_num % args.total_gpu))

    replay_certificate_filenames = [
        "replay.certificate.%d" % idx for idx in range(args.num_replays)
    ]

    # Clear existing failures in file recording.
    if args.failure_filename is not None:
        with open(args.failure_filename, 'w') as f:
            f.write('')

    total_checks = 0
    total_failures = crash_fails = unsat_fails = json_fails = nondet_fails = 0
    errors = {}  # crash message -> count, shown after every crash
    total_threads = args.total_gpu * args.num_threads
    current_thread = args.gpu_id * args.num_threads + thread_num

    while True:
        # Crawl the directory of trajectories and collect this worker's shard.
        valid_dirs = []
        count = 0
        for dir_name, _subdirs, _files in os.walk(args.data_path):
            if ("trial_" not in dir_name or "raw_images" in dir_name
                    or "pddl_states" in dir_name):
                continue
            json_file = os.path.join(dir_name, JSON_FILENAME)
            if not os.path.isfile(json_file):
                continue

            # If we're just stripping certificates, do that and continue.
            if args.remove_certificates:
                for filename in replay_certificate_filenames:
                    certificate_file = os.path.join(dir_name, filename)
                    if os.path.isfile(certificate_file):
                        # os.remove instead of os.system("rm ..."): no shell,
                        # so spaces/metacharacters in paths are safe.
                        os.remove(certificate_file)
                continue

            if count % total_threads == current_thread:
                valid_dirs.append(dir_name)
            count += 1

        print(len(valid_dirs))
        np.random.shuffle(valid_dirs)
        for ii, dir_name in enumerate(valid_dirs):
            # The directory may have vanished since the crawl (e.g. it was
            # moved/removed as a failure).
            if not os.path.exists(dir_name):
                continue
            json_file = os.path.join(dir_name, JSON_FILENAME)
            if not os.path.isfile(json_file):
                continue

            pending = _next_replay_certificate(dir_name,
                                               replay_certificate_filenames)
            if pending is None:
                continue  # all replays already certified
            cidx, certificate_file = pending

            print(ii)
            # Each replay attempt is worth 1/num_replays of one trajectory.
            total_checks += 1. / args.num_replays
            failed = False

            with open(json_file) as f:
                print("check %d/%d for file '%s'" %
                      (cidx + 1, args.num_replays, json_file))
                try:
                    traj_data = json.load(f)
                    env.set_task(traj_data, args, reward_type='dense')
                except json.decoder.JSONDecodeError:
                    failed = True
                    json_fails += 1

            if not failed:
                steps_taken = None
                try:
                    steps_taken = replay_json(env, json_file)
                except Exception as e:
                    import traceback
                    traceback.print_exc()
                    failed = True
                    crash_fails += 1
                    errors[str(e)] = errors.get(str(e), 0) + 1
                    _print_error_histogram(errors)

                    # cidx certificates precede this attempt, so cidx >= 1
                    # means the trajectory already replayed successfully.
                    if cidx >= 1:
                        print(
                            "WARNING: replay that has succeeded before has failed at attempt %d"
                            % (cidx + 1))
                        nondet_fails += 1

                # Executed without crashing, so now verify completion.
                if steps_taken is not None:
                    if env.get_goal_satisfied():
                        with open(certificate_file, 'w') as f:
                            f.write('%d' % steps_taken)
                    else:
                        failed = True
                        unsat_fails += 1
                        print("Goal was not satisfied after execution!")

            if failed:
                total_failures += 1
                # Count a failed trajectory as exactly one check: top up the
                # (cidx + 1) / num_replays already added for its attempts to 1.
                total_checks += 1.0 - (cidx + 1) / float(args.num_replays)
                _dispose_failed_trajectory(args, dir_name, json_file)

            _print_replay_stats(total_checks, total_failures, nondet_fails,
                                crash_fails, unsat_fails, json_fails)

        if not args.in_parallel:
            break
        time.sleep(60)
Ejemplo n.º 2
0
    class Thor(threading.Thread):
        def __init__(self, queue, train_eval="train"):
            Thread.__init__(self)
            self.action_queue = queue
            self.mask_rcnn = None
            self.env = None
            self.train_eval = train_eval
            self.controller_type = "oracle"

        def run(self):
            while True:
                action, reset, task_file = self.action_queue.get()
                try:
                    if reset:
                        self.reset(task_file)
                    else:
                        self.step(action)
                finally:
                    self.action_queue.task_done()

        def init_env(self, config):
            """Store *config*, build the THOR env once, and reset episode state.

            ``ThorEnv`` is constructed only while ``self.env`` is falsy; later
            calls reuse it.  Every call re-reads the controller type, clears the
            done flag, cached results, feedback and previous command, attaches
            a new ``HandCodedThorAgent`` expert (``max_steps=200``) to the env,
            and loads Mask R-CNN if the controller type requires it.
            """
            self.config = config

            thor_cfg = config['env']['thor']
            env_kwargs = dict(
                player_screen_height=thor_cfg['screen_height'],
                player_screen_width=thor_cfg['screen_width'],
                smooth_nav=thor_cfg['smooth_nav'],
                save_frames_to_disk=thor_cfg['save_frames_to_disk'],
            )
            if not self.env:
                self.env = ThorEnv(**env_kwargs)

            self.controller_type = self.config['controller']['type']
            self._done = False
            self._res = ()
            self._feedback = ""
            self.expert = HandCodedThorAgent(self.env, max_steps=200)
            self.prev_command = ""
            self.load_mask_rcnn()

        def load_mask_rcnn(self):
            # load pretrained MaskRCNN model if required
            if 'mrcnn' in self.config['controller'][
                    'type'] and not self.mask_rcnn:
                model_path = os.path.join(
                    os.environ['ALFRED_ROOT'],
                    self.config['mask_rcnn']['pretrained_model_path'])
                self.mask_rcnn = load_pretrained_model(model_path)

        def set_task(self, task_file):
            self.task_file = task_file
            self.traj_root = os.path.dirname(task_file)
            with open(task_file, 'r') as f:
                self.traj_data = json.load(f)

        def reset(self, task_file):
            """Start a new episode for the trajectory stored in *task_file*.

            Loads the trajectory, resets THOR to its floor plan and restores
            the recorded object poses/toggles/dirty-and-empty state, points
            frame saving at a per-trajectory directory, applies the recorded
            initial action, and registers the task with a dense reward.  It
            then builds the controller named by ``config['controller']['type']``
            and resets the step counter and the expert.

            Returns:
                The controller's initial feedback text.  It is also cached in
                ``self._feedback``, and ``self._res`` holds the matching
                ``get_info()`` tuple.

            Raises:
                NotImplementedError: for an unrecognised controller type.
            """
            assert self.env
            assert self.controller_type

            self.set_task(task_file)
            traj = self.traj_data
            scene = traj['scene']

            # Scene setup: load the floor plan, then restore the recorded state.
            scene_num = scene['scene_num']
            object_poses = scene['object_poses']
            dirty_and_empty = scene['dirty_and_empty']
            object_toggles = scene['object_toggles']
            self.env.reset('FloorPlan%d' % scene_num)
            self.env.restore_scene(object_poses, object_toggles,
                                   dirty_and_empty)

            # Recording: nest frames under save_frames_path ('../' stripped).
            frames_root = self.config['env']['thor']['save_frames_path']
            self.env.save_frames_path = os.path.join(
                frames_root, self.traj_root.replace('../', ''))

            # Initialize to the start position, then print the goal.
            self.env.step(dict(scene['init_action']))
            task_desc = get_templated_task_desc(traj)
            print("Task: %s" % task_desc)

            # Reward setup: env.set_task takes an attribute-holder object
            # (presumably it only reads ``reward_config`` — verify).
            class args:
                pass

            args.reward_config = os.path.join(os.environ['ALFRED_ROOT'],
                                              'agents/config/rewards.json')
            self.env.set_task(traj, args, reward_type='dense')

            # Controller: options shared by every agent type.
            controller_cfg = self.config['controller']
            self.controller_type = controller_cfg['type']
            self.goal_desc_human_anns_prob = self.config['env'][
                'goal_desc_human_anns_prob']
            agent_kwargs = dict(
                load_receps=controller_cfg['load_receps'],
                debug=controller_cfg['debug'],
                goal_desc_human_anns_prob=self.goal_desc_human_anns_prob,
            )

            kind = self.controller_type
            if kind in ('mrcnn', 'mrcnn_astar'):
                # Detector-based agents also get the model and dump settings.
                agent_kwargs.update(
                    pretrained_model=self.mask_rcnn,
                    save_detections_to_disk=self.env.save_frames_to_disk,
                    save_detections_path=self.env.save_frames_path,
                )

            if kind == 'oracle':
                agent_cls = OracleAgent
            elif kind == 'oracle_astar':
                agent_cls = OracleAStarAgent
            elif kind == 'mrcnn':
                agent_cls = MaskRCNNAgent
            elif kind == 'mrcnn_astar':
                agent_cls = MaskRCNNAStarAgent
            else:
                raise NotImplementedError()
            self.controller = agent_cls(self.env, traj, self.traj_root,
                                        **agent_kwargs)

            # Fresh episode bookkeeping.
            self.steps = 0
            self.expert.reset(task_file)
            self.prev_command = ""

            # Intro text and initial info tuple.
            self._feedback = self.controller.feedback
            self._res = self.get_info()
            return self._feedback

        def step(self, action):
            if not self._done:
                # take action
                self.prev_command = str(action)
                self._feedback = self.controller.step(action)
                self._res = self.get_info()
                if self.env.save_frames_to_disk:
                    self.record_action(action)
            self.steps += 1

        def get_results(self):
            return self._res

        def record_action(self, action):
            txt_file = os.path.join(self.env.save_frames_path, 'action.txt')
            with open(txt_file, 'a+') as f:
                f.write("%s\r\n" % str(action))

        def get_info(self):
            won = self.env.get_goal_satisfied()
            pcs = self.env.get_goal_conditions_met()
            goal_condition_success_rate = pcs[0] / float(pcs[1])
            acs = self.controller.get_admissible_commands()

            # expert action
            if self.train_eval == "train":
                game_state = {
                    'admissible_commands': acs,
                    'feedback': self._feedback,
                    'won': won
                }
                expert_actions = ["look"]
                try:
                    if not self.prev_command:
                        self.expert.observe(game_state['feedback'])
                    else:
                        next_action = self.expert.act(game_state, 0, won,
                                                      self.prev_command)
                        if next_action in acs:
                            expert_actions = [next_action]
                except HandCodedAgentTimeout:
                    print("Expert Timeout")
                except Exception as e:
                    print(e)
                    traceback.print_exc()
            else:
                expert_actions = []

            training_method = self.config["general"]["training_method"]
            if training_method == "dqn":
                max_nb_steps_per_episode = self.config["rl"]["training"][
                    "max_nb_steps_per_episode"]
            elif training_method == "dagger":
                max_nb_steps_per_episode = self.config["dagger"]["training"][
                    "max_nb_steps_per_episode"]
            else:
                raise NotImplementedError
            self._done = won or self.steps > max_nb_steps_per_episode
            return (self._feedback, self._done, acs, won,
                    goal_condition_success_rate, expert_actions)

        def get_last_frame(self):
            return self.env.last_event.frame[:, :, ::-1]

        def get_exploration_frames(self):
            return self.controller.get_exploration_frames()