Example #1
from ai2thor.controller import Controller  # required for the Controller used below


def get_all_objtype_in_room(room_type: str):
    '''
    Get all object types present in a given room type: kitchen, living_room, bedroom, bathroom.
    '''
    if room_type == "kitchen":
        room_start_index = 0
    elif room_type == "living_room":
        room_start_index = 200
    elif room_type == "bedroom":
        room_start_index = 300
    else:
        room_start_index = 400

    all_obj_type = []

    controller = Controller(scene="FloorPlan1",
                            renderInstanceSegmentation=True,
                            width=1080,
                            height=1080)

    for i in range(1, 31):
        controller.reset(scene="FloorPlan" + str(room_start_index + i))
        event = controller.step("Done")
        all_obj = get_all_objtype_in_event(event)
        for objtype in all_obj:
            if objtype not in all_obj_type:
                all_obj_type.append(objtype)

    controller.stop()
    return all_obj_type
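A minimal usage sketch for the helper above (the helper get_all_objtype_in_event is assumed to be defined elsewhere in the same module):

if __name__ == "__main__":
    kitchen_types = get_all_objtype_in_room("kitchen")
    print(len(kitchen_types), "object types found across the kitchen scenes")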
Example #2
def run_load(dataset_path, memory, conn):
    # init_logging, global_config, actor_config and ActiveDataset are provided
    # by the surrounding project; conn is the inter-process connection handed
    # to the ActiveDataset.
    init_logging()
    controller = Controller(
        x_display='0.%d' % global_config.actor_gpu,
        visibilityDistance=actor_config.visibilityDistance,
        renderDepthImage=global_config.depth)
    dataset = ActiveDataset(dataset_path, memory, controller, conn)
    dataset.process()
    controller.stop()
Example #3
class Ai2Thor():
    def __init__(self):
        self.visualize = False
        self.verbose = False
        self.save_imgs = True
        self.do_orbslam = False
        self.do_depth_noise = False
        self.makevideo = True
        # st()

        # these are all map names (30 scenes per room type: kitchens, living
        # rooms, bedrooms, bathrooms)
        a = np.arange(1, 31)
        b = np.arange(201, 231)
        c = np.arange(301, 331)
        d = np.arange(401, 431)
        abcd = np.hstack((a, b, c, d))
        mapnames = []
        for i in list(abcd):
            mapname = 'FloorPlan' + str(i)
            mapnames.append(mapname)

        random.shuffle(mapnames)
        self.mapnames = mapnames
        self.num_episodes = 1  #len(self.mapnames)

        self.ignore_classes = []
        # classes to save
        self.include_classes = [
            'ShowerDoor', 'Cabinet', 'CounterTop', 'Sink', 'Towel',
            'HandTowel', 'TowelHolder', 'SoapBar', 'ToiletPaper',
            'ToiletPaperHanger', 'HandTowelHolder', 'SoapBottle', 'GarbageCan',
            'Candle', 'ScrubBrush', 'Plunger', 'SinkBasin', 'Cloth',
            'SprayBottle', 'Toilet', 'Faucet', 'ShowerHead', 'Box', 'Bed',
            'Book', 'DeskLamp', 'BasketBall', 'Pen', 'Pillow', 'Pencil',
            'CellPhone', 'KeyChain', 'Painting', 'CreditCard', 'AlarmClock',
            'CD', 'Laptop', 'Drawer', 'SideTable', 'Chair', 'Blinds', 'Desk',
            'Curtains', 'Dresser', 'Watch', 'Television', 'WateringCan',
            'Newspaper', 'FloorLamp', 'RemoteControl', 'HousePlant', 'Statue',
            'Ottoman', 'ArmChair', 'Sofa', 'DogBed', 'BaseballBat',
            'TennisRacket', 'VacuumCleaner', 'Mug', 'ShelvingUnit', 'Shelf',
            'StoveBurner', 'Apple', 'Lettuce', 'Bottle', 'Egg', 'Microwave',
            'CoffeeMachine', 'Fork', 'Fridge', 'WineBottle', 'Spatula',
            'Bread', 'Tomato', 'Pan', 'Cup', 'Pot', 'SaltShaker', 'Potato',
            'PepperShaker', 'ButterKnife', 'StoveKnob', 'Toaster',
            'DishSponge', 'Spoon', 'Plate', 'Knife', 'DiningTable', 'Bowl',
            'LaundryHamper', 'Vase', 'Stool', 'CoffeeTable', 'Poster',
            'Bathtub', 'TissueBox', 'Footstool', 'BathtubBasin',
            'ShowerCurtain', 'TVStand', 'Boots', 'RoomDecor', 'PaperTowelRoll',
            'Ladle', 'Kettle', 'Safe', 'GarbageBag', 'TeddyBear',
            'TableTopDecor', 'Dumbbell', 'Desktop', 'AluminumFoil', 'Window'
        ]

        self.small_classes = []
        self.rot_interval = 5.0
        self.radius_max = 3.5  #3 #1.75
        self.radius_min = 1.0  #1.25
        self.num_flat_views = 3
        self.num_any_views = 7
        self.num_views = 25

        self.obj_per_scene = 10

        # self.origin_quaternion = np.quaternion(1, 0, 0, 0)
        # self.origin_rot_vector = quaternion.as_rotation_vector(self.origin_quaternion)

        self.homepath = f'/home/nel/gsarch/aithor/data/test'
        # self.basepath = '/home/nel/gsarch/replica_traj_bed'
        if not os.path.exists(self.homepath):
            os.mkdir(self.homepath)
        else:
            val = input("Delete homepath? [y/n]: ")
            if val == 'y':
                import shutil
                shutil.rmtree(self.homepath)
                os.mkdir(self.homepath)
            else:
                print("ENDING")
                assert (False)

        self.W = 256
        self.H = 256

        self.fov = 90
        hfov = float(self.fov) * np.pi / 180.
        self.pix_T_camX = np.array([[
            (self.W / 2.) * 1 / np.tan(hfov / 2.), 0., 0., 0.
        ], [0., (self.H / 2.) * 1 / np.tan(hfov / 2.), 0., 0.], [0., 0., 1, 0],
                                    [0., 0., 0, 1]])
        self.pix_T_camX[0, 2] = self.W / 2.
        self.pix_T_camX[1, 2] = self.H / 2.

        self.run_episodes()

    def run_episodes(self):
        self.ep_idx = 0
        # self.objects = []

        for episode in range(self.num_episodes):
            print("STARTING EPISODE ", episode)

            mapname = self.mapnames[episode]
            print("MAPNAME=", mapname)

            self.controller = Controller(
                scene=mapname,
                gridSize=0.25,
                width=self.W,
                height=self.H,
                fieldOfView=self.fov,
                renderObjectImage=True,
                renderDepthImage=True,
            )

            self.basepath = self.homepath + f"/{mapname}_{episode}"
            print("BASEPATH: ", self.basepath)

            # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            if not os.path.exists(self.basepath):
                os.mkdir(self.basepath)

            self.run()

            self.controller.stop()
            time.sleep(1)

            self.ep_idx += 1

    def save_datapoint(self, observations, data_path, viewnum, flat_view):
        if self.verbose:
            print("Print Sensor States.", self.agent.state.sensor_states)
        rgb = observations["color_sensor"]
        semantic = observations["semantic_sensor"]
        # st()
        depth = observations["depth_sensor"]
        agent_pos = observations["positions"]
        agent_rot = observations["rotations"]
        # Assuming all sensors have same extrinsics
        color_sensor_pos = observations["positions"]
        color_sensor_rot = observations["rotations"]
        #print("POS ", agent_pos)
        #print("ROT ", color_sensor_rot)
        object_list = observations['object_list']

        # print(viewnum, agent_pos)
        # print(agent_rot)

        if False:
            plt.imshow(rgb)
            plt_name = f'/home/nel/gsarch/aithor/data/test/img_mask{viewnum}.png'
            plt.savefig(plt_name)

        save_data = {
            'flat_view': flat_view,
            'objects_info': object_list,
            'rgb_camX': rgb,
            'depth_camX': depth,
            'sensor_pos': color_sensor_pos,
            'sensor_rot': color_sensor_rot
        }

        with open(os.path.join(data_path, str(viewnum) + ".p"), 'wb') as f:
            pickle.dump(save_data, f)

    def quat_from_angle_axis(self, theta: float,
                             axis: np.ndarray) -> np.quaternion:
        r"""Creates a quaternion from angle axis format

        :param theta: The angle to rotate about the axis by
        :param axis: The axis to rotate about
        :return: The quaternion
        """
        axis = axis.astype(np.float64)
        axis /= np.linalg.norm(axis)
        return quaternion.from_rotation_vector(theta * axis)
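        # For instance (illustrative values, not from the original code): a
        # 90-degree turn about the y axis,
        #     quat_from_angle_axis(np.pi / 2, np.array([0, 1.0, 0])),
        # yields quaternion(0.7071, 0, 0.7071, 0) in (w, x, y, z) order.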

    def run2(self):
        event = self.controller.step('GetReachablePositions')
        for obj in event.metadata['objects']:
            if obj['objectType'] not in self.objects:
                self.objects.append(obj['objectType'])

    def run(self):
        event = self.controller.step('GetReachablePositions')
        self.nav_pts = event.metadata['reachablePositions']
        self.nav_pts = np.array([list(d.values()) for d in self.nav_pts])
        objects = np.random.choice(event.metadata['objects'],
                                   self.obj_per_scene,
                                   replace=False)

        # objects = np.random.shuffle(event.metadata['objects'])
        # for obj in event.metadata['objects']: #objects:
        #     print(obj['name'])
        # objects = objects[0]
        for obj in objects:
            print("Object is ", obj['objectType'])
            # if obj['name'] in ['Microwave_b200e0bc']:
            #     print(obj['name'])
            # else:
            #     continue
            # print(obj['name'])

            if obj['objectType'] not in self.include_classes:
                print("Continuing... Invalid Object")
                continue

            # Calculate distance to object center
            # obj_center = np.array(list(obj['position'].values()))
            obj_center = np.array(
                list(obj['axisAlignedBoundingBox']['center'].values()))

            # obj_center = np.array(list(obj['position'].values()))

            #print(obj_center)
            obj_center = np.expand_dims(obj_center, axis=0)
            #print(obj_center)
            distances = np.sqrt(np.sum((self.nav_pts - obj_center)**2, axis=1))

            # Get points with r_min < dist < r_max
            valid_pts = self.nav_pts[np.where(
                (distances > self.radius_min) * (distances < self.radius_max))]
            # if not valid_pts:
            # continue

            # plot valid points that we happen to select
            # self.plot_navigable_points(valid_pts)

            # Bin points based on angles [vertical_angle (10 deg/bin), horizontal_angle (10 deg/bin)]
            valid_pts_shift = valid_pts - obj_center

            dz = valid_pts_shift[:, 2]
            dx = valid_pts_shift[:, 0]
            dy = valid_pts_shift[:, 1]

            # Get yaw for binning
            valid_yaw = np.degrees(np.arctan2(dz, dx))

            # # pitch calculation
            # dxdz_norm = np.sqrt((dx * dx) + (dz * dz))
            # valid_pitch = np.degrees(np.arctan2(dy,dxdz_norm))

            # binning yaw around object
            # nbins = 18

            nbins = 18
            bins = np.linspace(-180, 180, nbins + 1)
            bin_yaw = np.digitize(valid_yaw, bins)

            num_valid_bins = np.unique(bin_yaw).size

            # spawns_per_bin = int(self.num_views / num_valid_bins) + 2

            if False:
                import matplotlib.cm as cm
                colors = iter(cm.rainbow(np.linspace(0, 1, nbins)))
                plt.figure(2)
                plt.clf()
                print(np.unique(bin_yaw))
                for bi in range(nbins):
                    cur_bi = np.where(bin_yaw == (bi + 1))
                    points = valid_pts[cur_bi]
                    x_sample = points[:, 0]
                    z_sample = points[:, 2]
                    plt.plot(z_sample, x_sample, 'o', color=next(colors))
                plt.plot(self.nav_pts[:, 2],
                         self.nav_pts[:, 0],
                         'x',
                         color='red')
                plt.plot(obj_center[:, 2],
                         obj_center[:, 0],
                         'x',
                         color='black')
                plt_name = '/home/nel/gsarch/aithor/data/valid.png'
                plt.savefig(plt_name)

            if num_valid_bins == 0:
                continue

            spawns_per_bin = int(self.num_views / num_valid_bins) + 2
            # print(f'spawns_per_bin: {spawns_per_bin}')

            action = "do_nothing"
            episodes = []
            valid_pts_selected = []
            cnt = 0
            for b in range(nbins):

                # get all angle indices in the current bin range
                # st()
                inds_bin_cur = np.where(
                    bin_yaw == (b + 1))  # bins start at 1, so offset by +1
                inds_bin_cur = list(inds_bin_cur[0])
                if len(inds_bin_cur) == 0:
                    continue

                for s in range(spawns_per_bin):
                    observations = {}

                    if len(inds_bin_cur) == 0:
                        continue

                    rand_ind = np.random.randint(0, len(inds_bin_cur))
                    s_ind = inds_bin_cur.pop(rand_ind)

                    # st()
                    # s_ind = np.random.choice(inds_bin_cur)
                    #s_ind = inds_bin_cur[0][0]
                    pos_s = valid_pts[s_ind]
                    valid_pts_selected.append(pos_s)

                    # event = self.controller.step('TeleportFull', x=pos_s[0], y=pos_s[1], z=pos_s[2], rotation=dict(x=0.0, y=180.0, z=0.0), horizon=0.0)
                    # agent_pos = list(event.metadata['agent']['position'].values())
                    # print("Agent rot " , event.metadata['agent']['rotation'])
                    # print("Agent pos " , event.metadata['agent']['position'])
                    # print("Object center ", obj['axisAlignedBoundingBox']['center'])

                    # add height from center of agent to camera
                    pos_s[1] = pos_s[1] + 0.675

                    # YAW calculation - rotate to object
                    agent_to_obj = np.squeeze(
                        obj_center) - pos_s  # + np.array([0.0, 0.675, 0.0]))
                    agent_local_forward = np.array([0, 0, -1.0])  # y, z, x
                    # agent_local_forward = np.array([-1.0, 0, 0]) # y, z, x
                    flat_to_obj = np.array(
                        [agent_to_obj[0], 0.0, agent_to_obj[2]])
                    flat_dist_to_obj = np.linalg.norm(flat_to_obj)
                    flat_to_obj /= flat_dist_to_obj

                    det = (flat_to_obj[0] * agent_local_forward[2] -
                           agent_local_forward[0] * flat_to_obj[2])
                    turn_angle = math.atan2(
                        det, np.dot(agent_local_forward, flat_to_obj))
                    # quat_yaw = self.quat_from_angle_axis(turn_angle, np.array([0, 1.0, 0]))

                    # turn_yaw = np.degrees(quaternion.as_euler_angles(quat_yaw)[1])
                    turn_yaw = np.degrees(turn_angle)
                    # print("Turn Yaw=", turn_yaw)
                    # print("turn1 beg ", turn_yaw)

                    # agent_pos = list(event.metadata['agent']['position'].values())
                    # dist_to_origin = np.sqrt(agent_pos[0]**2 + agent_pos[2]**2)
                    # dist_to_object = np.sqrt((agent_pos[0] - obj_center[0,0])**2 + (agent_pos[2] - obj_center[0,2])**2)

                    # print("Agent rot " , event.metadata['agent']['rotation'])

                    # p0 = agent_pos
                    # p1 = np.squeeze(obj_center)
                    # C = np.cross(p0, p1)
                    # D = np.dot(p0, p1)
                    # NP0 = np.linalg.norm(p0)
                    # if ~np.all(C==0): # check for colinearity
                    #     Z = np.array([[0, -C[2], C[1]], [C[2], 0, -C[0]], [-C[1], C[0], 0]])
                    #     R = (np.eye(3) + Z + Z**2 * (1-D)/(np.linalg.norm(C)**2)) / NP0**2 # rotation matrix
                    # else:
                    #     R = np.sign(D) * (np.linalg.norm(p1) / NP0) # orientation and scaling

                    # quat_rot = quaternion.from_rotation_matrix(R)
                    # turns = np.degrees(quaternion.as_euler_angles(quat_rot))
                    # turn_yaw = turns[1]
                    # print("turn1 ", turn_yaw)
                    # print("turn0 ", turns[0])
                    # print("turn2 ", turns[2])

                    # if dist_to_origin > dist_to_object:
                    #     print('or>ob')
                    #     turn_yaw2 = np.degrees(np.cos(dist_to_object/dist_to_origin))
                    # else:
                    #     print('or>ob')
                    #     turn_yaw2 = np.degrees(np.cos(dist_to_origin/dist_to_object))
                    # print("TURN YAW2, ", turn_yaw2)
                    # Calculate Pitch from head to object
                    turn_pitch = -np.degrees(
                        math.atan2(agent_to_obj[1], flat_dist_to_obj))
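                    # Geometry note: turn_yaw above is the signed angle between the
                    # agent's default forward axis and the flat (y = 0) direction to
                    # the object (atan2 of cross and dot products); turn_pitch tilts
                    # the camera toward the object's height relative to the camera.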
                    # movement = "LookUp" if turn_pitch>0 else "LookDown"
                    # event = controller.step(movement, degrees=np.abs(turn_pitch))

                    event = self.controller.step('TeleportFull',
                                                 x=pos_s[0],
                                                 y=pos_s[1],
                                                 z=pos_s[2],
                                                 rotation=dict(x=0.0,
                                                               y=180.0 +
                                                               int(turn_yaw),
                                                               z=0.0),
                                                 horizon=int(turn_pitch))
                    # movement = "RotateRight" if turn_yaw>0 else "RotateLeft"
                    # event = self.controller.step(action='RotateRight', rotation=int(np.abs(turn_yaw)))

                    # movement = "LookDown" if turn_pitch>0 else "LookUp"
                    # event = self.controller.step(movement, degrees=np.abs(turn_pitch))

                    # print("Agent rot " , event.metadata['agent']['rotation'])

                    # angle_ranges = np.arange(0, 360, 15)
                    # angles_test = np.ones_like(angle_ranges) * 15
                    # for i in angle_ranges:
                    #     movement = "RotateRight"
                    #     event = self.controller.step(movement, degrees=15.0)

                    #     rgb = event.frame

                    #     if True:
                    #         plt.imshow(rgb)
                    #         plt_name = f'/home/nel/gsarch/aithor/data/img{i}.png'.format(i)
                    #         plt.savefig(plt_name)
                    # print(event.metadata['agent']['position'])
                    # print(event.metadata['agent']['rotation'])

                    rgb = event.frame

                    rotation_euler_radians = np.radians(
                        np.array([
                            event.metadata['agent']['cameraHorizon'],
                            event.metadata['agent']['rotation']['y'], 0.0
                        ]))

                    observations["positions"] = np.array(
                        list(event.metadata['agent']['position'].values())
                    ) + np.array([0.0, 0.675, 0.0])
                    # print(observations["positions"])
                    # print(pos_s)
                    # observations["rotations"] = quaternion.from_euler_angles(np.radians(np.array(list(event.metadata['agent']['rotation'].values()))))
                    observations["rotations"] = quaternion.from_euler_angles(
                        rotation_euler_radians)

                    # print(observations["positions"])

                    observations["color_sensor"] = rgb
                    observations["depth_sensor"] = event.depth_frame
                    observations[
                        "semantic_sensor"] = event.instance_segmentation_frame

                    if False:
                        plt.imshow(rgb)
                        plt_name = f'/home/nel/gsarch/aithor/data/test/img_true{s}{b}.png'
                        plt.savefig(plt_name)

                    # print("Processed image #", cnt, " for object ", obj['objectType'])

                    semantic = event.instance_segmentation_frame
                    object_id_to_color = event.object_id_to_color
                    color_to_object_id = event.color_to_object_id

                    obj_ids = np.unique(semantic.reshape(
                        -1, semantic.shape[2]),
                                        axis=0)

                    instance_masks = event.instance_masks
                    instance_detections2D = event.instance_detections2D

                    obj_metadata_IDs = []
                    for obj_m in event.metadata['objects']:  #objects:
                        obj_metadata_IDs.append(obj_m['objectId'])

                    object_list = []
                    for obj_idx in range(obj_ids.shape[0]):
                        try:
                            obj_color = tuple(obj_ids[obj_idx])
                            object_id = color_to_object_id[obj_color]
                        except KeyError:
                            # segmentation color not associated with any object id
                            continue

                        if object_id not in obj_metadata_IDs:
                            # print("Skipping ", object_id)
                            continue

                        obj_meta_index = obj_metadata_IDs.index(object_id)
                        obj_meta = event.metadata['objects'][obj_meta_index]
                        obj_category_name = obj_meta['objectType']

                        # continue if not visible or not in include classes
                        if obj_category_name not in self.include_classes or not obj_meta[
                                'visible']:
                            continue

                        obj_instance_mask = instance_masks[object_id]
                        obj_instance_detection2D = instance_detections2D[
                            object_id]  # [start_x, start_y, end_x, end_y]
                        obj_instance_detection2D = np.array([
                            obj_instance_detection2D[1],
                            obj_instance_detection2D[0],
                            obj_instance_detection2D[3],
                            obj_instance_detection2D[2]
                        ])  # ymin, xmin, ymax, xmax

                        if False:
                            print(object_id)
                            plt.imshow(obj_instance_mask)
                            plt_name = f'/home/nel/gsarch/aithor/data/test/img_mask{s}.png'
                            plt.savefig(plt_name)

                        obj_center_axisAligned = np.array(
                            list(obj_meta['axisAlignedBoundingBox']
                                 ['center'].values()))
                        obj_size_axisAligned = np.array(
                            list(obj_meta['axisAlignedBoundingBox']
                                 ['size'].values()))

                        # print(obj_category_name)

                        if self.verbose:
                            print("Saved class name is : ", obj_category_name)

                        obj_data = {
                            'instance_id': object_id,
                            'category_id': object_id,
                            'category_name': obj_category_name,
                            'bbox_center': obj_center_axisAligned,
                            'bbox_size': obj_size_axisAligned,
                            'mask_2d': obj_instance_mask,
                            'box_2d': obj_instance_detection2D
                        }
                        # object_list.append(obj_instance)
                        object_list.append(obj_data)

                    observations["object_list"] = object_list

                    # check if object visible (make sure agent is not behind a wall)
                    obj_id = obj['objectId']
                    obj_id_to_color = object_id_to_color[obj_id]
                    # if np.sum(obj_ids==object_id_to_color[obj_id]) > 0:
                    if self.verbose:
                        print("episode is valid......")
                    episodes.append(observations)

                    cnt += 1

            if len(episodes) >= self.num_views:
                print(f'num episodes: {len(episodes)}')
                data_folder = obj['name']
                data_path = os.path.join(self.basepath, data_folder)
                print("Saving to ", data_path)
                os.mkdir(data_path)
                # flat_obs = np.random.choice(episodes, self.num_views, replace=False)
                rand_inds = np.sort(
                    np.random.choice(len(episodes),
                                     self.num_views,
                                     replace=False))
                bool_inds = np.zeros(len(episodes), dtype=bool)
                bool_inds[rand_inds] = True
                flat_obs = np.array(episodes)[bool_inds]
                flat_obs = list(flat_obs)
                viewnum = 0
                for obs in flat_obs:
                    self.save_datapoint(obs, data_path, viewnum, True)
                    viewnum += 1
            else:
                print("Not enough episodes:", len(episodes))
Example #4
class RoboThorEnvironment:
    """Wrapper for the RoboTHOR controller providing additional functionality
    and bookkeeping.

    See [here](https://ai2thor.allenai.org/robothor/documentation) for comprehensive
    documentation on RoboTHOR.

    # Attributes

    controller : The AI2-THOR controller.
    config : The AI2-THOR controller configuration.
    """
    def __init__(self, **kwargs):
        self.config = dict(
            rotateStepDegrees=30.0,
            visibilityDistance=1.0,
            gridSize=0.25,
            agentType="stochastic",
            continuousMode=True,
            snapToGrid=False,
            agentMode="bot",
            width=640,
            height=480,
        )
        recursive_update(self.config, {**kwargs, "agentMode": "bot"})
        self.controller = Controller(**self.config)
        self.known_good_locations: Dict[str, Any] = {
            self.scene_name: copy.deepcopy(self.currently_reachable_points)
        }
        assert len(self.known_good_locations[self.scene_name]) > 10

        # onames = [o['objectId'] for o in self.last_event.metadata['objects']]
        # removed = []
        # for oname in onames:
        #     if 'Painting' in oname:
        #         self.controller.step("RemoveFromScene", objectId=oname)
        #         removed.append(oname)
        # get_logger().info("Removed {} Paintings from {}".format(len(removed), self.scene_name))

        # get_logger().warning("init to scene {} in pos {}".format(self.scene_name, self.agent_state()))
        # npoints = len(self.currently_reachable_points)
        # assert npoints > 100, "only {} reachable points after init".format(npoints)
        self.grids: Dict[str, Tuple[Dict[str, np.ndarray], int, int, int,
                                    int]] = {}
        self.initialize_grid()

    def initialize_grid_dimensions(
        self, reachable_points: Collection[Dict[str, float]]
    ) -> Tuple[int, int, int, int]:
        """Computes bounding box for reachable points quantized with the
        current gridSize."""
        points = {(
            round(p["x"] / self.config["gridSize"]),
            round(p["z"] / self.config["gridSize"]),
        ): p
                  for p in reachable_points}

        assert len(reachable_points) == len(points)

        xmin, xmax = min([p[0] for p in points]), max([p[0] for p in points])
        zmin, zmax = min([p[1] for p in points]), max([p[1] for p in points])

        return xmin, xmax, zmin, zmax

    def access_grid(self, target: str) -> float:
        """Returns the geodesic distance from the quantized location of the
        agent in the current scene's grid to the target object of given
        type."""
        if target not in self.grids[self.scene_name][0]:
            xmin, xmax, zmin, zmax = self.grids[self.scene_name][1:5]
            nx = xmax - xmin + 1
            nz = zmax - zmin + 1
            self.grids[self.scene_name][0][target] = -2 * np.ones(
                (nx, nz), dtype=np.float64)

        p = self.quantized_agent_state()

        if self.grids[self.scene_name][0][target][p[0], p[1]] < -1.5:
            corners = self.path_corners(target)
            dist = self.path_corners_to_dist(corners)
            if dist == float("inf"):
                dist = -1.0  # -1.0 for unreachable
            self.grids[self.scene_name][0][target][p[0], p[1]] = dist
            return dist

        return self.grids[self.scene_name][0][target][p[0], p[1]]

    def initialize_grid(self) -> None:
        """Initializes grid for current scene if not already initialized."""
        if self.scene_name in self.grids:
            return

        self.grids[self.scene_name] = ({}, ) + self.initialize_grid_dimensions(
            self.known_good_locations[self.scene_name])  # type: ignore

    def object_reachable(self, object_type: str) -> bool:
        """Determines whether a path can be computed from the discretized
        current agent location to the target object of given type."""
        return (self.access_grid(object_type) > -0.5
                )  # -1.0 for unreachable, 0.0 for end point

    def point_reachable(self, xyz: Dict[str, float]) -> bool:
        """Determines whether a path can be computed from the current agent
        location to the target point."""
        return self.dist_to_point(
            xyz) > -0.5  # -1.0 for unreachable, 0.0 for end point

    def path_corners(
            self, target: Union[str, Dict[str,
                                          float]]) -> List[Dict[str, float]]:
        """Returns an array with a sequence of xyz dictionaries objects
        representing the corners of the shortest path to the object of given
        type or end point location."""
        pose = self.agent_state()
        position = {k: pose[k] for k in ["x", "y", "z"]}
        # get_logger().debug("initial pos in path corners {} target {}".format(pose, target))
        try:
            if isinstance(target, str):
                path = metrics.get_shortest_path_to_object_type(
                    self.controller,
                    target,
                    position,
                    {**pose["rotation"]} if "rotation" in pose else None,
                )
            else:
                path = metrics.get_shortest_path_to_point(
                    self.controller, position, target)
        except ValueError:
            get_logger().debug("No path to object {} from {} in {}".format(
                target, position, self.scene_name))
            path = []
        finally:
            if isinstance(target, str):
                self.controller.step("TeleportFull", **pose)
                # pass
            new_pose = self.agent_state()
            try:
                assert abs(new_pose["x"] - pose["x"]) < 1e-5, "wrong x"
                assert abs(new_pose["y"] - pose["y"]) < 1e-5, "wrong y"
                assert abs(new_pose["z"] - pose["z"]) < 1e-5, "wrong z"
                assert (abs(new_pose["rotation"]["x"] - pose["rotation"]["x"])
                        < 1e-5), "wrong rotation x"
                assert (abs(new_pose["rotation"]["y"] - pose["rotation"]["y"])
                        < 1e-5), "wrong rotation y"
                assert (abs(new_pose["rotation"]["z"] - pose["rotation"]["z"])
                        < 1e-5), "wrong rotation z"
                assert (abs((new_pose["horizon"] % 360) -
                            (pose["horizon"] % 360)) <
                        1e-5), "wrong horizon {} vs {}".format(
                            (new_pose["horizon"] % 360),
                            (pose["horizon"] % 360))
            except Exception:
                # get_logger().error("new_pose {} old_pose {} in {}".format(new_pose, pose, self.scene_name))
                pass
            # if abs((new_pose['horizon'] % 360) - (pose['horizon'] % 360)) > 1e-5:
            #     get_logger().debug("wrong horizon {} vs {} after path to object {} from {} in {}".format((new_pose['horizon'] % 360), (pose['horizon'] % 360), target, position, self.scene_name))
            # else:
            #     get_logger().debug("correct horizon {} vs {} after path to object {} from {} in {}".format((new_pose['horizon'] % 360), (pose['horizon'] % 360), target, position, self.scene_name))
            # assert abs((new_pose['horizon'] % 360) - (pose['horizon'] % 360)) < 1e-5, "wrong horizon {} vs {}".format((new_pose['horizon'] % 360), (pose['horizon'] % 360))

            # # TODO: the agent will continue with a random horizon from here on
            # target_horizon = (pose['horizon'] % 360) - (360 if (pose['horizon'] % 360) >= 180 else 0)
            # new_pose = self.agent_state()['horizon']
            # update_horizon = (new_pose % 360) - (360 if (new_pose % 360) >= 180 else 0)
            # cond = abs(target_horizon - update_horizon) > 1e-5
            # nmovements = 0
            # while cond:
            #     cond = abs(target_horizon - update_horizon) > 1e-5 and target_horizon > update_horizon
            #     while cond:
            #         self.controller.step("LookDown")
            #         old = update_horizon
            #         new_pose = self.agent_state()['horizon']
            #         update_horizon = (new_pose % 360) - (360 if (new_pose % 360) >= 180 else 0)
            #         get_logger().debug("LookDown horizon {} -> {} ({})".format(old, update_horizon, target_horizon))
            #         nmovements += 1
            #         cond = abs(target_horizon - update_horizon) > 1e-5 and target_horizon > update_horizon
            #
            #     cond = abs(target_horizon - update_horizon) > 1e-5 and target_horizon < update_horizon
            #     while cond:
            #         self.controller.step("LookUp")
            #         old = update_horizon
            #         new_pose = self.agent_state()['horizon']
            #         update_horizon = (new_pose % 360) - (360 if (new_pose % 360) >= 180 else 0)
            #         get_logger().debug("LookUp horizon {} -> {} ({})".format(old, update_horizon, target_horizon))
            #         nmovements += 1
            #         cond = abs(target_horizon - update_horizon) > 1e-5 and target_horizon < update_horizon
            #
            #     cond = abs(target_horizon - update_horizon) > 1e-5
            # get_logger().debug("nmovements {}".format(nmovements))
            # new_pose = self.agent_state()
            # assert abs((new_pose['horizon'] % 360) - (pose['horizon'] % 360)) < 1e-5, "wrong horizon {} vs {}".format((new_pose['horizon'] % 360), (pose['horizon'] % 360))

            # try:
            #     assert abs((new_pose['horizon'] % 360) - (pose['horizon'] % 360)) < 1e-5, "wrong horizon {} vs {}".format((new_pose['horizon'] % 360), (pose['horizon'] % 360))
            # except Exception:
            #     get_logger().error("wrong horizon {} vs {}".format((new_pose['horizon'] % 360), (pose['horizon'] % 360)))
            #     self.controller.step("TeleportFull", **pose)
            #     assert abs(
            #         (new_pose['horizon'] % 360) - (pose['horizon'] % 360)) < 1e-5, "wrong horizon {} vs {} after teleport full".format(
            #         (new_pose['horizon'] % 360), (pose['horizon'] % 360))
        #     # get_logger().debug("initial pos in path corners {} current pos {} path {}".format(pose, self.agent_state(), path))
        return path

    def path_corners_to_dist(self, corners: Sequence[Dict[str,
                                                          float]]) -> float:
        """Computes the distance covered by the given path described by its
        corners."""

        if len(corners) == 0:
            return float("inf")

        total = 0.0
        for it in range(1, len(corners)):
            total += math.sqrt((corners[it]["x"] - corners[it - 1]["x"])**2 +
                               (corners[it]["z"] - corners[it - 1]["z"])**2)
        return total

    def quantized_agent_state(
            self,
            xz_subsampling: int = 1,
            rot_subsampling: int = 1) -> Tuple[int, int, int]:
        """Quantizes agent location (x, z) to a (subsampled) position in a
        fixed size grid derived from the initial set of reachable points; and
        rotation (around y axis) as a (subsampled) discretized angle given the
        current `rotateStepDegrees`."""
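        # Worked example (illustrative values): with gridSize = 0.25 and
        # rotateStepDegrees = 30, an agent at x = 1.13, z = -0.62 facing
        # y = 59 degrees quantizes to round(1.13 / 0.25) = 5 and
        # round(-0.62 / 0.25) = -2 (before clipping and the xmin/zmin shift),
        # with rotation bin r = round(((59 + 15) % 360) / 30) = 2.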
        pose = self.agent_state()
        p = {k: float(pose[k]) for k in ["x", "y", "z"]}

        xmin, xmax, zmin, zmax = self.grids[self.scene_name][1:5]
        x = int(np.clip(round(p["x"] / self.config["gridSize"]), xmin, xmax))
        z = int(np.clip(round(p["z"] / self.config["gridSize"]), zmin, zmax))

        rs = self.config["rotateStepDegrees"] * rot_subsampling
        shifted = pose["rotation"]["y"] + rs / 2
        normalized = shifted % 360.0
        r = int(round(normalized / rs))

        return (x - xmin) // xz_subsampling, (z - zmin) // xz_subsampling, r

    def dist_to_object(self, object_type: str) -> float:
        """Minimal geodesic distance to object of given type from agent's
        current location.

        It might return -1.0 for unreachable targets.
        """
        return self.access_grid(object_type)

    def dist_to_point(self, xyz: Dict[str, float]) -> float:
        """Minimal geodesic distance to end point from agent's current
        location.

        It might return -1.0 for unreachable targets.
        """
        corners = self.path_corners(xyz)
        dist = self.path_corners_to_dist(corners)
        if dist == float("inf"):
            dist = -1.0  # -1.0 for unreachable
        return dist

    def agent_state(self) -> Dict:
        """Return agent position, rotation and horizon."""
        agent_meta = self.last_event.metadata["agent"]
        return {
            **{k: float(v)
               for k, v in agent_meta["position"].items()},
            "rotation":
            {k: float(v)
             for k, v in agent_meta["rotation"].items()},
            "horizon": round(float(agent_meta["cameraHorizon"]), 1),
        }

    def teleport(self,
                 pose: Dict[str, float],
                 rotation: Dict[str, float],
                 horizon: float = 0.0):
        e = self.controller.step(
            "TeleportFull",
            x=pose["x"],
            y=pose["y"],
            z=pose["z"],
            rotation=rotation,
            horizon=horizon,
        )
        return e.metadata["lastActionSuccess"]

    def reset(self, scene_name: Optional[str] = None) -> None:
        """Resets scene to a known initial state."""
        if scene_name is not None and scene_name != self.scene_name:
            self.controller.reset(scene_name)
            assert self.last_action_success, "Could not reset to new scene"
            if scene_name not in self.known_good_locations:
                self.known_good_locations[scene_name] = copy.deepcopy(
                    self.currently_reachable_points)
                assert len(self.known_good_locations[scene_name]) > 10

            # onames = [o['objectId'] for o in self.last_event.metadata['objects']]
            # removed = []
            # for oname in onames:
            #     if 'Painting' in oname:
            #         self.controller.step("RemoveFromScene", objectId=oname)
            #         removed.append(oname)
            # get_logger().info("Removed {} Paintings from {}".format(len(removed), scene_name))

        # else:
        # assert (
        #     self.scene_name in self.known_good_locations
        # ), "Resetting scene without known good location"
        # get_logger().warning("Resetting {} to {}".format(self.scene_name, self.known_good_locations[self.scene_name]))
        # self.controller.step("TeleportFull", **self.known_good_locations[self.scene_name])
        # assert self.last_action_success, "Could not reset to known good location"

        # npoints = len(self.currently_reachable_points)
        # assert npoints > 100, "only {} reachable points after reset".format(npoints)

        self.initialize_grid()

    def randomize_agent_location(
        self,
        seed: Optional[int] = None,
        partial_position: Optional[Dict[str, float]] = None
    ) -> Dict[str, Union[Dict[str, float], float]]:
        """Teleports the agent to a random reachable location in the scene."""
        if partial_position is None:
            partial_position = {}
        k = 0
        state: Optional[Dict] = None

        while k == 0 or (not self.last_action_success and k < 10):
            # self.reset()
            state = {
                **self.random_reachable_state(seed=seed),
                **partial_position
            }
            # get_logger().debug("picked target location {}".format(state))
            self.controller.step("TeleportFull", **state)
            k += 1

        if not self.last_action_success:
            get_logger().warning((
                "Randomize agent location in scene {} and current random state {}"
                " with seed {} and partial position {} failed in "
                "10 attempts. Forcing the action.").format(
                    self.scene_name, state, seed, partial_position))
            self.controller.step("TeleportFull", **state,
                                 force_action=True)  # type: ignore
            assert self.last_action_success, "Force action failed with {}".format(
                state)

        # get_logger().debug("location after teleport full {}".format(self.agent_state()))
        # self.controller.step("TeleportFull", **self.agent_state())  # TODO only for debug
        # get_logger().debug("location after re-teleport full {}".format(self.agent_state()))

        return self.agent_state()

    def random_reachable_state(
        self,
        seed: Optional[int] = None
    ) -> Dict[str, Union[Dict[str, float], float]]:
        """Returns a random reachable location in the scene."""
        if seed is not None:
            random.seed(seed)
        # xyz = random.choice(self.currently_reachable_points)
        assert len(self.known_good_locations[self.scene_name]) > 10
        xyz = copy.deepcopy(
            random.choice(self.known_good_locations[self.scene_name]))
        rotation = random.choice(
            np.arange(0.0, 360.0, self.config["rotateStepDegrees"]))
        horizon = 0.0  # random.choice([0.0, 30.0, 330.0])
        return {
            **{k: float(v)
               for k, v in xyz.items()},
            "rotation": {
                "x": 0.0,
                "y": float(rotation),
                "z": 0.0
            },
            "horizon": float(horizon),
        }

    def known_good_locations_list(self):
        return self.known_good_locations[self.scene_name]

    @property
    def currently_reachable_points(self) -> List[Dict[str, float]]:
        """List of {"x": x, "y": y, "z": z} locations in the scene that are
        currently reachable."""
        self.controller.step(action="GetReachablePositions")
        return self.last_action_return

    @property
    def scene_name(self) -> str:
        """Current ai2thor scene."""
        return self.controller.last_event.metadata["sceneName"].replace(
            "_physics", "")

    @property
    def current_frame(self) -> np.ndarray:
        """Returns rgb image corresponding to the agent's egocentric view."""
        return self.controller.last_event.frame

    @property
    def current_depth(self) -> np.ndarray:
        """Returns depth image corresponding to the agent's egocentric view."""
        return self.controller.last_event.depth_frame

    @property
    def last_event(self) -> ai2thor.server.Event:
        """Last event returned by the controller."""
        return self.controller.last_event

    @property
    def last_action(self) -> str:
        """Last action, as a string, taken by the agent."""
        return self.controller.last_event.metadata["lastAction"]

    @property
    def last_action_success(self) -> bool:
        """Was the last action taken by the agent a success?"""
        return self.controller.last_event.metadata["lastActionSuccess"]

    @property
    def last_action_return(self) -> Any:
        """Get the value returned by the last action (if applicable).

        For an example of an action that returns a value, see
        `"GetReachablePositions"`.
        """
        return self.controller.last_event.metadata["actionReturn"]

    def step(self, action_dict: Dict) -> ai2thor.server.Event:
        """Take a step in the ai2thor environment."""
        return self.controller.step(**action_dict)

    def stop(self):
        """Stops the ai2thor controller."""
        try:
            self.controller.stop()
        except Exception as e:
            get_logger().warning(str(e))

    def all_objects(self) -> List[Dict[str, Any]]:
        """Return all object metadata."""
        return self.controller.last_event.metadata["objects"]

    def all_objects_with_properties(
            self, properties: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Find all objects with the given properties."""
        objects = []
        for o in self.all_objects():
            satisfies_all = True
            for k, v in properties.items():
                if o[k] != v:
                    satisfies_all = False
                    break
            if satisfies_all:
                objects.append(o)
        return objects

    def visible_objects(self) -> List[Dict[str, Any]]:
        """Return all visible objects."""
        return self.all_objects_with_properties({"visible": True})
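A minimal usage sketch of the wrapper above (the scene name and target object type are illustrative; extra keyword arguments are merged into the default controller config):

env = RoboThorEnvironment(scene="FloorPlan_Train1_1")
print(env.scene_name, len(env.currently_reachable_points))
env.randomize_agent_location()
if env.object_reachable("Television"):
    print("geodesic distance to a Television:", env.dist_to_object("Television"))
env.stop()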
Example #5
class IThorEnvironment(object):
    """Wrapper for the ai2thor controller providing additional functionality
    and bookkeeping.

    See [here](https://ai2thor.allenai.org/documentation/installation) for comprehensive
     documentation on AI2-THOR.

    # Attributes

    controller : The ai2thor controller.
    """
    def __init__(
        self,
        x_display: Optional[str] = None,
        docker_enabled: bool = False,
        local_thor_build: Optional[str] = None,
        visibility_distance: float = VISIBILITY_DISTANCE,
        fov: float = FOV,
        player_screen_width: int = 300,
        player_screen_height: int = 300,
        quality: str = "Very Low",
        restrict_to_initially_reachable_points: bool = False,
        make_agents_visible: bool = True,
        object_open_speed: float = 1.0,
        simplify_physics: bool = False,
    ) -> None:
        """Initializer.

        # Parameters

        x_display : The x display into which to launch ai2thor (possibly necessary if you are running on a server
            without an attached display).
        docker_enabled : Whether or not to run thor in a docker container (useful on a server without an attached
            display so that you don't have to start an x display).
        local_thor_build : The path to a local build of ai2thor. This is probably not necessary for your use case
            and can be safely ignored.
        visibility_distance : The distance (in meters) at which objects, in the viewport of the agent,
            are considered visible by ai2thor and will have their "visible" flag be set to `True` in the metadata.
        fov : The agent's camera's field of view.
        player_screen_width : The width resolution (in pixels) of the images returned by ai2thor.
        player_screen_height : The height resolution (in pixels) of the images returned by ai2thor.
        quality : The quality at which to render. Possible quality settings can be found in
            `ai2thor._quality_settings.QUALITY_SETTINGS`.
        restrict_to_initially_reachable_points : Whether or not to restrict the agent to locations in ai2thor
            that were found to be (initially) reachable by the agent (i.e. reachable by the agent after resetting
            the scene). This can be useful if you want to ensure there are only a fixed set of locations where the
            agent can go.
        make_agents_visible : Whether or not the agent should be visible. Most noticeable when there are multiple agents
            or when quality settings are high so that the agent casts a shadow.
        object_open_speed : How quickly objects should be opened. Higher speeds mean faster simulation but also mean
            that opened objects gain a lot of kinetic energy and can, possibly, knock other objects away.
        simplify_physics : Whether or not to simplify physics when applicable. Currently this only simplifies object
            interactions when opening drawers (when simplified, objects within a drawer do not slide around on
            their own when the drawer is opened or closed, instead they are effectively glued down).
        """

        self._start_player_screen_width = player_screen_width
        self._start_player_screen_height = player_screen_height
        self._local_thor_build = local_thor_build
        self.x_display = x_display
        self.controller: Optional[Controller] = None
        self._started = False
        self._quality = quality

        self._initially_reachable_points: Optional[List[Dict]] = None
        self._initially_reachable_points_set: Optional[Set[Tuple[
            float, float]]] = None
        self._move_mag: Optional[float] = None
        self._grid_size: Optional[float] = None
        self._visibility_distance = visibility_distance
        self._fov = fov
        self.restrict_to_initially_reachable_points = (
            restrict_to_initially_reachable_points)
        self.make_agents_visible = make_agents_visible
        self.object_open_speed = object_open_speed
        self._always_return_visible_range = False
        self.simplify_physics = simplify_physics

        self.start(None)
        # noinspection PyTypeHints
        self.controller.docker_enabled = docker_enabled  # type: ignore
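    # A hypothetical construction of this wrapper (screen size and scene name
    # are illustrative, not prescribed by the original code):
    #
    #     env = IThorEnvironment(player_screen_width=224, player_screen_height=224)
    #     env.reset(scene_name="FloorPlan1")
    #     print(env.last_action_success, env.current_frame.shape)
    #     env.stop()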

    @property
    def scene_name(self) -> str:
        """Current ai2thor scene."""
        return self.controller.last_event.metadata["sceneName"]

    @property
    def current_frame(self) -> np.ndarray:
        """Returns rgb image corresponding to the agent's egocentric view."""
        return self.controller.last_event.frame

    @property
    def last_event(self) -> ai2thor.server.Event:
        """Last event returned by the controller."""
        return self.controller.last_event

    @property
    def started(self) -> bool:
        """Has the ai2thor controller been started."""
        return self._started

    @property
    def last_action(self) -> str:
        """Last action, as a string, taken by the agent."""
        return self.controller.last_event.metadata["lastAction"]

    @last_action.setter
    def last_action(self, value: str) -> None:
        """Set the last action taken by the agent.

        Doing this is rewriting history, be careful.
        """
        self.controller.last_event.metadata["lastAction"] = value

    @property
    def last_action_success(self) -> bool:
        """Was the last action taken by the agent a success?"""
        return self.controller.last_event.metadata["lastActionSuccess"]

    @last_action_success.setter
    def last_action_success(self, value: bool) -> None:
        """Set whether or not the last action taken by the agent was a success.

        Doing this is rewriting history, be careful.
        """
        self.controller.last_event.metadata["lastActionSuccess"] = value

    @property
    def last_action_return(self) -> Any:
        """Get the value returned by the last action (if applicable).

        For an example of an action that returns a value, see
        `"GetReachablePositions"`.
        """
        return self.controller.last_event.metadata["actionReturn"]

    @last_action_return.setter
    def last_action_return(self, value: Any) -> None:
        """Set the value returned by the last action.

        Doing this is rewriting history, be careful.
        """
        self.controller.last_event.metadata["actionReturn"] = value

    def start(
        self,
        scene_name: Optional[str],
        move_mag: float = 0.25,
        **kwargs,
    ) -> None:
        """Starts the ai2thor controller if it was previously stopped.

        After starting, `reset` will be called with the scene name and move magnitude.

        # Parameters

        scene_name : The scene to load.
        move_mag : The amount of distance the agent moves in a single `MoveAhead` step.
        kwargs : additional kwargs, passed to reset.
        """
        if self._started:
            raise RuntimeError(
                "Trying to start the environment but it is already started.")

        self.controller = Controller(
            x_display=self.x_display,
            width=self._start_player_screen_width,
            height=self._start_player_screen_height,
            local_executable_path=self._local_thor_build,
            quality=self._quality,
            server_class=ai2thor.fifo_server.FifoServer,
        )

        if (
                self._start_player_screen_height,
                self._start_player_screen_width,
        ) != self.current_frame.shape[:2]:
            self.controller.step({
                "action": "ChangeResolution",
                "x": self._start_player_screen_width,
                "y": self._start_player_screen_height,
            })

        self._started = True
        self.reset(scene_name=scene_name, move_mag=move_mag, **kwargs)

    def stop(self) -> None:
        """Stops the ai2thor controller."""
        try:
            self.controller.stop()
        except Exception as e:
            get_logger().warning(str(e))
        finally:
            self._started = False

    def reset(
        self,
        scene_name: Optional[str],
        move_mag: float = 0.25,
        **kwargs,
    ):
        """Resets the ai2thor in a new scene.

        Resets ai2thor into a new scene and initializes the scene/agents with
        prespecified settings (e.g. move magnitude).

        # Parameters

        scene_name : The scene to load.
        move_mag : The amount of distance the agent moves in a single `MoveAhead` step.
        kwargs : additional kwargs, passed to the controller "Initialize" action.
        """
        self._move_mag = move_mag
        self._grid_size = self._move_mag

        if scene_name is None:
            scene_name = self.controller.last_event.metadata["sceneName"]
        self.controller.reset(scene_name)

        self.controller.step({
            "action": "Initialize",
            "gridSize": self._grid_size,
            "visibilityDistance": self._visibility_distance,
            "fov": self._fov,
            "makeAgentsVisible": self.make_agents_visible,
            "alwaysReturnVisibleRange": self._always_return_visible_range,
            **kwargs,
        })

        if self.object_open_speed != 1.0:
            self.controller.step({
                "action": "ChangeOpenSpeed",
                "x": self.object_open_speed
            })

        self._initially_reachable_points = None
        self._initially_reachable_points_set = None
        self.controller.step({"action": "GetReachablePositions"})
        if not self.controller.last_event.metadata["lastActionSuccess"]:
            get_logger().warning(
                "Error when getting reachable points: {}".format(
                    self.controller.last_event.metadata["errorMessage"]))
        self._initially_reachable_points = self.last_action_return

    def teleport_agent_to(
        self,
        x: float,
        y: float,
        z: float,
        rotation: float,
        horizon: float,
        standing: Optional[bool] = None,
        force_action: bool = False,
        only_initially_reachable: Optional[bool] = None,
        verbose=True,
        ignore_y_diffs=False,
    ) -> None:
        """Helper function teleporting the agent to a given location."""
        if standing is None:
            standing = self.last_event.metadata.get(
                "isStanding",
                self.last_event.metadata["agent"].get("isStanding"))
        original_location = self.get_agent_location()
        target = {"x": x, "y": y, "z": z}
        if only_initially_reachable is None:
            only_initially_reachable = self.restrict_to_initially_reachable_points
        if only_initially_reachable:
            reachable_points = self.initially_reachable_points
            reachable = False
            for p in reachable_points:
                if self.position_dist(target, p,
                                      ignore_y=ignore_y_diffs) < 0.01:
                    reachable = True
                    break
            if not reachable:
                self.last_action = "TeleportFull"
                self.last_event.metadata[
                    "errorMessage"] = "Target position was not initially reachable."
                self.last_action_success = False
                return
        self.controller.step(
            dict(
                action="TeleportFull",
                x=x,
                y=y,
                z=z,
                rotation={
                    "x": 0.0,
                    "y": rotation,
                    "z": 0.0
                },
                horizon=horizon,
                standing=standing,
                forceAction=force_action,
            ))
        if not self.last_action_success:
            agent_location = self.get_agent_location()
            rot_diff = (agent_location["rotation"] -
                        original_location["rotation"]) % 360
            new_old_dist = self.position_dist(original_location,
                                              agent_location,
                                              ignore_y=ignore_y_diffs)
            if (self.position_dist(original_location,
                                   agent_location,
                                   ignore_y=ignore_y_diffs) > 1e-2
                    or min(rot_diff, 360 - rot_diff) > 1):
                get_logger().warning(
                    "Teleportation FAILED but agent still moved (position_dist {}, rot diff {})"
                    " (\nprevious location\n{}\ncurrent_location\n{}\n)".
                    format(new_old_dist, rot_diff, original_location,
                           agent_location))
            return

        if force_action:
            assert self.last_action_success
            return

        agent_location = self.get_agent_location()
        rot_diff = (agent_location["rotation"] - rotation) % 360
        if (self.position_dist(agent_location, target, ignore_y=ignore_y_diffs)
                > 1e-2 or min(rot_diff, 360 - rot_diff) > 1):
            if only_initially_reachable:
                self._snap_agent_to_initially_reachable(verbose=False)
            if verbose:
                get_logger().warning(
                    "Teleportation did not place agent"
                    " precisely where desired in scene {}"
                    " (\ndesired\n{}\nactual\n{}\n)"
                    " perhaps due to grid snapping."
                    " Action is considered failed but agent may have moved.".
                    format(
                        self.scene_name,
                        {
                            "x": x,
                            "y": y,
                            "z": z,
                            "rotation": rotation,
                            "standing": standing,
                            "horizon": horizon,
                        },
                        agent_location,
                    ))
            self.last_action_success = False
        return

    def random_reachable_state(self, seed: int = None) -> Dict:
        """Returns a random reachable location in the scene."""
        if seed is not None:
            random.seed(seed)
        xyz = random.choice(self.currently_reachable_points)
        rotation = random.choice([0, 90, 180, 270])
        horizon = random.choice([0, 30, 60, 330])
        state = copy.copy(xyz)
        state["rotation"] = rotation
        state["horizon"] = horizon
        return state

    def randomize_agent_location(
            self,
            seed: int = None,
            partial_position: Optional[Dict[str, float]] = None) -> Dict:
        """Teleports the agent to a random reachable location in the scene."""
        if partial_position is None:
            partial_position = {}
        k = 0
        state: Optional[Dict] = None

        while k == 0 or (not self.last_action_success and k < 10):
            state = self.random_reachable_state(seed=seed)
            self.teleport_agent_to(**{**state, **partial_position})
            k += 1

        if not self.last_action_success:
            get_logger().warning(
                ("Randomize agent location in scene {}"
                 " with seed {} and partial position {} failed in "
                 "10 attempts. Forcing the action.").format(
                     self.scene_name, seed, partial_position))
            self.teleport_agent_to(**{
                **state,
                **partial_position
            },
                                   force_action=True)  # type: ignore
            assert self.last_action_success

        assert state is not None
        return state

    def object_pixels_in_frame(self,
                               object_id: str,
                               hide_all: bool = True,
                               hide_transparent: bool = False) -> np.ndarray:
        """Return an mask for a given object in the agent's current view.

        # Parameters

        object_id : The id of the object.
        hide_all : Whether or not to hide all other objects in the scene before getting the mask.
        hide_transparent : Whether or not partially transparent objects are considered to occlude the object.

        # Returns

        A numpy array of the mask.
        """

        # Emphasizing an object turns it magenta and hides all other objects
        # from view; we can therefore find where the object is on the screen by
        # emphasizing it and then scanning the image for magenta pixels.
        if hide_all:
            self.step({"action": "EmphasizeObject", "objectId": object_id})
        else:
            self.step({"action": "MaskObject", "objectId": object_id})
            if hide_transparent:
                self.step({"action": "HideTranslucentObjects"})
        # noinspection PyShadowingBuiltins
        filter = np.array([[[255, 0, 255]]])
        object_pixels = 1 * np.all(self.current_frame == filter, axis=2)
        if hide_all:
            self.step({"action": "UnemphasizeAll"})
        else:
            self.step({"action": "UnmaskObject", "objectId": object_id})
            if hide_transparent:
                self.step({"action": "UnhideAllObjects"})
        return object_pixels

    def object_pixels_on_grid(
        self,
        object_id: str,
        grid_shape: Tuple[int, int],
        hide_all: bool = True,
        hide_transparent: bool = False,
    ) -> np.ndarray:
        """Like `object_pixels_in_frame` but counts object pixels in a
        partitioning of the image."""
        def partition(n, num_parts):
            m = n // num_parts
            parts = [m] * num_parts
            num_extra = n % num_parts
            for k in range(num_extra):
                parts[k] += 1
            return parts

        object_pixels = self.object_pixels_in_frame(
            object_id=object_id,
            hide_all=hide_all,
            hide_transparent=hide_transparent)

        # Divide the current frame into a grid and count the number
        # of object pixels in each of the grid squares
        sums_in_blocks: List[List] = []
        frame_shape = self.current_frame.shape[:2]
        row_inds = np.cumsum([0] + partition(frame_shape[0], grid_shape[0]))
        col_inds = np.cumsum([0] + partition(frame_shape[1], grid_shape[1]))
        for i in range(len(row_inds) - 1):
            sums_in_blocks.append([])
            for j in range(len(col_inds) - 1):
                sums_in_blocks[i].append(
                    np.sum(object_pixels[row_inds[i]:row_inds[i + 1],
                                         col_inds[j]:col_inds[j + 1]]))
        return np.array(sums_in_blocks, dtype=np.float32)

    def object_in_hand(self):
        """Object metadata for the object in the agent's hand."""
        inv_objs = self.last_event.metadata["inventoryObjects"]
        if len(inv_objs) == 0:
            return None
        elif len(inv_objs) == 1:
            return self.get_object_by_id(
                self.last_event.metadata["inventoryObjects"][0]["objectId"])
        else:
            raise AttributeError("Must be <= 1 inventory objects.")

    @property
    def initially_reachable_points(self) -> List[Dict[str, float]]:
        """List of {"x": x, "y": y, "z": z} locations in the scene that were
        reachable after initially resetting."""
        assert self._initially_reachable_points is not None
        return copy.deepcopy(self._initially_reachable_points)  # type:ignore

    @property
    def initially_reachable_points_set(self) -> Set[Tuple[float, float]]:
        """Set of (x,z) locations in the scene that were reachable after
        initially resetting."""
        if self._initially_reachable_points_set is None:
            self._initially_reachable_points_set = set()
            for p in self.initially_reachable_points:
                self._initially_reachable_points_set.add(
                    self._agent_location_to_tuple(p))

        return self._initially_reachable_points_set

    @property
    def currently_reachable_points(self) -> List[Dict[str, float]]:
        """List of {"x": x, "y": y, "z": z} locations in the scene that are
        currently reachable."""
        self.step({"action": "GetReachablePositions"})
        return self.last_event.metadata["actionReturn"]  # type:ignore

    def get_agent_location(self) -> Dict[str, Union[float, bool]]:
        """Gets agent's location."""
        metadata = self.controller.last_event.metadata
        location = {
            "x":
            metadata["agent"]["position"]["x"],
            "y":
            metadata["agent"]["position"]["y"],
            "z":
            metadata["agent"]["position"]["z"],
            "rotation":
            metadata["agent"]["rotation"]["y"],
            "horizon":
            metadata["agent"]["cameraHorizon"],
            "standing":
            metadata.get("isStanding", metadata["agent"].get("isStanding")),
        }
        return location

    @staticmethod
    def _agent_location_to_tuple(p: Dict[str, float]) -> Tuple[float, float]:
        return round(p["x"], 2), round(p["z"], 2)

    def _snap_agent_to_initially_reachable(self, verbose=True):
        agent_location = self.get_agent_location()

        end_location_tuple = self._agent_location_to_tuple(agent_location)
        if end_location_tuple in self.initially_reachable_points_set:
            return

        agent_x = agent_location["x"]
        agent_z = agent_location["z"]

        closest_reachable_points = list(self.initially_reachable_points_set)
        closest_reachable_points = sorted(
            closest_reachable_points,
            key=lambda xz: abs(xz[0] - agent_x) + abs(xz[1] - agent_z),
        )

        # In rare cases end_location_tuple might not be considered to be in
        # self.initially_reachable_points_set even when it effectively is;
        # here we check for such cases.
        if (math.sqrt(((np.array(closest_reachable_points[0]) -
                        np.array(end_location_tuple))**2).sum()) < 1e-6):
            return

        saved_last_action = self.last_action
        saved_last_action_success = self.last_action_success
        saved_last_action_return = self.last_action_return
        saved_error_message = self.last_event.metadata["errorMessage"]

        # Thor behaves weirdly when the agent gets off the grid and you
        # try to teleport the agent back to the closest grid location. To
        # get around this we first teleport the agent to a random location
        # and then back to where it should be.
        for point in self.initially_reachable_points:
            if abs(agent_x - point["x"]) > 0.1 or abs(agent_z -
                                                      point["z"]) > 0.1:
                self.teleport_agent_to(
                    rotation=0,
                    horizon=30,
                    **point,
                    only_initially_reachable=False,
                    verbose=False,
                )
                if self.last_action_success:
                    break

        for p in closest_reachable_points:
            self.teleport_agent_to(
                **{
                    **agent_location, "x": p[0],
                    "z": p[1]
                },
                only_initially_reachable=False,
                verbose=False,
            )
            if self.last_action_success:
                break

        teleport_forced = False
        if not self.last_action_success:
            self.teleport_agent_to(
                **{
                    **agent_location,
                    "x": closest_reachable_points[0][0],
                    "z": closest_reachable_points[0][1],
                },
                force_action=True,
                only_initially_reachable=False,
                verbose=False,
            )
            teleport_forced = True

        self.last_action = saved_last_action
        self.last_action_success = saved_last_action_success
        self.last_action_return = saved_last_action_return
        self.last_event.metadata["errorMessage"] = saved_error_message
        new_agent_location = self.get_agent_location()
        if verbose:
            get_logger().warning((
                "In {}, at location (x,z)=({},{}) which is not in the set "
                "of initially reachable points;"
                " attempting to correct this: agent teleported to (x,z)=({},{}).\n"
                "Teleportation {} forced.").format(
                    self.scene_name,
                    agent_x,
                    agent_z,
                    new_agent_location["x"],
                    new_agent_location["z"],
                    "was" if teleport_forced else "wasn't",
                ))

    def step(
        self,
        action_dict: Optional[Dict[str, Union[str, int, float, Dict]]] = None,
        **kwargs: Union[str, int, float, Dict],
    ) -> ai2thor.server.Event:
        """Take a step in the ai2thor environment."""
        if action_dict is None:
            action_dict = dict()
        action_dict.update(kwargs)

        action = cast(str, action_dict["action"])

        skip_render = "renderImage" in action_dict and not action_dict[
            "renderImage"]
        last_frame: Optional[np.ndarray] = None
        if skip_render:
            last_frame = self.current_frame

        if self.simplify_physics:
            action_dict["simplifyOPhysics"] = True

        if "Move" in action and "Hand" not in action:  # type: ignore
            action_dict = {
                **action_dict,
                "moveMagnitude": self._move_mag,
            }  # type: ignore
            start_location = self.get_agent_location()
            sr = self.controller.step(action_dict)

            if self.restrict_to_initially_reachable_points:
                end_location_tuple = self._agent_location_to_tuple(
                    self.get_agent_location())
                if end_location_tuple not in self.initially_reachable_points_set:
                    self.teleport_agent_to(**start_location,
                                           force_action=True)  # type: ignore
                    self.last_action = action
                    self.last_action_success = False
                    self.last_event.metadata[
                        "errorMessage"] = "Moved to location outside of initially reachable points."
        elif "RandomizeHideSeekObjects" in action:
            last_position = self.get_agent_location()
            self.controller.step(action_dict)
            metadata = self.last_event.metadata
            if self.position_dist(last_position,
                                  self.get_agent_location()) > 0.001:
                self.teleport_agent_to(**last_position,
                                       force_action=True)  # type: ignore
                get_logger().warning(
                    "In scene {}, after randomization of hide and seek objects, agent moved."
                    .format(self.scene_name))

            sr = self.controller.step({"action": "GetReachablePositions"})
            self._initially_reachable_points = self.controller.last_event.metadata[
                "actionReturn"]
            self._initially_reachable_points_set = None
            self.last_action = action
            self.last_action_success = metadata["lastActionSuccess"]
            self.controller.last_event.metadata["actionReturn"] = []
        elif "RotateUniverse" in action:
            sr = self.controller.step(action_dict)
            metadata = self.last_event.metadata

            if metadata["lastActionSuccess"]:
                sr = self.controller.step({"action": "GetReachablePositions"})
                self._initially_reachable_points = self.controller.last_event.metadata[
                    "actionReturn"]
                self._initially_reachable_points_set = None
                self.last_action = action
                self.last_action_success = metadata["lastActionSuccess"]
                self.controller.last_event.metadata["actionReturn"] = []
        else:
            sr = self.controller.step(action_dict)

        if self.restrict_to_initially_reachable_points:
            self._snap_agent_to_initially_reachable()

        if skip_render:
            assert last_frame is not None
            self.last_event.frame = last_frame

        return sr

    @staticmethod
    def position_dist(
        p0: Mapping[str, Any],
        p1: Mapping[str, Any],
        ignore_y: bool = False,
        l1_dist: bool = False,
    ) -> float:
        """Distance between two points of the form {"x": x, "y":y, "z":z"}."""
        if l1_dist:
            return (abs(p0["x"] - p1["x"]) +
                    (0 if ignore_y else abs(p0["y"] - p1["y"])) +
                    abs(p0["z"] - p1["z"]))
        else:
            return math.sqrt((p0["x"] - p1["x"])**2 +
                             (0 if ignore_y else (p0["y"] - p1["y"])**2) +
                             (p0["z"] - p1["z"])**2)

    @staticmethod
    def rotation_dist(a: Dict[str, float], b: Dict[str, float]):
        """Distance between rotations."""
        def deg_dist(d0: float, d1: float):
            dist = (d0 - d1) % 360
            return min(dist, 360 - dist)

        return sum(deg_dist(a[k], b[k]) for k in ["x", "y", "z"])

    @staticmethod
    def angle_between_rotations(a: Dict[str, float], b: Dict[str, float]):
        return np.abs(
            (180 / (2 * math.pi)) *
            (Rotation.from_euler("xyz", [a[k] for k in "xyz"], degrees=True) *
             Rotation.from_euler("xyz", [b[k] for k in "xyz"],
                                 degrees=True).inv()).as_rotvec()).sum()

    def closest_object_with_properties(
            self, properties: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Find the object closest to the agent that has the given
        properties."""
        agent_pos = self.controller.last_event.metadata["agent"]["position"]
        min_dist = float("inf")
        closest = None
        for o in self.all_objects():
            satisfies_all = True
            for k, v in properties.items():
                if o[k] != v:
                    satisfies_all = False
                    break
            if satisfies_all:
                d = self.position_dist(agent_pos, o["position"])
                if d < min_dist:
                    min_dist = d
                    closest = o
        return closest

    def closest_visible_object_of_type(
            self, object_type: str) -> Optional[Dict[str, Any]]:
        """Find the object closest to the agent that is visible and has the
        given type."""
        properties = {"visible": True, "objectType": object_type}
        return self.closest_object_with_properties(properties)

    def closest_object_of_type(self,
                               object_type: str) -> Optional[Dict[str, Any]]:
        """Find the object closest to the agent that has the given type."""
        properties = {"objectType": object_type}
        return self.closest_object_with_properties(properties)

    def closest_reachable_point_to_position(
            self, position: Dict[str,
                                 float]) -> Tuple[Dict[str, float], float]:
        """Of all reachable positions, find the one that is closest to the
        given location."""
        target = np.array([position["x"], position["z"]])
        min_dist = float("inf")
        closest_point = None
        for pt in self.initially_reachable_points:
            dist = np.linalg.norm(target - np.array([pt["x"], pt["z"]]))
            if dist < min_dist:
                closest_point = pt
                min_dist = dist
                if min_dist < 1e-3:
                    break
        assert closest_point is not None
        return closest_point, min_dist

    @staticmethod
    def _angle_from_to(a_from: float, a_to: float) -> float:
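        # Returns the signed angular offset (degrees) from a_from to a_to along
        # the shorter arc; positive means the target is reached by increasing
        # the angle (mod 360), negative by decreasing it.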
        a_from = a_from % 360
        a_to = a_to % 360
        min_rot = min(a_from, a_to)
        max_rot = max(a_from, a_to)
        rot_across_0 = (360 - max_rot) + min_rot
        rot_not_across_0 = max_rot - min_rot
        rot_err = min(rot_across_0, rot_not_across_0)
        if rot_across_0 == rot_err:
            rot_err *= -1 if a_to > a_from else 1
        else:
            rot_err *= 1 if a_to > a_from else -1
        return rot_err

    def agent_xz_to_scene_xz(self, agent_xz: Dict[str,
                                                  float]) -> Dict[str, float]:
        agent_pos = self.get_agent_location()

        x_rel_agent = agent_xz["x"]
        z_rel_agent = agent_xz["z"]
        scene_x = agent_pos["x"]
        scene_z = agent_pos["z"]
        rotation = agent_pos["rotation"]
        if abs(rotation) < 1e-5:
            scene_x += x_rel_agent
            scene_z += z_rel_agent
        elif abs(rotation - 90) < 1e-5:
            scene_x += z_rel_agent
            scene_z += -x_rel_agent
        elif abs(rotation - 180) < 1e-5:
            scene_x += -x_rel_agent
            scene_z += -z_rel_agent
        elif abs(rotation - 270) < 1e-5:
            scene_x += -z_rel_agent
            scene_z += x_rel_agent
        else:
            raise Exception("Rotation must be one of 0, 90, 180, or 270.")

        return {"x": scene_x, "z": scene_z}

    def scene_xz_to_agent_xz(self, scene_xz: Dict[str,
                                                  float]) -> Dict[str, float]:
        agent_pos = self.get_agent_location()
        x_err = scene_xz["x"] - agent_pos["x"]
        z_err = scene_xz["z"] - agent_pos["z"]

        rotation = agent_pos["rotation"]
        if abs(rotation) < 1e-5:
            agent_x = x_err
            agent_z = z_err
        elif abs(rotation - 90) < 1e-5:
            agent_x = -z_err
            agent_z = x_err
        elif abs(rotation - 180) < 1e-5:
            agent_x = -x_err
            agent_z = -z_err
        elif abs(rotation - 270) < 1e-5:
            agent_x = z_err
            agent_z = -x_err
        else:
            raise Exception("Rotation must be one of 0, 90, 180, or 270.")

        return {"x": agent_x, "z": agent_z}

    def all_objects(self) -> List[Dict[str, Any]]:
        """Return all object metadata."""
        return self.controller.last_event.metadata["objects"]

    def all_objects_with_properties(
            self, properties: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Find all objects with the given properties."""
        objects = []
        for o in self.all_objects():
            satisfies_all = True
            for k, v in properties.items():
                if o[k] != v:
                    satisfies_all = False
                    break
            if satisfies_all:
                objects.append(o)
        return objects

    def visible_objects(self) -> List[Dict[str, Any]]:
        """Return all visible objects."""
        return self.all_objects_with_properties({"visible": True})

    def get_object_by_id(self, object_id: str) -> Optional[Dict[str, Any]]:
        for o in self.last_event.metadata["objects"]:
            if o["objectId"] == object_id:
                return o
        return None

    ###
    # Following is used for computing shortest paths between states
    ###
    _CACHED_GRAPHS: Dict[str, nx.DiGraph] = {}

    GRAPH_ACTIONS_SET = {
        "LookUp", "LookDown", "RotateLeft", "RotateRight", "MoveAhead"
    }

    def reachable_points_with_rotations_and_horizons(self):
        self.controller.step({"action": "GetReachablePositions"})
        assert self.last_action_success

        points_slim = self.last_event.metadata["actionReturn"]

        points = []
        for r in [0, 90, 180, 270]:
            for horizon in [-30, 0, 30, 60]:
                for p in points_slim:
                    p = copy.copy(p)
                    p["rotation"] = r
                    p["horizon"] = horizon
                    points.append(p)
        return points

    @staticmethod
    def location_for_key(key, y_value=0.0):
        x, z, rot, hor = key
        loc = dict(x=x, y=y_value, z=z, rotation=rot, horizon=hor)
        return loc

    @staticmethod
    def get_key(input_dict: Dict[str, Any]) -> Tuple[float, float, int, int]:
        if "x" in input_dict:
            x = input_dict["x"]
            z = input_dict["z"]
            rot = input_dict["rotation"]
            hor = input_dict["horizon"]
        else:
            x = input_dict["position"]["x"]
            z = input_dict["position"]["z"]
            rot = input_dict["rotation"]["y"]
            hor = input_dict["cameraHorizon"]

        return (
            round(x, 2),
            round(z, 2),
            round_to_factor(rot, 90) % 360,
            round_to_factor(hor, 30) % 360,
        )

    def update_graph_with_failed_action(self, failed_action: str):
        if (self.scene_name not in self._CACHED_GRAPHS
                or failed_action not in self.GRAPH_ACTIONS_SET):
            return

        source_key = self.get_key(self.last_event.metadata["agent"])
        self._check_contains_key(source_key)

        edge_dict = self.graph[source_key]
        to_remove_key = None
        for target_key in self.graph[source_key]:
            if edge_dict[target_key]["action"] == failed_action:
                to_remove_key = target_key
                break
        if to_remove_key is not None:
            self.graph.remove_edge(source_key, to_remove_key)

    def _add_from_to_edge(
        self,
        g: nx.DiGraph,
        s: Tuple[float, float, int, int],
        t: Tuple[float, float, int, int],
    ):
        def ae(x, y):
            return abs(x - y) < 0.001

        s_x, s_z, s_rot, s_hor = s
        t_x, t_z, t_rot, t_hor = t

        dist = round(math.sqrt((s_x - t_x)**2 + (s_z - t_z)**2), 2)
        angle_dist = (round_to_factor(t_rot - s_rot, 90) % 360) // 90
        horz_dist = (round_to_factor(t_hor - s_hor, 30) % 360) // 30

        # Source and target must differ in exactly one of position, rotation,
        # or horizon; otherwise no single action connects them and no edge is added.
        if sum(x != 0 for x in [dist, angle_dist, horz_dist]) != 1:
            return

        grid_size = self._grid_size
        action = None
        if angle_dist != 0:
            if angle_dist == 1:
                action = "RotateRight"
            elif angle_dist == 3:
                action = "RotateLeft"

        elif horz_dist != 0:
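            # Horizon changes are quantized to 30-degree steps; a LookUp lowers
            # the horizon by 30 degrees, so (-30 % 360) // 30 == 11, while a
            # LookDown raises it by 30 degrees (horz_dist == 1).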
            if horz_dist == 11:
                action = "LookUp"
            elif horz_dist == 1:
                action = "LookDown"
        elif ae(dist, grid_size):
            if ((s_rot == 0 and ae(t_z - s_z, grid_size))
                    or (s_rot == 90 and ae(t_x - s_x, grid_size))
                    or (s_rot == 180 and ae(t_z - s_z, -grid_size))
                    or (s_rot == 270 and ae(t_x - s_x, -grid_size))):
                g.add_edge(s, t, action="MoveAhead")

        if action is not None:
            g.add_edge(s, t, action=action)

    @functools.lru_cache(1)
    def possible_neighbor_offsets(
            self) -> Tuple[Tuple[float, float, int, int], ...]:
        grid_size = round(self._grid_size, 2)
        offsets = []
        for rot_diff in [-90, 0, 90]:
            for horz_diff in [-30, 0, 30, 60]:
                for x_diff in [-grid_size, 0, grid_size]:
                    for z_diff in [-grid_size, 0, grid_size]:
                        if (rot_diff != 0) + (horz_diff != 0) + (
                                x_diff != 0) + (z_diff != 0) == 1:
                            offsets.append(
                                (x_diff, z_diff, rot_diff, horz_diff))
        return tuple(offsets)

    def _add_node_to_graph(self, graph: nx.DiGraph, s: Tuple[float, float, int,
                                                             int]):
        if s in graph:
            return

        existing_nodes = set(graph.nodes())
        graph.add_node(s)

        for o in self.possible_neighbor_offsets():
            t = (s[0] + o[0], s[1] + o[1], s[2] + o[2], s[3] + o[3])
            if t in existing_nodes:
                self._add_from_to_edge(graph, s, t)
                self._add_from_to_edge(graph, t, s)

    @property
    def graph(self):
        if self.scene_name not in self._CACHED_GRAPHS:
            g = nx.DiGraph()
            points = self.reachable_points_with_rotations_and_horizons()
            for p in points:
                self._add_node_to_graph(g, self.get_key(p))

            self._CACHED_GRAPHS[self.scene_name] = g
        return self._CACHED_GRAPHS[self.scene_name]

    @graph.setter
    def graph(self, g):
        self._CACHED_GRAPHS[self.scene_name] = g

    def _check_contains_key(self,
                            key: Tuple[float, float, int, int],
                            add_if_not=True):
        if key not in self.graph:
            get_logger().warning(
                "{} was not in the graph for scene {}.".format(
                    key, self.scene_name))
            if add_if_not:
                self._add_node_to_graph(self.graph, key)

    def shortest_state_path(self, source_state_key, goal_state_key):
        self._check_contains_key(source_state_key)
        self._check_contains_key(goal_state_key)
        # noinspection PyBroadException
        try:
            path = nx.shortest_path(self.graph, source_state_key,
                                    goal_state_key)
            return path
        except Exception as _:
            return None

    def action_transitioning_between_keys(self, s, t):
        self._check_contains_key(s)
        self._check_contains_key(t)
        if self.graph.has_edge(s, t):
            return self.graph.get_edge_data(s, t)["action"]
        else:
            return None

    def shortest_path_next_state(self, source_state_key, goal_state_key):
        self._check_contains_key(source_state_key)
        self._check_contains_key(goal_state_key)
        if source_state_key == goal_state_key:
            raise RuntimeError(
                "called next state on the same source and goal state")
        state_path = self.shortest_state_path(source_state_key, goal_state_key)
        return state_path[1]

    def shortest_path_next_action(self, source_state_key, goal_state_key):
        self._check_contains_key(source_state_key)
        self._check_contains_key(goal_state_key)

        next_state_key = self.shortest_path_next_state(source_state_key,
                                                       goal_state_key)
        return self.graph.get_edge_data(source_state_key,
                                        next_state_key)["action"]

    def shortest_path_length(self, source_state_key, goal_state_key):
        self._check_contains_key(source_state_key)
        self._check_contains_key(goal_state_key)
        try:
            return nx.shortest_path_length(self.graph, source_state_key,
                                           goal_state_key)
        except nx.NetworkXNoPath as _:
            return float("inf")
Example #6
class RoboThorEnvironment:
    """Wrapper for the robo2thor controller providing additional functionality
    and bookkeeping.

    See [here](https://ai2thor.allenai.org/robothor/documentation) for comprehensive
     documentation on RoboTHOR.

    # Attributes

    controller : The AI2-THOR controller.
    config : The AI2-THOR controller configuration.
    """
    def __init__(self, all_metadata_available: bool = True, **kwargs):
        self.config = dict(
            rotateStepDegrees=30.0,
            visibilityDistance=1.0,
            gridSize=0.25,
            continuousMode=True,
            snapToGrid=False,
            agentMode="locobot",
            width=640,
            height=480,
            agentCount=1,
            server_class=FifoServer,
        )

        if "agentCount" in kwargs:
            assert kwargs["agentCount"] > 0

        kwargs["agentMode"] = kwargs.get("agentMode", "locobot")
        if kwargs["agentMode"] not in ["bot", "locobot"]:
            warnings.warn(f"The RoboTHOR environment has not been tested using"
                          f" an agent of mode '{kwargs['agentMode']}'.")

        recursive_update(self.config, kwargs)
        self.controller = Controller(**self.config)

        self.all_metadata_available = all_metadata_available

        self.scene_to_reachable_positions: Optional[Dict[str, Any]] = None
        self.distance_cache: Optional[DynamicDistanceCache] = None

        if self.all_metadata_available:
            self.scene_to_reachable_positions = {
                self.scene_name: copy.deepcopy(self.currently_reachable_points)
            }
            assert len(self.scene_to_reachable_positions[self.scene_name]) > 10

            self.distance_cache = DynamicDistanceCache(rounding=1)

        self.agent_count = self.config["agentCount"]

        # Used for backwards compatibility with the teleport action
        self._extra_teleport_kwargs: Dict[str, Any] = {}

    def initialize_grid_dimensions(
        self, reachable_points: Collection[Dict[str, float]]
    ) -> Tuple[int, int, int, int]:
        """Computes bounding box for reachable points quantized with the
        current gridSize."""
        points = {(
            round(p["x"] / self.config["gridSize"]),
            round(p["z"] / self.config["gridSize"]),
        ): p
                  for p in reachable_points}

        assert len(reachable_points) == len(points)

        xmin, xmax = min([p[0] for p in points]), max([p[0] for p in points])
        zmin, zmax = min([p[1] for p in points]), max([p[1] for p in points])

        return xmin, xmax, zmin, zmax

    def set_object_filter(self, object_ids: List[str]):
        self.controller.step("SetObjectFilter",
                             objectIds=object_ids,
                             renderImage=False)

    def reset_object_filter(self):
        self.controller.step("ResetObjectFilter", renderImage=False)

    def path_from_point_to_object_type(
            self, point: Dict[str, float], object_type: str,
            allowed_error: float) -> Optional[List[Dict[str, float]]]:
        event = self.controller.step(
            action="GetShortestPath",
            objectType=object_type,
            position=point,
            allowedError=allowed_error,
        )
        if event.metadata["lastActionSuccess"]:
            return event.metadata["actionReturn"]["corners"]
        else:
            get_logger().debug(
                "Failed to find path for {} in {}. Start point {}, agent state {}."
                .format(
                    object_type,
                    self.controller.last_event.metadata["sceneName"],
                    point,
                    self.agent_state(),
                ))
            return None

    def distance_from_point_to_object_type(self, point: Dict[str, float],
                                           object_type: str,
                                           allowed_error: float) -> float:
        """Minimal geodesic distance from a point to an object of the given
        type.

        It might return -1.0 for unreachable targets.
        """
        path = self.path_from_point_to_object_type(point, object_type,
                                                   allowed_error)
        if path:
            # Because `allowed_error != 0` means that the path returned above might not start
            # at `point`, we explicitly add any offset there is.
            s_dist = math.sqrt((point["x"] - path[0]["x"])**2 +
                               (point["z"] - path[0]["z"])**2)
            return metrics.path_distance(path) + s_dist
        return -1.0

    def distance_to_object_type(self,
                                object_type: str,
                                agent_id: int = 0) -> float:
        """Minimal geodesic distance to object of given type from agent's
        current location.

        It might return -1.0 for unreachable targets.
        """
        assert 0 <= agent_id < self.agent_count
        assert (
            self.all_metadata_available
        ), "`distance_to_object_type` cannot be called when `self.all_metadata_available` is `False`."

        def retry_dist(position: Dict[str, float], object_type: str):
            allowed_error = 0.05
            debug_log = ""
            d = -1.0
            while allowed_error < 2.5:
                d = self.distance_from_point_to_object_type(
                    position, object_type, allowed_error)
                if d < 0:
                    debug_log = (
                        f"In scene {self.scene_name}, could not find a path from {position} to {object_type} with"
                        f" {allowed_error} error tolerance. Increasing this tolerance to"
                        f" {2 * allowed_error} any trying again.")
                    allowed_error *= 2
                else:
                    break
            if d < 0:
                get_logger().warning(
                    f"In scene {self.scene_name}, could not find a path from {position} to {object_type}"
                    f" with {allowed_error} error tolerance. Returning a distance of -1."
                )
            elif debug_log != "":
                get_logger().debug(debug_log)
            return d

        return self.distance_cache.find_distance(
            self.scene_name,
            self.controller.last_event.events[agent_id].metadata["agent"]
            ["position"],
            object_type,
            retry_dist,
        )

    def path_from_point_to_point(
            self, position: Dict[str, float], target: Dict[str, float],
            allowedError: float) -> Optional[List[Dict[str, float]]]:
        try:
            return self.controller.step(
                action="GetShortestPathToPoint",
                position=position,
                x=target["x"],
                y=target["y"],
                z=target["z"],
                allowedError=allowedError,
            ).metadata["actionReturn"]["corners"]
        except Exception:
            get_logger().debug(
                "Failed to find path for {} in {}. Start point {}, agent state {}."
                .format(
                    target,
                    self.controller.last_event.metadata["sceneName"],
                    position,
                    self.agent_state(),
                ))
            return None

    def distance_from_point_to_point(self, position: Dict[str, float],
                                     target: Dict[str, float],
                                     allowed_error: float) -> float:
        path = self.path_from_point_to_point(position, target, allowed_error)
        if path:
            # Because `allowed_error != 0` means that the path returned above might not start
            # or end exactly at the position/target points, we explicitly add any offset there is.
            s_dist = math.sqrt((position["x"] - path[0]["x"])**2 +
                               (position["z"] - path[0]["z"])**2)
            t_dist = math.sqrt((target["x"] - path[-1]["x"])**2 +
                               (target["z"] - path[-1]["z"])**2)
            return metrics.path_distance(path) + s_dist + t_dist
        return -1.0

    def distance_to_point(self,
                          target: Dict[str, float],
                          agent_id: int = 0) -> float:
        """Minimal geodesic distance to end point from agent's current
        location.

        It might return -1.0 for unreachable targets.
        """
        assert 0 <= agent_id < self.agent_count
        assert (
            self.all_metadata_available
        ), "`distance_to_object_type` cannot be called when `self.all_metadata_available` is `False`."

        def retry_dist(position: Dict[str, float], target: Dict[str, float]):
            allowed_error = 0.05
            debug_log = ""
            d = -1.0
            while allowed_error < 2.5:
                d = self.distance_from_point_to_point(position, target,
                                                      allowed_error)
                if d < 0:
                    debug_log = (
                        f"In scene {self.scene_name}, could not find a path from {position} to {target} with"
                        f" {allowed_error} error tolerance. Increasing this tolerance to"
                        f" {2 * allowed_error} any trying again.")
                    allowed_error *= 2
                else:
                    break
            if d < 0:
                get_logger().warning(
                    f"In scene {self.scene_name}, could not find a path from {position} to {target}"
                    f" with {allowed_error} error tolerance. Returning a distance of -1."
                )
            elif debug_log != "":
                get_logger().debug(debug_log)
            return d

        return self.distance_cache.find_distance(
            self.scene_name,
            self.controller.last_event.events[agent_id].metadata["agent"]
            ["position"],
            target,
            retry_dist,
        )

    def agent_state(self, agent_id: int = 0) -> Dict:
        """Return agent position, rotation and horizon."""
        assert 0 <= agent_id < self.agent_count

        agent_meta = self.last_event.events[agent_id].metadata["agent"]
        return {
            **{k: float(v)
               for k, v in agent_meta["position"].items()},
            "rotation":
            {k: float(v)
             for k, v in agent_meta["rotation"].items()},
            "horizon": round(float(agent_meta["cameraHorizon"]), 1),
        }

    def teleport(
        self,
        pose: Dict[str, float],
        rotation: Dict[str, float],
        horizon: float = 0.0,
        agent_id: int = 0,
    ):
        assert 0 <= agent_id < self.agent_count
        try:
            e = self.controller.step(
                action="TeleportFull",
                x=pose["x"],
                y=pose["y"],
                z=pose["z"],
                rotation=rotation,
                horizon=horizon,
                agentId=agent_id,
                **self._extra_teleport_kwargs,
            )
        except ValueError as e:
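            # A ValueError here typically means this AI2-THOR build expects an
            # extra argument for TeleportFull (e.g. `standing`); add it once to
            # `_extra_teleport_kwargs` and retry, otherwise re-raise.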
            if len(self._extra_teleport_kwargs) == 0:
                self._extra_teleport_kwargs["standing"] = True
            else:
                raise e
            return self.teleport(pose=pose,
                                 rotation=rotation,
                                 horizon=horizon,
                                 agent_id=agent_id)
        return e.metadata["lastActionSuccess"]

    def reset(self,
              scene_name: str = None,
              filtered_objects: Optional[List[str]] = None) -> None:
        """Resets scene to a known initial state."""
        if scene_name is not None and scene_name != self.scene_name:
            self.controller.reset(scene_name)
            assert self.last_action_success, "Could not reset to new scene"

            if (self.all_metadata_available
                    and scene_name not in self.scene_to_reachable_positions):
                self.scene_to_reachable_positions[scene_name] = copy.deepcopy(
                    self.currently_reachable_points)
                assert len(self.scene_to_reachable_positions[scene_name]) > 10
        if filtered_objects:
            self.set_object_filter(filtered_objects)
        else:
            self.reset_object_filter()

    def random_reachable_state(
        self,
        seed: Optional[int] = None
    ) -> Dict[str, Union[Dict[str, float], float]]:
        """Returns a random reachable location in the scene."""
        assert (
            self.all_metadata_available
        ), "`random_reachable_state` cannot be called when `self.all_metadata_available` is `False`."

        if seed is not None:
            random.seed(seed)
        # xyz = random.choice(self.currently_reachable_points)
        assert len(self.scene_to_reachable_positions[self.scene_name]) > 10
        xyz = copy.deepcopy(
            random.choice(self.scene_to_reachable_positions[self.scene_name]))
        rotation = random.choice(
            np.arange(0.0, 360.0, self.config["rotateStepDegrees"]))
        horizon = 0.0  # random.choice([0.0, 30.0, 330.0])
        return {
            **{k: float(v)
               for k, v in xyz.items()},
            "rotation": {
                "x": 0.0,
                "y": float(rotation),
                "z": 0.0
            },
            "horizon": float(horizon),
        }

    def randomize_agent_location(
        self,
        seed: int = None,
        partial_position: Optional[Dict[str, float]] = None,
        agent_id: int = 0,
    ) -> Dict[str, Union[Dict[str, float], float]]:
        """Teleports the agent to a random reachable location in the scene."""
        assert 0 <= agent_id < self.agent_count

        if partial_position is None:
            partial_position = {}
        k = 0
        state: Optional[Dict] = None

        while k == 0 or (not self.last_action_success and k < 10):
            # self.reset()
            state = {
                **self.random_reachable_state(seed=seed),
                **partial_position
            }
            # get_logger().debug("picked target location {}".format(state))
            self.controller.step("TeleportFull", **state, agentId=agent_id)
            k += 1

        if not self.last_action_success:
            get_logger().warning((
                "Randomize agent location in scene {} and current random state {}"
                " with seed {} and partial position {} failed in "
                "10 attempts. Forcing the action.").format(
                    self.scene_name, state, seed, partial_position))
            self.controller.step("TeleportFull",
                                 **state,
                                 force_action=True,
                                 agentId=agent_id)  # type: ignore
            assert self.last_action_success, "Force action failed with {}".format(
                state)

        # get_logger().debug("location after teleport full {}".format(self.agent_state()))
        # self.controller.step("TeleportFull", **self.agent_state())  # TODO only for debug
        # get_logger().debug("location after re-teleport full {}".format(self.agent_state()))

        return self.agent_state(agent_id=agent_id)

    def known_good_locations_list(self):
        assert (
            self.all_metadata_available
        ), "`known_good_locations_list` cannot be called when `self.all_metadata_available` is `False`."
        return self.scene_to_reachable_positions[self.scene_name]

    @property
    def currently_reachable_points(self) -> List[Dict[str, float]]:
        """List of {"x": x, "y": y, "z": z} locations in the scene that are
        currently reachable."""
        self.controller.step(action="GetReachablePositions")
        assert (
            self.last_action_success
        ), f"Could not get reachable positions for reason {self.last_event.metadata['errorMessage']}."
        return self.last_action_return

    @property
    def scene_name(self) -> str:
        """Current ai2thor scene."""
        return self.controller.last_event.metadata["sceneName"].replace(
            "_physics", "")

    @property
    def current_frame(self) -> np.ndarray:
        """Returns rgb image corresponding to the agent's egocentric view."""
        return self.controller.last_event.frame

    @property
    def current_depth(self) -> np.ndarray:
        """Returns depth image corresponding to the agent's egocentric view."""
        return self.controller.last_event.depth_frame

    @property
    def current_frames(self) -> List[np.ndarray]:
        """Returns rgb images corresponding to the agents' egocentric views."""
        return [
            self.controller.last_event.events[agent_id].frame
            for agent_id in range(self.agent_count)
        ]

    @property
    def current_depths(self) -> List[np.ndarray]:
        """Returns depth images corresponding to the agents' egocentric
        views."""
        return [
            self.controller.last_event.events[agent_id].depth_frame
            for agent_id in range(self.agent_count)
        ]

    @property
    def last_event(self) -> ai2thor.server.Event:
        """Last event returned by the controller."""
        return self.controller.last_event

    @property
    def last_action(self) -> str:
        """Last action, as a string, taken by the agent."""
        return self.controller.last_event.metadata["lastAction"]

    @property
    def last_action_success(self) -> bool:
        """Was the last action taken by the agent a success?"""
        return self.controller.last_event.metadata["lastActionSuccess"]

    @property
    def last_action_return(self) -> Any:
        """Get the value returned by the last action (if applicable).

        For an example of an action that returns a value, see
        `"GetReachablePositions"`.
        """
        return self.controller.last_event.metadata["actionReturn"]

    def step(
        self,
        action_dict: Optional[Dict[str, Union[str, int, float, Dict]]] = None,
        **kwargs: Union[str, int, float, Dict],
    ) -> ai2thor.server.Event:
        """Take a step in the ai2thor environment."""
        if action_dict is None:
            action_dict = dict()
        action_dict.update(kwargs)

        return self.controller.step(**action_dict)

    def stop(self):
        """Stops the ai2thor controller."""
        try:
            self.controller.stop()
        except Exception as e:
            get_logger().warning(str(e))

    def all_objects(self) -> List[Dict[str, Any]]:
        """Return all object metadata."""
        return self.controller.last_event.metadata["objects"]

    def all_objects_with_properties(
            self, properties: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Find all objects with the given properties."""
        objects = []
        for o in self.all_objects():
            satisfies_all = True
            for k, v in properties.items():
                if o[k] != v:
                    satisfies_all = False
                    break
            if satisfies_all:
                objects.append(o)
        return objects

    def visible_objects(self) -> List[Dict[str, Any]]:
        """Return all visible objects."""
        return self.all_objects_with_properties({"visible": True})
Example #7
class QueryAgentSAC():
    def __init__(self, scene, target, room_object_types = G_livingroom_objtype):
        # ai2thor
        self.scene = scene
        self.target = target
        assert self.target in room_object_types
        self.room_object_types = room_object_types
        self.controller = Controller(scene=scene, 
                        renderInstanceSegmentation=True,
                        width=1080,
                        height=1080)
        self.target_type = target

        # register query
        FrameInfo.candidates = self.room_object_types
        self.query_indicates = [True for _ in range(len(FrameInfo.candidates))]

        # Keep track of QA history
        self.keep_frame_map = False  # whether to keep history to avoid duplicated queries
        self.history = deque(maxlen = 1000)
        # self.frame_map = {} # (position and rotation) -> FrameInfo

        # RL part
        self.observation = None
        self.last_action = None

        self.episode_done = False

        # reward
        self.time_penalty = -0.01
        self.action_fail_penalty = -0.01

        self.first_seen = False
        self.first_seen_reward = 1 # reward for finding the object initially

        self.first_in_range = False
        self.first_in_range_reward = 5.0 # object in interaction range

        self.mission_success_reward = 10.0 # object in interaction range and done

        # init event
        self.event = None
        self.step(5) # self.controller.step("Done") 
        self.observation = self.get_observation()

        # training
        self.use_gpu = True
        
        # learning
        self.batch_size = 4
        self.learning_rate = 0.001
        self.alpha = 1 # SAC entropy temperature
        self.gamma = 0.95 # discount factor
        self.soft_target_tau = 0.01
        self.target_update_period = 1 # how often (in steps) to update the target networks

        # record
        self.n_train_steps_total = 0
        self.episode_steps = 0
        self.episode_total_reward = 0
        self.episode_policy_loss = []
        self.episode_pf1_loss = []
        self.episode_pf2_loss = []
        

        # policy network
        self.input_dim = len(self.observation)
        self.hidden_dim = 64
        self.action_dim = 6
        self.policy = PolicyNetwork(self.input_dim, self.hidden_dim, self.action_dim)

        self.qf1 = QNetwork(self.input_dim , self.action_dim, self.hidden_dim)
        self.qf2 = QNetwork(self.input_dim , self.action_dim, self.hidden_dim)
        self.target_qf1 = QNetwork(self.input_dim , self.action_dim, self.hidden_dim)
        self.target_qf2 = QNetwork(self.input_dim , self.action_dim, self.hidden_dim)
        
        if self.use_gpu:
            self.policy = self.policy.cuda()
            self.qf1 = self.qf1.cuda()
            self.qf2 = self.qf2.cuda()
            self.target_qf1 = self.target_qf1.cuda()
            self.target_qf2 = self.target_qf2.cuda()

        # loss
        self.qf_criterion = nn.MSELoss()

        self.update_target_networks()

        self.policy_optimizer = optim.Adam(
            self.policy.parameters(),
            lr=self.learning_rate,
        )
        self.qf1_optimizer = optim.Adam(
            self.qf1.parameters(),
            lr=self.learning_rate,
        )
        self.qf2_optimizer = optim.Adam(
            self.qf2.parameters(),
            lr=self.learning_rate,
        )

        self.update_target_networks()
    
    def update_target_networks(self):
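        # Polyak (soft) update of the target critics; assuming an rlkit-style
        # helper, soft_update_from_to computes target <- tau * source + (1 - tau) * target.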
        soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)
    
    def take_action(self, epsilon = 0.2, print_action = False):
        current_state = self.observation

        if np.random.rand() < epsilon:
            action_code = np.random.randint(self.action_dim)
        else:
            current_state_tensor = torch.FloatTensor(current_state).unsqueeze(0)
            if self.use_gpu:
                current_state_tensor = current_state_tensor.to("cuda")
            action, _ = self.policy.sample_action_with_prob(current_state_tensor)
            
            action_code = torch.argmax(action, dim = -1)[0].item()
            #print(action_code)

        if print_action:
            print("Agent action: ", action_code)
        self.step(action_code)
        next_observation = self.get_observation()
        
        # calculate reward
        reward = self.time_penalty
        
        if not self.event.metadata["lastActionSuccess"]:
            reward += self.action_fail_penalty

        frame_info = FrameInfo(self.event)
        if not self.first_seen:
            for obj in frame_info.object_info:
                if self.target_type == obj["objectType"]:
                    reward += self.first_seen_reward
                    self.first_seen = True
                    break
        
        if not self.first_in_range:
            for obj in frame_info.object_info:
                if self.target_type == obj["objectType"] and obj["visible"] == True:
                    reward += self.first_in_range_reward
                    self.first_in_range = True
                    break
        

        for obj in frame_info.object_info:
            if self.target_type == obj["objectType"] and obj["visible"] == True:
                if action_code == 5:
                    reward += self.mission_success_reward
                    self.episode_done = True
                break

        self.history.append([self.observation.copy(), action_code, reward, next_observation.copy(), self.episode_done])
        
        self.observation = next_observation

        self.episode_steps += 1
        self.episode_total_reward += reward
        
    def step(self, action_code:int):
        self.last_action = action_code
        self.event = self.controller.step(G_action_code2action[action_code])

    def get_observation(self):
        '''
        Get the RL state vector: target one-hot encoding, last-action success flag,
        last-action one-hot encoding, discretized head pose, and per-object query answers
        '''
        state = []
        frame_info = FrameInfo(self.event)

        # target encode
        target_encode = [0 for _ in range(len(self.room_object_types))]
        target_index = self.room_object_types.index(self.target_type)
        target_encode[target_index] = 1
        state.extend(target_encode)

        # agent state: last action success encode
        last_action_success = 1 if self.event.metadata["lastActionSuccess"] else -1
        state.append(last_action_success)
        last_action_encode = [0] * len(G_action2code)
        last_action_encode[self.last_action] = 1
        state.extend(last_action_encode)

        # agent head position
        head_pose = round(self.event.metadata["agent"]['cameraHorizon']) // 30
        state.append(head_pose)

        # object state query
        obj_query_state = frame_info.get_answer_array_for_all_candidates(self.query_indicates)
        state.extend(obj_query_state)

        return state

    def sample_history(self, strategy="positive first"):
        all_indexes = np.arange(len(self.history))
        all_rewards = np.asarray([h[2] for h in self.history])
        sample_prob = softmax(all_rewards)
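        # "positive first": a softmax over raw rewards biases sampling toward
        # high-reward transitions; uniform replay would use equal probabilities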
        sample_indexes = np.random.choice(all_indexes, size=self.batch_size, replace=True, p=sample_prob)

        return [self.history[index] for index in sample_indexes]

    def learn(self):
        # sample history
        # sample_list = random.sample((self.history), self.batch_size)
        sample_list = self.sample_history()

        s0 = [sample_list[i][0] for i in range(self.batch_size)]
        a = [[1 if j == sample_list[i][1] else 0 for j in range(self.action_dim)] for i in range(self.batch_size)]
        r = [[sample_list[i][2]] for i in range(self.batch_size)]
        s1 = [sample_list[i][3] for i in range(self.batch_size)]
        d = [[int(sample_list[i][4])] for i in range(self.batch_size)]
        
        s0 = torch.FloatTensor(s0)
        a = torch.LongTensor(a)
        r = torch.FloatTensor(r)
        s1 = torch.FloatTensor(s1)
        d = torch.FloatTensor(d)

        if self.use_gpu:
            s0 = s0.to("cuda")
            a = a.to("cuda")
            r = r.to("cuda")
            s1 = s1.to("cuda")
            d = d.to("cuda")

        """
        Policy loss
        """
        new_obs_actions, log_pi = self.policy.sample_action_with_prob(s0)
        log_pi = log_pi.unsqueeze(-1)

        q_new_actions = torch.min(
            self.qf1(s0, new_obs_actions),
            self.qf2(s0, new_obs_actions),
        )

        policy_loss = (self.alpha*log_pi - q_new_actions).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(s0, a)
        q2_pred = self.qf2(s0, a)
        new_next_actions, new_log_pi = self.policy.sample_action_with_prob(s1)
        new_log_pi = new_log_pi.unsqueeze(-1)
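        # soft Bellman target: r + gamma * (1 - done) * (min_i targetQ_i(s', a') - alpha * log pi(a'|s'))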

        target_q_values = torch.min(
            self.target_qf1(s1, new_next_actions),
            self.target_qf2(s1, new_next_actions),
        ) - self.alpha * new_log_pi

        q_target = r + (1. - d) * self.gamma * target_q_values

        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())


        #update parameters
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.n_train_steps_total += 1
        if self.n_train_steps_total % self.target_update_period == 0:
            self.update_target_networks()

        # record
        self.episode_policy_loss.append(policy_loss.item())
        self.episode_pf1_loss.append(qf1_loss.item())
        self.episode_pf2_loss.append(qf2_loss.item())
        

    def reset_scene_and_target(self, scene: str, target: str):
        self.controller.reset(scene=scene)
        self.target_type = target

    def reset_episode(self):
        self.reset_scene_and_target(self.scene, self.target)
        self.episode_done = False
        self.episode_steps = 0
        self.episode_total_reward = 0
        self.episode_policy_loss.clear()
        self.episode_pf1_loss.clear()
        self.episode_pf2_loss.clear()
        self.first_seen = False
        self.first_in_range = False

    def close(self):
        self.controller.stop()

    def save_model(self):
        from datetime import datetime
        now = datetime.now()
        time_str = now.strftime("%H:%M:%S")
        torch.save(self.qf1.state_dict(), "record/qf1_" + time_str + ".pth")
        torch.save(self.qf2.state_dict(), "record/qf2_" + time_str + ".pth")
        torch.save(self.qf2.state_dict(), "record/policy_" + time_str + ".pth")
Example #8
0
class QueryAgentSIL():
    def __init__(self, scene, target, room_object_types=G_livingroom_objtype):
        # ai2thor
        self.scene = scene
        self.target = target
        assert self.target in room_object_types
        self.room_object_types = room_object_types
        self.controller = Controller(scene=scene,
                                     renderInstanceSegmentation=True,
                                     width=1080,
                                     height=1080)
        self.target_type = target

        # register query
        FrameInfo.candidates = self.room_object_types
        self.query_indicates = [True for _ in range(len(FrameInfo.candidates))]

        # Keep track of qa history
        self.keep_frame_map = False  # whether to keep history to avoid duplicated queries
        self.replay_buffer = deque(maxlen=1000)

        # self.frame_map = {} # (position and rotation) -> FrameInfo

        # RL part
        self.observation = None
        self.last_action = None

        self.episode_done = False
        self.episode_history = []

        # reward
        self.time_penalty = -0.01
        self.action_fail_penalty = -0.01

        self.first_seen = False
        self.first_seen_reward = 1  # reward for finding the object initially

        self.first_in_range = False
        self.first_in_range_reward = 5.0  # object in interaction range

        self.mission_success_reward = 10.0  # object in interaction range and done

        # init event
        self.event = None
        self.step(5)  # self.controller.step("Done")
        self.observation = self.get_observation()

        # training
        self.use_gpu = True

        # learning
        self.batch_size = 4
        self.learning_rate = 0.001
        self.alpha = 1  # entropy temperature (entropy regularization coefficient)
        self.gamma = 0.95  # discount factor

        # record
        self.n_train_steps_total = 0
        self.episode_steps = 0
        self.episode_total_reward = 0

        # policy network
        self.state_dim = len(self.observation)
        self.hidden_dim = 64
        self.action_dim = 6
        self.policy = PolicyNetwork(self.state_dim, self.hidden_dim,
                                    self.action_dim)
        self.value_net = ValueNetwork(self.state_dim, self.action_dim)

        if self.use_gpu:
            self.policy = self.policy.cuda()
            self.value_net = self.value_net.cuda()

        # loss
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optim.Adam(
            self.policy.parameters(),
            lr=self.learning_rate,
        )
        self.value_optimizer = optim.Adam(
            self.value_net.parameters(),
            lr=self.learning_rate,
        )

    def step(self, action_code: int):
        self.last_action = action_code
        self.event = self.controller.step(G_action_code2action[action_code])

    def get_observation(self):
        '''
        Get state for RL
        '''
        state = []
        frame_info = FrameInfo(self.event)

        # target encode
        target_encode = [0 for _ in range(len(self.room_object_types))]
        target_index = self.room_object_types.index(self.target_type)
        target_encode[target_index] = 1
        state.extend(target_encode)

        # agent state: last action success encode
        last_action_success = 1 if self.event.metadata[
            "lastActionSuccess"] else -1
        state.append(last_action_success)
        last_action_encode = [0] * len(G_action2code)
        last_action_encode[self.last_action] = 1
        state.extend(last_action_encode)

        # agent head position
        head_pose = round(self.event.metadata["agent"]['cameraHorizon']) // 30
        state.append(head_pose)

        # object state query
        obj_query_state = frame_info.get_answer_array_for_all_candidates(
            self.query_indicates)
        state.extend(obj_query_state)

        return state

    def take_action(self):
        current_state = self.observation

        current_state_tensor = torch.FloatTensor(current_state).unsqueeze(0)
        if self.use_gpu:
            current_state_tensor = current_state_tensor.to("cuda")
        action, log_prob = self.policy.sample_action_with_prob(
            current_state_tensor)

        action_code = torch.argmax(action, dim=-1)[0].item()
        #print(action_code)

        self.step(action_code)
        next_observation = self.get_observation()

        # calculate reward
        reward = self.time_penalty

        if not self.event.metadata["lastActionSuccess"]:
            reward += self.action_fail_penalty

        frame_info = FrameInfo(self.event)
        if not self.first_seen:
            for obj in frame_info.object_info:
                if self.target_type == obj["objectType"]:
                    reward += self.first_seen_reward
                    self.first_seen = True
                    break

        if not self.first_in_range:
            for obj in frame_info.object_info:
                if self.target_type == obj["objectType"] and obj[
                        "visible"] == True:
                    reward += self.first_in_range_reward
                    self.first_in_range = True
                    break

        for obj in frame_info.object_info:
            if self.target_type == obj["objectType"] and obj["visible"]:
                # action 5 is the terminating "Done" action
                if action_code == 5:
                    reward += self.mission_success_reward
                    self.episode_done = True
                break

        self.episode_history.append([
            self.observation.copy(), action_code, reward, log_prob[0],
            self.episode_done
        ])

        self.observation = next_observation

        # for print record
        self.episode_steps += 1
        self.episode_total_reward += reward

    def learn(self):
        '''
        When the episode ends, run one A2C update over the collected trajectory
        '''

        # A2C part
        policy_loss = 0
        value_loss = 0

        R = 0
        for i in reversed(range(len(self.episode_history))):
            h = self.episode_history[i]
            state = h[0]
            action = h[1]
            reward = h[2]
            log_probs = h[3]

            R = self.gamma * R + reward

            s0 = torch.FloatTensor(state)
            if self.use_gpu:
                s0 = s0.to("cuda")

            value_i = self.value_net(s0.unsqueeze(0))[0]  # add batch dimension
            advantage = R - value_i
            value_loss += 0.5 * advantage ** 2

            entropy = -torch.sum(torch.exp(log_probs) * log_probs)

            policy_loss = policy_loss - log_probs[action] * advantage.detach() \
                - self.alpha * entropy

        #update parameters
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

    def reset_scene_and_target(self, scene: str, target: str):
        self.controller.reset(scene=scene)
        self.target_type = target

    def reset_episode(self):
        self.reset_scene_and_target(self.scene, self.target)
        self.episode_done = False
        self.episode_steps = 0
        self.episode_total_reward = 0
        self.episode_history.clear()
        self.first_seen = False
        self.first_in_range = False

    def close(self):
        self.controller.stop()

    def save_model(self):
        from datetime import datetime
        now = datetime.now()
        time_str = now.strftime("%H:%M:%S")
        torch.save(self.value_net.state_dict(), "record/value_" + time_str + ".pth")
        torch.save(self.policy.state_dict(), "record/policy_" + time_str + ".pth")
Example #9
0
class drone_explorer(Thread):
    def __init__(self, pid, cfg, queue, mode='train'):
        super(drone_explorer, self).__init__()

        ### thread setting
        self.pid = pid
        set_random_seed(pid + cfg.framework.seed)
        self.cfg = cfg
        self.queue = queue
        self.iters = 0
        self.reset = False
        self.kill = False

        ### data
        self.objects = json.load(open(cfg.object_dir))
        if mode == 'train':
            self.trajectory = json.load(open(cfg.train.meta))
            self.object_random = cfg.train.random_object
        elif mode == 'val':
            self.trajectory = json.load(open(cfg.val.meta))
            self.object_random = cfg.val.random_object
        elif mode == 'test':
            self.trajectory = json.load(open(cfg.test.meta))
            self.object_random = cfg.test.random_object
        else:
            raise NotImplementedError

        self.trajectory_idx = np.arange(len(self.trajectory['scene']))
        np.random.shuffle(self.trajectory_idx)

        if pid == 0:
            for epoch in range(self.cfg.train.num_epoch):
                for ii in range(len(self.trajectory_idx)):
                    self.queue.put(self.trajectory_idx[ii])
        self.force, self.angle_x, self.angle_y = 0, 0, 0
        self.event = None

        ### THOR setting
        self.port = cfg.thor.port + pid
        self.x_display = '0.{}'.format(pid % cfg.framework.num_gpu + cfg.thor.x_display_offset)
        self.restart_cond = cfg.thor.restart_cond
        self.controller = None

        ### try it
        self.restart()

    def restart(self):
        ### restart the Unity process to avoid latency issues
        if self.controller is not None:
            self.controller.stop()
        self.controller = Controller(local_executable_path=PATH,
                                     scene="FloorPlan201_physics",
                                     x_display=self.x_display,
                                     agentMode='drone',
                                     fieldOfView=60,
                                     port=self.port)
        _ = self.controller.reset(self.trajectory['scene'][0])
        _ = self.controller.step(dict(action='ChangeAutoResetTimeScale', timeScale=self.cfg.thor.time_scale))
        _ = self.controller.step(dict(action='ChangeFixedDeltaTime', fixedDeltaTime=self.cfg.thor.delta_time))
        print('Thread:'+str(self.pid)+' ['+str(self.iters+1)+']: restart finish')

    def run(self):
        while not self.kill:
            ### check if it needs to reset or the agent is still exploring
            if self.reset and not self.queue.empty():
                ### check if it reaches the restart condition
                if self.iters % self.restart_cond == 0 and self.iters > 1:
                    self.restart()

                ### get meta data of the trajectory
                try:
                    tidx = self.queue.get(timeout=1)
                except Empty:
                    if self.kill:
                        break
                    else:
                        self.good_to_go = False
                        self.reset = False
                        continue

                ### have a loop to ensure the trajectory has a good start
                self.good_to_go = False
                while not self.good_to_go:
                    scene = self.trajectory['scene'][tidx]
                    object_name = self.trajectory['object'][tidx]
                    mass = self.objects[object_name][1]
                    drone_position = self.trajectory['drone_position'][tidx]
                    launcher_position = self.trajectory['launcher_position'][tidx]
                    force = self.trajectory['force'][tidx]
                    angle_y = self.trajectory['angle_y'][tidx]
                    angle_x = self.trajectory['angle_x'][tidx]
                    if object_name == "Glassbottle":
                        object_name = "Bottle"

                    # round positions and launch parameters to two decimal places
                    drone_position['x'] = np.round(drone_position['x'], 2)
                    drone_position['y'] = np.round(1.5, 2)
                    drone_position['z'] = np.round(drone_position['z'], 2)
                    launcher_position['x'] = np.round(launcher_position['x'], 2)
                    launcher_position['y'] = np.round(launcher_position['y'], 2)
                    launcher_position['z'] = np.round(launcher_position['z'], 2)
                    force = np.round(force, 2)
                    angle_y = np.round(angle_y, 2)
                    angle_x = np.round(angle_x, 2)

                    ### set THOR
                    event = self.controller.reset(scene)
                    event = self.controller.step(dict(action='SpawnDroneLauncher', position=launcher_position))
                    event = self.controller.step(dict(action='FlyAssignStart',
                                                      position=drone_position,
                                                      x=launcher_position['x'],
                                                      y=launcher_position['y'],
                                                      z=launcher_position['z']))
                    event = self.controller.step(dict(action='Rotate', rotation=dict(x=0, y=0, z=0)))
                    event = self.controller.step(dict(action='ChangeAutoResetTimeScale',
                                                      timeScale=self.cfg.thor.time_scale))
                    event = self.controller.step(dict(action='ChangeFixedDeltaTime',
                                                      fixedDeltaTime=self.cfg.thor.delta_time))
                    if "noise_sigma" in self.cfg.agent:
                        event = self.controller.step(dict(action='ChangeDronePositionRandomNoiseSigma',
                                                          dronePositionRandomNoiseSigma=self.cfg.agent.noise_sigma))

                    ### prepare to launch the object
                    position = event.metadata['agent']['position']
                    event = self.controller.step(dict(action = 'LaunchDroneObject',
                                                      moveMagnitude = force,
                                                      x = angle_x,
                                                      y = angle_y,
                                                      z = -1,
                                                      objectName=object_name,
                                                      objectRandom=self.object_random))

                    if np.round(event.metadata['currentTime'], 2) == 0.00:
                        event = self.controller.step(dict(action='Pass'))

                    if np.round(event.metadata['currentTime'], 2) == 0.02:
                        self.good_to_go = True

                ### some record
                self.angle_x = angle_x
                self.angle_y = angle_y
                self.force = force
                self.object_name = object_name
                self.mass = mass
                self.event = event
                self.tidx = tidx

                ### thread setting
                self.reset = False
                self.iters += 1

            ### otherwise, just go to sleep
            else:
                time.sleep(0.1)
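`drone_explorer` pulls trajectory indices from a shared work queue (the pid-0 thread fills it in `__init__`) and only loads a new trajectory once its `reset` flag is set. A hypothetical driver for a small thread pool might look like the sketch below; `cfg` is whatever config object the class expects, and the thread count and episode handling are assumptions.

from queue import Queue

work_queue = Queue()
explorers = [drone_explorer(pid, cfg, work_queue, mode='train') for pid in range(4)]
for t in explorers:
    t.start()

# the training loop sets t.reset = True whenever a thread should load its next
# trajectory, then reads t.event / t.force / t.mass once loading has finished
for t in explorers:
    t.reset = True

# ... run the episode logic, then shut everything down
for t in explorers:
    t.kill = True
for t in explorers:
    t.join()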
class Ai2Thor():
    def __init__(self):
        self.visualize = False
        self.verbose = False
        self.save_imgs = True
        self.do_orbslam = False
        self.do_depth_noise = False
        self.makevideo = True
        # st()

        # these are all map names
        a = np.arange(1, 30)
        b = np.arange(201, 231)
        c = np.arange(301, 331)
        d = np.arange(401, 431)
        abcd = np.hstack((a, b, c, d))
        mapnames = []
        for i in list(abcd):
            mapname = 'FloorPlan' + str(i)
            mapnames.append(mapname)

        np.random.seed(1)
        random.seed(1)  # random.shuffle uses the random module, not numpy
        random.shuffle(mapnames)
        self.mapnames = mapnames
        self.num_episodes = len(self.mapnames)

        self.ignore_classes = []
        # classes to save
        self.include_classes = [
            'ShowerDoor', 'Cabinet', 'CounterTop', 'Sink', 'Towel',
            'HandTowel', 'TowelHolder', 'SoapBar', 'ToiletPaper',
            'ToiletPaperHanger', 'HandTowelHolder', 'SoapBottle', 'GarbageCan',
            'Candle', 'ScrubBrush', 'Plunger', 'SinkBasin', 'Cloth',
            'SprayBottle', 'Toilet', 'Faucet', 'ShowerHead', 'Box', 'Bed',
            'Book', 'DeskLamp', 'BasketBall', 'Pen', 'Pillow', 'Pencil',
            'CellPhone', 'KeyChain', 'Painting', 'CreditCard', 'AlarmClock',
            'CD', 'Laptop', 'Drawer', 'SideTable', 'Chair', 'Blinds', 'Desk',
            'Curtains', 'Dresser', 'Watch', 'Television', 'WateringCan',
            'Newspaper', 'FloorLamp', 'RemoteControl', 'HousePlant', 'Statue',
            'Ottoman', 'ArmChair', 'Sofa', 'DogBed', 'BaseballBat',
            'TennisRacket', 'VacuumCleaner', 'Mug', 'ShelvingUnit', 'Shelf',
            'StoveBurner', 'Apple', 'Lettuce', 'Bottle', 'Egg', 'Microwave',
            'CoffeeMachine', 'Fork', 'Fridge', 'WineBottle', 'Spatula',
            'Bread', 'Tomato', 'Pan', 'Cup', 'Pot', 'SaltShaker', 'Potato',
            'PepperShaker', 'ButterKnife', 'StoveKnob', 'Toaster',
            'DishSponge', 'Spoon', 'Plate', 'Knife', 'DiningTable', 'Bowl',
            'LaundryHamper', 'Vase', 'Stool', 'CoffeeTable', 'Poster',
            'Bathtub', 'TissueBox', 'Footstool', 'BathtubBasin',
            'ShowerCurtain', 'TVStand', 'Boots', 'RoomDecor', 'PaperTowelRoll',
            'Ladle', 'Kettle', 'Safe', 'GarbageBag', 'TeddyBear',
            'TableTopDecor', 'Dumbbell', 'Desktop', 'AluminumFoil', 'Window'
        ]

        self.small_classes = []
        self.rot_interval = 5.0
        self.radius_max = 3.5  #3 #1.75
        self.radius_min = 1.0  #1.25
        self.num_flat_views = 3
        self.num_any_views = 7
        self.num_views = 25

        self.obj_per_scene = 5

        # self.origin_quaternion = np.quaternion(1, 0, 0, 0)
        # self.origin_rot_vector = quaternion.as_rotation_vector(self.origin_quaternion)

        self.homepath = f'/home/sirdome/katefgroup/gsarch/ithor/data/test'
        # self.basepath = '/home/nel/gsarch/replica_traj_bed'
        if not os.path.exists(self.homepath):
            os.mkdir(self.homepath)
        else:
            val = input("Delete homepath? [y/n]: ")
            if val == 'y':
                import shutil
                shutil.rmtree(self.homepath)
                os.mkdir(self.homepath)
            else:
                print("ENDING")
                assert (False)

        self.W = 256
        self.H = 256

        self.fov = 90
        hfov = float(self.fov) * np.pi / 180.
        self.pix_T_camX = np.array([[
            (self.W / 2.) * 1 / np.tan(hfov / 2.), 0., 0., 0.
        ], [0., (self.H / 2.) * 1 / np.tan(hfov / 2.), 0., 0.], [0., 0., 1, 0],
                                    [0., 0., 0, 1]])
        self.pix_T_camX[0, 2] = self.W / 2.
        self.pix_T_camX[1, 2] = self.H / 2.

        self.run_episodes()

    def run_episodes(self):
        self.ep_idx = 0
        # self.objects = []

        for episode in range(self.num_episodes):
            print("STARTING EPISODE ", episode)

            mapname = self.mapnames[episode]
            print("MAPNAME=", mapname)

            self.controller = Controller(
                scene=mapname,
                gridSize=0.25,
                width=self.W,
                height=self.H,
                fieldOfView=self.fov,
                renderObjectImage=True,
                renderDepthImage=True,
            )
            # self.controller.start()

            self.basepath = self.homepath + f"/{mapname}_{episode}"
            print("BASEPATH: ", self.basepath)

            # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            if not os.path.exists(self.basepath):
                os.mkdir(self.basepath)

            self.run()

            self.controller.stop()
            time.sleep(1)

            self.ep_idx += 1

    def save_datapoint(self, observations, data_path, viewnum, flat_view):
        if self.verbose:
            print("Print Sensor States.", self.agent.state.sensor_states)
        rgb = observations["color_sensor"]
        semantic = observations["semantic_sensor"]
        # st()
        depth = observations["depth_sensor"]
        agent_pos = observations["positions"]
        agent_rot = observations["rotations"]
        # Assuming all sensors have same extrinsics
        color_sensor_pos = observations["positions"]
        color_sensor_rot = observations["rotations"]
        #print("POS ", agent_pos)
        #print("ROT ", color_sensor_rot)
        object_list = observations['object_list']

        # print(viewnum, agent_pos)
        # print(agent_rot)

        if False:
            plt.imshow(rgb)
            plt_name = f'/home/nel/gsarch/aithor/data/test/img_mask{viewnum}.png'
            plt.savefig(plt_name)

        save_data = {
            'flat_view': flat_view,
            'objects_info': object_list,
            'rgb_camX': rgb,
            'depth_camX': depth,
            'sensor_pos': color_sensor_pos,
            'sensor_rot': color_sensor_rot
        }

        with open(os.path.join(data_path, str(viewnum) + ".p"), 'wb') as f:
            pickle.dump(save_data, f)

    def quat_from_angle_axis(self, theta: float,
                             axis: np.ndarray) -> np.quaternion:
        r"""Creates a quaternion from angle axis format

        :param theta: The angle to rotate about the axis by
        :param axis: The axis to rotate about
        :return: The quaternion
        """
        axis = axis.astype(float)  # np.float is removed in recent NumPy versions
        axis /= np.linalg.norm(axis)
        return quaternion.from_rotation_vector(theta * axis)

    def safe_inverse_single(self, a):
        r, t = self.split_rt_single(a)
        t = np.reshape(t, (3, 1))
        r_transpose = r.T
        inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
        bottom_row = a[3:4, :]  # this is [0, 0, 0, 1]
        # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4)
        inv = np.concatenate([inv, bottom_row], 0)
        return inv

    def split_rt_single(self, rt):
        r = rt[:3, :3]
        t = np.reshape(rt[:3, 3], 3)
        return r, t

    def eul2rotm(self, rx, ry, rz):
        # inputs are shaped B
        # this func is copied from matlab
        # R = [  cy*cz   sy*sx*cz-sz*cx    sy*cx*cz+sz*sx
        #        cy*sz   sy*sx*sz+cz*cx    sy*cx*sz-cz*sx
        #        -sy            cy*sx             cy*cx]
        # rx = np.expand_dims(rx, axis=1)
        # ry = np.expand_dims(ry, axis=1)
        # rz = np.expand_dims(rz, axis=1)
        # st()
        # these are B x 1
        sinz = np.sin(rz)
        siny = np.sin(ry)
        sinx = np.sin(rx)
        cosz = np.cos(rz)
        cosy = np.cos(ry)
        cosx = np.cos(rx)
        r11 = cosy * cosz
        r12 = sinx * siny * cosz - cosx * sinz
        r13 = cosx * siny * cosz + sinx * sinz
        r21 = cosy * sinz
        r22 = sinx * siny * sinz + cosx * cosz
        r23 = cosx * siny * sinz - sinx * cosz
        r31 = -siny
        r32 = sinx * cosy
        r33 = cosx * cosy

        r = np.array([[r11, r12, r13], [r21, r22, r23], [r31, r32, r33]])
        # r1 = np.stack([r11,r12,r13],axis=2)
        # r2 = np.stack([r21,r22,r23],axis=2)
        # r3 = np.stack([r31,r32,r33],axis=2)
        # r = np.concatenate([r1,r2,r3],axis=1)
        return r

    def rotm2eul(self, r):
        # r is Bx3x3, or Bx4x4
        r00 = r[0, 0]
        r10 = r[1, 0]
        r11 = r[1, 1]
        r12 = r[1, 2]
        r20 = r[2, 0]
        r21 = r[2, 1]
        r22 = r[2, 2]

        ## python guide:
        # if sy > 1e-6: # singular
        #     x = math.atan2(R[2,1] , R[2,2])
        #     y = math.atan2(-R[2,0], sy)
        #     z = math.atan2(R[1,0], R[0,0])
        # else:
        #     x = math.atan2(-R[1,2], R[1,1])
        #     y = math.atan2(-R[2,0], sy)
        #     z = 0

        sy = np.sqrt(r00 * r00 + r10 * r10)

        cond = (sy > 1e-6)
        rx = np.where(cond, np.arctan2(r21, r22), np.arctan2(-r12, r11))
        ry = np.where(cond, np.arctan2(-r20, sy), np.arctan2(-r20, sy))
        rz = np.where(cond, np.arctan2(r10, r00), np.zeros_like(r20))

        # rx = torch.atan2(r21, r22)
        # ry = torch.atan2(-r20, sy)
        # rz = torch.atan2(r10, r00)
        # rx[cond] = torch.atan2(-r12, r11)
        # ry[cond] = torch.atan2(-r20, sy)
        # rz[cond] = 0.0
        return rx, ry, rz

    def generate_xyz_habitatCamXs(self, flat_obs):

        pix_T_camX = self.pix_T_camX
        xyz_camXs = []
        for i in range(2):  #depth_camXs.shape[0]):
            K = pix_T_camX
            xs, ys = np.meshgrid(
                np.linspace(-1 * self.W / 2., 1 * self.W / 2., self.W),
                np.linspace(1 * self.W / 2., -1 * self.W / 2., self.W))
            depth = flat_obs[i]['depth_sensor'].reshape(1, self.W, self.W)
            if i > 0:
                rotation_X = flat_obs[i][
                    'rotations']  #quaternion.as_rotation_matrix(rot)
                pos = flat_obs[i]['positions']
                euler_rot_X = flat_obs[i]["rotations_euler"]
                # euler_rot_X_rad = np.radians(flat_obs[i]["rotations_euler"])
                origin_T_camX_4x4 = np.eye(4)
                origin_T_camX_4x4[0:3, 0:3] = rotation_X
                origin_T_camX_4x4[0:3, 3] = pos
                origin_T_camX = rotation_X
                camX_T_origin_4x4 = self.safe_inverse_single(origin_T_camX_4x4)
                camX0_T_camX = np.matmul(camX0_T_origin, origin_T_camX)
                camX0_T_camX_4x4 = np.matmul(camX0_T_origin_4x4,
                                             origin_T_camX_4x4)
                # camX0_T_camX = np.matmul(rotation_0, np.linalg.inv(origin_T_camX))
                # camX0_T_camX_4x4 = np.matmul(origin_T_camX0_4x4, camX_T_origin_4x4)
                # r = R.from_matrix(camX0_T_camX)
                # print("CamX0_T_camX CHECK: ", r.as_euler('xyz', degrees=True))
                # r = R.from_matrix(origin_T_camX_4x4[0:3,0:3])
                # print("CamX0_T_camX CHECK: ", r.as_euler('xyz', degrees=True))
                # r = R.from_matrix(rotation_X)
                # rx = euler_rot_X_rad[0]
                # ry = euler_rot_X_rad[1]
                # rz = euler_rot_X_rad[2]
                # rotm = self.eul2rotm(rx, ry, rz)
                rx, ry, rz = self.rotm2eul(rotation_X)
                print("EULER ACTUAL: ", euler_rot_X)
                print("EULER OBTAINED: ", rx, ry, rz)

                rx, ry, rz = self.rotm2eul(camX0_T_camX)
                # print("origin_T_camx CHECK: ", r.as_euler('xyz', degrees=True))s
                print("EULER SUBTRACT: ", euler_rot_0 - euler_rot_X)
                print("EULER OBTAINED: ", rx, ry, rz)
                # st()
            elif i == 0:
                rotation_0 = flat_obs[i][
                    'rotations']  #quaternion.as_rotation_matrix(rot)
                pos = flat_obs[i]['positions']
                euler_rot_0 = flat_obs[i]["rotations_euler"]
                origin_T_camX0_4x4 = np.eye(4)
                origin_T_camX0_4x4[0:3, 0:3] = rotation_0
                origin_T_camX0_4x4[0:3, 3] = pos
                camX0_T_origin_4x4 = self.safe_inverse_single(
                    origin_T_camX0_4x4)
                camX0_T_origin = np.linalg.inv(rotation_0)

            xs = xs.reshape(1, self.W, self.W)
            ys = ys.reshape(1, self.W, self.W)

            xys = np.vstack(
                (xs * depth, ys * depth, -depth, np.ones(depth.shape)))
            xys = xys.reshape(4, -1)
            xy_c0 = np.matmul(np.linalg.inv(K), xys)
            xyz_camX = xy_c0.T[:, :3]
            xyz_camXs.append(xyz_camX)
            if i == 1:
                xyz = np.expand_dims(xyz_camX, axis=0)
                B, N, _ = list(xyz.shape)
                ones = np.ones_like(xyz[:, :, 0:1])
                xyz1 = np.concatenate([xyz, ones], 2)
                xyz1_t = np.transpose(xyz1, (0, 2, 1))
                # this is B x 4 x N
                xyz2_t = np.matmul(camX0_T_camX_4x4, xyz1_t)
                xyz2 = np.transpose(xyz2_t, (0, 2, 1))
                xyz2 = np.squeeze(xyz2[:, :, :3])
                # xyz2 = xyz_camX

        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
        fig = plt.figure()
        # ax = fig.add_subplot(111, projection='3d')
        xs = xyz2[:, 0]
        ys = xyz2[:, 1]
        zs = xyz2[:, 2]
        # ax.scatter(xs, ys, zs)
        plt.plot(xs, zs, 'x', color='red')
        xyz3 = xyz_camXs[0]
        xs = xyz3[:, 0]
        ys = xyz3[:, 1]
        zs = xyz3[:, 2]
        # ax.scatter(xs, ys, zs)
        plt.plot(xs, zs, 'o', color='blue')
        plt_name = '/home/nel/gsarch/aithor/data/pointcloud.png'
        plt.savefig(plt_name)
        st()

        return np.stack(xyz_camXs)

    def run2(self):
        event = self.controller.step('GetReachablePositions')
        for obj in event.metadata['objects']:
            if obj['objectType'] not in self.objects:
                self.objects.append(obj['objectType'])

    def run(self):
        event = self.controller.step('GetReachablePositions')
        self.nav_pts = event.metadata['reachablePositions']
        self.nav_pts = np.array([list(d.values()) for d in self.nav_pts])
        np.random.seed(1)
        # objects = np.random.choice(event.metadata['objects'], self.obj_per_scene, replace=False)
        objects = event.metadata['objects']
        objects_inds = np.arange(len(event.metadata['objects']))
        np.random.shuffle(objects_inds)

        # objects = np.random.shuffle(event.metadata['objects'])
        # for obj in event.metadata['objects']: #objects:
        #     print(obj['name'])
        # objects = objects[0]
        successes = 0
        meta_obj_idx = 0
        while successes < self.obj_per_scene and meta_obj_idx <= len(
                event.metadata['objects']
        ) - 1:  #obj in objects: #event.metadata['objects']: #objects:
            obj = objects[objects_inds[meta_obj_idx]]
            meta_obj_idx += 1
            print("Object is ", obj['objectType'])
            # if obj['name'] in ['Microwave_b200e0bc']:
            #     print(obj['name'])
            # else:
            #     continue
            # print(obj['name'])

            if obj['objectType'] not in self.include_classes:
                print("Continuing... Invalid Object")
                continue

            # Calculate distance to object center
            obj_center = np.array(
                list(obj['axisAlignedBoundingBox']['center'].values()))

            obj_center = np.expand_dims(obj_center, axis=0)
            distances = np.sqrt(np.sum((self.nav_pts - obj_center)**2, axis=1))

            # Get points with r_min < dist < r_max
            valid_pts = self.nav_pts[np.where(
                (distances > self.radius_min) * (distances < self.radius_max))]

            # Bin points based on angles [vertical_angle (10 deg/bin), horizontal_angle (10 deg/bin)]
            valid_pts_shift = valid_pts - obj_center

            dz = valid_pts_shift[:, 2]
            dx = valid_pts_shift[:, 0]
            dy = valid_pts_shift[:, 1]

            # Get yaw for binning
            valid_yaw = np.degrees(np.arctan2(dz, dx))

            nbins = 18
            bins = np.linspace(-180, 180, nbins + 1)
            bin_yaw = np.digitize(valid_yaw, bins)

            num_valid_bins = np.unique(bin_yaw).size

            if False:
                import matplotlib.cm as cm
                colors = iter(cm.rainbow(np.linspace(0, 1, nbins)))
                plt.figure(2)
                plt.clf()
                print(np.unique(bin_yaw))
                for bi in range(nbins):
                    cur_bi = np.where(bin_yaw == (bi + 1))
                    points = valid_pts[cur_bi]
                    x_sample = points[:, 0]
                    z_sample = points[:, 2]
                    plt.plot(z_sample, x_sample, 'o', color=next(colors))
                plt.plot(self.nav_pts[:, 2],
                         self.nav_pts[:, 0],
                         'x',
                         color='red')
                plt.plot(obj_center[:, 2],
                         obj_center[:, 0],
                         'x',
                         color='black')
                plt_name = '/home/nel/gsarch/aithor/data/valid.png'
                plt.savefig(plt_name)

            if num_valid_bins == 0:
                continue

            spawns_per_bin = int(self.num_views / num_valid_bins) + 2
            # print(f'spawns_per_bin: {spawns_per_bin}')

            action = "do_nothing"
            episodes = []
            valid_pts_selected = []
            camXs_T_camX0_4x4 = []
            camX0_T_camXs_4x4 = []
            origin_T_camXs = []
            origin_T_camXs_t = []
            cnt = 0
            for b in range(nbins):

                # get all angle indices in the current bin range
                inds_bin_cur = np.where(
                    bin_yaw == (b + 1))  # bins start 1 so need +1
                inds_bin_cur = list(inds_bin_cur[0])
                if len(inds_bin_cur) == 0:
                    continue

                for s in range(spawns_per_bin):
                    observations = {}

                    if len(inds_bin_cur) == 0:
                        continue

                    rand_ind = np.random.randint(0, len(inds_bin_cur))
                    s_ind = inds_bin_cur.pop(rand_ind)

                    pos_s = valid_pts[s_ind]
                    valid_pts_selected.append(pos_s)

                    # add height from center of agent to camera
                    pos_s[1] = pos_s[1] + 0.675

                    # YAW calculation - rotate to object
                    agent_to_obj = np.squeeze(obj_center) - pos_s
                    agent_local_forward = np.array([0, 0, 1.0])
                    flat_to_obj = np.array(
                        [agent_to_obj[0], 0.0, agent_to_obj[2]])
                    flat_dist_to_obj = np.linalg.norm(flat_to_obj)
                    flat_to_obj /= flat_dist_to_obj

                    det = (flat_to_obj[0] * agent_local_forward[2] -
                           agent_local_forward[0] * flat_to_obj[2])
                    turn_angle = math.atan2(
                        det, np.dot(agent_local_forward, flat_to_obj))

                    turn_yaw = np.degrees(turn_angle)

                    turn_pitch = -np.degrees(
                        math.atan2(agent_to_obj[1], flat_dist_to_obj))

                    event = self.controller.step('TeleportFull',
                                                 x=pos_s[0],
                                                 y=pos_s[1],
                                                 z=pos_s[2],
                                                 rotation=dict(x=0.0,
                                                               y=int(turn_yaw),
                                                               z=0.0),
                                                 horizon=int(turn_pitch))

                    rgb = event.frame

                    eulers_xyz_rad = np.radians(
                        np.array([
                            event.metadata['agent']['cameraHorizon'],
                            event.metadata['agent']['rotation']['y'], 0.0
                        ]))

                    rx = eulers_xyz_rad[0]
                    ry = eulers_xyz_rad[1]
                    rz = eulers_xyz_rad[2]
                    rotation_r_matrix = self.eul2rotm(-rx, -ry, rz)

                    agent_position = np.array(
                        list(event.metadata['agent']['position'].values())
                    ) + np.array([0.0, 0.675, 0.0])
                    # need to invert since z is positive here by convention
                    agent_position[2] = -agent_position[2]

                    observations["positions"] = agent_position

                    observations["rotations"] = rotation_r_matrix

                    # rt_4x4 = np.eye(4)
                    # rt_4x4[0:3,0:3] = observations["rotations"]
                    # rt_4x4[0:3,3] = observations["positions"]
                    # rt_4x4_inv = self.safe_inverse_single(rt_4x4)
                    # r, t = self.split_rt_single(rt_4x4_inv)

                    # observations["positions"] = r

                    # observations["positions"] = t

                    # observations["rotations_euler"] = np.array([rx, ry, rz]) #rotation_r.as_euler('xyz', degrees=True)

                    observations["color_sensor"] = rgb
                    observations["depth_sensor"] = event.depth_frame
                    observations[
                        "semantic_sensor"] = event.instance_segmentation_frame

                    if False:
                        plt.imshow(rgb)
                        plt_name = f'/home/nel/gsarch/aithor/data/test/img_true{s}{b}.png'
                        plt.savefig(plt_name)

                    # print("Processed image #", cnt, " for object ", obj['objectType'])

                    semantic = event.instance_segmentation_frame
                    object_id_to_color = event.object_id_to_color
                    color_to_object_id = event.color_to_object_id

                    obj_ids = np.unique(semantic.reshape(
                        -1, semantic.shape[2]),
                                        axis=0)

                    instance_masks = event.instance_masks
                    instance_detections2D = event.instance_detections2D

                    obj_metadata_IDs = []
                    for obj_m in event.metadata['objects']:  #objects:
                        obj_metadata_IDs.append(obj_m['objectId'])

                    object_list = []
                    for obj_idx in range(obj_ids.shape[0]):
                        try:
                            obj_color = tuple(obj_ids[obj_idx])
                            object_id = color_to_object_id[obj_color]
                        except:
                            # print("Skipping ", object_id)
                            continue

                        if object_id not in obj_metadata_IDs:
                            # print("Skipping ", object_id)
                            continue

                        obj_meta_index = obj_metadata_IDs.index(object_id)
                        obj_meta = event.metadata['objects'][obj_meta_index]
                        obj_category_name = obj_meta['objectType']

                        # continue if not visible or not in include classes
                        if obj_category_name not in self.include_classes or not obj_meta[
                                'visible']:
                            continue

                        obj_instance_mask = instance_masks[object_id]
                        obj_instance_detection2D = instance_detections2D[
                            object_id]  # [start_x, start_y, end_x, end_y]
                        obj_instance_detection2D = np.array([
                            obj_instance_detection2D[1],
                            obj_instance_detection2D[0],
                            obj_instance_detection2D[3],
                            obj_instance_detection2D[2]
                        ])  # ymin, xmin, ymax, xmax

                        if False:
                            print(object_id)
                            plt.imshow(obj_instance_mask)
                            plt_name = f'/home/nel/gsarch/aithor/data/test/img_mask{s}.png'
                            plt.savefig(plt_name)

                        obj_center_axisAligned = np.array(
                            list(obj_meta['axisAlignedBoundingBox']
                                 ['center'].values()))
                        obj_center_axisAligned[2] = -obj_center_axisAligned[2]
                        obj_size_axisAligned = np.array(
                            list(obj_meta['axisAlignedBoundingBox']
                                 ['size'].values()))

                        # print(obj_category_name)

                        if self.verbose:
                            print("Saved class name is : ", obj_category_name)

                        obj_data = {
                            'instance_id': object_id,
                            'category_id': object_id,
                            'category_name': obj_category_name,
                            'bbox_center': obj_center_axisAligned,
                            'bbox_size': obj_size_axisAligned,
                            'mask_2d': obj_instance_mask,
                            'box_2d': obj_instance_detection2D
                        }
                        # object_list.append(obj_instance)
                        object_list.append(obj_data)

                    observations["object_list"] = object_list

                    # check if object visible (make sure agent is not behind a wall)
                    obj_id = obj['objectId']
                    obj_id_to_color = object_id_to_color[obj_id]
                    # if np.sum(obj_ids==object_id_to_color[obj_id]) > 0:
                    if self.verbose:
                        print("episode is valid......")
                    episodes.append(observations)
                    # print(cnt)

                    if False:
                        if cnt > 0:

                            origin_T_camX = episodes[cnt]["rotations"]

                            r = R.from_matrix(origin_T_camX)
                            print("EULER CHECK: ",
                                  r.as_euler('xyz', degrees=True))

                            camX0_T_camX = np.matmul(camX0_T_origin,
                                                     origin_T_camX)
                            r = R.from_matrix(camX0_T_camX)
                            print("EULER CHECK: ",
                                  r.as_euler('xyz', degrees=True))
                            # camX0_T_camX = np.matmul(camX0_T_origin, origin_T_camX)
                            r = R.from_matrix(camX0_T_origin)
                            print("EULER CHECK: ",
                                  r.as_euler('xyz', degrees=True))

                            origin_T_camXs.append(origin_T_camX)
                            origin_T_camXs_t.append(episodes[cnt]["positions"])

                            origin_T_camX_4x4 = np.eye(4)
                            origin_T_camX_4x4[0:3, 0:3] = origin_T_camX
                            origin_T_camX_4x4[:3,
                                              3] = episodes[cnt]["positions"]
                            camX0_T_camX_4x4 = np.matmul(
                                camX0_T_origin_4x4, origin_T_camX_4x4)
                            camX_T_camX0_4x4 = self.safe_inverse_single(
                                camX0_T_camX_4x4)

                            camXs_T_camX0_4x4.append(camX_T_camX0_4x4)
                            camX0_T_camXs_4x4.append(camX0_T_camX_4x4)

                            camX0_T_camX_quat = quaternion.from_rotation_matrix(
                                camX0_T_camX)
                            camX0_T_camX_eul = quaternion.as_euler_angles(
                                camX0_T_camX_quat)

                            camX0_T_camX_4x4 = self.safe_inverse_single(
                                camX_T_camX0_4x4)
                            origin_T_camX_4x4 = np.matmul(
                                origin_T_camX0_4x4, camX0_T_camX_4x4)
                            r_origin_T_camX, t_origin_T_camX, = self.split_rt_single(
                                origin_T_camX_4x4)

                            if self.verbose:
                                print(r_origin_T_camX)
                                print(origin_T_camX)

                        else:
                            origin_T_camX0 = episodes[0]["rotations"]
                            camX0_T_origin = np.linalg.inv(origin_T_camX0)
                            # camX0_T_origin = self.safe_inverse_single(origin_T_camX0)

                            origin_T_camXs.append(origin_T_camX0)
                            origin_T_camXs_t.append(episodes[0]["positions"])

                            origin_T_camX0_4x4 = np.eye(4)
                            origin_T_camX0_4x4[0:3, 0:3] = origin_T_camX0
                            origin_T_camX0_4x4[:3,
                                               3] = episodes[0]["positions"]
                            camX0_T_origin_4x4 = self.safe_inverse_single(
                                origin_T_camX0_4x4)

                            camXs_T_camX0_4x4.append(np.eye(4))

                            camX0_T_camXs_4x4.append(np.eye(4))

                            origin_T_camX0_t = episodes[0]["positions"]

                    cnt += 1

            if len(episodes) >= self.num_views:
                print(f'num episodes: {len(episodes)}')
                data_folder = obj['name']
                data_path = os.path.join(self.basepath, data_folder)
                print("Saving to ", data_path)
                os.mkdir(data_path)
                # flat_obs = np.random.choice(episodes, self.num_views, replace=False)
                np.random.seed(1)
                rand_inds = np.sort(
                    np.random.choice(len(episodes),
                                     self.num_views,
                                     replace=False))
                bool_inds = np.zeros(len(episodes), dtype=bool)
                bool_inds[rand_inds] = True
                flat_obs = np.array(episodes)[bool_inds]
                flat_obs = list(flat_obs)
                viewnum = 0
                if False:
                    self.generate_xyz_habitatCamXs(flat_obs)
                for obs in flat_obs:
                    self.save_datapoint(obs, data_path, viewnum, True)
                    viewnum += 1
                print("SUCCESS #", successes)
                successes += 1
            else:
                print("Not enough episodes:", len(episodes))
Example #11
0
class RoboThorEnvironment:
    """Wrapper for the robo2thor controller providing additional functionality
    and bookkeeping.

    See [here](https://ai2thor.allenai.org/robothor/documentation) for comprehensive
     documentation on RoboTHOR.

    # Attributes

    controller : The AI2THOR controller.
    config : The AI2THOR controller configuration
    """
    def __init__(self, **kwargs):
        self.config = dict(
            rotateStepDegrees=30.0,
            visibilityDistance=1.0,
            gridSize=0.25,
            agentType="stochastic",
            continuousMode=True,
            snapToGrid=False,
            agentMode="bot",
            width=640,
            height=480,
        )
        recursive_update(self.config, {**kwargs, "agentMode": "bot"})
        self.controller = Controller(
            **self.config, server_class=ai2thor.fifo_server.FifoServer)
        self.known_good_locations: Dict[str, Any] = {
            self.scene_name: copy.deepcopy(self.currently_reachable_points)
        }
        self.distance_cache = DynamicDistanceCache(rounding=1)
        assert len(self.known_good_locations[self.scene_name]) > 10

    def initialize_grid_dimensions(
        self, reachable_points: Collection[Dict[str, float]]
    ) -> Tuple[int, int, int, int]:
        """Computes bounding box for reachable points quantized with the
        current gridSize."""
        points = {(
            round(p["x"] / self.config["gridSize"]),
            round(p["z"] / self.config["gridSize"]),
        ): p
                  for p in reachable_points}

        assert len(reachable_points) == len(points)

        xmin, xmax = min([p[0] for p in points]), max([p[0] for p in points])
        zmin, zmax = min([p[1] for p in points]), max([p[1] for p in points])

        return xmin, xmax, zmin, zmax

    def set_object_filter(self, object_ids: List[str]):
        self.controller.step("SetObjectFilter",
                             objectIds=object_ids,
                             renderImage=False)

    def reset_object_filter(self):
        self.controller.step("ResetObjectFilter", renderImage=False)

    def path_from_point_to_object_type(
            self, point: Dict[str, float],
            object_type: str) -> Optional[List[Dict[str, float]]]:
        try:
            return metrics.get_shortest_path_to_object_type(
                self.controller, object_type, point)
        except Exception:
            get_logger().debug(
                "Failed to find path for {} in {}. Start point {}, agent state {}."
                .format(
                    object_type,
                    self.controller.last_event.metadata["sceneName"],
                    point,
                    self.agent_state(),
                ))
            return None

    def distance_from_point_to_object_type(self, point: Dict[str, float],
                                           object_type: str) -> float:
        """Minimal geodesic distance from a point to an object of the given
        type.

        It might return -1.0 for unreachable targets.
        """
        path = self.path_from_point_to_object_type(point, object_type)
        if path:
            return metrics.path_distance(path)
        return -1.0

    def distance_to_object_type(self, object_type: str) -> float:
        """Minimal geodesic distance to object of given type from agent's
        current location.

        It might return -1.0 for unreachable targets.
        """
        return self.distance_cache.find_distance(
            self.controller.last_event.metadata["agent"]["position"],
            object_type,
            self.distance_from_point_to_object_type,
        )

    def path_from_point_to_point(
            self, position: Dict[str, float],
            target: Dict[str, float]) -> Optional[List[Dict[str, float]]]:
        try:
            return self.controller.step(
                action="GetShortestPathToPoint",
                position=position,
                x=target["x"],
                y=target["y"],
                z=target["z"],
                # renderImage=False
            ).metadata["actionReturn"]["corners"]
        except Exception:
            get_logger().debug(
                "Failed to find path for {} in {}. Start point {}, agent state {}."
                .format(
                    target,
                    self.controller.last_event.metadata["sceneName"],
                    position,
                    self.agent_state(),
                ))
            return None

    def distance_from_point_to_point(self, position: Dict[str, float],
                                     target: Dict[str, float]) -> float:
        path = self.path_from_point_to_point(position, target)
        if path:
            return metrics.path_distance(path)
        return -1.0

    def distance_to_point(self, target: Dict[str, float]) -> float:
        """Minimal geodesic distance to end point from agent's current
        location.

        Returns -1.0 if the target is unreachable.
        """
        return self.distance_cache.find_distance(
            self.controller.last_event.metadata["agent"]["position"],
            target,
            self.distance_from_point_to_point,
        )

    def agent_state(self) -> Dict:
        """Return agent position, rotation and horizon."""
        agent_meta = self.last_event.metadata["agent"]
        return {
            **{k: float(v)
               for k, v in agent_meta["position"].items()},
            "rotation":
            {k: float(v)
             for k, v in agent_meta["rotation"].items()},
            "horizon": round(float(agent_meta["cameraHorizon"]), 1),
        }

    def teleport(self,
                 pose: Dict[str, float],
                 rotation: Dict[str, float],
                 horizon: float = 0.0):
        e = self.controller.step(
            action="TeleportFull",
            x=pose["x"],
            y=pose["y"],
            z=pose["z"],
            rotation=rotation,
            horizon=horizon,
        )
        return e.metadata["lastActionSuccess"]

    def reset(self,
              scene_name: Optional[str] = None,
              filtered_objects: Optional[List[str]] = None) -> None:
        """Resets scene to a known initial state."""
        if scene_name is not None and scene_name != self.scene_name:
            self.controller.reset(scene_name)
            assert self.last_action_success, "Could not reset to new scene"
            if scene_name not in self.known_good_locations:
                self.known_good_locations[scene_name] = copy.deepcopy(
                    self.currently_reachable_points)
                assert len(self.known_good_locations[scene_name]) > 10
        if filtered_objects:
            self.set_object_filter(filtered_objects)
        else:
            self.reset_object_filter()

    def random_reachable_state(
        self,
        seed: Optional[int] = None
    ) -> Dict[str, Union[Dict[str, float], float]]:
        """Returns a random reachable location in the scene."""
        if seed is not None:
            random.seed(seed)
        # xyz = random.choice(self.currently_reachable_points)
        assert len(self.known_good_locations[self.scene_name]) > 10
        xyz = copy.deepcopy(
            random.choice(self.known_good_locations[self.scene_name]))
        rotation = random.choice(
            np.arange(0.0, 360.0, self.config["rotateStepDegrees"]))
        horizon = 0.0  # random.choice([0.0, 30.0, 330.0])
        return {
            **{k: float(v)
               for k, v in xyz.items()},
            "rotation": {
                "x": 0.0,
                "y": float(rotation),
                "z": 0.0
            },
            "horizon": float(horizon),
        }

    def randomize_agent_location(
        self,
        seed: Optional[int] = None,
        partial_position: Optional[Dict[str, float]] = None
    ) -> Dict[str, Union[Dict[str, float], float]]:
        """Teleports the agent to a random reachable location in the scene."""
        if partial_position is None:
            partial_position = {}
        k = 0
        state: Optional[Dict] = None

        while k == 0 or (not self.last_action_success and k < 10):
            # self.reset()
            state = {
                **self.random_reachable_state(seed=seed),
                **partial_position
            }
            # get_logger().debug("picked target location {}".format(state))
            self.controller.step("TeleportFull", **state)
            k += 1

        if not self.last_action_success:
            get_logger().warning((
                "Randomize agent location in scene {} and current random state {}"
                " with seed {} and partial position {} failed in "
                "10 attempts. Forcing the action.").format(
                    self.scene_name, state, seed, partial_position))
            self.controller.step("TeleportFull", **state,
                                 force_action=True)  # type: ignore
            assert self.last_action_success, "Force action failed with {}".format(
                state)

        # get_logger().debug("location after teleport full {}".format(self.agent_state()))
        # self.controller.step("TeleportFull", **self.agent_state())  # TODO only for debug
        # get_logger().debug("location after re-teleport full {}".format(self.agent_state()))

        return self.agent_state()

    def known_good_locations_list(self):
        return self.known_good_locations[self.scene_name]

    @property
    def currently_reachable_points(self) -> List[Dict[str, float]]:
        """List of {"x": x, "y": y, "z": z} locations in the scene that are
        currently reachable."""
        self.controller.step(action="GetReachablePositions")
        assert (
            self.last_action_success
        ), f"Could not get reachable positions for reason {self.last_event.metadata['errorMessage']}."
        return self.last_action_return

    @property
    def scene_name(self) -> str:
        """Current ai2thor scene."""
        return self.controller.last_event.metadata["sceneName"].replace(
            "_physics", "")

    @property
    def current_frame(self) -> np.ndarray:
        """Returns rgb image corresponding to the agent's egocentric view."""
        return self.controller.last_event.frame

    @property
    def current_depth(self) -> np.ndarray:
        """Returns depth image corresponding to the agent's egocentric view."""
        return self.controller.last_event.depth_frame

    @property
    def last_event(self) -> ai2thor.server.Event:
        """Last event returned by the controller."""
        return self.controller.last_event

    @property
    def last_action(self) -> str:
        """Last action, as a string, taken by the agent."""
        return self.controller.last_event.metadata["lastAction"]

    @property
    def last_action_success(self) -> bool:
        """Was the last action taken by the agent a success?"""
        return self.controller.last_event.metadata["lastActionSuccess"]

    @property
    def last_action_return(self) -> Any:
        """Get the value returned by the last action (if applicable).

        For an example of an action that returns a value, see
        `"GetReachablePositions"`.
        """
        return self.controller.last_event.metadata["actionReturn"]

    def step(self, action_dict: Dict) -> ai2thor.server.Event:
        """Take a step in the ai2thor environment."""
        return self.controller.step(**action_dict)

    def stop(self):
        """Stops the ai2thor controller."""
        try:
            self.controller.stop()
        except Exception as e:
            get_logger().warning(str(e))

    def all_objects(self) -> List[Dict[str, Any]]:
        """Return all object metadata."""
        return self.controller.last_event.metadata["objects"]

    def all_objects_with_properties(
            self, properties: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Find all objects with the given properties."""
        objects = []
        for o in self.all_objects():
            satisfies_all = True
            for k, v in properties.items():
                if o[k] != v:
                    satisfies_all = False
                    break
            if satisfies_all:
                objects.append(o)
        return objects

    def visible_objects(self) -> List[Dict[str, Any]]:
        """Return all visible objects."""
        return self.all_objects_with_properties({"visible": True})
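
A minimal usage sketch of the environment wrapper above (hypothetical: the class name Env and its constructor are placeholders, since the constructor that sets up self.controller, self.config, self.distance_cache and self.known_good_locations is not part of this excerpt):

env = Env(...)  # placeholder constructor, not shown in this example
env.reset(scene_name="FloorPlan1")
env.randomize_agent_location()
dist = env.distance_to_object_type("Mug")  # -1.0 if no path to a Mug exists
if dist >= 0:
    path = env.path_from_point_to_object_type(
        env.last_event.metadata["agent"]["position"], "Mug")
env.stop()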
Пример #12
0
class Ai2Thor():
    def __init__(self):
        self.visualize = False
        self.verbose = False
        self.save_imgs = True

        self.plot_loss = True
        # st()

        # these are all map names
        a = np.arange(1, 30)
        b = np.arange(201, 231)
        c = np.arange(301, 331)
        d = np.arange(401, 431)
        abcd = np.hstack((a, b, c, d))
        mapnames = []
        for i in list(abcd):
            mapname = 'FloorPlan' + str(i)
            mapnames.append(mapname)

        train_len = int(0.9 * len(mapnames))

        np.random.seed(1)
        random.shuffle(mapnames)
        self.mapnames_train = mapnames[:train_len]
        self.mapnames_val = mapnames[train_len:]
        # self.num_episodes = len(self.mapnames)

        self.ignore_classes = []
        # classes to save
        self.include_classes = [
            'ShowerDoor', 'Cabinet', 'CounterTop', 'Sink', 'Towel',
            'HandTowel', 'TowelHolder', 'SoapBar', 'ToiletPaper',
            'ToiletPaperHanger', 'HandTowelHolder', 'SoapBottle', 'GarbageCan',
            'Candle', 'ScrubBrush', 'Plunger', 'SinkBasin', 'Cloth',
            'SprayBottle', 'Toilet', 'Faucet', 'ShowerHead', 'Box', 'Bed',
            'Book', 'DeskLamp', 'BasketBall', 'Pen', 'Pillow', 'Pencil',
            'CellPhone', 'KeyChain', 'Painting', 'CreditCard', 'AlarmClock',
            'CD', 'Laptop', 'Drawer', 'SideTable', 'Chair', 'Blinds', 'Desk',
            'Curtains', 'Dresser', 'Watch', 'Television', 'WateringCan',
            'Newspaper', 'FloorLamp', 'RemoteControl', 'HousePlant', 'Statue',
            'Ottoman', 'ArmChair', 'Sofa', 'DogBed', 'BaseballBat',
            'TennisRacket', 'VacuumCleaner', 'Mug', 'ShelvingUnit', 'Shelf',
            'StoveBurner', 'Apple', 'Lettuce', 'Bottle', 'Egg', 'Microwave',
            'CoffeeMachine', 'Fork', 'Fridge', 'WineBottle', 'Spatula',
            'Bread', 'Tomato', 'Pan', 'Cup', 'Pot', 'SaltShaker', 'Potato',
            'PepperShaker', 'ButterKnife', 'StoveKnob', 'Toaster',
            'DishSponge', 'Spoon', 'Plate', 'Knife', 'DiningTable', 'Bowl',
            'LaundryHamper', 'Vase', 'Stool', 'CoffeeTable', 'Poster',
            'Bathtub', 'TissueBox', 'Footstool', 'BathtubBasin',
            'ShowerCurtain', 'TVStand', 'Boots', 'RoomDecor', 'PaperTowelRoll',
            'Ladle', 'Kettle', 'Safe', 'GarbageBag', 'TeddyBear',
            'TableTopDecor', 'Dumbbell', 'Desktop', 'AluminumFoil', 'Window'
        ]

        self.action_space = {
            0: "MoveLeft",
            1: "MoveRight",
            2: "MoveAhead",
            3: "MoveBack"
        }
        self.num_actions = len(self.action_space)

        cfg_det = get_cfg()
        cfg_det.merge_from_file(
            model_zoo.get_config_file(
                "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg_det.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2  # set threshold for this model
        cfg_det.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        cfg_det.MODEL.DEVICE = 'cpu'
        self.cfg_det = cfg_det
        self.maskrcnn = DefaultPredictor(cfg_det)

        self.conf_thresh_detect = 0.7  # upper bound when picking an initially low-confidence object
        self.conf_thresh_init = 0.8  # confidence threshold after turning the head toward the object
        self.conf_thresh_end = 0.9  # stop collecting observations once this confidence is reached

        self.BATCH_SIZE = 12
        self.percentile = 70
        self.max_iters = 100000
        self.max_frames = 10
        self.val_interval = 15
        self.save_interval = 50
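        # NOTE: the assignments below override the hyperparameters above
        # (apparently a debug configuration: batch size 1, single-frame episodes).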

        self.BATCH_SIZE = 1
        self.percentile = 70
        self.max_iters = 100000
        self.max_frames = 1
        self.val_interval = 1
        self.save_interval = 1

        self.small_classes = []
        self.rot_interval = 5.0
        self.radius_max = 3.5  #3 #1.75
        self.radius_min = 1.0  #1.25
        self.num_flat_views = 3
        self.num_any_views = 7
        self.num_views = 25
        self.center_from_mask = False  # get object centroid from maskrcnn (True) or gt (False)

        self.obj_per_scene = 5

        # self.origin_quaternion = np.quaternion(1, 0, 0, 0)
        # self.origin_rot_vector = quaternion.as_rotation_vector(self.origin_quaternion)

        # self.homepath = f'/home/nel/gsarch/aithor/data/test2'
        self.homepath = '/home/sirdome/katefgroup/gsarch/ithor/data/test'
        if not os.path.exists(self.homepath):
            os.mkdir(self.homepath)
        else:
            val = input("Delete homepath? [y/n]: ")
            if val == 'y':
                import shutil
                shutil.rmtree(self.homepath)
                os.mkdir(self.homepath)
            else:
                print("ENDING")
                assert (False)

        self.log_freq = 1
        self.log_dir = self.homepath + '/..' + '/log_cem' + '/aa'
        if not os.path.exists(self.log_dir):
            os.mkdir(self.log_dir)
        MAX_QUEUE = 10  # flushes when this amount waiting
        self.writer = SummaryWriter(self.log_dir,
                                    max_queue=MAX_QUEUE,
                                    flush_secs=60)

        self.W = 256
        self.H = 256

        # self.fov = 90
        # hfov = float(self.fov) * np.pi / 180.
        # self.pix_T_camX = np.array([
        #     [(self.W/2.)*1 / np.tan(hfov / 2.), 0., 0., 0.],
        #     [0., (self.H/2.)*1 / np.tan(hfov / 2.), 0., 0.],
        #     [0., 0.,  1, 0],
        #     [0., 0., 0, 1]])
        # self.pix_T_camX[0,2] = self.W/2.
        # self.pix_T_camX[1,2] = self.H/2.

        self.fov = 90
        self.camera_matrix = self.get_camera_matrix(self.W, self.H, self.fov)
        self.K = self.get_habitat_pix_T_camX(self.fov)

        self.init_network()

        self.run_episodes()

    def init_network(self):

        self.cemnet = CEMNet(h1=32,
                             h2=64,
                             fc_dim=1024,
                             num_actions=self.num_actions)

        self.loss = nn.CrossEntropyLoss()

        self.optimizer = torch.optim.Adam(params=self.cemnet.parameters(),
                                          lr=0.01)

    def batch_iteration(self, mapnames, NN, BATCH_SIZE):
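        # Generator: repeatedly open a Controller on a random training scene, run
        # one episode, and yield once BATCH_SIZE (rewards, obs, actions) triples
        # have been collected.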

        batch = {"rewards": [], "obs": [], "actions": []}
        episode_rewards = 0.0
        iter_idx = 0
        while True:

            mapname = np.random.choice(mapnames)

            # self.basepath = self.homepath + f"/{mapname}_{episode}"
            # print("BASEPATH: ", self.basepath)

            # # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            # if not os.path.exists(self.basepath):
            #     os.mkdir(self.basepath)

            self.controller = Controller(
                scene=mapname,
                gridSize=0.25,
                width=self.W,
                height=self.H,
                fieldOfView=self.fov,
                renderObjectImage=True,
                renderDepthImage=True,
            )

            episode_rewards, obs, actions = self.run()

            print("Total reward for train batch # ", iter_idx, " :",
                  episode_rewards)

            self.controller.stop()
            time.sleep(1)

            if episode_rewards is None:
                print("NO EPISODE REWARDS.. SKIPPING BATCH INSTANCE")
                continue

            batch["rewards"].append(episode_rewards)
            batch["obs"].append(obs)
            batch["actions"].append(actions)

            iter_idx += 1
            print(len(batch["rewards"]))
            if len(batch["rewards"]) == BATCH_SIZE:

                yield batch
                iter_idx = 0
                batch = {"rewards": [], "obs": [], "actions": []}

            if len(batch["rewards"]) > BATCH_SIZE:
                st()

    def elite_batch(self, batch, percentile):
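        # Keep only episodes whose total reward is at or above the `percentile`
        # boundary and flatten their observations/actions into training tensors.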

        rewards = np.array(batch["rewards"])
        obs = batch["obs"]
        actions = batch["actions"]

        rewards_mean = float(np.mean(rewards))

        rewards_boundary = np.percentile(rewards, percentile)

        print("Reward boundary: ", rewards_boundary)

        training_obs = []
        training_actions = []

        for idx in range(rewards.shape[0]):
            reward_idx = rewards[idx]
            if reward_idx < rewards_boundary:
                continue

            training_obs.extend(obs[idx])
            training_actions.extend(actions[idx])

        obs_tensor = torch.FloatTensor(training_obs).permute(0, 3, 1, 2)
        act_tensor = torch.LongTensor(training_actions)

        return obs_tensor, act_tensor, rewards_mean, rewards_boundary

    def run_val(self, mapnames, BATCH_SIZE):

        batch = {"rewards": [], "obs": [], "actions": []}
        episode_rewards = 0.0
        episode_steps = []

        iter_idx = 0
        while True:

            mapname = np.random.choice(mapnames)

            # self.basepath = self.homepath + f"/{mapname}_{episode}"
            # print("BASEPATH: ", self.basepath)

            # # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            # if not os.path.exists(self.basepath):
            #     os.mkdir(self.basepath)

            self.controller = Controller(
                scene=mapname,
                gridSize=0.25,
                width=self.W,
                height=self.H,
                fieldOfView=self.fov,
                renderObjectImage=True,
                renderDepthImage=True,
            )

            episode_rewards, obs, actions = self.run()

            print("Total reward for val batch # ", iter_idx, " :",
                  episode_rewards)

            self.controller.stop()
            time.sleep(1)

            if episode_rewards is None:
                print("NO EPISODE REWARDS.. SKIPPING BATCH INSTANCE")
                continue

            batch["rewards"].append(episode_rewards)
            batch["obs"].append(obs)
            batch["actions"].append(actions)

            train, labels, mean_rewards, boundary = self.elite_batch(
                batch, self.percentile)

            with torch.no_grad():
                scores = self.cemnet(train, softmax=False)

            loss_value = self.loss(scores, labels)

            iter_idx += 1

            if len(batch["rewards"]) == BATCH_SIZE:

                break

        return loss_value, mean_rewards

    def run_episodes(self):
        self.ep_idx = 0
        # self.objects = []

        train_loss = []
        val_loss = []
        train_iters = []
        val_iters = []
        mean_rewards_train = []
        mean_rewards_val = []
        if self.plot_loss:
            # plt.figure(1) # loss
            # plt.figure(2) # reward
            fig, (ax1, ax2) = plt.subplots(1, 2)
        for iteration, batch in enumerate(
                self.batch_iteration(self.mapnames_train, self.cemnet,
                                     self.BATCH_SIZE)):
            print("ITERATION #", iteration)

            self.summ_writer = utils.improc.Summ_writer(writer=self.writer,
                                                        global_step=iteration,
                                                        log_freq=self.log_freq,
                                                        fps=8,
                                                        just_gif=True)

            train, labels, mean_rewards, boundary = self.elite_batch(
                batch, self.percentile)
            mean_rewards_train.append(mean_rewards)

            self.optimizer.zero_grad()

            scores = self.cemnet(train, softmax=False)

            loss_value = self.loss(scores, labels)
            train_loss.append(float(loss_value.clone().detach().cpu().numpy()))
            train_iters.append(iteration)

            loss_value.backward()

            self.optimizer.step()

            print('rewards mean = ', mean_rewards)
            print('')

            if iteration >= self.max_iters:
                print("MAX ITERS REACHED")
                self.writer.close()
                break

            # if self.plot_loss:
            # plt.figure(1) # loss
            # fig.clf()

            # plt.figure(2) # reward
            # plt.clf()

            if iteration % self.val_interval == 0:
                loss_val, mean_reward_v = self.run_val(self.mapnames_val,
                                                       self.BATCH_SIZE)
                val_loss.append(float(loss_val.clone().detach().cpu().numpy()))
                val_iters.append(iteration)
                mean_rewards_val.append(mean_reward_v)
                if self.plot_loss:
                    # plt.figure(1) # loss
                    ax1.plot(val_iters, val_loss, color='red')

                    # plt.figure(2) # reward
                    ax2.plot(val_iters, mean_rewards_val, color='red')
                    self.summ_writer.summ_scalar('unscaled_loss_val', loss_val)
                    self.summ_writer.summ_scalar('unscaled_mean_reward_val',
                                                 mean_reward_v)

            if iteration % self.save_interval == 0:
                PATH = self.homepath + f'/checkpoint{iteration}.tar'
                torch.save(self.cemnet.state_dict(), PATH)

            if self.plot_loss:
                self.summ_writer.summ_scalar('unscaled_loss', loss_value)
                self.summ_writer.summ_scalar('unscaled_mean_reward',
                                             mean_rewards)

                # plt.figure(1) # loss
                ax1.plot(train_iters, train_loss, color='blue')
                ax1.set(xlabel='iterations', ylabel='loss')
                # ax1.xlabel('iterations')
                # ax1.ylabel('loss')
                # plt_name = '/home/nel/gsarch/aithor/data/test/loss.png'
                # fig.savefig(plt_name)

                # plt.figure(2) # reward
                ax2.plot(train_iters, mean_rewards_train, color='blue')
                ax2.set(xlabel='iterations', ylabel='reward')
                # ax2.xlabel('iterations')
                # ax2.ylabel('reward')

                plt_name = os.path.join(self.homepath, 'train.png')
                fig.savefig(plt_name)

            # if mean_rewards > 500:
            #     print('Accomplished!')
            #     break

    def save_datapoint(self, observations, data_path, viewnum, flat_view):
        if self.verbose:
            print("Print Sensor States.", self.agent.state.sensor_states)
        rgb = observations["color_sensor"]
        semantic = observations["semantic_sensor"]
        # st()
        depth = observations["depth_sensor"]
        agent_pos = observations["positions"]
        agent_rot = observations["rotations"]
        # Assuming all sensors have same extrinsics
        color_sensor_pos = observations["positions"]
        color_sensor_rot = observations["rotations"]
        #print("POS ", agent_pos)
        #print("ROT ", color_sensor_rot)
        object_list = observations['object_list']

        # print(viewnum, agent_pos)
        # print(agent_rot)

        if False:
            plt.imshow(rgb)
            plt_name = f'/home/nel/gsarch/aithor/data/test/img_mask{viewnum}.png'
            plt.savefig(plt_name)

        save_data = {
            'flat_view': flat_view,
            'objects_info': object_list,
            'rgb_camX': rgb,
            'depth_camX': depth,
            'sensor_pos': color_sensor_pos,
            'sensor_rot': color_sensor_rot
        }

        with open(os.path.join(data_path, str(viewnum) + ".p"), 'wb') as f:
            pickle.dump(save_data, f)
        f.close()

    def quat_from_angle_axis(self, theta: float,
                             axis: np.ndarray) -> np.quaternion:
        r"""Creates a quaternion from angle axis format

        :param theta: The angle to rotate about the axis by
        :param axis: The axis to rotate about
        :return: The quaternion
        """
        axis = axis.astype(float)
        axis /= np.linalg.norm(axis)
        return quaternion.from_rotation_vector(theta * axis)

    def safe_inverse_single(self, a):
        r, t = self.split_rt_single(a)
        t = np.reshape(t, (3, 1))
        r_transpose = r.T
        inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
        bottom_row = a[3:4, :]  # this is [0, 0, 0, 1]
        # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4)
        inv = np.concatenate([inv, bottom_row], 0)
        return inv

    def split_rt_single(self, rt):
        r = rt[:3, :3]
        t = np.reshape(rt[:3, 3], 3)
        return r, t

    def eul2rotm(self, rx, ry, rz):
        # inputs are shaped B
        # this func is copied from matlab
        # R = [  cy*cz   sy*sx*cz-sz*cx    sy*cx*cz+sz*sx
        #        cy*sz   sy*sx*sz+cz*cx    sy*cx*sz-cz*sx
        #        -sy            cy*sx             cy*cx]
        # rx = np.expand_dims(rx, axis=1)
        # ry = np.expand_dims(ry, axis=1)
        # rz = np.expand_dims(rz, axis=1)
        # st()
        # these are B x 1
        sinz = np.sin(rz)
        siny = np.sin(ry)
        sinx = np.sin(rx)
        cosz = np.cos(rz)
        cosy = np.cos(ry)
        cosx = np.cos(rx)
        r11 = cosy * cosz
        r12 = sinx * siny * cosz - cosx * sinz
        r13 = cosx * siny * cosz + sinx * sinz
        r21 = cosy * sinz
        r22 = sinx * siny * sinz + cosx * cosz
        r23 = cosx * siny * sinz - sinx * cosz
        r31 = -siny
        r32 = sinx * cosy
        r33 = cosx * cosy

        r = np.array([[r11, r12, r13], [r21, r22, r23], [r31, r32, r33]])
        # r1 = np.stack([r11,r12,r13],axis=2)
        # r2 = np.stack([r21,r22,r23],axis=2)
        # r3 = np.stack([r31,r32,r33],axis=2)
        # r = np.concatenate([r1,r2,r3],axis=1)
        return r

    def rotm2eul(self, r):
        # r is Bx3x3, or Bx4x4
        r00 = r[0, 0]
        r10 = r[1, 0]
        r11 = r[1, 1]
        r12 = r[1, 2]
        r20 = r[2, 0]
        r21 = r[2, 1]
        r22 = r[2, 2]

        ## python guide:
        # if sy > 1e-6: # singular
        #     x = math.atan2(R[2,1] , R[2,2])
        #     y = math.atan2(-R[2,0], sy)
        #     z = math.atan2(R[1,0], R[0,0])
        # else:
        #     x = math.atan2(-R[1,2], R[1,1])
        #     y = math.atan2(-R[2,0], sy)
        #     z = 0

        sy = np.sqrt(r00 * r00 + r10 * r10)

        cond = (sy > 1e-6)
        rx = np.where(cond, np.arctan2(r21, r22), np.arctan2(-r12, r11))
        ry = np.where(cond, np.arctan2(-r20, sy), np.arctan2(-r20, sy))
        rz = np.where(cond, np.arctan2(r10, r00), np.zeros_like(r20))

        # rx = torch.atan2(r21, r22)
        # ry = torch.atan2(-r20, sy)
        # rz = torch.atan2(r10, r00)
        # rx[cond] = torch.atan2(-r12, r11)
        # ry[cond] = torch.atan2(-r20, sy)
        # rz[cond] = 0.0
        return rx, ry, rz

    def get_habitat_pix_T_camX(self, fov):
        hfov = float(self.fov) * np.pi / 180.
        pix_T_camX = np.array([[
            (self.W / 2.) * 1 / np.tan(hfov / 2.), 0., 0., 0.
        ], [0., (self.H / 2.) * 1 / np.tan(hfov / 2.), 0., 0.], [0., 0., 1, 0],
                               [0., 0., 0, 1]])
        return pix_T_camX

    def get_camera_matrix(self, width, height, fov):
        """Returns a camera matrix from image size and fov."""
        xc = (width - 1.) / 2.
        zc = (height - 1.) / 2.
        f = (width / 2.) / np.tan(np.deg2rad(fov / 2.))
        camera_matrix = {'xc': xc, 'zc': zc, 'f': f}
        camera_matrix = Namespace(**camera_matrix)
        return camera_matrix

    def run2(self):
        event = self.controller.step('GetReachablePositions')
        for obj in event.metadata['objects']:
            if obj['objectType'] not in self.objects:
                self.objects.append(obj['objectType'])

    def get_detectron_conf_center_obj(self, im, frame=None):
        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        # plt.imshow(im)
        # plt.show()

        outputs = self.maskrcnn(im)

        v = Visualizer(im[:, :, ::-1],
                       MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                       scale=1.2)
        out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg{frame}.png'
            plt.savefig(plt_name)
            # plt.show()

        pred_masks = outputs['instances'].pred_masks
        pred_scores = outputs['instances'].scores
        pred_classes = outputs['instances'].pred_classes

        len_pad = 5

        W2_low = self.W // 2 - len_pad
        W2_high = self.W // 2 + len_pad
        H2_low = self.H // 2 - len_pad
        H2_high = self.H // 2 + len_pad

        ind_obj = None
        max_overlap = 0
        for idx in range(pred_masks.shape[0]):
            pred_mask_cur = pred_masks[idx]
            pred_masks_center = pred_mask_cur[W2_low:W2_high, H2_low:H2_high]
            # print(torch.sum(pred_masks_center))
            if torch.sum(pred_masks_center) > max_overlap:
                ind_obj = idx
                max_overlap = torch.sum(pred_masks_center)
        if ind_obj is None:
            return None

        print("OBJ CLASS ID=",
              int(pred_classes[ind_obj].detach().cpu().numpy()))
        # pred_boxes = outputs['instances'].pred_boxes.tensor
        # pred_classes = outputs['instances'].pred_classes
        # pred_scores = outputs['instances'].scores
        obj_score = pred_scores[ind_obj]

        return obj_score

    def detect_object_centroid(self, im, event):

        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        # plt.imshow(im)
        # plt.show()

        outputs = self.maskrcnn(im)

        v = Visualizer(im[:, :, ::-1],
                       MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                       scale=1.2)
        out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + '/seg_init.png'
            plt.savefig(plt_name)
            # plt.show()

        pred_masks = outputs['instances'].pred_masks
        pred_boxes = outputs['instances'].pred_boxes.tensor
        pred_classes = outputs['instances'].pred_classes
        pred_scores = outputs['instances'].scores

        # obj_ids = []
        obj_catids = []
        obj_scores = []
        obj_masks = []
        # obj_all_catids = []
        # obj_all_scores = []
        # obj_all_boxes = []
        for segs in range(len(pred_masks)):
            if pred_scores[segs] <= self.conf_thresh_detect:
                # obj_ids.append(segs)
                obj_catids.append(pred_classes[segs].item())
                obj_scores.append(pred_scores[segs].item())
                obj_masks.append(pred_masks[segs])

                # obj_all_catids.append(pred_classes[segs].item())
                # obj_all_scores.append(pred_scores[segs].item())
                # y, x = torch.where(pred_masks[segs])
                # pred_box = torch.Tensor([min(y), min(x), max(y), max(x)]) # ymin, xmin, ymax, xmaxs
                # obj_all_boxes.append(pred_box)

        # print("MASKS ", len(pred_masks))
        # print("VALID ", len(obj_scores))
        # print(obj_scores)
        # print(pred_scores.shape)

        eulers_xyz_rad = np.radians(
            np.array([
                event.metadata['agent']['cameraHorizon'],
                event.metadata['agent']['rotation']['y'], 0.0
            ]))

        rx = eulers_xyz_rad[0]
        ry = eulers_xyz_rad[1]
        rz = eulers_xyz_rad[2]
        rotation_ = self.eul2rotm(-rx, -ry, rz)

        translation_ = np.array(
            list(event.metadata['agent']['position'].values())) + np.array(
                [0.0, 0.675, 0.0])
        # need to invert since z is positive here by convention
        translation_[2] = -translation_[2]

        T_world_cam = np.eye(4)
        T_world_cam[0:3, 0:3] = rotation_
        T_world_cam[0:3, 3] = translation_

        if not obj_masks:
            return None
        elif self.center_from_mask:

            # want an object not on the edges of the image
            sum_interior = 0
            while sum_interior == 0:
                if len(obj_masks) == 0:
                    return None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W - 50,
                                                        50:self.H - 50])

            depth = event.depth_frame
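            # Back-project the depth map: centered pixel coordinates scaled by
            # depth and multiplied by K^-1 give camera-frame points; the points
            # under the object mask are rotated/translated into world coordinates
            # and averaged to obtain the object centroid.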

            xs, ys = np.meshgrid(np.linspace(-1 * 256 / 2., 1 * 256 / 2., 256),
                                 np.linspace(1 * 256 / 2., -1 * 256 / 2., 256))
            depth = depth.reshape(1, 256, 256)
            xs = xs.reshape(1, 256, 256)
            ys = ys.reshape(1, 256, 256)

            xys = np.vstack(
                (xs * depth, ys * depth, -depth, np.ones(depth.shape)))
            xys = xys.reshape(4, -1)
            xy_c0 = np.matmul(np.linalg.inv(self.K), xys)
            xyz = xy_c0.T[:, :3].reshape(256, 256, 3)
            xyz_obj_masked = xyz[obj_mask_focus]

            xyz_obj_masked = np.matmul(
                rotation_, xyz_obj_masked.T) + translation_.reshape(3, 1)
            xyz_obj_mid = np.mean(xyz_obj_masked, axis=1)

            xyz_obj_mid[2] = -xyz_obj_mid[2]
        else:

            # want an object not on the edges of the image
            sum_interior = 0
            while True:
                if len(obj_masks) == 0:
                    return None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W - 50,
                                                        50:self.H - 50])
                if sum_interior < 500:
                    continue  # exclude too small objects

                pixel_locs_obj = np.where(obj_mask_focus)
                x_mid = np.round(np.median(pixel_locs_obj[1]) / self.W, 4)
                y_mid = np.round(np.median(pixel_locs_obj[0]) / self.H, 4)

                if False:
                    plt.figure(1)
                    plt.clf()
                    plt.imshow(obj_mask_focus)
                    plt.plot(np.median(pixel_locs_obj[1]),
                             np.median(pixel_locs_obj[0]), 'x')
                    plt_name = self.homepath + '/seg_mask.png'
                    plt.savefig(plt_name)

                event = self.controller.step('TouchThenApplyForce',
                                             x=x_mid,
                                             y=y_mid,
                                             handDistance=1000000.0,
                                             direction=dict(x=0.0,
                                                            y=0.0,
                                                            z=0.0),
                                             moveMagnitude=0.0)
                obj_focus_id = event.metadata['actionReturn']['objectId']

                xyz_obj_mid = None
                for o in event.metadata['objects']:
                    if o['objectId'] == obj_focus_id:
                        xyz_obj_mid = np.array(
                            list(o['axisAlignedBoundingBox']
                                 ['center'].values()))

                if xyz_obj_mid is not None:
                    break
            # if xyz_obj_mid is None:
            #     st()

        print("MIDPOINT=", xyz_obj_mid)
        return xyz_obj_mid

        # semantic = event.instance_segmentation_frame
        # object_id_to_color = event.object_id_to_color
        # color_to_object_id = event.color_to_object_id

        # obj_ids = np.unique(semantic.reshape(-1, semantic.shape[2]), axis=0)

        # instance_masks = event.instance_masks
        # instance_detections2D = event.instance_detections2D

        # obj_metadata_IDs = []
        # for obj_m in event.metadata['objects']: #objects:
        #     obj_metadata_IDs.append(obj_m['objectId'])

        # object_list = []
        # for obj_idx in range(obj_ids.shape[0]):
        #     try:
        #         obj_color = tuple(obj_ids[obj_idx])
        #         object_id = color_to_object_id[obj_color]
        #     except:
        #         # print("Skipping ", object_id)
        #         continue

        #     if object_id not in obj_metadata_IDs:
        #         # print("Skipping ", object_id)
        #         continue

        #     obj_meta_index = obj_metadata_IDs.index(object_id)
        #     obj_meta = event.metadata['objects'][obj_meta_index]
        #     obj_category_name = obj_meta['objectType']

        #     # continue if not visible or not in include classes
        #     if obj_category_name not in self.include_classes: # or not obj_meta['visible']:
        #         continue

        #     obj_instance_mask = instance_masks[object_id]
        #     obj_instance_detection2D = instance_detections2D[object_id] # [start_x, start_y, end_x, end_y]
        #     obj_instance_detection2D = np.array([obj_instance_detection2D[1], obj_instance_detection2D[0], obj_instance_detection2D[3], obj_instance_detection2D[2]])  # ymin, xmin, ymax, xmax

        #     if True:
        #         print(object_id)
        #         print(np.array(list(obj_meta['axisAlignedBoundingBox']['center'].values())))
        #         # plt.imshow(obj_instance_mask)
        #         # plt_name = f'/home/nel/gsarch/aithor/data/test/img_mask{s}.png'
        #         # plt.savefig(plt_name)

        #     obj_center_axisAligned = np.array(list(obj_meta['axisAlignedBoundingBox']['center'].values()))
        #     obj_center_axisAligned[2] = -obj_center_axisAligned[2]
        #     obj_size_axisAligned = np.array(list(obj_meta['axisAlignedBoundingBox']['size'].values()))

    def get_rotation_to_obj(self, obj_center, pos_s):
        # YAW calculation - rotate to object
        agent_to_obj = np.squeeze(obj_center) - pos_s
        agent_local_forward = np.array([0, 0, 1.0])
        flat_to_obj = np.array([agent_to_obj[0], 0.0, agent_to_obj[2]])
        flat_dist_to_obj = np.linalg.norm(flat_to_obj)
        flat_to_obj /= flat_dist_to_obj
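        # Signed yaw: atan2 of the 2-D cross product and dot product between the
        # agent's forward axis and the (unit) flattened direction to the object.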

        det = (flat_to_obj[0] * agent_local_forward[2] -
               agent_local_forward[0] * flat_to_obj[2])
        turn_angle = math.atan2(det, np.dot(agent_local_forward, flat_to_obj))

        turn_yaw = np.degrees(turn_angle)

        turn_pitch = -np.degrees(math.atan2(agent_to_obj[1], flat_dist_to_obj))

        return turn_yaw, turn_pitch

    def run(self):

        event = self.controller.step('GetReachablePositions')
        if not event.metadata['reachablePositions']:
            # In some ai2thor versions this field comes back empty; take a step
            # and read the reachable positions from the new event instead.
            event = self.controller.step(action='MoveAhead')
        self.nav_pts = event.metadata['reachablePositions']
        self.nav_pts = np.array([list(d.values()) for d in self.nav_pts])
        np.random.seed(1)
        # objects = np.random.choice(event.metadata['objects'], self.obj_per_scene, replace=False)
        objects = event.metadata['objects']
        objects_inds = np.arange(len(event.metadata['objects']))
        np.random.shuffle(objects_inds)

        # objects = np.random.shuffle(event.metadata['objects'])
        # for obj in event.metadata['objects']: #objects:
        #     print(obj['name'])
        # objects = objects[0]
        successes = 0
        meta_obj_idx = 0
        while True:  #successes < self.obj_per_scene and meta_obj_idx <= len(event.metadata['objects']) - 1:
            if meta_obj_idx > len(event.metadata['objects']) - 1:
                print("OUT OF OBJECT... RETURNING")
                return None, None, None

            obj = objects[objects_inds[meta_obj_idx]]
            meta_obj_idx += 1
            print("Center object is ", obj['objectType'])
            # if obj['name'] in ['Microwave_b200e0bc']:
            #     print(obj['name'])
            # else:
            #     continue
            # print(obj['name'])

            if obj['objectType'] not in self.include_classes:
                print("Continuing... Invalid Object")
                continue

            # Calculate distance to object center
            obj_center = np.array(
                list(obj['axisAlignedBoundingBox']['center'].values()))

            obj_center = np.expand_dims(obj_center, axis=0)
            distances = np.sqrt(np.sum((self.nav_pts - obj_center)**2, axis=1))

            # Get points with r_min < dist < r_max
            valid_pts = self.nav_pts[np.where(
                (distances > self.radius_min) * (distances < self.radius_max))]

            # add height from center of agent to camera
            rand_pos_int = np.random.randint(low=0, high=valid_pts.shape[0])
            pos_s = valid_pts[rand_pos_int]
            pos_s[1] = pos_s[1] + 0.675

            turn_yaw, turn_pitch = self.get_rotation_to_obj(obj_center, pos_s)
            event = self.controller.step('TeleportFull',
                                         x=pos_s[0],
                                         y=pos_s[1],
                                         z=pos_s[2],
                                         rotation=dict(x=0.0,
                                                       y=int(turn_yaw),
                                                       z=0.0),
                                         horizon=int(turn_pitch))
            rgb = event.frame

            # get object center of a low confidence object
            obj_center = self.detect_object_centroid(rgb, event)

            if obj_center is None:
                print("NO LOW CONFIDENCE OBJECTS... SKIPPING...")
                continue

            # initialize object in center of FOV
            turn_yaw, turn_pitch = self.get_rotation_to_obj(obj_center, pos_s)
            event = self.controller.step('TeleportFull',
                                         x=pos_s[0],
                                         y=pos_s[1],
                                         z=pos_s[2],
                                         rotation=dict(x=0.0,
                                                       y=int(turn_yaw),
                                                       z=0.0),
                                         horizon=int(turn_pitch))
            rgb = event.frame
            init_conf = self.get_detectron_conf_center_obj(rgb)
            if init_conf is None:
                print("Nothing detected in the center... SKIPPING")
                continue
            conf_prev = init_conf
            if init_conf > self.conf_thresh_init:
                print("HIGH INITIAL CONFIDENCE... SKIPPING...")
                continue

            obs = []
            actions = []
            episode_rewards = 0.0
            frame = 0
            while True:

                rgb_tensor = torch.FloatTensor([rgb]).permute(0, 3, 1, 2)

                with torch.no_grad():
                    actions_probability = self.cemnet(rgb_tensor, softmax=True)

                act_proba = actions_probability.data.numpy()[0]

                action_ind = np.random.choice(len(act_proba), p=act_proba)

                action = self.action_space[action_ind]

                event = self.controller.step(action)
                rgb = event.frame

                conf_cur = self.get_detectron_conf_center_obj(rgb, frame)
                if conf_cur is None:
                    reward = -1  #-0.2 # fixed negative reward for no detection
                    conf_cur = conf_prev
                    # conf_prev = conf_prev # use same conf
                else:
                    # reward = (conf_cur - conf_prev)/(1 - init_conf) # normalize by intial confidence to account for differences in starting confidence
                    diff = conf_cur - conf_prev
                    if diff > 0:
                        reward = 1
                    elif diff == 0:
                        reward = 0
                    else:
                        reward = -1
                conf_prev = conf_cur

                episode_rewards += reward

                obs.append(rgb)
                actions.append(action_ind)

                if conf_cur > self.conf_thresh_end:
                    print("CONFIDENCE THRESHOLD REACHED!")
                    print("End confidence: ", conf_cur)
                    break

                if frame >= self.max_frames - 1:
                    print("MAX FRAMES REACHED")
                    print("End confidence: ", conf_cur)
                    break

                frame += 1
            return episode_rewards, obs, actions
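
The loop above is a cross-entropy-method style policy search: collect a batch of episodes, keep only those whose total reward reaches a chosen percentile, and fit the policy to the (observation, action) pairs of the surviving episodes. A minimal, self-contained sketch of that elite-selection step (a hypothetical standalone helper; elite_batch above does the same thing and additionally builds torch tensors):

import numpy as np

def select_elite(batch, percentile=70):
    """Keep (obs, action) pairs from episodes whose total reward is at or
    above the given percentile of the batch."""
    rewards = np.asarray(batch["rewards"], dtype=np.float64)
    boundary = float(np.percentile(rewards, percentile))
    elite_obs, elite_actions = [], []
    for r, o, a in zip(rewards, batch["obs"], batch["actions"]):
        if r >= boundary:  # elite episode: keep its frames and actions
            elite_obs.extend(o)
            elite_actions.extend(a)
    return elite_obs, elite_actions, float(rewards.mean()), boundary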
Пример #13
0
class Actor:
    """
    The basic THOR actor class that we use. This process can be controlled via pipes. The main functionality is to
    provide RGBD views from THOR scenes, receive and perform pokes in these scenes, and provide the feedback.

    The feedback can be in the form of raw images, or already processed (for performance). The agent can run in standard
    mode (images are only extracted from iTHOR when the scene is at rest), or in video mode (images are extracted in
    real time; very slow).

    Additionally, this class can provide superpixel segmentations of the THOR scenes.
    """
    def __init__(self, pipe, gpu, configs, ground_truth=0, run=True):
        self.pipe = pipe
        self.gc, self.ac = configs
        self.ground_truth = ground_truth
        self.directions = [
            dict(x=a, y=b, z=c) for a in [-1, 1] for b in [1, 2]
            for c in [-1, 1]
        ]
        self.controller = Controller(
            x_display='0.%d' % gpu,
            visibilityDistance=self.ac.visibilityDistance,
            renderDepthImage=self.gc.depth or ground_truth > 0,
            renderClassImage=ground_truth > 0,
            renderObjectImage=ground_truth > 0)
        self.grid = [(x, y) for x in range(self.gc.grid_size)
                     for y in range(self.gc.grid_size)]

        self.depth_correction = self._make_depth_correction(
            self.gc.resolution, self.gc.resolution, 90)

        self.kernel_size = (self.ac.check_change_kernel.shape[0] - 1) // 2

        if run:
            self.run()

    @staticmethod
    def _make_depth_correction(height, width, vertical_field_of_view):
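        # Pinhole model: f = (height / 2) / tan(vfov / 2). Each pixel's factor
        # sqrt(1 + (dx^2 + dy^2) / f^2) converts planar (z-buffer) depth into
        # Euclidean distance from the camera.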
        focal_length_in_pixels = height / np.tan(
            vertical_field_of_view / 2 * np.pi / 180) / 2
        xs = (np.arange(height).astype(np.float32) - (height - 1) / 2)**2
        ys = (np.arange(width).astype(np.float32) - (width - 1) / 2)**2
        return np.sqrt(1 +
                       (xs[:, None] + ys[None, :]) / focal_length_in_pixels**2)

    def run(self):
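        # Pipe protocol: the parent process sends 'stop' to terminate, a dict to
        # (re)set a scene (an image/superpixel/metadata dict is sent back), or a
        # list of poke points to execute (per-poke feedback is sent back). On
        # 'stop' the controller is shut down and 'stop' is echoed back.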
        while True:
            action = self.pipe.recv()
            if action == 'stop':
                break
            elif type(action) == dict:
                data = dict()
                image, superpixel, metadata = self.set_scene(
                    action) if self.ac.use_dataset else self.set_scene0(action)
                data['image'] = image
                if self.gc.superpixels:
                    data['superpixels'] = superpixel
                if self.ground_truth:
                    data['metadata'] = metadata
                self.pipe.send(data)
            if type(action) == list:
                feedback = self.poke(action, superpixel)
                self.pipe.send(feedback)
        self.controller.stop()
        self.pipe.send('stop')

    def set_scene(self, scene_data):
        if self.ac.video_mode:
            self.controller.step(action='UnpausePhysicsAutoSim')
        scene, position, rotation, horizon, seed = (scene_data['scene'],
                                                    scene_data['position'],
                                                    scene_data['rotation'],
                                                    scene_data['horizon'],
                                                    scene_data['seed'])
        self.controller.reset(scene)
        self.controller.step(action='InitialRandomSpawn',
                             seed=seed,
                             forceVisible=True,
                             numPlacementAttempts=5)
        self.controller.step(action='MakeAllObjectsMoveable')

        event = self.controller.step(action='TeleportFull',
                                     x=position['x'],
                                     y=position['y'],
                                     z=position['z'],
                                     rotation=rotation,
                                     horizon=horizon)

        image = event.frame

        if self.gc.superpixels:
            superpixel = felzenszwalb(
                image, scale=200, sigma=.5,
                min_size=200)[::self.gc.stride, ::self.gc.stride].astype(
                    np.int32)
        else:
            superpixel = None

        if self.gc.depth:
            depth = event.depth_frame
            if self.gc.correct_depth:
                depth = (
                    depth - .1
                ) * self.depth_correction  # convert depth from camera plane to distance from camera
            image = (image, depth)

        if self.ground_truth:
            ground_truth = self.compute_ground_truth(event)
            metadata = (ground_truth, seed, position, rotation, horizon)
        else:
            metadata = None

        return image, superpixel, metadata

    def set_scene0(self, scene):
        if self.ac.video_mode:
            self.controller.step(action='UnpausePhysicsAutoSim')
        self.controller.reset(scene['scene'])
        seed = randint(0, 2**30)
        self.controller.step(action='InitialRandomSpawn',
                             seed=seed,
                             forceVisible=True,
                             numPlacementAttempts=5)
        self.controller.step(action='MakeAllObjectsMoveable')
        event = self.controller.step(action='GetReachablePositions')
        positions = deepcopy(event.metadata['reachablePositions'])

        position, rotation, horizon = choice(positions), choice(
            [0., 90., 180., 270.]), choice([-30., 0., 30., 60.])
        event = self.controller.step(action='TeleportFull',
                                     x=position['x'],
                                     y=position['y'],
                                     z=position['z'],
                                     rotation=rotation,
                                     horizon=horizon)

        if self.gc.respawn_until_object:
            contains_interactable_object = len([
                o for o in event.metadata['objects']
                if o['visible'] and (o['moveable'] or o['pickupable'])
            ]) > 0
            while not contains_interactable_object:
                position, rotation, horizon = choice(positions), choice([0., 90., 180., 270.]), \
                                              choice([-30., 0., 30., 60.])
                event = self.controller.step(action='TeleportFull',
                                             x=position['x'],
                                             y=position['y'],
                                             z=position['z'],
                                             rotation=rotation,
                                             horizon=horizon)
                contains_interactable_object = len([
                    o for o in event.metadata['objects']
                    if o['visible'] and (o['moveable'] or o['pickupable'])
                ]) > 0

        image = event.frame

        if self.gc.superpixels:
            superpixel = felzenszwalb(
                image, scale=200, sigma=.5,
                min_size=200)[::self.gc.stride, ::self.gc.stride].astype(
                    np.int32)
        else:
            superpixel = None

        if self.gc.depth:
            depth = event.depth_frame
            if self.gc.correct_depth:
                depth = (
                    depth - .1
                ) * self.depth_correction  # convert depth from camera plane to distance from camera
            image = (image, depth)

        if self.ground_truth:
            ground_truth = self.compute_ground_truth(event)
            metadata = (ground_truth, seed, position, rotation, horizon)
        else:
            metadata = None

        return image, superpixel, metadata

    def poke(self, action, superpixel):
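        # Execute the requested pokes in sequence; after each poke the "before"
        # frame is updated to the last "after" frame, and feedback per poke is
        # collected and returned.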
        feedback = []

        if self.ac.video_mode:
            im1 = self.controller.step(action='PausePhysicsAutoSim').frame
        else:
            im1 = self.controller.step(action='Pass').frame
        for poke_point in action:
            if self.ac.instance_only:
                im2 = self.touch(poke_point['point'],
                                 self.ac.force_buckets[-1])
                poke_feedback = self.compute_feedback(im1, im2,
                                                      poke_point['point'],
                                                      superpixel)
            elif self.ac.scaleable:
                poke_feedback, im2 = self.touch_with_forces(
                    im1, poke_point, superpixel)
            else:
                poke_feedback, im2 = self.touch_with_forces_nonscaleable(
                    im1, poke_point, superpixel)
            feedback.append(poke_feedback)
            im1 = im2[-1]
        return feedback

    def touch_with_forces(self, im1, point_and_force, superpixel):
        direction = choice(self.directions)
        point, force = point_and_force['point'], point_and_force['force']
        smaller_force = max(force - 1, 0)
        im2 = self.touch(point, self.ac.force_buckets[smaller_force],
                         direction)
        vis_feedback = self.compute_feedback(im1, im2, point, superpixel)
        if self.get_score(vis_feedback, point) > 1.5:
            return (vis_feedback, -1 if force > 0 else 0), im2
        im1 = im2[-1]
        im2 = self.touch(point, self.ac.force_buckets[force], direction)
        vis_feedback = self.compute_feedback(im1, im2, point, superpixel)
        if self.get_score(vis_feedback, point) > 1.5:
            return (vis_feedback, 0), im2
        if force < len(self.ac.force_buckets) - 1:
            im1 = im2[-1]
            im2 = self.touch(point, self.ac.force_buckets[-1], direction)
            vis_feedback = self.compute_feedback(im1, im2, point, superpixel)
            if self.get_score(vis_feedback, point) > 1.5:
                return (vis_feedback, 1), im2
        return (vis_feedback, 2), im2

    def touch_with_forces_nonscaleable(self, im1, point_and_force, superpixel):
        point = point_and_force['point']
        smaller_force = 0
        im2 = self.touch(point, self.ac.force_buckets[smaller_force])
        vis_feedback = self.compute_feedback(im1, im2, point, superpixel)
        if self.get_score(vis_feedback, point) > 1.5:
            return (vis_feedback, 0), im2
        im1 = im2[-1]
        im2 = self.touch(point, self.ac.force_buckets[1])
        vis_feedback = self.compute_feedback(im1, im2, point, superpixel)
        if self.get_score(vis_feedback, point) > 1.5:
            return (vis_feedback, 1), im2
        im1 = im2[-1]
        im2 = self.touch(point, self.ac.force_buckets[-1])
        vis_feedback = self.compute_feedback(im1, im2, point, superpixel)
        if self.get_score(vis_feedback, point) > 1.5:
            return (vis_feedback, 1), im2
        return (vis_feedback, 2), im2

    def touch(self, point, force, direction=None):
        y, x = point  # x axis (=first axis) in numpy is y-axis (=second axis) in images (THOR)
        if direction is None:
            direction = choice(self.directions)
        im2 = []
        if self.ac.video_mode:
            self.controller.step(
                dict(action='TouchThenApplyForce',
                     x=x / self.gc.grid_size + 1 / 2 / self.gc.grid_size,
                     y=y / self.gc.grid_size + 1 / 2 / self.gc.grid_size,
                     direction=direction,
                     handDistance=self.ac.handDistance,
                     moveMagnitude=force))
            im2.append(
                self.controller.step(action='AdvancePhysicsStep',
                                     timeStep=0.01).frame)
            for _ in range(25):
                im2.append(
                    self.controller.step(action='AdvancePhysicsStep',
                                         timeStep=0.05).frame)

        else:
            im2.append(
                self.controller.step(
                    dict(action='TouchThenApplyForce',
                         x=x / self.gc.grid_size + 1 / 2 / self.gc.grid_size,
                         y=y / self.gc.grid_size + 1 / 2 / self.gc.grid_size,
                         direction=direction,
                         handDistance=self.ac.handDistance,
                         moveMagnitude=force)).frame)
        return im2

    def compute_feedback(self, im1, im2, poke_point, superpixel):
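        # Turn the before/after frames into a change mask. Unless raw_feedback is
        # set (raw frames returned as-is): optionally convert to HSV, compute frame
        # differences (videoPCA residuals or mean-subtraction in video mode, a plain
        # before-minus-after otherwise), threshold into a mask, and optionally
        # post-process with superpixels or the connected component at the poke point.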
        if self.ac.raw_feedback:
            return im1, im2

        if self.ac.hsv:
            im1, im2 = rgb2hsv(im1), [rgb2hsv(im) for im in im2]
        else:
            im1, im2 = im1.astype(
                np.float32), [im.astype(np.float32) for im in im2]

        if self.ac.video_mode:
            diff = self.pca_diff(
                [im1] + im2) if self.ac.pca else self.compute_mean_of([im1] +
                                                                      im2)
        else:
            im2 = im2[0]
            diff = [im1 - im2]

        if self.ac.smooth_mask:
            mask = self.make_smooth_mask(diff, [im1] + im2)
        else:
            mask = self.make_mask(diff)

        if superpixel is not None and self.ac.superpixel_postprocessed_feedback:
            if self.ac.smooth_mask:
                raise ValueError
            mask = self.smooth_mask_over_superpixels(mask, superpixel)
        if self.ac.connectedness_postprocessed_feedback:
            if self.ac.smooth_mask:
                raise ValueError
            mask = self.connected_component(mask, poke_point)

        return mask

    def make_mask(self, diff):
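        # Average the squared difference over each stride x stride cell (and over the
        # color channels), then threshold to get a grid_size x grid_size change mask.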
        diff = diff[0]
        scores = (diff**2).reshape(self.gc.grid_size, self.gc.stride,
                                   self.gc.grid_size, self.gc.stride,
                                   3).mean(axis=(1, 3, 4))

        mask = scores > self.ac.pixel_change_threshold
        return mask

    def smooth_mask_over_superpixels(self, mask, superpixels):
        smoothed_mask = np.zeros_like(mask)
        superpixels = [superpixels == i for i in np.unique(superpixels)]
        for superpixel in superpixels:
            if mask[superpixel].sum() / superpixel.sum(
            ) > self.ac.superpixel_postprocessing_threshold:
                smoothed_mask[superpixel] = True
        return smoothed_mask

    def connected_component(self, mask, poke_point):
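        # Keep only the connected component of the change mask that contains the poke
        # point; optional fattening bridges one-cell gaps before labeling.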
        x, y = poke_point
        b = mask[x, y]
        mask[x, y] = True
        fat_mask = self.fatten(mask) if self.ac.fatten else mask
        labels = label(fat_mask) * mask
        i = labels[x, y]
        labels[x, y] *= b
        mask[x, y] = b
        return labels == i

    # Below: Functionality for computing videoPCA soft masks

    def make_smooth_mask(self, diff, video):
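        # Smooth the per-pixel change magnitude with a convolution kernel, normalize
        # to [0, 1] and threshold at 0.5; the resulting hard mask seeds a
        # color-histogram soft mask computed over the whole poke video.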
        with torch.no_grad():
            diff = torch.from_numpy(diff.transpose(0, 3, 1, 2)).float()
            diff = torch.sqrt((diff**2).sum(dim=1)).unsqueeze(1)
            diff = torch.nn.functional.conv2d(
                diff,
                self.ac.kernel,
                padding=(self.ac.kernel.shape[-1] - 1) // 2)
            if (diff[0] > self.ac.soft_mask_threshold).sum() == 0:
                return np.zeros(video[0].shape[:-1], dtype=np.float32)
            diff = self.bn_torch(diff)
            mask = diff.squeeze(1) > .5
            mask_np = mask.numpy()
            return self.color_histogram_soft_mask(mask_np, mask, video)

    def color_histogram_soft_mask(self, mask, mask_torch, video):
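        # Build foreground/background histograms over quantized HSV colors of the
        # masked video and score each pixel of the first frame by fg / (fg + bg);
        # the result is restricted to the neighborhood of the hard mask and passed
        # through a hysteresis threshold to form the final soft mask.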
        if mask[0].sum() == 0 or (~mask[0]).sum() == 0:
            return mask[0].astype(np.float32)
        if not self.ac.hsv:
            video = [rgb2hsv(im) for im in video]
        video = self.combine_colors(video)
        fg_histogram = np.histogram(video[mask],
                                    bins=self.ac.num_bins,
                                    density=True)[0]
        bg_histogram = np.histogram(
            video[~mask], bins=self.ac.num_bins, density=True)[0] + 1e-5
        image = video[0].reshape(-1)
        soft_mask = fg_histogram[image] / (fg_histogram[image] +
                                           bg_histogram[image])
        soft_mask = soft_mask.astype(np.float32).reshape(video[0].shape)
        soft_mask = self.center_soft_masks(soft_mask, mask_torch[0])
        return self.bn_np(self.hysteresis_threshold(self.bn_np(soft_mask)))

    def combine_colors(self, video):
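        # Quantize HSV colors into a single integer bin index: hue and saturation are
        # embedded on a circle (saturation-weighted cos/sin of hue) and value is
        # binned separately, giving colres1 * colres2 * colres3 possible bins.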
        video = np.stack(video)
        # np.int was removed in NumPy 1.24; use a concrete integer dtype instead
        ret = np.zeros(video.shape[:-1], dtype=np.int64)
        ret += ((video[..., 1] * np.cos(2 * np.pi * video[..., 0]) + 1) / 2 *
                (self.ac.colres1 - 1)).astype(np.int64)
        ret += ((video[..., 1] * np.sin(2 * np.pi * video[..., 0]) + 1) / 2 *
                (self.ac.colres2 - 1)).astype(np.int64) * self.ac.colres1
        ret += (video[..., 2] * (self.ac.colres3 - 1)).astype(
            np.int64) * self.ac.colres1 * self.ac.colres2
        return ret

    def hysteresis_threshold(self, mask):
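        # Pixels above hyst_thresholds[0] are dilated twice (fatten) to form seed
        # regions; only pixels inside those regions that also exceed
        # hyst_thresholds[1] are kept, and the soft mask is zeroed elsewhere.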
        thresholding = self.fatten(
            self.fatten(mask > self.ac.hyst_thresholds[0]))
        thresholding *= mask > self.ac.hyst_thresholds[1]
        return mask * thresholding

    @staticmethod
    def compute_mean_of(images):
        mean = sum(images) / len(images)
        return np.stack([im - mean for im in images])

    def pca_diff(self, video):
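        # videoPCA: pool each frame down to the poking grid, fit a low-rank PCA over
        # the frames, and return the reconstruction residual as the change signal.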
        video = np.stack(video)
        bs = video.shape[0]
        video = video.reshape(
            (bs, self.gc.grid_size, self.gc.stride, self.gc.grid_size,
             self.gc.stride, 3)).mean(axis=(2, 4))
        video_shape = video[0].shape
        video = video.reshape((bs, -1))
        pca = PCA(n_components=self.ac.num_pca_components)
        pca.fit(video)
        reconstruction = pca.inverse_transform(pca.transform(video))
        diff = video - reconstruction
        return diff.reshape(*((bs, ) + video_shape))

    def center_soft_masks(self, soft_mask, mask):
        with torch.no_grad():
            mask = mask.float().unsqueeze(0).unsqueeze(1)
            mask = torch.nn.functional.conv2d(
                mask,
                self.ac.centering_kernel,
                padding=(self.ac.centering_kernel.shape[-1] - 1) //
                2).squeeze() > 1
        return soft_mask * mask.numpy()

    @staticmethod
    def fatten(mask):
        fat_mask = mask.copy()
        fat_mask[:-1] = fat_mask[:-1] | mask[1:]
        fat_mask[1:] = fat_mask[1:] | mask[:-1]
        fat_mask[:, :-1] = fat_mask[:, :-1] | mask[:, 1:]
        fat_mask[:, 1:] = fat_mask[:, 1:] | mask[:, :-1]
        return fat_mask

    @staticmethod
    def bn_torch(array):
        bs = array.shape[0]
        unsqueeze = (bs, ) + (1, ) * (len(array.shape) - 1)
        diffmin = array.view(bs, -1).min(1)[0].view(*unsqueeze)
        diffmax = array.view(bs, -1).max(1)[0].view(*unsqueeze)
        return (array - diffmin) / (diffmax - diffmin + 1e-6)

    @staticmethod
    def bn_np(array):
        bs = array.shape[0]
        unsqueeze = (bs, ) + (1, ) * (len(array.shape) - 1)
        diffmin = array.reshape((bs, -1)).min(1).reshape(unsqueeze)
        diffmax = array.reshape((bs, -1)).max(1).reshape(unsqueeze)
        return (array - diffmin) / (diffmax - diffmin + 1e-6)

    def compute_ground_truth(self, event):
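        # Ground truth from simulator metadata: take instance masks of visible,
        # moveable or pickupable objects below the mass threshold whose reachable
        # pixel count lies within the configured bounds, bucket their masses, and
        # sample up to max_poke_attempts poking points per object on the poking grid.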
        depth = (event.depth_frame - 0.1) * self.depth_correction
        reachable_pixels = depth < self.ac.handDistance

        keys = [
            o['objectId'] for o in event.metadata['objects']
            if o['objectId'] in event.instance_masks.keys() and o['visible']
            and (o['moveable'] or o['pickupable'])
            and o['mass'] < self.ac.mass_threshold
        ]
        masses_unf = [
            self.round_mass(o['mass']) for o in event.metadata['objects']
            if o['objectId'] in event.instance_masks.keys() and o['visible']
            and (o['moveable'] or o['pickupable'])
            and o['mass'] < self.ac.mass_threshold
        ]

        masks_unf = [event.instance_masks[key] for key in keys]
        masks, masses = [], []
        for mask, mass in zip(masks_unf, masses_unf):
            if self.ac.max_pixel_threshold > (
                    mask *
                    reachable_pixels).sum() > self.ac.min_pixel_threshold:
                masks.append(mask)
                masses.append(mass)

        poking_points = [
            np.stack(np.where(mask * reachable_pixels)) for mask in masks
        ]
        poking_points = [[
            self.round(*tuple(points[:, i]))
            for i in sample(range(points.shape[1]),
                            k=min(points.shape[1], self.ac.max_poke_attempts))
        ] for points in poking_points]
        return masks, poking_points, masses

    def round(self, x, y):
        return x // self.gc.stride, y // self.gc.stride

    def round_mass(self, mass):
        if mass < self.ac.mass_buckets[0]:
            return 0
        if mass < self.ac.mass_buckets[1]:
            return 1
        return 2

    def get_score(self, mask, action):
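        # Sum the change mask in a window around the poked grid cell, weighted by
        # check_change_kernel; the window is clipped at the grid boundary.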
        x, y = action
        dx1 = min(x, self.kernel_size)
        dx2 = min(self.gc.grid_size - 1 - x, self.kernel_size) + 1
        dy1 = min(y, self.kernel_size)
        dy2 = min(self.gc.grid_size - 1 - y, self.kernel_size) + 1
        x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2
        return (mask[x1:x2, y1:y2] * self.ac.check_change_kernel[
            self.kernel_size - dx1:self.kernel_size + dx2,
            self.kernel_size - dy1:self.kernel_size + dy2]).sum()
Example #14
0
class Ai2Thor():
    def __init__(self):   
        self.visualize = False
        self.verbose = False
        self.save_imgs = True

        self.plot_loss = True
        # st()

        # these are all map names
        a = np.arange(1, 30)
        b = np.arange(201, 231)
        c = np.arange(301, 331)
        d = np.arange(401, 431)
        abcd = np.hstack((a,b,c,d))
        mapnames = []
        for i in list(abcd):
            mapname = 'FloorPlan' + str(i)
            mapnames.append(mapname)

        train_len = int(0.9 * len(mapnames))

        random.shuffle(mapnames)
        self.mapnames_train = mapnames[:train_len]
        self.mapnames_val = mapnames[train_len:]
        # self.num_episodes = len(self.mapnames)   

        self.ignore_classes = []  
        # classes to save   
        self.include_classes = [
            'ShowerDoor', 'Cabinet', 'CounterTop', 'Sink', 'Towel', 'HandTowel', 'TowelHolder', 'SoapBar', 
            'ToiletPaper', 'ToiletPaperHanger', 'HandTowelHolder', 'SoapBottle', 'GarbageCan', 'Candle', 'ScrubBrush', 
            'Plunger', 'SinkBasin', 'Cloth', 'SprayBottle', 'Toilet', 'Faucet', 'ShowerHead', 'Box', 'Bed', 'Book', 
            'DeskLamp', 'BasketBall', 'Pen', 'Pillow', 'Pencil', 'CellPhone', 'KeyChain', 'Painting', 'CreditCard', 
            'AlarmClock', 'CD', 'Laptop', 'Drawer', 'SideTable', 'Chair', 'Blinds', 'Desk', 'Curtains', 'Dresser', 
            'Watch', 'Television', 'WateringCan', 'Newspaper', 'FloorLamp', 'RemoteControl', 'HousePlant', 'Statue', 
            'Ottoman', 'ArmChair', 'Sofa', 'DogBed', 'BaseballBat', 'TennisRacket', 'VacuumCleaner', 'Mug', 'ShelvingUnit', 
            'Shelf', 'StoveBurner', 'Apple', 'Lettuce', 'Bottle', 'Egg', 'Microwave', 'CoffeeMachine', 'Fork', 'Fridge', 
            'WineBottle', 'Spatula', 'Bread', 'Tomato', 'Pan', 'Cup', 'Pot', 'SaltShaker', 'Potato', 'PepperShaker', 
            'ButterKnife', 'StoveKnob', 'Toaster', 'DishSponge', 'Spoon', 'Plate', 'Knife', 'DiningTable', 'Bowl', 
            'LaundryHamper', 'Vase', 'Stool', 'CoffeeTable', 'Poster', 'Bathtub', 'TissueBox', 'Footstool', 'BathtubBasin', 
            'ShowerCurtain', 'TVStand', 'Boots', 'RoomDecor', 'PaperTowelRoll', 'Ladle', 'Kettle', 'Safe', 'GarbageBag', 'TeddyBear', 
            'TableTopDecor', 'Dumbbell', 'Desktop', 'AluminumFoil', 'Window']

        self.include_classes_final = [
            'Sink', 
            'Toilet', 'Bed', 'Book', 
            'CellPhone', 
            'AlarmClock', 'Laptop', 'Chair',
            'Television', 'RemoteControl', 'HousePlant', 
            'Ottoman', 'ArmChair', 'Sofa', 'BaseballBat', 'TennisRacket', 'Mug', 
            'Apple', 'Bottle', 'Microwave', 'Fork', 'Fridge', 
            'WineBottle', 'Cup', 
            'ButterKnife', 'Toaster', 'Spoon', 'Knife', 'DiningTable', 'Bowl', 
            'Vase', 
            'TeddyBear', 'StoveKnob', 'StoveBurner',
            ]

        # self.include_classes = [
        #     'Sink', 
        #     'Toilet', 'Bed', 'Book', 
        #     'CellPhone', 
        #     'AlarmClock', 'Laptop', 'Chair',
        #     'Television', 'RemoteControl', 'HousePlant', 
        #     'Ottoman', 'ArmChair', 'Sofa', 'BaseballBat', 'TennisRacket', 'Mug', 
        #     'Apple', 'Bottle', 'Microwave', 'Fork', 'Fridge', 
        #     'WineBottle', 'Cup', 
        #     'ButterKnife', 'Toaster', 'Spoon', 'Knife', 'DiningTable', 'Bowl', 
        #     'Vase', 
        #     'TeddyBear', 
        #     ]

        self.action_space = {0: "MoveLeft", 1: "MoveRight", 2: "MoveAhead", 3: "MoveBack", 4: "DoNothing"}
        self.num_actions = len(self.action_space)

        cfg_det = get_cfg()
        cfg_det.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg_det.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1  # set threshold for this model
        cfg_det.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        cfg_det.MODEL.DEVICE='cuda'
        self.cfg_det = cfg_det
        self.maskrcnn = DefaultPredictor(cfg_det)

        self.conf_thresh_detect = 0.7 # for initially detecting a low confident object
        self.conf_thresh_init = 0.8 # for after turning head toward object threshold
        self.conf_thresh_end = 0.9 # if reach this then stop getting obs

        self.BATCH_SIZE = 50 #50 # frames (not episodes) - this is approximate - it could be higher 
        # self.percentile = 70
        self.max_iters = 100000
        self.max_frames = 10
        self.val_interval = 10 #10 #10
        self.save_interval = 50

        # self.BATCH_SIZE = 2
        # self.percentile = 70
        # self.max_iters = 100000
        # self.max_frames = 2
        # self.val_interval = 1
        # self.save_interval = 1

        self.small_classes = []
        self.rot_interval = 5.0
        self.radius_max = 3.5 #3 #1.75
        self.radius_min = 1.0 #1.25
        self.num_flat_views = 3
        self.num_any_views = 7
        self.num_views = 25
        self.center_from_mask = False # get object centroid from maskrcnn (True) or gt (False)

        self.obj_per_scene = 5

        mod = 'conf05'

        # self.homepath = f'/home/nel/gsarch/aithor/data/test2'
        self.homepath = '/home/sirdome/katefgroup/gsarch/ithor/data/' + mod
        print(self.homepath)
        if not os.path.exists(self.homepath):
            os.mkdir(self.homepath)
        else:
            val = input("Delete homepath? [y/n]: ")
            if val == 'y':
                import shutil
                shutil.rmtree(self.homepath)
                os.mkdir(self.homepath)
            else:
                print("ENDING")
                assert(False)

        self.log_freq = 1
        self.log_dir = self.homepath +'/..' + '/log_cem/' + mod
        if not os.path.exists(self.log_dir):
            os.mkdir(self.log_dir)
        MAX_QUEUE = 10 # flushes when this amount waiting
        self.writer = SummaryWriter(self.log_dir, max_queue=MAX_QUEUE, flush_secs=60)


        self.W = 256
        self.H = 256

        self.fov = 90

        self.utils = Utils(self.fov, self.W, self.H)
        self.K = self.utils.get_habitat_pix_T_camX(self.fov)
        self.camera_matrix = self.utils.get_camera_matrix(self.W, self.H, self.fov)

        self.controller = Controller(
            scene='FloorPlan30', # will change 
            gridSize=0.25,
            width=self.W,
            height=self.H,
            fieldOfView= self.fov,
            renderObjectImage=True,
            renderDepthImage=True,
            )

        self.init_network()

        self.run_episodes()
    
    def init_network(self):

        input_shape = np.array([3, self.W, self.H])
        
        self.localpnet = LocalPNET(input_shape=input_shape, num_actions=self.num_actions).cuda()

        self.loss = nn.CrossEntropyLoss()

        self.optimizer = torch.optim.Adam(params=self.localpnet.parameters(),lr=0.00001)

    def batch_iteration(self,mapnames,BATCH_SIZE):
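        # Collect training episodes from randomly chosen scenes until roughly
        # BATCH_SIZE supervised frames have been gathered; per-step cross-entropy
        # losses are accumulated into total_loss, and the rest of the batch dict is
        # only used for logging and plotting.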

        batch = {"actions": [], "obs_all": [], "seg_ims": [], "conf_end_change": [], "conf_avg_change": [], "conf_median_change": []}
        iter_idx = 0
        total_loss = torch.tensor(0.0).cuda()
        num_obs = 0
        while True:

            mapname = np.random.choice(mapnames)

            # self.basepath = self.homepath + f"/{mapname}_{episode}"
            # print("BASEPATH: ", self.basepath)

            # # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            # if not os.path.exists(self.basepath):
            #     os.mkdir(self.basepath)

            self.controller.reset(scene=mapname)

            total_loss, obs, actions, seg_ims, confs = self.run("train", total_loss)

            if obs is None:
                print("NO EPISODE LOSS.. SKIPPING BATCH INSTANCE")
                continue

            num_obs += len(actions)

            print("Total loss for train batch # ",iter_idx," :",total_loss)

            confs = np.array(confs)
            conf_end_change = confs[-1] - confs[0]
            conf_avg_change = np.mean(np.diff(confs))
            conf_median_change = np.median(np.diff(confs))

            batch["actions"].append(actions) 
            # These are only used for plotting
            batch["obs_all"].append(obs)
            batch["seg_ims"].append(seg_ims)
            batch["conf_end_change"].append(conf_end_change)
            batch["conf_avg_change"].append(conf_avg_change)
            batch["conf_median_change"].append(conf_median_change)

            iter_idx += 1   

            # if len(batch["obs_all"]) == BATCH_SIZE:
            if num_obs >= BATCH_SIZE:
                print("NUM OBS IN BATCH=", num_obs)
                # batch["total_loss"] = total_loss
                print("Total loss for iter: ", total_loss)

                return total_loss, batch, num_obs
                # iter_idx = 0
                # total_loss = torch.tensor(0.0).cuda()
                # batch = {"actions": [], "obs_all": [], "seg_ims": [], "conf_end_change": [], "conf_avg_change": []}
                
    # def elite_batch(self,batch,percentile):

    #     rewards = np.array(batch["rewards"])
    #     obs = batch["obs"]
    #     actions = batch["actions"]

    #     rewards_mean = float(np.mean(rewards))
        
    #     rewards_boundary = np.percentile(rewards,percentile)

    #     print("Reward boundary: ", rewards_boundary)

    #     rewards_mean = float(np.mean(rewards))

    #     training_obs = []
    #     training_actions = []

    #     for idx in range(rewards.shape[0]):
    #         reward_idx = rewards[idx]
    #         if reward_idx < rewards_boundary:
    #             continue

    #         training_obs.extend(obs[idx])
    #         training_actions.extend(actions[idx])

    #     obs_tensor = torch.FloatTensor(training_obs).permute(0, 3, 1, 2).cuda()
    #     act_tensor = torch.LongTensor(training_actions).cuda()

    #     return obs_tensor, act_tensor, rewards_mean, rewards_boundary
    
    def run_val(self, mapnames, BATCH_SIZE, summ_writer=None):
        # run validation every x steps

        # batch = {"rewards": [], "obs": [], "actions": []}
        episode_rewards = 0.0
        seg_ims_batch = []
        obs_ims_batch = []

        iter_idx = 0
        while True:

            mapname = np.random.choice(mapnames)
        
            self.controller.reset(scene=mapname)

            _, obs, actions, seg_ims, confs = self.run("val", None)

            # self.controller.stop()
            # time.sleep(1)

            if obs is None:
                print("NO EPISODE REWARDS.. SKIPPING BATCH INSTANCE")
                continue

            seg_ims_batch.append(seg_ims)
            obs_ims_batch.append(obs)
            
            confs = np.array(confs)
            conf_end_change = confs[-1] - confs[0]
            conf_avg_change = np.mean(np.diff(confs))
            conf_median_change = np.median(np.diff(confs))

            print("val confidence change (end-start):", conf_end_change)
            print("val average confidence difference between frames", conf_avg_change)

            # batch["obs"].append(obs[1:]) # first obs is initial pos (for plotting)
            # batch["actions"].append(actions) 

            iter_idx += 1   

            if len(obs_ims_batch) == 1: # only one for val
        
                try:
                    if summ_writer is not None:
                        name = 'inputs_val/rgbs_original'
                        self.summ_writer.summ_imgs_aithor(name,obs_ims_batch, self.W, self.H, self.max_frames)
                        name = 'inputs_val/rgbs_maskrcnn'
                        self.summ_writer.summ_imgs_aithor(name,seg_ims_batch, self.W, self.H, self.max_frames)

                except:
                    print("PLOTTING DIDNT WORK")
                    pass

                break
        
        return conf_end_change, conf_avg_change, conf_median_change

    def run_episodes(self):
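        # Main training loop: build a batch, back-propagate the accumulated loss,
        # periodically run validation, save checkpoints, and log confidence-change
        # statistics and example frames to TensorBoard.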

        iteration = 0
        while True:
            
            iteration += 1
            print("ITERATION #", iteration)

            self.summ_writer = utils.improc.Summ_writer(
                writer=self.writer,
                global_step=iteration,
                log_freq=self.log_freq,
                fps=8,
                just_gif=True)

            total_loss, batch, num_obs = self.batch_iteration(self.mapnames_train,self.BATCH_SIZE)

            self.optimizer.zero_grad()

            total_loss.backward()

            self.optimizer.step()

            if iteration >= self.max_iters:
                print("MAX ITERS REACHED")
                self.writer.close()
                break

            if iteration % self.val_interval == 0:
                conf_end_change, conf_avg_change, conf_median_change = self.run_val(self.mapnames_val, self.BATCH_SIZE, self.summ_writer)
                if self.plot_loss:
                    self.summ_writer.summ_scalar('val_conf_end_change', conf_end_change)
                    self.summ_writer.summ_scalar('val_conf_avg_change', conf_avg_change)
                    self.summ_writer.summ_scalar('val_conf_median_change', conf_median_change)

            if iteration % self.save_interval == 0:
                PATH = self.homepath + f'/checkpoint{iteration}.tar'
                torch.save(self.localpnet.state_dict(), PATH)
            
            if self.plot_loss:
                conf_end_change_t = np.mean(np.array(batch["conf_end_change"]))
                conf_avg_change_t = np.mean(np.array(batch["conf_avg_change"]))
                conf_median_change_t = np.mean(batch["conf_median_change"])
                self.summ_writer.summ_scalar('train_conf_end_change_batchavg', conf_end_change_t)
                self.summ_writer.summ_scalar('train_conf_avg_change_batchavg', conf_avg_change_t)
                self.summ_writer.summ_scalar('train_conf_median_change_batchavg', conf_median_change_t)
                self.summ_writer.summ_scalar('total_loss', total_loss)
            
            ## PLOTTING #############
            try:
                summ_writer = self.summ_writer
                if summ_writer is not None and (iteration % self.val_interval == 0):
                    obs_ims_batch = batch["obs_all"]
                    seg_ims_batch = batch["seg_ims"]

                    name = 'inputs_train/rgbs_original'
                    self.summ_writer.summ_imgs_aithor(name,obs_ims_batch, self.W, self.H, self.max_frames)
                    name = 'inputs_train/rgbs_maskrcnn'
                    self.summ_writer.summ_imgs_aithor(name,seg_ims_batch, self.W, self.H, self.max_frames)
            except:
                print("PLOTTING DIDNT WORK")
                pass
                
            self.writer.close() # close tensorboard to flush

        self.controller.stop()
        time.sleep(10)
    
    def run2(self):
        event = self.controller.step('GetReachablePositions')
        for obj in event.metadata['objects']:
            if obj['objectType'] not in self.objects:
                self.objects.append(obj['objectType'])

    
    def get_detectron_conf_center_obj(self,im, obj_mask, frame=None):
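        # Run Mask R-CNN on the frame and pick the detection whose mask covers the
        # image center and whose pixel area is closest to the tracked object mask;
        # return its confidence, class id, mask, and a visualization image.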
        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        outputs = self.maskrcnn(im)

        pred_masks = outputs['instances'].pred_masks
        pred_scores = outputs['instances'].scores
        pred_classes = outputs['instances'].pred_classes

        len_pad = 5

        W2_low = self.W//2 - len_pad
        W2_high = self.W//2 + len_pad
        H2_low = self.H//2 - len_pad
        H2_high = self.H//2 + len_pad

        if False:

            v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]), scale=1.0)
            out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
            seg_im = out.get_image()
        
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg_all{frame}.png'
            plt.savefig(plt_name)

            seg_im[W2_low:W2_high, H2_low:H2_high,:] = 0.0
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg_all_mask{frame}.png'
            plt.savefig(plt_name)

        ind_obj = None
        # max_overlap = 0
        sum_obj_mask = np.sum(obj_mask)
        mask_sum_thresh = 7000
        for idx in range(pred_masks.shape[0]):
            pred_mask_cur = pred_masks[idx].detach().cpu().numpy()
            pred_masks_center = pred_mask_cur[W2_low:W2_high, H2_low:H2_high]
            sum_pred_mask_cur = np.sum(pred_mask_cur)
            # print(torch.sum(pred_masks_center))
            if np.sum(pred_masks_center) > 0:
                if np.abs(sum_pred_mask_cur - sum_obj_mask) < mask_sum_thresh:
                    ind_obj = idx
                    mask_sum_thresh = np.abs(sum_pred_mask_cur - sum_obj_mask)
                # max_overlap = torch.sum(pred_masks_center)
        if ind_obj is None:
            print("RETURNING NONE")
            return None, None, None, None

        v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]), scale=1.0)
        out = v.draw_instance_predictions(outputs['instances'][ind_obj].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg{frame}.png'
            plt.savefig(plt_name)

        # print("OBJ CLASS ID=", int(pred_classes[ind_obj].detach().cpu().numpy()))
        # pred_boxes = outputs['instances'].pred_boxes.tensor
        # pred_classes = outputs['instances'].pred_classes
        # pred_scores = outputs['instances'].scores
        obj_score = float(pred_scores[ind_obj].detach().cpu().numpy())
        obj_pred_classes = int(pred_classes[ind_obj].detach().cpu().numpy())
        obj_pred_mask = pred_masks[ind_obj].detach().cpu().numpy()


        return obj_score, obj_pred_classes, obj_pred_mask, seg_im            

            
    def detect_object_centroid(self, im, event):
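        # Run Mask R-CNN and pick a low-confidence detection that is not on the image
        # border, then estimate the object's 3D centroid: either by unprojecting the
        # masked depth (center_from_mask=True) or by touching the mask's median pixel
        # with zero force and reading the hit object's bounding-box center from the
        # simulator metadata (center_from_mask=False).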

        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        outputs = self.maskrcnn(im)

        v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]), scale=1.2)
        out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + '/seg_init.png'
            plt.savefig(plt_name)

        pred_masks = outputs['instances'].pred_masks
        pred_boxes = outputs['instances'].pred_boxes.tensor
        pred_classes = outputs['instances'].pred_classes
        pred_scores = outputs['instances'].scores

        obj_catids = []
        obj_scores = []
        obj_masks = []
        for segs in range(len(pred_masks)):
            if pred_scores[segs] <= self.conf_thresh_detect:
                obj_catids.append(pred_classes[segs].item())
                obj_scores.append(pred_scores[segs].item())
                obj_masks.append(pred_masks[segs])

        eulers_xyz_rad = np.radians(np.array([event.metadata['agent']['cameraHorizon'], event.metadata['agent']['rotation']['y'], 0.0]))

        rx = eulers_xyz_rad[0]
        ry = eulers_xyz_rad[1]
        rz = eulers_xyz_rad[2]
        rotation_ = self.utils.eul2rotm(-rx, -ry, rz)

        translation_ = np.array(list(event.metadata['agent']['position'].values())) + np.array([0.0, 0.675, 0.0])
        # need to invert since z is positive here by convention
        translation_[2] =  -translation_[2]

        T_world_cam = np.eye(4)
        T_world_cam[0:3,0:3] =  rotation_
        T_world_cam[0:3,3] = translation_

        if not obj_masks:
            return None, None
        elif self.center_from_mask: 

            # want an object not on the edges of the image
            sum_interior = 0
            while sum_interior==0:
                if len(obj_masks)==0:
                    return None, None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W-50, 50:self.H-50])

            depth = event.depth_frame

            xs, ys = np.meshgrid(np.linspace(-1*256/2.,1*256/2.,256), np.linspace(1*256/2.,-1*256/2., 256))
            depth = depth.reshape(1,256,256)
            xs = xs.reshape(1,256,256)
            ys = ys.reshape(1,256,256)

            xys = np.vstack((xs * depth , ys * depth, -depth, np.ones(depth.shape)))
            xys = xys.reshape(4, -1)
            xy_c0 = np.matmul(np.linalg.inv(self.K), xys)
            xyz = xy_c0.T[:,:3].reshape(256,256,3)
            xyz_obj_masked = xyz[obj_mask_focus]

            xyz_obj_masked = np.matmul(rotation_, xyz_obj_masked.T) + translation_.reshape(3,1)
            xyz_obj_mid = np.mean(xyz_obj_masked, axis=1)

            xyz_obj_mid[2] = -xyz_obj_mid[2]
        else:

            # want an object not on the edges of the image
            sum_interior = 0
            while True:
                if len(obj_masks)==0:
                    return None, None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                # print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W-50, 50:self.H-50])
                if sum_interior < 500:
                    continue # exclude too small objects


                pixel_locs_obj = np.where(obj_mask_focus.cpu().numpy())
                x_mid = np.round(np.median(pixel_locs_obj[1])/self.W, 4)
                y_mid = np.round(np.median(pixel_locs_obj[0])/self.H, 4)

                if False:
                    plt.figure(1)
                    plt.clf()
                    plt.imshow(obj_mask_focus)
                    plt.plot(np.median(pixel_locs_obj[1]), np.median(pixel_locs_obj[0]), 'x')
                    plt_name = self.homepath + '/seg_mask.png'
                    plt.savefig(plt_name)

                
                event = self.controller.step('TouchThenApplyForce', x=x_mid, y=y_mid, handDistance = 1000000.0, direction=dict(x=0.0, y=0.0, z=0.0), moveMagnitude = 0.0)
                obj_focus_id = event.metadata['actionReturn']['objectId']

                xyz_obj_mid = None
                for o in event.metadata['objects']:
                    if o['objectId'] == obj_focus_id:
                        if o['objectType'] not in self.include_classes_final:
                            continue
                        xyz_obj_mid = np.array(list(o['axisAlignedBoundingBox']['center'].values()))
                
                if xyz_obj_mid is not None:
                    break

        print("MIDPOINT=", xyz_obj_mid)
        return xyz_obj_mid, obj_mask_focus  


    def run(self, mode, total_loss, summ_writer=None):
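        # One episode: teleport near a randomly chosen valid object, find a
        # low-confidence detection and turn toward it, then step the policy. In train
        # mode every step also tries each movement action, teleports back, and
        # supervises the policy with the action that most improves the detector's
        # confidence; the episode ends on DoNothing, on reaching conf_thresh_end, or
        # after max_frames steps.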
        
        event = self.controller.step('GetReachablePositions')
        if not event.metadata['reachablePositions']:
            # Depending on the AI2-THOR version this list may come back empty
            event = self.controller.step(action='MoveAhead')
        self.nav_pts = event.metadata['reachablePositions']
        self.nav_pts = np.array([list(d.values()) for d in self.nav_pts])
        # objects = np.random.choice(event.metadata['objects'], self.obj_per_scene, replace=False)
        objects = event.metadata['objects']
        objects_inds = np.arange(len(event.metadata['objects']))
        np.random.shuffle(objects_inds)

        # objects = np.random.shuffle(event.metadata['objects'])
        # for obj in event.metadata['objects']: #objects:
        #     print(obj['name'])
        # objects = objects[0]
        successes = 0
        meta_obj_idx = 0
        num_obs = 0
        while True: #successes < self.obj_per_scene and meta_obj_idx <= len(event.metadata['objects']) - 1: 
            if meta_obj_idx > len(event.metadata['objects']) - 1:
                print("OUT OF OBJECT... RETURNING")
                return total_loss, None, None, None, None
                
            obj = objects[objects_inds[meta_obj_idx]]
            meta_obj_idx += 1
            print("Center object is ", obj['objectType'])
            # if obj['name'] in ['Microwave_b200e0bc']:
            #     print(obj['name'])
            # else:
            #     continue
            # print(obj['name'])

            if obj['objectType'] not in self.include_classes:
                print("Continuing... Invalid Object")
                continue
            
            # Calculate distance to object center
            obj_center = np.array(list(obj['axisAlignedBoundingBox']['center'].values()))
                        
            obj_center = np.expand_dims(obj_center, axis=0)
            distances = np.sqrt(np.sum((self.nav_pts - obj_center)**2, axis=1))

            # Get points with r_min < dist < r_max
            valid_pts = self.nav_pts[np.where((distances > self.radius_min)*(distances<self.radius_max))]

            # add height from center of agent to camera
            rand_pos_int = np.random.randint(low=0, high=valid_pts.shape[0])
            pos_s = valid_pts[rand_pos_int]
            pos_s[1] = pos_s[1] + 0.675

            turn_yaw, turn_pitch = self.utils.get_rotation_to_obj(obj_center, pos_s)
            event = self.controller.step('TeleportFull', x=pos_s[0], y=pos_s[1], z=pos_s[2], rotation=dict(x=0.0, y=int(turn_yaw), z=0.0), horizon=int(turn_pitch))
            rgb = event.frame

            # get object center of a low confidence object
            obj_center, obj_mask = self.detect_object_centroid(rgb, event)

            if obj_center is None:
                print("NO LOW CONFIDENCE OBJECTS... SKIPPING...")
                continue

            # initialize object in center of FOV
            turn_yaw, turn_pitch = self.utils.get_rotation_to_obj(obj_center, pos_s)
            if mode=="train":
                pos_s_prev = pos_s
                turn_yaw_prev = turn_yaw
                turn_pitch_prev = turn_pitch
            event = self.controller.step('TeleportFull', x=pos_s[0], y=pos_s[1], z=pos_s[2], rotation=dict(x=0.0, y=int(turn_yaw), z=0.0), horizon=int(turn_pitch))
            rgb = event.frame
            seg_ims = []
            obs = []
            init_conf, obj_pred_classes, obj_mask, seg_im = self.get_detectron_conf_center_obj(rgb, obj_mask.detach().cpu().numpy())
            if init_conf is None:
                print("Nothing detected in the center... SKIPPING")
                continue
            conf_cur = init_conf
            conf_prev = init_conf
            if init_conf > self.conf_thresh_init:
                print("HIGH INITIAL CONFIDENCE... SKIPPING...")
                continue
            seg_ims.append(seg_im)
            obs.append(rgb)
            
            actions = []
            confs = []
            confs.append(conf_cur)
            episode_rewards = 0.0
            frame = 0
            while True:

                rgb_tensor = torch.FloatTensor([rgb]).permute(0, 3, 1, 2).cuda()
                
                if mode=="train":
                    torch.set_grad_enabled(True)
                    action_ind, act_proba = self.localpnet(rgb_tensor)
                elif mode=="val":
                    with torch.no_grad():
                        action_ind, act_proba = self.localpnet(rgb_tensor)

                # action_ind, act_proba = actions_probability.data.cpu().numpy()[0]

                # action_ind = np.random.choice(len(act_proba),p=act_proba)

                action_ind = int(action_ind.detach().cpu().numpy())

                action = self.action_space[action_ind]
                print("ACTION=", action)

                obs.append(rgb)
                actions.append(action_ind)
                
                # get best action in a confidence sense
                if mode=="train":
                    best_action = 4 # "DoNothing"
                    best_conf = conf_prev
                    for action_idx in [0,1,2,3]:
                        action_t = self.action_space[action_idx]
                        event_t = self.controller.step(action_t)
                        agent_position_t = np.array(list(event_t.metadata['agent']['position'].values())) + np.array([0.0, 0.675, 0.0])
                        turn_yaw_t, turn_pitch_t = self.utils.get_rotation_to_obj(obj_center, agent_position_t)
                        event_t = self.controller.step('TeleportFull', x=agent_position_t[0], y=agent_position_t[1], z=agent_position_t[2], rotation=dict(x=0.0, y=int(turn_yaw_t), z=0.0), horizon=int(turn_pitch_t))
                        rgb_t = event_t.frame
                        conf_t, _, _, _ = self.get_detectron_conf_center_obj(rgb_t, obj_mask, frame)
                        if conf_t is None:
                            conf_t = best_conf - 1 # penalize missing detections
                        if conf_t > best_conf:
                            best_action = action_idx
                            best_conf = conf_t
                        _ = self.controller.step('TeleportFull', x=pos_s_prev[0], y=pos_s_prev[1], z=pos_s_prev[2], rotation=dict(x=0.0, y=int(turn_yaw_prev), z=0.0), horizon=int(turn_pitch_prev))

                    best_action = torch.LongTensor([best_action]).cuda()
                    total_loss += self.loss(act_proba, best_action)
                    num_obs += 1
                
                    print("BEST ACTION=", self.action_space[int(best_action.detach().cpu().numpy())])

                if not action=="DoNothing":
                    event = self.controller.step(action)
                    agent_position = np.array(list(event.metadata['agent']['position'].values())) + np.array([0.0, 0.675, 0.0])
                    turn_yaw, turn_pitch = self.utils.get_rotation_to_obj(obj_center, agent_position)
                    event = self.controller.step('TeleportFull', x=agent_position[0], y=agent_position[1], z=agent_position[2], rotation=dict(x=0.0, y=int(turn_yaw), z=0.0), horizon=int(turn_pitch))
                else:
                    print("Do nothing reached")
                    print("End confidence: ", conf_prev)
                    break

                if mode=="train":
                    pos_s_prev = agent_position
                    turn_yaw_prev = turn_yaw
                    turn_pitch_prev = turn_pitch

                rgb = event.frame
                conf_cur, obj_pred_classes, obj_mask_new, seg_im = self.get_detectron_conf_center_obj(rgb, obj_mask, frame)
                seg_ims.append(seg_im)
                if conf_cur is None:
                    conf_cur = conf_prev
                    seg_im = rgb
                else:
                    obj_mask = obj_mask_new
                
                if True:
                    plt.figure(1)
                    plt.clf()
                    plt.imshow(seg_im)
                    plt_name = self.homepath + f'/seg{frame}.png'
                    plt.savefig(plt_name)

                confs.append(conf_cur)

                conf_prev = conf_cur


                if conf_cur > self.conf_thresh_end:
                    print("CONFIDENCE THRESHOLD REACHED!")
                    print("End confidence: ", conf_cur)
                    break

                if frame >= self.max_frames - 1:
                    print("MAX FRAMES REACHED")
                    print("End confidence: ", conf_cur)
                    break

                frame += 1
                
                        
            return total_loss, obs, actions, seg_ims, confs
Example #15
0
class Ai2Thor():
    def __init__(self):
        self.visualize = False
        self.verbose = False
        self.save_imgs = True

        self.plot_loss = True
        # st()

        mapnames = []
        for i in [1, 201, 301, 401]:
            mapname = 'FloorPlan' + str(i)
            mapnames.append(mapname)

        # random.shuffle(mapnames)
        self.mapnames_train = mapnames
        self.num_episodes = len(self.mapnames_train)

        # get rest of the house in orders
        a = np.arange(2, 30)
        b = np.arange(202, 231)
        c = np.arange(302, 331)
        d = np.arange(402, 431)
        abcd = np.hstack((a, b, c, d))
        mapnames = []
        for i in range(a.shape[0]):
            mapname = 'FloorPlan' + str(a[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(b[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(c[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(d[i])
            mapnames.append(mapname)

        self.mapnames_test = mapnames

        self.ignore_classes = []
        # classes to save
        # self.include_classes = [
        #     'ShowerDoor', 'Cabinet', 'CounterTop', 'Sink', 'Towel', 'HandTowel', 'TowelHolder', 'SoapBar',
        #     'ToiletPaper', 'ToiletPaperHanger', 'HandTowelHolder', 'SoapBottle', 'GarbageCan', 'Candle', 'ScrubBrush',
        #     'Plunger', 'SinkBasin', 'Cloth', 'SprayBottle', 'Toilet', 'Faucet', 'ShowerHead', 'Box', 'Bed', 'Book',
        #     'DeskLamp', 'BasketBall', 'Pen', 'Pillow', 'Pencil', 'CellPhone', 'KeyChain', 'Painting', 'CreditCard',
        #     'AlarmClock', 'CD', 'Laptop', 'Drawer', 'SideTable', 'Chair', 'Blinds', 'Desk', 'Curtains', 'Dresser',
        #     'Watch', 'Television', 'WateringCan', 'Newspaper', 'FloorLamp', 'RemoteControl', 'HousePlant', 'Statue',
        #     'Ottoman', 'ArmChair', 'Sofa', 'DogBed', 'BaseballBat', 'TennisRacket', 'VacuumCleaner', 'Mug', 'ShelvingUnit',
        #     'Shelf', 'StoveBurner', 'Apple', 'Lettuce', 'Bottle', 'Egg', 'Microwave', 'CoffeeMachine', 'Fork', 'Fridge',
        #     'WineBottle', 'Spatula', 'Bread', 'Tomato', 'Pan', 'Cup', 'Pot', 'SaltShaker', 'Potato', 'PepperShaker',
        #     'ButterKnife', 'StoveKnob', 'Toaster', 'DishSponge', 'Spoon', 'Plate', 'Knife', 'DiningTable', 'Bowl',
        #     'LaundryHamper', 'Vase', 'Stool', 'CoffeeTable', 'Poster', 'Bathtub', 'TissueBox', 'Footstool', 'BathtubBasin',
        #     'ShowerCurtain', 'TVStand', 'Boots', 'RoomDecor', 'PaperTowelRoll', 'Ladle', 'Kettle', 'Safe', 'GarbageBag', 'TeddyBear',
        #     'TableTopDecor', 'Dumbbell', 'Desktop', 'AluminumFoil', 'Window']

        # These are all classes shared between aithor and coco
        self.include_classes = [
            'Sink',
            'Toilet',
            'Bed',
            'Book',
            'CellPhone',
            'AlarmClock',
            'Laptop',
            'Chair',
            'Television',
            'RemoteControl',
            'HousePlant',
            'Ottoman',
            'ArmChair',
            'Sofa',
            'BaseballBat',
            'TennisRacket',
            'Mug',
            'Apple',
            'Bottle',
            'Microwave',
            'Fork',
            'Fridge',
            'WineBottle',
            'Cup',
            'ButterKnife',
            'Toaster',
            'Spoon',
            'Knife',
            'DiningTable',
            'Bowl',
            'Vase',
            'TeddyBear',
        ]
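
        # Note: several COCO ids below (e.g. 62, 44, 47, 49) correspond to more than
        # one iTHOR class, and a Python dict literal keeps only the last entry for a
        # duplicated key, so maskrcnn_to_ithor is a lossy mapping in that direction.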

        self.maskrcnn_to_ithor = {
            81: 'Sink',
            70: 'Toilet',
            65: 'Bed',
            84: 'Book',
            77: 'CellPhone',
            85: 'AlarmClock',
            73: 'Laptop',
            62: 'Chair',
            72: 'Television',
            75: 'RemoteControl',
            64: 'HousePlant',
            62: 'Ottoman',
            62: 'ArmChair',
            63: 'Sofa',
            39: 'BaseballBat',
            43: 'TennisRacket',
            47: 'Mug',
            53: 'Apple',
            44: 'Bottle',
            78: 'Microwave',
            48: 'Fork',
            82: 'Fridge',
            44: 'WineBottle',
            47: 'Cup',
            49: 'ButterKnife',
            80: 'Toaster',
            50: 'Spoon',
            49: 'Knife',
            67: 'DiningTable',
            51: 'Bowl',
            86: 'Vase',
            88: 'TeddyBear',
        }

        self.ithor_to_maskrcnn = {
            'Sink': 81,
            'Toilet': 70,
            'Bed': 65,
            'Book': 84,
            'CellPhone': 77,
            'AlarmClock': 85,
            'Laptop': 73,
            'Chair': 62,
            'Television': 72,
            'RemoteControl': 75,
            'HousePlant': 64,
            'Ottoman': 62,
            'ArmChair': 62,
            'Sofa': 63,
            'BaseballBat': 39,
            'TennisRacket': 43,
            'Mug': 47,
            'Apple': 53,
            'Bottle': 44,
            'Microwave': 78,
            'Fork': 48,
            'Fridge': 82,
            'WineBottle': 44,
            'Cup': 47,
            'ButterKnife': 49,
            'Toaster': 80,
            'Spoon': 50,
            'Knife': 49,
            'DiningTable': 67,
            'Bowl': 51,
            'Vase': 86,
            'TeddyBear': 88,
        }

        self.maskrcnn_to_catname = {
            81: 'sink',
            67: 'dining table',
            65: 'bed',
            84: 'book',
            77: 'cell phone',
            70: 'toilet',
            85: 'clock',
            73: 'laptop',
            62: 'chair',
            72: 'tv',
            75: 'remote',
            64: 'potted plant',
            63: 'couch',
            39: 'baseball bat',
            43: 'tennis racket',
            47: 'cup',
            53: 'apple',
            44: 'bottle',
            78: 'microwave',
            48: 'fork',
            82: 'refrigerator',
            46: 'wine glass',
            49: 'knife',
            79: 'oven',
            80: 'toaster',
            50: 'spoon',
            67: 'dining table',
            51: 'bowl',
            86: 'vase',
            88: 'teddy bear',
        }

        self.obj_conf_dict = {
            'sink': [],
            'dining table': [],
            'bed': [],
            'book': [],
            'cell phone': [],
            'clock': [],
            'laptop': [],
            'chair': [],
            'tv': [],
            'remote': [],
            'potted plant': [],
            'couch': [],
            'baseball bat': [],
            'tennis racket': [],
            'cup': [],
            'apple': [],
            'bottle': [],
            'microwave': [],
            'fork': [],
            'refrigerator': [],
            'wine glass': [],
            'knife': [],
            'oven': [],
            'toaster': [],
            'spoon': [],
            'dining table': [],
            'bowl': [],
            'vase': [],
            'teddy bear': [],
        }

        self.data_store = {
            'sink': {},
            'dining table': {},
            'bed': {},
            'book': {},
            'cell phone': {},
            'clock': {},
            'laptop': {},
            'chair': {},
            'tv': {},
            'remote': {},
            'potted plant': {},
            'couch': {},
            'baseball bat': {},
            'tennis racket': {},
            'cup': {},
            'apple': {},
            'bottle': {},
            'microwave': {},
            'fork': {},
            'refrigerator': {},
            'wine glass': {},
            'knife': {},
            'oven': {},
            'toaster': {},
            'spoon': {},
            'dining table': {},
            'bowl': {},
            'vase': {},
            'teddy bear': {},
        }

        self.data_store_features = []
        self.feature_obj_ids = []
        self.first_time = True
        self.Softmax = nn.Softmax(dim=0)

        self.action_space = {
            0: "MoveLeft",
            1: "MoveRight",
            2: "MoveAhead",
            3: "MoveBack",
            4: "DoNothing"
        }
        self.num_actions = len(self.action_space)

        cfg_det = get_cfg()
        cfg_det.merge_from_file(
            model_zoo.get_config_file(
                "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg_det.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1  # set threshold for this model
        cfg_det.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        cfg_det.MODEL.DEVICE = 'cuda'
        self.cfg_det = cfg_det
        self.maskrcnn = DefaultPredictor(cfg_det)

        self.normalize = transforms.Compose([
            transforms.Resize(256, interpolation=PIL.Image.BILINEAR),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # Initialize VGG16 (pretrained on ImageNet); the final ReLU and max-pool of
        # vgg16.features are dropped so the truncated network serves as a frozen
        # feature extractor, and vgg_mean/vgg_std are the standard ImageNet statistics.
        vgg16 = torchvision.models.vgg16(pretrained=True).double().cuda()
        vgg16.eval()
        print(torch.nn.Sequential(*list(vgg16.features.children())))
        self.vgg_feat_extractor = torch.nn.Sequential(
            *list(vgg16.features.children())[:-2])
        print(self.vgg_feat_extractor)
        self.vgg_mean = torch.from_numpy(
            np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1))
        self.vgg_std = torch.from_numpy(
            np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1))

        self.conf_thresh_detect = 0.7  # for initially detecting a low confident object
        self.conf_thresh_init = 0.8  # for after turning head toward object threshold
        self.conf_thresh_end = 0.9  # if reach this then stop getting obs

        self.BATCH_SIZE = 50  # frames (not episodes) - this is approximate - it could be higher
        # self.percentile = 70
        self.max_iters = 100000
        self.max_frames = 10
        self.val_interval = 10  #10
        self.save_interval = 50

        # self.BATCH_SIZE = 2
        # self.percentile = 70
        # self.max_iters = 100000
        # self.max_frames = 2
        # self.val_interval = 1
        # self.save_interval = 1

        self.small_classes = []
        self.rot_interval = 5.0
        self.radius_max = 3.5  #3 #1.75
        self.radius_min = 1.0  #1.25
        self.num_flat_views = 3
        self.num_any_views = 7
        self.num_views = 25
        self.center_from_mask = False  # get object centroid from maskrcnn (True) or gt (False)

        self.obj_per_scene = 5

        mod = 'test00'

        # self.homepath = f'/home/nel/gsarch/aithor/data/test2'
        self.homepath = '/home/sirdome/katefgroup/gsarch/ithor/data/' + mod
        print(self.homepath)
        if not os.path.exists(self.homepath):
            os.mkdir(self.homepath)
        else:
            val = input("Delete homepath? [y/n]: ")
            if val == 'y':
                import shutil
                shutil.rmtree(self.homepath)
                os.mkdir(self.homepath)
            else:
                print("ENDING")
                assert (False)

        self.log_freq = 1
        self.log_dir = self.homepath + '/..' + '/log_cem/' + mod
        if not os.path.exists(self.log_dir):
            os.mkdir(self.log_dir)
        MAX_QUEUE = 10  # flushes when this amount waiting
        self.writer = SummaryWriter(self.log_dir,
                                    max_queue=MAX_QUEUE,
                                    flush_secs=60)

        self.W = 256
        self.H = 256

        self.fov = 90

        self.utils = Utils(self.fov, self.W, self.H)
        self.K = self.utils.get_habitat_pix_T_camX(self.fov)
        self.camera_matrix = self.utils.get_camera_matrix(
            self.W, self.H, self.fov)

        self.controller = Controller(
            scene='FloorPlan30',  # will change 
            gridSize=0.25,
            width=self.W,
            height=self.H,
            fieldOfView=self.fov,
            renderObjectImage=True,
            renderDepthImage=True,
        )

        self.init_network()

        self.run_episodes()

    def init_network(self):

        input_shape = np.array([3, self.W, self.H])

        self.localpnet = LocalPNET(input_shape=input_shape,
                                   num_actions=self.num_actions).cuda()

        self.loss = nn.CrossEntropyLoss()

        self.optimizer = torch.optim.Adam(params=self.localpnet.parameters(),
                                          lr=0.00001)

    def batch_iteration(self, mapnames, BATCH_SIZE):

        batch = {
            "actions": [],
            "obs_all": [],
            "seg_ims": [],
            "conf_end_change": [],
            "conf_avg_change": [],
            "conf_median_change": []
        }
        iter_idx = 0
        total_loss = torch.tensor(0.0).cuda()
        num_obs = 0
        while True:

            mapname = np.random.choice(mapnames)

            # self.basepath = self.homepath + f"/{mapname}_{episode}"
            # print("BASEPATH: ", self.basepath)

            # # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            # if not os.path.exists(self.basepath):
            #     os.mkdir(self.basepath)

            self.controller.reset(scene=mapname)

            total_loss, obs, actions, seg_ims, confs = self.run(
                "train", total_loss)

            if obs is None:
                print("NO EPISODE LOSS.. SKIPPING BATCH INSTANCE")
                continue

            num_obs += len(actions)

            print("Total loss for train batch # ", iter_idx, " :", total_loss)

            confs = np.array(confs)
            conf_end_change = confs[-1] - confs[0]
            conf_avg_change = np.mean(np.diff(confs))
            conf_median_change = np.median(np.diff(confs))

            batch["actions"].append(actions)
            # These are only used for plotting
            batch["obs_all"].append(obs)
            batch["seg_ims"].append(seg_ims)
            batch["conf_end_change"].append(conf_end_change)
            batch["conf_avg_change"].append(conf_avg_change)
            batch["conf_median_change"].append(conf_median_change)

            iter_idx += 1

            # if len(batch["obs_all"]) == BATCH_SIZE:
            if num_obs >= BATCH_SIZE:
                print("NUM OBS IN BATCH=", num_obs)
                # batch["total_loss"] = total_loss
                print("Total loss for iter: ", total_loss)

                return total_loss, batch, num_obs
                # iter_idx = 0
                # total_loss = torch.tensor(0.0).cuda()
                # batch = {"actions": [], "obs_all": [], "seg_ims": [], "conf_end_change": [], "conf_avg_change": []}

    def run_episodes(self):
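        # First pass: run the four training scenes in "train" mode (used here,
        # apparently, to populate the stored object features rather than to update
        # any weights). Second pass: run the remaining scenes as a test set; every
        # four scenes (one house of each room type) plot histograms of the
        # nearest-neighbor inner products for correct vs. incorrect predictions and
        # a per-house confusion matrix.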
        self.ep_idx = 0
        # self.objects = []

        for episode in range(len(self.mapnames_train)):
            print("STARTING EPISODE ", episode)

            mapname = self.mapnames_train[episode]
            print("MAPNAME=", mapname)

            self.controller.reset(scene=mapname)

            # self.controller.start()

            self.basepath = self.homepath + f"/{mapname}_{episode}"
            print("BASEPATH: ", self.basepath)

            # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            if not os.path.exists(self.basepath):
                os.mkdir(self.basepath)

            self.run(mode="train")

            self.ep_idx += 1

        self.ep_idx = 1
        self.best_inner_prods = []
        self.pred_ids = []
        self.true_ids = []
        self.pred_catnames = []
        self.true_catnames = []
        self.pred_catnames_all = []
        self.true_catnames_all = []
        self.conf_mats = []
        # self.pred_catnames = []
        for episode in range(len(self.mapnames_test)):
            print("STARTING EPISODE ", episode)

            mapname = self.mapnames_test[episode]
            print("MAPNAME=", mapname)

            self.controller.reset(scene=mapname)

            # self.controller.start()

            self.basepath = self.homepath + f"/{mapname}_{episode}"
            print("BASEPATH: ", self.basepath)

            # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            if not os.path.exists(self.basepath):
                os.mkdir(self.basepath)

            self.run(mode="test")

            if self.ep_idx % 4 == 0:
                self.best_inner_prods = np.array(self.best_inner_prods)
                self.pred_ids = np.array(self.pred_ids)
                self.true_ids = np.array(self.true_ids)
                # for i in range(len(self.best_inner_prods)):

                correct_pred = self.best_inner_prods[self.pred_ids ==
                                                     self.true_ids]
                incorrect_pred = self.best_inner_prods[
                    self.pred_ids != self.true_ids]

                bins = 50
                plt.figure(1)
                plt.clf()
                plt.hist([correct_pred, incorrect_pred],
                         alpha=0.5,
                         histtype='stepfilled',
                         label=['correct', 'incorrect'],
                         bins=bins)
                plt.title(f'testhouse{self.ep_idx//4}')
                plt.xlabel('inner product of nearest neighbor')
                plt.ylabel('Counts')
                plt.legend()
                plt_name = self.homepath + f'/correct_incorrect_testhouse{self.ep_idx//4}.png'
                plt.savefig(plt_name)

                conf_mat = confusion_matrix(self.pred_catnames,
                                            self.true_catnames,
                                            labels=self.include_classes)
                self.conf_mats.append(conf_mat)

                plt.figure(1)
                plt.clf()
                df_cm = pd.DataFrame(conf_mat,
                                     index=[i for i in self.include_classes],
                                     columns=[i for i in self.include_classes])
                plt.figure(figsize=(10, 7))
                sn.heatmap(df_cm, annot=True)
                plt_name = self.homepath + f'/confusion_matrix_testhouse{self.ep_idx//4}.png'
                plt.savefig(plt_name)
                # plt.show()

                self.pred_catnames_all.extend(self.pred_catnames)
                self.true_catnames_all.extend(self.true_catnames)
                self.best_inner_prods = []
                self.pred_ids = []
                self.true_ids = []
                self.pred_catnames = []
                self.true_catnames = []

                conf_mat = confusion_matrix(self.pred_catnames_all,
                                            self.true_catnames_all,
                                            labels=self.include_classes)
                plt.figure(1)
                plt.clf()
                df_cm = pd.DataFrame(conf_mat,
                                     index=[i for i in self.include_classes],
                                     columns=[i for i in self.include_classes])
                plt.figure(figsize=(10, 7))
                sn.heatmap(df_cm, annot=True)
                plt_name = self.homepath + '/confusion_matrix_testhouses_all.png'
                plt.savefig(plt_name)

            self.ep_idx += 1

        self.controller.stop()
        time.sleep(1)

    def run2(self):
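        # Collects the object types present in the current scene. Note: this assumes
        # self.objects has been initialized elsewhere (its initialization in
        # run_episodes is commented out above).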
        event = self.controller.step('GetReachablePositions')
        for obj in event.metadata['objects']:
            if obj['objectType'] not in self.objects:
                self.objects.append(obj['objectType'])

    def get_detectron_conf_center_obj(self, im, obj_mask, frame=None):
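        # Runs Mask R-CNN on the frame and returns the score, class id, mask and
        # visualization of the detection that overlaps the image center and whose
        # mask area is closest to the provided ground-truth mask.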
        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        outputs = self.maskrcnn(im)

        pred_masks = outputs['instances'].pred_masks
        pred_scores = outputs['instances'].scores
        pred_classes = outputs['instances'].pred_classes

        len_pad = 5

        W2_low = self.W // 2 - len_pad
        W2_high = self.W // 2 + len_pad
        H2_low = self.H // 2 - len_pad
        H2_high = self.H // 2 + len_pad

        if False:

            v = Visualizer(im[:, :, ::-1],
                           MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                           scale=1.0)
            out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
            seg_im = out.get_image()

            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg_all{frame}.png'
            plt.savefig(plt_name)

            seg_im[W2_low:W2_high, H2_low:H2_high, :] = 0.0
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg_all_mask{frame}.png'
            plt.savefig(plt_name)

        # Among the detections overlapping the image center, keep the one whose
        # mask area is closest to the ground-truth mask (within 7000 pixels).
        ind_obj = None
        sum_obj_mask = np.sum(obj_mask)
        mask_sum_thresh = 7000
        for idx in range(pred_masks.shape[0]):
            pred_mask_cur = pred_masks[idx].detach().cpu().numpy()
            pred_masks_center = pred_mask_cur[W2_low:W2_high, H2_low:H2_high]
            sum_pred_mask_cur = np.sum(pred_mask_cur)
            # print(torch.sum(pred_masks_center))
            if np.sum(pred_masks_center) > 0:
                if np.abs(sum_pred_mask_cur - sum_obj_mask) < mask_sum_thresh:
                    ind_obj = idx
                    mask_sum_thresh = np.abs(sum_pred_mask_cur - sum_obj_mask)
                # max_overlap = torch.sum(pred_masks_center)
        if ind_obj is None:
            print("RETURNING NONE")
            return None, None, None, None

        v = Visualizer(im[:, :, ::-1],
                       MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                       scale=1.0)
        out = v.draw_instance_predictions(
            outputs['instances'][ind_obj].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg{frame}.png'
            plt.savefig(plt_name)

        # print("OBJ CLASS ID=", int(pred_classes[ind_obj].detach().cpu().numpy()))
        # pred_boxes = outputs['instances'].pred_boxes.tensor
        # pred_classes = outputs['instances'].pred_classes
        # pred_scores = outputs['instances'].scores
        obj_score = float(pred_scores[ind_obj].detach().cpu().numpy())
        obj_pred_classes = int(pred_classes[ind_obj].detach().cpu().numpy())
        obj_pred_mask = pred_masks[ind_obj].detach().cpu().numpy()

        return obj_score, obj_pred_classes, obj_pred_mask, seg_im

    def detect_object_centroid(self, im, event):
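        # Picks a detected object that is not on the image border and estimates its
        # 3D centroid: either by unprojecting the predicted mask with the depth frame
        # (center_from_mask), or by touching the mask's median pixel in the simulator
        # and reading the hit object's bounding-box center.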

        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        outputs = self.maskrcnn(im)

        v = Visualizer(im[:, :, ::-1],
                       MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                       scale=1.2)
        out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + '/seg_init.png'
            plt.savefig(plt_name)

        pred_masks = outputs['instances'].pred_masks
        pred_boxes = outputs['instances'].pred_boxes.tensor
        pred_classes = outputs['instances'].pred_classes
        pred_scores = outputs['instances'].scores

        obj_catids = []
        obj_scores = []
        obj_masks = []
        # As written, this keeps detections scored at or below conf_thresh_detect,
        # i.e. the low-confidence detections are the candidates collected here.
        for segs in range(len(pred_masks)):
            if pred_scores[segs] <= self.conf_thresh_detect:
                obj_catids.append(pred_classes[segs].item())
                obj_scores.append(pred_scores[segs].item())
                obj_masks.append(pred_masks[segs])

        # Build the camera pose in world coordinates from the agent's pitch
        # (cameraHorizon) and yaw; the camera height offset is added below.
        eulers_xyz_rad = np.radians(
            np.array([
                event.metadata['agent']['cameraHorizon'],
                event.metadata['agent']['rotation']['y'], 0.0
            ]))

        rx = eulers_xyz_rad[0]
        ry = eulers_xyz_rad[1]
        rz = eulers_xyz_rad[2]
        rotation_ = self.utils.eul2rotm(-rx, -ry, rz)

        translation_ = np.array(
            list(event.metadata['agent']['position'].values())) + np.array(
                [0.0, 0.675, 0.0])
        # need to invert since z is positive here by convention
        translation_[2] = -translation_[2]

        T_world_cam = np.eye(4)
        T_world_cam[0:3, 0:3] = rotation_
        T_world_cam[0:3, 3] = translation_

        if not obj_masks:
            return None, None
        elif self.center_from_mask:

            # want an object not on the edges of the image
            sum_interior = 0
            while sum_interior == 0:
                if len(obj_masks) == 0:
                    return None, None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W - 50,
                                                        50:self.H - 50])

            depth = event.depth_frame

            # Unproject the depth frame into camera-frame 3D points.
            # Note: the hard-coded 256x256 grid assumes self.W == self.H == 256.
            xs, ys = np.meshgrid(np.linspace(-1 * 256 / 2., 1 * 256 / 2., 256),
                                 np.linspace(1 * 256 / 2., -1 * 256 / 2., 256))
            depth = depth.reshape(1, 256, 256)
            xs = xs.reshape(1, 256, 256)
            ys = ys.reshape(1, 256, 256)

            xys = np.vstack(
                (xs * depth, ys * depth, -depth, np.ones(depth.shape)))
            xys = xys.reshape(4, -1)
            xy_c0 = np.matmul(np.linalg.inv(self.K), xys)
            xyz = xy_c0.T[:, :3].reshape(256, 256, 3)
            xyz_obj_masked = xyz[obj_mask_focus]

            xyz_obj_masked = np.matmul(
                rotation_, xyz_obj_masked.T) + translation_.reshape(3, 1)
            xyz_obj_mid = np.mean(xyz_obj_masked, axis=1)

            xyz_obj_mid[2] = -xyz_obj_mid[2]
        else:

            # want an object not on the edges of the image
            while True:
                if len(obj_masks) == 0:
                    return None, None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                # print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W - 50,
                                                        50:self.H - 50])
                if sum_interior < 500:
                    continue  # exclude too small objects

                pixel_locs_obj = np.where(obj_mask_focus.cpu().numpy())
                x_mid = np.round(np.median(pixel_locs_obj[1]) / self.W, 4)
                y_mid = np.round(np.median(pixel_locs_obj[0]) / self.H, 4)

                if False:
                    plt.figure(1)
                    plt.clf()
                    plt.imshow(obj_mask_focus)
                    plt.plot(np.median(pixel_locs_obj[1]),
                             np.median(pixel_locs_obj[0]), 'x')
                    plt_name = self.homepath + '/seg_mask.png'
                    plt.savefig(plt_name)

                # Ask the simulator which object lies under the mask's median pixel
                # (zero-magnitude force, so nothing is actually moved).
                event = self.controller.step('TouchThenApplyForce',
                                             x=x_mid,
                                             y=y_mid,
                                             handDistance=1000000.0,
                                             direction=dict(x=0.0,
                                                            y=0.0,
                                                            z=0.0),
                                             moveMagnitude=0.0)
                obj_focus_id = event.metadata['actionReturn']['objectId']

                xyz_obj_mid = None
                for o in event.metadata['objects']:
                    if o['objectId'] == obj_focus_id:
                        if o['objectType'] not in self.include_classes_final:
                            continue
                        xyz_obj_mid = np.array(
                            list(o['axisAlignedBoundingBox']
                                 ['center'].values()))

                if xyz_obj_mid is not None:
                    break

        print("MIDPOINT=", xyz_obj_mid)
        return xyz_obj_mid, obj_mask_focus

    def run(self, mode=None, total_loss=None, summ_writer=None):
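        # For each object in the scene: find reachable points within a distance band
        # around the object, bin them by yaw, teleport the agent into sampled bins
        # facing the object, crop the object from the RGB frame using the ground-truth
        # 2D instance detection, and extract VGG features. In train mode the features
        # are stored with their class ids; in test mode they are classified by a
        # softmax-weighted k-NN vote over the stored features.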

        event = self.controller.step('GetReachablePositions')
        if not event.metadata['reachablePositions']:
            # Some AI2-THOR versions only populate this after another action is taken
            event = self.controller.step(action='MoveAhead')
        self.nav_pts = event.metadata['reachablePositions']
        self.nav_pts = np.array([list(d.values()) for d in self.nav_pts])
        # objects = np.random.choice(event.metadata['objects'], self.obj_per_scene, replace=False)
        objects = event.metadata['objects']
        objects_inds = np.arange(len(event.metadata['objects']))
        np.random.shuffle(objects_inds)

        # objects = np.random.shuffle(event.metadata['objects'])
        # for obj in event.metadata['objects']: #objects:
        #     print(obj['name'])
        # objects = objects[0]
        successes = 0
        # meta_obj_idx = 0
        num_obs = 0
        # while successes < self.obj_per_scene and meta_obj_idx <= len(event.metadata['objects']) - 1:
        for obj in objects:
            # if meta_obj_idx > len(event.metadata['objects']) - 1:
            #     print("OUT OF OBJECT... RETURNING")
            #     return total_loss, None, None, None, None

            # obj = objects[objects_inds[meta_obj_idx]]
            # meta_obj_idx += 1
            print("Center object is ", obj['objectType'])

            # st()  # debugger breakpoint left over from development; disabled
            # if obj['name'] in ['Microwave_b200e0bc']:
            #     print(obj['name'])
            # else:
            #     continue
            # print(obj['name'])

            if obj['objectType'] not in self.include_classes:
                print("Continuing... Invalid Object")
                continue

            # Calculate distance to object center
            obj_center = np.array(
                list(obj['axisAlignedBoundingBox']['center'].values()))

            obj_center = np.expand_dims(obj_center, axis=0)
            distances = np.sqrt(np.sum((self.nav_pts - obj_center)**2, axis=1))

            # Get points with r_min < dist < r_max
            valid_pts = self.nav_pts[np.where(
                (distances > self.radius_min) * (distances < self.radius_max))]

            # Bin the valid points by yaw angle around the object center (nbins bins over 360 deg)
            valid_pts_shift = valid_pts - obj_center

            dz = valid_pts_shift[:, 2]
            dx = valid_pts_shift[:, 0]
            dy = valid_pts_shift[:, 1]

            # Get yaw for binning
            valid_yaw = np.degrees(np.arctan2(dz, dx))

            if mode == "train":
                nbins = 10  #20
            else:
                nbins = 5
            bins = np.linspace(-180, 180, nbins + 1)
            bin_yaw = np.digitize(valid_yaw, bins)

            num_valid_bins = np.unique(bin_yaw).size

            if False:
                import matplotlib.cm as cm
                colors = iter(cm.rainbow(np.linspace(0, 1, nbins)))
                plt.figure(2)
                plt.clf()
                print(np.unique(bin_yaw))
                for bi in range(nbins):
                    cur_bi = np.where(bin_yaw == (bi + 1))
                    points = valid_pts[cur_bi]
                    x_sample = points[:, 0]
                    z_sample = points[:, 2]
                    plt.plot(z_sample, x_sample, 'o', color=next(colors))
                plt.plot(self.nav_pts[:, 2],
                         self.nav_pts[:, 0],
                         'x',
                         color='red')
                plt.plot(obj_center[:, 2],
                         obj_center[:, 0],
                         'x',
                         color='black')
                plt_name = '/home/nel/gsarch/aithor/data/valid.png'
                plt.savefig(plt_name)

            if num_valid_bins == 0:
                continue

            if mode == "train":
                spawns_per_bin = 3  #20
            else:
                spawns_per_bin = 1  #int(self.num_views / num_valid_bins) + 2
            # print(f'spawns_per_bin: {spawns_per_bin}')

            action = "do_nothing"
            episodes = []
            valid_pts_selected = []
            camXs_T_camX0_4x4 = []
            camX0_T_camXs_4x4 = []
            origin_T_camXs = []
            origin_T_camXs_t = []
            cnt = 0
            for b in range(nbins):

                # get all angle indices in the current bin range
                # (np.digitize returns 1-based bin indices, hence b + 1)
                inds_bin_cur = np.where(bin_yaw == (b + 1))
                inds_bin_cur = list(inds_bin_cur[0])
                inds_bin_cur = list(inds_bin_cur[0])
                if len(inds_bin_cur) == 0:
                    continue

                for s in range(spawns_per_bin):

                    observations = {}

                    if len(inds_bin_cur) == 0:
                        continue

                    rand_ind = np.random.randint(0, len(inds_bin_cur))
                    s_ind = inds_bin_cur.pop(rand_ind)

                    pos_s = valid_pts[s_ind]
                    valid_pts_selected.append(pos_s)

                    # add height from center of agent to camera
                    pos_s[1] = pos_s[1] + 0.675

                    turn_yaw, turn_pitch = self.utils.get_rotation_to_obj(
                        obj_center, pos_s)

                    event = self.controller.step('TeleportFull',
                                                 x=pos_s[0],
                                                 y=pos_s[1],
                                                 z=pos_s[2],
                                                 rotation=dict(x=0.0,
                                                               y=int(turn_yaw),
                                                               z=0.0),
                                                 horizon=int(turn_pitch))

                    rgb = event.frame

                    object_id = obj['objectId']

                    instance_detections2D = event.instance_detections2D

                    if object_id not in instance_detections2D:
                        print("NOT in instance detections 2D.. continuing")
                        continue
                    obj_instance_detection2D = instance_detections2D[
                        object_id]  # [start_x, start_y, end_x, end_y]

                    max_len = np.max(
                        np.array([
                            obj_instance_detection2D[2] -
                            obj_instance_detection2D[0],
                            obj_instance_detection2D[3] -
                            obj_instance_detection2D[1]
                        ]))
                    pad_len = max_len // 8

                    if pad_len == 0:
                        print("pad len 0.. continuing")
                        continue

                    # Note: "x" here indexes image rows: detection2D is
                    # [start_x, start_y, end_x, end_y], so indices 1 and 3 are vertical.
                    x_center = (obj_instance_detection2D[3] +
                                obj_instance_detection2D[1]) // 2
                    x_low = x_center - max_len - pad_len
                    if x_low < 0:
                        x_low = 0
                    x_high = x_center + max_len + pad_len  #x_low + max_len + 2*pad_len
                    if x_high > self.W:
                        x_high = self.W

                    y_center = (obj_instance_detection2D[2] +
                                obj_instance_detection2D[0]) // 2
                    y_low = y_center - max_len - pad_len  #-pad_len
                    if y_low < 0:
                        y_low = 0
                    y_high = y_center + max_len + pad_len  #y_low + max_len + 2*pad_len
                    if y_high > self.H:
                        y_high = self.H

                    rgb_crop = rgb[x_low:x_high, y_low:y_high, :]

                    rgb_crop = Image.fromarray(rgb_crop)

                    normalize_cropped_rgb = self.normalize(rgb_crop).unsqueeze(
                        0).double().cuda()

                    obj_features = self.vgg_feat_extractor(
                        normalize_cropped_rgb).view((512, -1))

                    obj_features = obj_features.detach().cpu().numpy()

                    # pca = PCA(n_components=10)
                    # obj_features = pca.fit_transform(obj_features.T).flatten()

                    # obj_features = torch.from_numpy(obj_features).view(-1).cuda()
                    obj_features = obj_features.flatten()

                    if mode == "train":
                        if self.first_time:
                            self.first_time = False
                            self.data_store_features = obj_features
                            # self.data_store_features = self.data_store_features.cuda()
                            self.feature_obj_ids.append(
                                self.ithor_to_maskrcnn[obj['objectType']])
                        else:
                            # self.data_store_features = torch.vstack((self.data_store_features, obj_features))
                            self.data_store_features = np.vstack(
                                (self.data_store_features, obj_features))
                            self.feature_obj_ids.append(
                                self.ithor_to_maskrcnn[obj['objectType']])

                    elif mode == "test":

                        # obj_features = obj_features.unsqueeze(0)

                        # inner_prod = torch.abs(torch.mm(obj_features, self.data_store_features.T)).squeeze()

                        # inner_prod = inner_prod.detach().cpu().numpy()

                        # dist = np.squeeze(np.abs(np.matmul(obj_features, self.data_store_features.transpose())))

                        # L2 distance from this view's features to every stored
                        # feature vector; the k nearest neighbours then cast
                        # softmax-weighted votes.
                        dist = np.linalg.norm(self.data_store_features -
                                              obj_features,
                                              axis=1)

                        k = 10

                        ind_knn = list(np.argsort(dist)[:k])

                        dist_knn = np.sort(dist)[:k]
                        dist_knn_norm = list(
                            self.Softmax(torch.from_numpy(-dist_knn)).numpy())

                        match_knn_id = [
                            self.feature_obj_ids[i] for i in ind_knn
                        ]

                        # Sum the softmax weights of neighbours that share a class id,
                        # then predict the class with the largest summed weight.
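                        # An equivalent, simpler aggregation (sketch only, not used):
                        # from collections import defaultdict
                        # votes = defaultdict(float)
                        # for cid, w in zip(match_knn_id, dist_knn_norm):
                        #     votes[cid] += w
                        # match_nn_id = max(votes, key=votes.get)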
                        idx = 0
                        dist_knn_norm_add = []
                        match_knn_id_add = []
                        while True:
                            if not match_knn_id:
                                break
                            match_knn_cur = match_knn_id.pop(0)
                            dist_knn_norm_cur = dist_knn_norm.pop(0)
                            match_knn_id_add.append(match_knn_cur)
                            idxs_ = []
                            for i in range(len(match_knn_id)):
                                if match_knn_id[i] == match_knn_cur:
                                    dist_knn_norm_cur += dist_knn_norm[i]
                                    # match_knn_id_.pop(i)
                                else:
                                    idxs_.append(i)
                            match_knn_id = [match_knn_id[idx] for idx in idxs_]
                            dist_knn_norm = [
                                dist_knn_norm[idx] for idx in idxs_
                            ]
                            dist_knn_norm_add.append(dist_knn_norm_cur)

                        dist_knn_norm_add = np.array(dist_knn_norm_add)

                        dist_knn_argmax = np.argmax(dist_knn_norm_add)

                        match_nn_id = match_knn_id_add[
                            dist_knn_argmax]  #self.feature_obj_ids[ind_nn]

                        match_nn_catname = self.maskrcnn_to_ithor[match_nn_id]

                        self.best_inner_prods.append(
                            dist_knn_norm_add[dist_knn_argmax])
                        self.pred_ids.append(match_nn_id)
                        # self.pred_catnames.append(match_nn_catname)
                        self.true_ids.append(
                            self.ithor_to_maskrcnn[obj['objectType']])
                        self.pred_catnames.append(match_nn_catname)
                        self.true_catnames.append(obj['objectType'])

                        print(match_nn_catname)

                        self.data_store_features = np.vstack(
                            (self.data_store_features, obj_features))
                        self.feature_obj_ids.append(
                            self.ithor_to_maskrcnn[obj['objectType']])

                    if False:
                        normalize_cropped_rgb = np.transpose(
                            normalize_cropped_rgb.squeeze(
                                0).detach().cpu().numpy(), (1, 2, 0))
                        plt.figure(1)
                        plt.clf()
                        plt.imshow(normalize_cropped_rgb)
                        # plt_name = self.homepath + '/seg_init.png'
                        plt.figure(2)
                        plt.clf()
                        plt.imshow(rgb)
                        plt.show()

                        plt.figure(3)
                        plt.clf()
                        plt.imshow(np.array(rgb_crop))
                        plt.show()