def run():
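    """Worker loop: for each remaining THOR floorplan, compute the reachable
    (x, z) points and the best agent poses for opening/using each receptacle,
    saving results under layouts/ as .npy and .json files."""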
    print(all_scene_numbers)
    # create env and agent
    env = ThorEnv()
    while True:
        # Check emptiness and pop under the same lock so parallel workers
        # can't race between the check and the pop.
        lock.acquire()
        if len(all_scene_numbers) == 0:
            lock.release()
            break
        scene_num = all_scene_numbers.pop()
        lock.release()
        fn = os.path.join('layouts', 'FloorPlan%d-layout.npy' % scene_num)
        if os.path.isfile(fn):
            print("file %s already exists; skipping this floorplan" % fn)
            continue

        openable_json_file = os.path.join(
            'layouts', 'FloorPlan%d-openable.json' % scene_num)
        scene_objs_json_file = os.path.join(
            'layouts', 'FloorPlan%d-objects.json' % scene_num)

        scene_name = 'FloorPlan%d' % scene_num
        print('Running ' + scene_name)
        event = env.reset(scene_name,
                          render_image=False,
                          render_depth_image=False,
                          render_class_image=False,
                          render_object_image=True)
        agent_height = event.metadata['agent']['position']['y']

        scene_objs = list(
            set([obj['objectType'] for obj in event.metadata['objects']]))
        with open(scene_objs_json_file, 'w') as sof:
            json.dump(scene_objs, sof, sort_keys=True, indent=4)

        # Get all the reachable points through Unity for this step size.
        event = env.step(
            dict(action='GetReachablePositions',
                 gridSize=constants.AGENT_STEP_SIZE /
                 constants.RECORD_SMOOTHING_FACTOR))
        if event.metadata['actionReturn'] is None:
            print("ERROR: scene %d 'GetReachablePositions' returns None" %
                  scene_num)
        else:
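            # Unity returns full (x, y, z) positions; keep only the ground-plane
            # (x, z) coordinates since the agent height is fixed.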
            reachable_points = set()
            for point in event.metadata['actionReturn']:
                reachable_points.add((point['x'], point['z']))
            print("scene %d got %d reachable points, now checking" %
                  (scene_num, len(reachable_points)))

            # Pick up a small object to use in testing whether points are good for openable objects.
            open_test_objs = {
                'ButterKnife', 'CD', 'CellPhone', 'Cloth', 'CreditCard',
                'DishSponge', 'Fork', 'KeyChain', 'Pen', 'Pencil', 'SoapBar',
                'Spoon', 'Watch'
            }
            good_obj_point = None
            good_obj_point = get_obj(env, open_test_objs, reachable_points,
                                     agent_height, scene_name, good_obj_point)

            # Map from object names to the best point from which they can be successfully opened.
            best_open_point = {}
            # Number of pixels in the semantic segmentation mask of the receptacle at the existing best open point.
            best_sem_coverage = {}
            checked_points = set()
            scene_receptacles = set()
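            # Scan every reachable point: first verify the agent can stand there
            # and freely tilt and rotate; then, from each of the four facing
            # directions and three camera tilts, test whether nearby receptacles
            # can be opened, loaded, unloaded, and closed.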
            for point in reachable_points:
                point_is_valid = True
                action = {
                    'action': 'TeleportFull',
                    'x': point[0],
                    'y': agent_height,
                    'z': point[1],
                }
                event = env.step(action)
                if event.metadata['lastActionSuccess']:
                    for horizon in [-30, 0, 30]:
                        action = {
                            'action': 'TeleportFull',
                            'x': point[0],
                            'y': agent_height,
                            'z': point[1],
                            'rotateOnTeleport': True,
                            'rotation': 0,
                            'horizon': horizon
                        }
                        event = env.step(action)
                        if not event.metadata['lastActionSuccess']:
                            point_is_valid = False
                            break
                        for rotation in range(3):
                            action = {'action': 'RotateLeft'}
                            event = env.step(action)
                            if not event.metadata['lastActionSuccess']:
                                point_is_valid = False
                                break
                        if not point_is_valid:
                            break
                    if point_is_valid:
                        checked_points.add(point)
                    else:
                        continue

                    # Check whether we can open objects from here in any direction with any tilt.
                    for rotation in range(4):
                        # First try up, then down, then return to the horizon before moving again.
                        for horizon in [-30, 0, 30]:

                            action = {
                                'action': 'TeleportFull',
                                'x': point[0],
                                'y': agent_height,
                                'z': point[1],
                                'rotateOnTeleport': True,
                                'rotation': rotation * 90,
                                'horizon': horizon
                            }
                            event = env.step(action)
                            for obj in event.metadata['objects']:
                                if (obj['visible'] and obj['objectId']
                                        and obj['receptacle']
                                        and not obj['pickupable']
                                        and obj['objectType']
                                        in constants.VAL_RECEPTACLE_OBJECTS):
                                    obj_name = obj['objectId']
                                    # Use the ground-plane (x, z) position so the
                                    # distance to the agent's (x, z) point is meaningful.
                                    obj_point = (obj['position']['x'],
                                                 obj['position']['z'])
                                    scene_receptacles.add(obj_name)

                                    # Go ahead and attempt to close the object from this position if it's open.
                                    if obj['openable'] and obj['isOpen']:
                                        close_action = {
                                            'action': 'CloseObject',
                                            'objectId': obj['objectId']
                                        }
                                        event = env.step(close_action)

                                    point_to_recep = np.linalg.norm(
                                        np.array(point) - np.array(obj_point))
                                    if len(env.last_event.metadata['inventoryObjects']) > 0:
                                        inv_obj = env.last_event.metadata['inventoryObjects'][0]['objectId']
                                    else:
                                        inv_obj = None

                                    # Heuristic implemented in task_game_state keeps the agent 0.5 or farther from the receptacle in agent space.
                                    heuristic_far_enough_from_recep = point_to_recep > 0.5
                                    # Ensure this point affords a larger view according to the semantic segmentation
                                    # of the receptacle than the existing.
                                    point_sem_coverage = get_mask_of_obj(
                                        env, obj['objectId'])
                                    if point_sem_coverage is None:
                                        use_sem_heuristic = False
                                        better_sem_coverage = False
                                    else:
                                        use_sem_heuristic = True
                                        better_sem_coverage = (
                                            obj_name not in best_sem_coverage
                                            or best_sem_coverage[obj_name] is None
                                            or point_sem_coverage >
                                            best_sem_coverage[obj_name])
                                    # Ensure that this point is farther away than our existing best candidate.
                                    # We'd like to open each receptacle from as far away as possible while retaining
                                    # the ability to pick/place from it.
                                    farther_than_existing_good_point = (
                                        obj_name not in best_open_point
                                        or point_to_recep > np.linalg.norm(
                                            np.array(point) - np.array(
                                                best_open_point[obj_name][:2]))
                                    )
                                    # If we don't have an inventory object, though, we'll fall back to the heuristic
                                    # of being able to open/close as _close_ as possible.
                                    closer_than_existing_good_point = (
                                        obj_name not in best_open_point
                                        or point_to_recep < np.linalg.norm(
                                            np.array(point) - np.array(
                                                best_open_point[obj_name][:2]))
                                    )
                                    # Semantic segmentation heuristic.
                                    if ((use_sem_heuristic
                                         and heuristic_far_enough_from_recep
                                         and better_sem_coverage)
                                            or (not use_sem_heuristic
                                                # Distance heuristics.
                                                and ((heuristic_far_enough_from_recep
                                                      and inv_obj
                                                      and farther_than_existing_good_point)
                                                     or (not inv_obj
                                                         and closer_than_existing_good_point)))):
                                        if obj['openable']:
                                            action = {
                                                'action': 'OpenObject',
                                                'objectId': obj['objectId']
                                            }
                                            event = env.step(action)
                                        if not obj['openable'] or event.metadata['lastActionSuccess']:
                                            # We can open the object, so try placing our small inventory obj inside.
                                            # If it can be placed inside and retrieved, then this is a safe point.
                                            action = {
                                                'action': 'PutObject',
                                                'objectId': inv_obj,
                                                'receptacleObjectId': obj['objectId'],
                                                'forceAction': True,
                                                'placeStationary': True
                                            }
                                            if inv_obj:
                                                event = env.step(action)
                                            if inv_obj is None or event.metadata['lastActionSuccess']:
                                                action = {
                                                    'action': 'PickupObject',
                                                    'objectId': inv_obj
                                                }
                                                if inv_obj:
                                                    event = env.step(action)
                                                if inv_obj is None or event.metadata['lastActionSuccess']:

                                                    # Finally, ensure we can also close the receptacle.
                                                    if obj['openable']:
                                                        action = {
                                                            'action': 'CloseObject',
                                                            'objectId': obj['objectId']
                                                        }
                                                        event = env.step(action)
                                                    if not obj['openable'] or event.metadata['lastActionSuccess']:

                                                        # We can put/pick our inv object into the receptacle from here.
                                                        # We have already ensured this point is farther than any
                                                        # existing best, so this is the new best.
                                                        best_open_point[obj_name] = [
                                                            point[0], point[1],
                                                            rotation * 90, horizon
                                                        ]
                                                        best_sem_coverage[obj_name] = point_sem_coverage

                                                # We could not retrieve our inv object, so we need to go get another one
                                                else:
                                                    good_obj_point = get_obj(
                                                        env, open_test_objs,
                                                        reachable_points,
                                                        agent_height,
                                                        scene_name,
                                                        good_obj_point)
                                                    action = {
                                                        'action': 'TeleportFull',
                                                        'x': point[0],
                                                        'y': agent_height,
                                                        'z': point[1],
                                                        'rotateOnTeleport': True,
                                                        'rotation': rotation * 90,
                                                        'horizon': horizon
                                                    }
                                                    event = env.step(action)

                                    # Regardless of what happened up there, try to close the receptacle again if
                                    # it remained open.
                                    if obj['isOpen']:
                                        action = {
                                            'action': 'CloseObject',
                                            'objectId': obj['objectId']
                                        }
                                        event = env.step(action)

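            # Sanity check: warn if a kitchen scene has no usable open point for
            # its essential receptacles (Microwave, Fridge).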
            essential_objs = []
            if scene_num in constants.SCENE_TYPE["Kitchen"]:
                essential_objs.extend(["Microwave", "Fridge"])
            for obj in essential_objs:
                if not np.any([obj in obj_key for obj_key in best_open_point]):
                    print(
                        "WARNING: Essential object %s has no open points in scene %d"
                        % (obj, scene_num))

            print(
                "scene %d found open/pick/place/close positions for %d/%d receptacle objects"
                % (scene_num, len(best_open_point), len(scene_receptacles)))
            with open(openable_json_file, 'w') as f:
                json.dump(best_open_point, f, sort_keys=True, indent=4)

            print("scene %d reachable %d, checked %d; taking intersection" %
                  (scene_num, len(reachable_points), len(checked_points)))

            points = np.array(list(checked_points))[:, :2]
            points = points[np.lexsort((points[:, 0], points[:, 1])), :]
            np.save(fn, points)

    env.stop()
    print('Done')
Example #2
def main(args, thread_num=0):
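    """Generate ALFRED-style trajectories: sample (goal, pickup, movable,
    receptacle, scene) task tuples, plan and execute them in THOR, and record
    successes/failures to keep the dataset balanced across tuples."""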

    print(thread_num)
    # settings
    alfred_dataset_path = '../data/json_2.1.0/train'

    constants.DATA_SAVE_PATH = args.save_path
    print("Force Unsave Data: %s" % str(args.force_unsave))

    # Set up data structure to track dataset balance and use for selecting next parameters.
    # In actively gathering data, we will try to maximize entropy for each (e.g., uniform spread of goals,
    # uniform spread over patient objects, uniform recipient objects, and uniform scenes).
    succ_traj = pd.DataFrame(
        columns=["goal", "pickup", "movable", "receptacle", "scene"])

    # objects-to-scene and scene-to-objects database
    for scene_type, ids in constants.SCENE_TYPE.items():
        for id in ids:
            obj_json_file = os.path.join('layouts',
                                         'FloorPlan%d-objects.json' % id)
            with open(obj_json_file, 'r') as of:
                scene_objs = json.load(of)

            id_str = str(id)
            scene_id_to_objs[id_str] = scene_objs
            for obj in scene_objs:
                if obj not in obj_to_scene_ids:
                    obj_to_scene_ids[obj] = set()
                obj_to_scene_ids[obj].add(id_str)

    # scene-goal database
    for g in constants.GOALS:
        for st in constants.GOALS_VALID[g]:
            scenes_for_goal[g].extend(
                [str(s) for s in constants.SCENE_TYPE[st]])
        scenes_for_goal[g] = set(scenes_for_goal[g])

    # scene-type database
    for st in constants.SCENE_TYPE:
        for s in constants.SCENE_TYPE[st]:
            scene_to_type[str(s)] = st

    # pre-populate counts in this structure using saved trajectories path.
    succ_traj, full_traj = load_successes_from_disk(args.save_path, succ_traj,
                                                    args.just_examine,
                                                    args.repeats_per_cond)
    if args.just_examine:
        print_successes(succ_traj)
        return

    print(succ_traj.groupby('goal').count())
    # pre-populate failed trajectories.
    fail_traj = load_fails_from_disk(args.save_path)
    print("Loaded %d known failed tuples" % len(fail_traj))

    # create env and agent
    env = ThorEnv(x_display='0.%d' % (thread_num % 2))

    game_state = TaskGameStateFullKnowledge(env)
    agent = DeterministicPlannerAgent(thread_id=0, game_state=game_state)

    # map from error strings to counts, to be shown after every failure.
    errors = {}
    goal_candidates = constants.GOALS[:]
    # Union of all objects that can be placed in some receptacle.
    pickup_candidates = list(
        set().union(*constants.VAL_RECEPTACLE_OBJECTS.values()))
    pickup_candidates = [
        p for p in pickup_candidates
        if constants.OBJ_PARENTS[p] in obj_to_scene_ids
    ]
    movable_candidates = list(
        set(constants.MOVABLE_RECEPTACLES).intersection(
            obj_to_scene_ids.keys()))
    receptacle_candidates = [obj for obj in constants.VAL_RECEPTACLE_OBJECTS
                             if obj not in constants.MOVABLE_RECEPTACLES and obj in obj_to_scene_ids] + \
                            [obj for obj in constants.VAL_ACTION_OBJECTS["Toggleable"]
                             if obj in obj_to_scene_ids]

    # toaster isn't interesting in terms of producing linguistic diversity
    receptacle_candidates.remove('Toaster')
    receptacle_candidates.sort()

    scene_candidates = list(scene_id_to_objs.keys())

    n_until_load_successes = args.async_load_every_n_samples
    print_successes(succ_traj)
    task_sampler = sample_task_params(succ_traj, full_traj, fail_traj,
                                      goal_candidates, pickup_candidates,
                                      movable_candidates,
                                      receptacle_candidates, scene_candidates)

    # main generation loop
    # keeps trying out new task tuples as trajectories either fail or succeed
    while True:
        # for _ in range(20):
        for ii, json_path in enumerate(
                glob.iglob(os.path.join(alfred_dataset_path, "**",
                                        "traj_data.json"),
                           recursive=True)):
            # if ii % args.num_threads == thread_num:
            # if ii == 5:
            sampled_task = json_path.split('/')[-3].split('-')
            # sampled_task = next(task_sampler)
            # print("===============")
            # print(ii, json_path)
            print(sampled_task)  # DEBUG
            # print("===============")

            if sampled_task is None:
                sys.exit(
                    "No valid tuples left to sample (all are known to fail or already have %d trajectories)"
                    % args.repeats_per_cond)
            gtype, pickup_obj, movable_obj, receptacle_obj, sampled_scene = sampled_task

            sampled_scene = int(sampled_scene)
            print("sampled tuple: " + str((gtype, pickup_obj, movable_obj,
                                           receptacle_obj, sampled_scene)))

            tries_remaining = args.trials_before_fail
            # only try to get the number of trajectories left to make this tuple full.
            target_remaining = args.repeats_per_cond - len(
                succ_traj.loc[(succ_traj['goal'] == gtype)
                              & (succ_traj['pickup'] == pickup_obj) &
                              (succ_traj['movable'] == movable_obj) &
                              (succ_traj['receptacle'] == receptacle_obj) &
                              (succ_traj['scene'] == str(sampled_scene))])
            num_place_fails = 0  # count of errors related to placement failure for no valid positions.

            # continue until we're out of tries (having never succeeded) or we've gathered the target number of instances
            while tries_remaining > 0 and target_remaining > 0:

                # environment setup
                constants.pddl_goal_type = gtype
                print("PDDLGoalType: " + constants.pddl_goal_type)
                task_id = create_dirs(gtype, pickup_obj, movable_obj,
                                      receptacle_obj, sampled_scene)

                # setup data dictionary
                setup_data_dict()
                constants.data_dict['task_id'] = task_id
                constants.data_dict['task_type'] = constants.pddl_goal_type
                constants.data_dict['dataset_params']['video_frame_rate'] = constants.VIDEO_FRAME_RATE

                # plan & execute
                try:
                    # if True:
                    # Agent reset to new scene.
                    constraint_objs = {
                        # Generate multiple parent objs.
                        'repeat': [(constants.OBJ_PARENTS[pickup_obj],
                                    np.random.randint(
                                        2 if gtype == "pick_two_obj_and_place" else 1,
                                        constants.PICKUP_REPEAT_MAX + 1))],
                        'sparse': [(receptacle_obj.replace('Basin', ''),
                                    num_place_fails * constants.RECEPTACLE_SPARSE_POINTS)]
                    }
                    if movable_obj != "None":
                        constraint_objs['repeat'].append(
                            (movable_obj,
                             np.random.randint(1, constants.PICKUP_REPEAT_MAX +
                                               1)))
                    for obj_type in scene_id_to_objs[str(sampled_scene)]:
                        if (obj_type in pickup_candidates and
                                obj_type != constants.OBJ_PARENTS[pickup_obj]
                                and obj_type != movable_obj):
                            constraint_objs['repeat'].append(
                                (obj_type,
                                 np.random.randint(
                                     1,
                                     constants.MAX_NUM_OF_OBJ_INSTANCES + 1)))
                    if gtype in goal_to_invalid_receptacle:
                        constraint_objs['empty'] = [
                            (r.replace('Basin', ''), num_place_fails *
                             constants.RECEPTACLE_EMPTY_POINTS)
                            for r in goal_to_invalid_receptacle[gtype]
                        ]
                    constraint_objs['seton'] = []
                    if gtype == 'look_at_obj_in_light':
                        constraint_objs['seton'].append(
                            (receptacle_obj, False))
                    if num_place_fails > 0:
                        print(
                            "Failed %d placements in the past; increased free point constraints: "
                            % num_place_fails + str(constraint_objs))
                    scene_info = {
                        'scene_num': sampled_scene,
                        'random_seed': random.randint(0, 2**32)
                    }
                    info = agent.reset(scene=scene_info, objs=constraint_objs)

                    # Problem initialization with given constraints.
                    task_objs = {'pickup': pickup_obj}
                    if movable_obj != "None":
                        task_objs['mrecep'] = movable_obj
                    if gtype == "look_at_obj_in_light":
                        task_objs['toggle'] = receptacle_obj
                    else:
                        task_objs['receptacle'] = receptacle_obj
                    agent.setup_problem({'info': info},
                                        scene=scene_info,
                                        objs=task_objs)

                    # Now that objects are in their initial places, record them.
                    object_poses = [{
                        'objectName': obj['name'].split('(Clone)')[0],
                        'position': obj['position'],
                        'rotation': obj['rotation']
                    } for obj in env.last_event.metadata['objects']
                      if obj['pickupable']]
                    dirty_and_empty = gtype == 'pick_clean_then_place_in_recep'
                    object_toggles = [{
                        'objectType': o,
                        'stateChange': 'toggleable',
                        'isToggled': v
                    } for o, v in constraint_objs['seton']]
                    constants.data_dict['scene']['object_poses'] = object_poses
                    constants.data_dict['scene']['dirty_and_empty'] = dirty_and_empty
                    constants.data_dict['scene']['object_toggles'] = object_toggles

                    # Pre-restore the scene to cause objects to "jitter" like they will when the episode is replayed
                    # based on stored object and toggle info. This should put objects closer to the final positions
                    # they'll be in at inference time (e.g., mugs fallen and broken, knives fallen over, etc.).
                    print("Performing reset via thor_env API")
                    env.reset(sampled_scene)
                    print("Performing restore via thor_env API")
                    env.restore_scene(object_poses, object_toggles,
                                      dirty_and_empty)
                    event = env.step(
                        dict(constants.data_dict['scene']['init_action']))

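                    # Roll out the planner until the episode terminates or the
                    # step budget (constants.MAX_EPISODE_LENGTH) is exhausted.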
                    terminal = False
                    while not terminal and agent.current_frame_count <= constants.MAX_EPISODE_LENGTH:
                        action_dict = agent.get_action(None)
                        agent.step(action_dict)
                        reward, terminal = agent.get_reward()

                    dump_data_dict()
                    save_video()
                # else:
                except Exception as e:
                    import traceback
                    traceback.print_exc()
                    print("Error: " + repr(e))
                    print("Invalid Task: skipping...")
                    if args.debug:
                        print(traceback.format_exc())

                    deleted = delete_save(args.in_parallel)
                    if not deleted:  # another thread is filling this task successfully, so leave it alone.
                        target_remaining = 0  # stop trying to do this task.
                    else:
                        if str(e) == "API Action Failed: No valid positions to place object found":
                            # Try increasing the space available on sparse and empty flagged objects.
                            num_place_fails += 1
                        tries_remaining -= 1  # decrement on both placement and generic errors

                    estr = str(e)
                    if len(estr) > 120:
                        estr = estr[:120]
                    if estr not in errors:
                        errors[estr] = 0
                    errors[estr] += 1
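                    # Show a running histogram of error counts, most frequent first.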
                    print("%%%%%%%%%%")
                    es = sum(errors.values())
                    print("\terrors (%d):" % es)
                    for er, v in sorted(errors.items(),
                                        key=lambda kv: kv[1],
                                        reverse=True):
                        if v / es < 0.01:  # stop showing below 1% of errors.
                            break
                        print("\t(%.2f) (%d)\t%s" % (v / es, v, er))
                    print("%%%%%%%%%%")

                    continue

                if args.force_unsave:
                    delete_save(args.in_parallel)

                # add to save structure.
                succ_traj = succ_traj.append(
                    {
                        "goal": gtype,
                        "movable": movable_obj,
                        "pickup": pickup_obj,
                        "receptacle": receptacle_obj,
                        "scene": str(sampled_scene)
                    },
                    ignore_index=True)
                target_remaining -= 1
                tries_remaining += args.trials_before_fail  # on success, add more tries for future successes

            # if this combination resulted in a certain number of failures with no successes, flag it as not possible.
            if tries_remaining == 0 and target_remaining == args.repeats_per_cond:
                new_fails = [(gtype, pickup_obj, movable_obj, receptacle_obj,
                              str(sampled_scene))]
                fail_traj = load_fails_from_disk(args.save_path,
                                                 to_write=new_fails)
                print("%%%%%%%%%%")
                print("failures (%d)" % len(fail_traj))
                # print("\t" + "\n\t".join([str(ft) for ft in fail_traj]))
                print("%%%%%%%%%%")

            # if this combination gave us the repeats we wanted, note it as filled.
            if target_remaining == 0:
                full_traj.add((gtype, pickup_obj, movable_obj, receptacle_obj,
                               sampled_scene))

            # if we're sharing with other processes, reload successes from disk to update local copy with others' additions.
            if args.in_parallel:
                if n_until_load_successes > 0:
                    n_until_load_successes -= 1
                else:
                    print(
                        "Reloading trajectories from disk because of parallel processes..."
                    )
                    succ_traj = pd.DataFrame(
                        columns=succ_traj.columns)  # Drop all rows.
                    succ_traj, full_traj = load_successes_from_disk(
                        args.save_path, succ_traj, False,
                        args.repeats_per_cond)
                    print("... Loaded %d trajectories" % len(succ_traj.index))
                    n_until_load_successes = args.async_load_every_n_samples
                    print_successes(succ_traj)
                    task_sampler = sample_task_params(
                        succ_traj, full_traj, fail_traj, goal_candidates,
                        pickup_candidates, movable_candidates,
                        receptacle_candidates, scene_candidates)
                    print(
                        "... Created fresh instance of sample_task_params generator"
                    )
Example #3
    class Thor(threading.Thread):
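        """Worker thread owning a single ThorEnv; it services (action, reset,
        task_file) requests arriving on a shared action queue."""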
        def __init__(self, queue, train_eval="train"):
            threading.Thread.__init__(self)
            self.action_queue = queue
            self.mask_rcnn = None
            self.env = None
            self.train_eval = train_eval
            self.controller_type = "oracle"

        def run(self):
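            # Serve requests forever; task_done() lets producers join() the
            # queue even when an action raises.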
            while True:
                action, reset, task_file = self.action_queue.get()
                try:
                    if reset:
                        self.reset(task_file)
                    else:
                        self.step(action)
                finally:
                    self.action_queue.task_done()

        def init_env(self, config):
            self.config = config

            screen_height = config['env']['thor']['screen_height']
            screen_width = config['env']['thor']['screen_width']
            smooth_nav = config['env']['thor']['smooth_nav']
            save_frames_to_disk = config['env']['thor']['save_frames_to_disk']

            if not self.env:
                self.env = ThorEnv(player_screen_height=screen_height,
                                   player_screen_width=screen_width,
                                   smooth_nav=smooth_nav,
                                   save_frames_to_disk=save_frames_to_disk)
            self.controller_type = self.config['controller']['type']
            self._done = False
            self._res = ()
            self._feedback = ""
            self.expert = HandCodedThorAgent(self.env, max_steps=200)
            self.prev_command = ""
            self.load_mask_rcnn()

        def load_mask_rcnn(self):
            # load pretrained MaskRCNN model if required
            if 'mrcnn' in self.config['controller']['type'] and not self.mask_rcnn:
                model_path = os.path.join(
                    os.environ['ALFRED_ROOT'],
                    self.config['mask_rcnn']['pretrained_model_path'])
                self.mask_rcnn = load_pretrained_model(model_path)

        def set_task(self, task_file):
            self.task_file = task_file
            self.traj_root = os.path.dirname(task_file)
            with open(task_file, 'r') as f:
                self.traj_data = json.load(f)

        def reset(self, task_file):
            assert self.env
            assert self.controller_type

            self.set_task(task_file)

            # scene setup
            scene_num = self.traj_data['scene']['scene_num']
            object_poses = self.traj_data['scene']['object_poses']
            dirty_and_empty = self.traj_data['scene']['dirty_and_empty']
            object_toggles = self.traj_data['scene']['object_toggles']
            scene_name = 'FloorPlan%d' % scene_num
            self.env.reset(scene_name)
            self.env.restore_scene(object_poses, object_toggles,
                                   dirty_and_empty)

            # recording
            save_frames_path = self.config['env']['thor']['save_frames_path']
            self.env.save_frames_path = os.path.join(
                save_frames_path, self.traj_root.replace('../', ''))

            # initialize to start position
            self.env.step(dict(self.traj_data['scene']['init_action']))

            # print goal instruction
            task_desc = get_templated_task_desc(self.traj_data)
            print("Task: %s" % task_desc)

            # print("Task: %s" % (self.traj_data['turk_annotations']['anns'][0]['task_desc']))

            # setup task for reward
            class args:
                pass

            args.reward_config = os.path.join(os.environ['ALFRED_ROOT'],
                                              'agents/config/rewards.json')
            self.env.set_task(self.traj_data, args, reward_type='dense')

            # set controller
            self.controller_type = self.config['controller']['type']
            self.goal_desc_human_anns_prob = self.config['env']['goal_desc_human_anns_prob']
            load_receps = self.config['controller']['load_receps']
            debug = self.config['controller']['debug']

            if self.controller_type == 'oracle':
                self.controller = OracleAgent(
                    self.env,
                    self.traj_data,
                    self.traj_root,
                    load_receps=load_receps,
                    debug=debug,
                    goal_desc_human_anns_prob=self.goal_desc_human_anns_prob)
            elif self.controller_type == 'oracle_astar':
                self.controller = OracleAStarAgent(
                    self.env,
                    self.traj_data,
                    self.traj_root,
                    load_receps=load_receps,
                    debug=debug,
                    goal_desc_human_anns_prob=self.goal_desc_human_anns_prob)
            elif self.controller_type == 'mrcnn':
                self.controller = MaskRCNNAgent(
                    self.env,
                    self.traj_data,
                    self.traj_root,
                    pretrained_model=self.mask_rcnn,
                    load_receps=load_receps,
                    debug=debug,
                    goal_desc_human_anns_prob=self.goal_desc_human_anns_prob,
                    save_detections_to_disk=self.env.save_frames_to_disk,
                    save_detections_path=self.env.save_frames_path)
            elif self.controller_type == 'mrcnn_astar':
                self.controller = MaskRCNNAStarAgent(
                    self.env,
                    self.traj_data,
                    self.traj_root,
                    pretrained_model=self.mask_rcnn,
                    load_receps=load_receps,
                    debug=debug,
                    goal_desc_human_anns_prob=self.goal_desc_human_anns_prob,
                    save_detections_to_disk=self.env.save_frames_to_disk,
                    save_detections_path=self.env.save_frames_path)
            else:
                raise NotImplementedError()

            # zero steps
            self.steps = 0

            # reset expert state
            self.expert.reset(task_file)
            self.prev_command = ""

            # return intro text
            self._feedback = self.controller.feedback
            self._res = self.get_info()

            return self._feedback

        def step(self, action):
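            # Once the episode is done, further actions are ignored but steps
            # are still counted.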
            if not self._done:
                # take action
                self.prev_command = str(action)
                self._feedback = self.controller.step(action)
                self._res = self.get_info()
                if self.env.save_frames_to_disk:
                    self.record_action(action)
            self.steps += 1

        def get_results(self):
            return self._res

        def record_action(self, action):
            txt_file = os.path.join(self.env.save_frames_path, 'action.txt')
            with open(txt_file, 'a+') as f:
                f.write("%s\r\n" % str(action))

        def get_info(self):
            won = self.env.get_goal_satisfied()
            pcs = self.env.get_goal_conditions_met()
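            # pcs is (num goal-conditions met, total goal-conditions).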
            goal_condition_success_rate = pcs[0] / float(pcs[1])
            acs = self.controller.get_admissible_commands()

            # expert action
            if self.train_eval == "train":
                game_state = {
                    'admissible_commands': acs,
                    'feedback': self._feedback,
                    'won': won
                }
                expert_actions = ["look"]
                try:
                    if not self.prev_command:
                        self.expert.observe(game_state['feedback'])
                    else:
                        next_action = self.expert.act(game_state, 0, won,
                                                      self.prev_command)
                        if next_action in acs:
                            expert_actions = [next_action]
                except HandCodedAgentTimeout:
                    print("Expert Timeout")
                except Exception as e:
                    print(e)
                    traceback.print_exc()
            else:
                expert_actions = []

            training_method = self.config["general"]["training_method"]
            if training_method == "dqn":
                max_nb_steps_per_episode = self.config["rl"]["training"]["max_nb_steps_per_episode"]
            elif training_method == "dagger":
                max_nb_steps_per_episode = self.config["dagger"]["training"]["max_nb_steps_per_episode"]
            else:
                raise NotImplementedError
            self._done = won or self.steps > max_nb_steps_per_episode
            return (self._feedback, self._done, acs, won,
                    goal_condition_success_rate, expert_actions)

        def get_last_frame(self):
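            # Reverse the channel axis to convert the frame from RGB to BGR.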
            return self.env.last_event.frame[:, :, ::-1]

        def get_exploration_frames(self):
            return self.controller.get_exploration_frames()